Files
2026-03-27 12:17:25 +01:00

144 lines
5.0 KiB
Python

"""Self-play training loop with parallel game generation."""
import os
import numpy as np
from collections import deque
from multiprocessing import Pool, cpu_count
from .game import ConnectFour
from .model import build_model, print_model_info
from .mcts import run_mcts
from .config import (
NUM_ITERATIONS, GAMES_PER_ITERATION, MCTS_SIMULATIONS,
MCTS_TEMPERATURE, TEMP_DROP_MOVE,
WIN_REWARD, DRAW_REWARD, LOSS_REWARD,
BATCH_SIZE, EPOCHS_PER_ITERATION, REPLAY_BUFFER_SIZE,
CHECKPOINT_DIR, CHECKPOINT_INTERVAL, NUM_WORKERS,
)
# Per-worker global model: assigned by _init_worker (the Pool initializer) once
# per worker process and read by _play_one_game. This module never assigns it
# in the parent process, where it keeps its initial value of None.
_worker_model = None
def _init_worker(weights_list):
    """Initialize a worker process with its own model copy.

    Intended to be used as the ``initializer`` of a ``multiprocessing.Pool``
    so that it runs once per worker process.

    Args:
        weights_list: Weights to load into the freshly built model, in the
            form accepted by the model's ``set_weights`` (the caller passes
            the result of ``model.get_weights()``).
    """
    global _worker_model

    # With the "fork" start method each worker inherits an identical copy of
    # the parent's NumPy global RNG state, so every worker would draw the same
    # random sequence and generate duplicate games. Reseed from OS entropy so
    # workers explore independently.
    np.random.seed()

    # Quiet backend logging and keep workers on the CPU.
    # NOTE(review): these variables are only set here, inside the worker; if
    # the backend was already imported/initialized before this runs (e.g. in a
    # forked parent) they may have no effect - confirm.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    _worker_model = build_model()
    _worker_model.set_weights(weights_list)
def _play_one_game(_):
    """Play a single self-play game in a worker process.

    Uses the per-process ``_worker_model`` to guide MCTS. Moves are sampled
    from the temperature-scaled visit counts while ``game.move_count`` is
    below ``TEMP_DROP_MOVE``; afterwards (or whenever the temperature is
    below 0.2) the most-visited move is played greedily.

    Args:
        _: Ignored. Present so the function can be mapped over a range by
            ``Pool.map``.

    Returns:
        A list of ``(state, policy, value)`` tuples, one per move played.
        ``policy`` is a float32 array over actions and ``value`` is
        ``WIN_REWARD``, ``DRAW_REWARD`` or ``LOSS_REWARD`` from the point of
        view of the player who was to move in ``state``.

    Raises:
        ValueError: If a move must be sampled but the visit counts do not
            form a valid distribution (zero or non-finite total).
    """
    game = ConnectFour()
    trajectory = []

    while not game.done:
        state = game.get_state()
        # float64 keeps the normalized probabilities within the sum-to-one
        # tolerance that np.random.choice enforces.
        visit_counts = np.asarray(
            run_mcts(game, _worker_model, MCTS_SIMULATIONS), dtype=np.float64
        )
        num_actions = visit_counts.size

        temp = MCTS_TEMPERATURE if game.move_count < TEMP_DROP_MOVE else 0.1

        if temp < 0.2:
            # Near-zero temperature: play the most-visited move and train on
            # a one-hot target.
            action = int(np.argmax(visit_counts))
            policy = np.zeros(num_actions, dtype=np.float32)
            policy[action] = 1.0
        else:
            scaled = visit_counts ** (1.0 / temp)
            total = scaled.sum()
            if not np.isfinite(total) or total <= 0:
                raise ValueError(
                    "MCTS visit counts do not form a valid distribution: "
                    f"{visit_counts!r}"
                )
            probs = scaled / total
            # Plain int, matching the greedy branch.
            action = int(np.random.choice(num_actions, p=probs))
            # Store targets with the same dtype as the greedy branch.
            policy = probs.astype(np.float32)

        # Record who was to move so the final result can be assigned from
        # that player's perspective once the game ends.
        trajectory.append((state, policy, game.current_player))
        game.step(action)

    winner = game.winner
    samples = []
    for state, policy, player in trajectory:
        if winner == 0:
            value = DRAW_REWARD
        elif winner == player:
            value = WIN_REWARD
        else:
            value = LOSS_REWARD
        samples.append((state, policy, value))
    return samples
def _tally_outcomes(results):
    """Count self-play outcomes per player.

    The value stored in a game's last sample is from the perspective of the
    player who made the final move, so on its own it cannot tell player 1
    from player 2. The last mover is recovered from the number of samples
    (one per move).

    Args:
        results: Iterable of per-game sample lists as returned by
            ``_play_one_game``.

    Returns:
        Dict mapping ``1`` (player 1 wins), ``-1`` (player 2 wins) and ``0``
        (draws) to counts. Games with no samples are not counted.
    """
    wins = {1: 0, -1: 0, 0: 0}
    for samples in results:
        if not samples:
            continue
        # NOTE(review): assumes player 1 makes the first move and players
        # strictly alternate - confirm against ConnectFour.
        last_mover = 1 if len(samples) % 2 == 1 else -1
        final_value = samples[-1][2]
        if final_value == WIN_REWARD:
            wins[last_mover] += 1
        elif final_value == LOSS_REWARD:
            wins[-last_mover] += 1
        else:
            wins[0] += 1
    return wins


def train():
    """Main training entry point.

    For each of ``NUM_ITERATIONS`` iterations: generates
    ``GAMES_PER_ITERATION`` self-play games in a pool of worker processes
    seeded with the current weights, appends the samples to a bounded replay
    buffer, fits the model on a random subset of the buffer, and saves a
    checkpoint every ``CHECKPOINT_INTERVAL`` iterations. The final model is
    saved to ``CHECKPOINT_DIR`` as ``model_final.keras``.

    Returns:
        The trained model.
    """
    model = build_model()
    print_model_info(model)

    # A non-positive NUM_WORKERS means "use every available core".
    num_workers = NUM_WORKERS if NUM_WORKERS > 0 else cpu_count()
    print(f"Using {num_workers} worker processes for self-play")

    replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    for iteration in range(1, NUM_ITERATIONS + 1):
        print(f"\n{'='*60}")
        print(f"Iteration {iteration}/{NUM_ITERATIONS}")
        print(f"{'='*60}")

        # ── Self-play (parallel) ───────────────────────────────
        # A fresh pool per iteration so each worker is initialized with the
        # latest weights.
        weights = model.get_weights()
        with Pool(
            processes=num_workers,
            initializer=_init_worker,
            initargs=(weights,),
        ) as pool:
            results = pool.map(_play_one_game, range(GAMES_PER_ITERATION))

        for samples in results:
            replay_buffer.extend(samples)
        wins = _tally_outcomes(results)

        print(f"  Self-play: {GAMES_PER_ITERATION} games "
              f"(P1 wins: {wins[1]}, P2 wins: {wins[-1]}, draws: {wins[0]})")
        print(f"  Buffer size: {len(replay_buffer)}")

        # ── Train ───────────────────────────────────────────────
        if len(replay_buffer) >= BATCH_SIZE:
            # Indexing into the middle of a deque is O(n); snapshot to a list
            # so random access during sampling is O(1).
            buffer_snapshot = list(replay_buffer)
            sample_size = min(
                len(buffer_snapshot), BATCH_SIZE * EPOCHS_PER_ITERATION
            )
            indices = np.random.choice(
                len(buffer_snapshot), size=sample_size, replace=False
            )
            batch = [buffer_snapshot[i] for i in indices]

            states = np.array([s[0] for s in batch])
            policies = np.array([s[1] for s in batch])
            values = np.array([s[2] for s in batch]).reshape(-1, 1)

            history = model.fit(
                states,
                {"policy_logits": policies, "value": values},
                batch_size=BATCH_SIZE,
                epochs=EPOCHS_PER_ITERATION,
                verbose=1,
            )
            policy_loss = history.history["policy_logits_loss"][-1]
            value_loss = history.history["value_loss"][-1]
            print(f"  Policy loss: {policy_loss:.4f}  Value loss: {value_loss:.4f}")

        # ── Checkpoint ──────────────────────────────────────────
        if iteration % CHECKPOINT_INTERVAL == 0:
            path = os.path.join(CHECKPOINT_DIR, f"model_iter{iteration}.keras")
            model.save(path)
            print(f"  Saved checkpoint: {path}")

    final_path = os.path.join(CHECKPOINT_DIR, "model_final.keras")
    model.save(final_path)
    print(f"\nTraining complete. Final model saved to {final_path}")
    return model
# Script entry point. The module uses relative imports, so it has to be run
# as part of its package (``python -m <package>.<module>``) rather than as a
# standalone file.
if __name__ == "__main__":
    train()