"""Self-play training loop with parallel game generation."""

import os
from collections import deque
from multiprocessing import Pool, cpu_count

import numpy as np

from .game import ConnectFour
from .model import build_model, print_model_info
from .mcts import run_mcts
from .config import (
    NUM_ITERATIONS,
    GAMES_PER_ITERATION,
    MCTS_SIMULATIONS,
    MCTS_TEMPERATURE,
    TEMP_DROP_MOVE,
    WIN_REWARD,
    DRAW_REWARD,
    LOSS_REWARD,
    BATCH_SIZE,
    EPOCHS_PER_ITERATION,
    REPLAY_BUFFER_SIZE,
    CHECKPOINT_DIR,
    CHECKPOINT_INTERVAL,
    NUM_WORKERS,
)

# Per-worker global model (loaded once per process)
_worker_model = None

# Temperature used once ``game.move_count`` reaches TEMP_DROP_MOVE.
_LATE_GAME_TEMPERATURE = 0.1

# Below this temperature the move is chosen greedily (argmax of the visit
# counts) instead of being sampled.
_GREEDY_TEMPERATURE_CUTOFF = 0.2


def _init_worker(weights_list):
    """Initialize a worker process with its own model copy.

    Args:
        weights_list: Weights obtained from ``model.get_weights()`` in the
            parent process; loaded into a freshly built model.
    """
    global _worker_model
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    # With the "fork" start method every worker inherits an identical copy of
    # the parent's global NumPy RNG state, so all workers would draw the same
    # random numbers and play duplicate games. Reseed from OS entropy.
    np.random.seed()

    _worker_model = build_model()
    _worker_model.set_weights(weights_list)


def _select_action(visit_counts, temp):
    """Convert MCTS visit counts into a move and its policy target.

    Args:
        visit_counts: Per-action visit counts returned by ``run_mcts``.
        temp: Exploration temperature. Values below
            ``_GREEDY_TEMPERATURE_CUTOFF`` select the most visited action.

    Returns:
        ``(action, policy)`` where ``action`` is a Python ``int`` and
        ``policy`` is a float32 vector with one entry per action.
    """
    num_actions = len(visit_counts)

    if temp < _GREEDY_TEMPERATURE_CUTOFF:
        action = int(np.argmax(visit_counts))
        policy = np.zeros(num_actions, dtype=np.float32)
        policy[action] = 1.0
        return action, policy

    counts = np.asarray(visit_counts, dtype=np.float64) ** (1.0 / temp)
    probs = counts / counts.sum()
    # Sample with the float64 probabilities, then store float32 so both
    # branches yield policy targets of the same dtype.
    action = int(np.random.choice(num_actions, p=probs))
    return action, probs.astype(np.float32)


def _outcome_value(winner, player):
    """Return the reward for ``player`` given the game's ``winner``.

    A ``winner`` of 0 is treated as a draw.
    """
    if winner == 0:
        return DRAW_REWARD
    if winner == player:
        return WIN_REWARD
    return LOSS_REWARD


def _play_one_game(_):
    """Play a single self-play game in a worker process.

    Args:
        _: Ignored; present so the function can be used with ``Pool.map``.

    Returns:
        A list of ``(state, policy, value)`` training samples, one per move,
        where ``value`` is the final outcome from the perspective of the
        player who was to move in ``state``.
    """
    game = ConnectFour()
    trajectory = []

    while not game.done:
        state = game.get_state()
        visit_counts = run_mcts(game, _worker_model, MCTS_SIMULATIONS)

        if game.move_count < TEMP_DROP_MOVE:
            temp = MCTS_TEMPERATURE
        else:
            temp = _LATE_GAME_TEMPERATURE

        action, policy = _select_action(visit_counts, temp)
        trajectory.append((state, policy, game.current_player))
        game.step(action)

    # Rewards are only known once the game has finished, so label afterwards.
    return [
        (state, policy, _outcome_value(game.winner, player))
        for state, policy, player in trajectory
    ]


def _count_outcomes(results):
    """Tally game outcomes from the first mover's point of view.

    Args:
        results: Iterable of per-game sample lists from ``_play_one_game``.

    Returns:
        A dict with keys ``1`` (first mover won), ``-1`` (first mover lost)
        and ``0`` (anything else, i.e. a draw). Games with no samples are
        not counted.
    """
    wins = {1: 0, -1: 0, 0: 0}
    for samples in results:
        if not samples:
            continue
        # The first sample belongs to whoever moved first, so its value is
        # that player's result. (The last sample belongs to the last mover,
        # which says nothing about which side won.)
        # NOTE(review): reported as "P1" — assumes ConnectFour starts with
        # player 1; confirm against the game implementation.
        first_mover_value = samples[0][2]
        if first_mover_value == WIN_REWARD:
            wins[1] += 1
        elif first_mover_value == LOSS_REWARD:
            wins[-1] += 1
        else:
            wins[0] += 1
    return wins


def train():
    """Main training entry point.

    Repeats NUM_ITERATIONS times: generate GAMES_PER_ITERATION self-play
    games in a process pool, add them to a bounded replay buffer, fit the
    model on a random sample of the buffer, and periodically checkpoint.

    Returns:
        The trained model (also saved as ``model_final.keras`` in
        CHECKPOINT_DIR).
    """
    model = build_model()
    print_model_info(model)

    num_workers = NUM_WORKERS if NUM_WORKERS > 0 else cpu_count()
    print(f"Using {num_workers} worker processes for self-play")

    replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    for iteration in range(1, NUM_ITERATIONS + 1):
        print(f"\n{'='*60}")
        print(f"Iteration {iteration}/{NUM_ITERATIONS}")
        print(f"{'='*60}")

        # -- Self-play (parallel) ------------------------------------------
        # A fresh pool per iteration lets the initializer hand every worker
        # the current weights.
        weights = model.get_weights()
        with Pool(
            processes=num_workers,
            initializer=_init_worker,
            initargs=(weights,),
        ) as pool:
            results = pool.map(_play_one_game, range(GAMES_PER_ITERATION))

        for samples in results:
            replay_buffer.extend(samples)
        wins = _count_outcomes(results)

        print(f" Self-play: {GAMES_PER_ITERATION} games "
              f"(P1 wins: {wins[1]}, P2 wins: {wins[-1]}, draws: {wins[0]})")
        print(f" Buffer size: {len(replay_buffer)}")

        # -- Train ---------------------------------------------------------
        if len(replay_buffer) >= BATCH_SIZE:
            # Indexing into the middle of a deque is O(n); copy to a list
            # once so the random lookups below are O(1) each.
            buffer_snapshot = list(replay_buffer)
            sample_size = min(
                len(buffer_snapshot), BATCH_SIZE * EPOCHS_PER_ITERATION
            )
            indices = np.random.choice(
                len(buffer_snapshot), size=sample_size, replace=False
            )
            batch = [buffer_snapshot[i] for i in indices]

            states = np.array([s[0] for s in batch])
            policies = np.array([s[1] for s in batch])
            values = np.array([s[2] for s in batch]).reshape(-1, 1)

            history = model.fit(
                states,
                {"policy_logits": policies, "value": values},
                batch_size=BATCH_SIZE,
                epochs=EPOCHS_PER_ITERATION,
                verbose=1,
            )

            policy_loss = history.history["policy_logits_loss"][-1]
            value_loss = history.history["value_loss"][-1]
            print(f" Policy loss: {policy_loss:.4f} Value loss: {value_loss:.4f}")

        # -- Checkpoint ----------------------------------------------------
        if iteration % CHECKPOINT_INTERVAL == 0:
            path = os.path.join(CHECKPOINT_DIR, f"model_iter{iteration}.keras")
            model.save(path)
            print(f" Saved checkpoint: {path}")

    final_path = os.path.join(CHECKPOINT_DIR, "model_final.keras")
    model.save(final_path)
    print(f"\nTraining complete. Final model saved to {final_path}")
    return model


if __name__ == "__main__":
    train()