144 lines
5.0 KiB
Python
144 lines
5.0 KiB
Python
"""Self-play training loop with parallel game generation."""
|
|
|
|
import os
|
|
import numpy as np
|
|
from collections import deque
|
|
from multiprocessing import Pool, cpu_count
|
|
|
|
from .game import ConnectFour
|
|
from .model import build_model, print_model_info
|
|
from .mcts import run_mcts
|
|
from .config import (
|
|
NUM_ITERATIONS, GAMES_PER_ITERATION, MCTS_SIMULATIONS,
|
|
MCTS_TEMPERATURE, TEMP_DROP_MOVE,
|
|
WIN_REWARD, DRAW_REWARD, LOSS_REWARD,
|
|
BATCH_SIZE, EPOCHS_PER_ITERATION, REPLAY_BUFFER_SIZE,
|
|
CHECKPOINT_DIR, CHECKPOINT_INTERVAL, NUM_WORKERS,
|
|
)
|
|
|
|
# Model owned by the current worker process. It is populated by _init_worker
# (the Pool initializer) and read by _play_one_game; nothing in this module
# assigns it in the parent process.
_worker_model = None


def _init_worker(weights_list):
    """Initialize a self-play worker process with its own model copy.

    Used as the ``multiprocessing.Pool`` initializer, so it runs once in each
    worker before any games are played.

    Args:
        weights_list: Weight arrays obtained from ``model.get_weights()`` in
            the parent process; loaded into a freshly built model with
            ``set_weights``.
    """
    global _worker_model

    # Quiet TensorFlow logging and hide GPUs from the worker.
    # NOTE(review): these variables only take effect if TensorFlow has not yet
    # been initialized in this process. If `.model` imports TensorFlow at
    # module load (likely — confirm) and the "fork" start method is in use,
    # the worker inherits the parent's already-initialized runtime and these
    # settings may be ignored.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    # With the "fork" start method every worker inherits an identical copy of
    # NumPy's global RNG state, so the np.random.choice calls in
    # _play_one_game would produce the same move sequences in every worker
    # (duplicated self-play games). Reseed from OS entropy so each worker
    # explores independently.
    np.random.seed()

    _worker_model = build_model()
    _worker_model.set_weights(weights_list)
|
|
|
|
|
|
def _select_action(visit_counts, temp):
    """Choose a move from MCTS visit counts and return ``(action, policy)``.

    With ``temp < 0.2`` the most-visited move is played greedily and the
    policy target is one-hot. Otherwise the move is sampled from the visit
    counts raised to ``1 / temp`` and normalised, and that distribution is the
    policy target. The policy is always returned as float32 so that every
    replay-buffer entry has the same dtype.

    Raises:
        ValueError: if sampling is requested but every visit count is zero.
    """
    counts = np.asarray(visit_counts, dtype=np.float64)
    num_actions = counts.shape[0]

    if temp < 0.2:
        action = int(np.argmax(counts))
        policy = np.zeros(num_actions, dtype=np.float32)
        policy[action] = 1.0
        return action, policy

    max_count = counts.max()
    if max_count <= 0:
        raise ValueError("MCTS returned no visits; cannot sample a move")

    # Dividing by the max count first gives the same distribution after
    # normalisation, but avoids overflow for large counts or low temperatures.
    scaled = (counts / max_count) ** (1.0 / temp)
    probs = scaled / scaled.sum()
    action = int(np.random.choice(num_actions, p=probs))
    return action, probs.astype(np.float32)


def _outcome_value(winner, player):
    """Return the training reward for ``player`` given the final ``winner``.

    A ``winner`` of 0 is treated as a draw.
    """
    if winner == 0:
        return DRAW_REWARD
    return WIN_REWARD if winner == player else LOSS_REWARD


def _play_one_game(_):
    """Play a single self-play game in a worker process.

    Relies on ``_worker_model`` having been set by ``_init_worker``.

    Args:
        _: Ignored; present so the function can be mapped over
            ``range(GAMES_PER_ITERATION)`` with ``Pool.map``.

    Returns:
        A list with one ``(state, policy, value)`` sample per move played.
        ``value`` is the final game outcome from the perspective of the player
        who was to move in ``state``.
    """
    game = ConnectFour()
    trajectory = []

    while not game.done:
        state = game.get_state()
        visit_counts = run_mcts(game, _worker_model, MCTS_SIMULATIONS)

        # Explore with MCTS_TEMPERATURE early in the game. Afterwards use 0.1,
        # which is below the 0.2 greedy threshold in _select_action, so play
        # becomes greedy after TEMP_DROP_MOVE.
        temp = MCTS_TEMPERATURE if game.move_count < TEMP_DROP_MOVE else 0.1
        action, policy = _select_action(visit_counts, temp)

        trajectory.append((state, policy, game.current_player))
        game.step(action)

    winner = game.winner
    return [
        (state, policy, _outcome_value(winner, player))
        for state, policy, player in trajectory
    ]
|
|
|
|
|
|
def _run_self_play(model, num_workers):
    """Play GAMES_PER_ITERATION games in parallel with ``model``'s current weights.

    Returns the list of per-game sample lists produced by ``_play_one_game``.
    """
    weights = model.get_weights()
    # Never start more workers than there are games: each worker builds its
    # own model in _init_worker, so idle workers would be pure overhead.
    processes = max(1, min(num_workers, GAMES_PER_ITERATION))
    with Pool(processes=processes, initializer=_init_worker,
              initargs=(weights,)) as pool:
        return pool.map(_play_one_game, range(GAMES_PER_ITERATION))


def _tally_outcomes(results):
    """Count game outcomes from the first mover's (P1's) point of view.

    Returns a dict ``{1: P1 wins, -1: P2 wins, 0: draws}``.
    """
    wins = {1: 0, -1: 0, 0: 0}
    for samples in results:
        if not samples:
            continue
        # samples[0] belongs to the player who moved first, so its value is the
        # outcome from P1's perspective. The last sample belongs to the last
        # mover, who presumably is the winner of any decisive Connect Four
        # game, so it cannot distinguish P1 wins from P2 wins.
        first_value = samples[0][2]
        if first_value == WIN_REWARD:
            wins[1] += 1
        elif first_value == LOSS_REWARD:
            wins[-1] += 1
        else:
            wins[0] += 1
    return wins


def _train_on_buffer(model, replay_buffer):
    """Fit ``model`` on a random sample of ``replay_buffer``.

    Returns ``(policy_loss, value_loss)`` from the last epoch. Returns ``None``
    without training if the buffer holds fewer than BATCH_SIZE samples.
    """
    if len(replay_buffer) < BATCH_SIZE:
        return None

    sample_size = min(len(replay_buffer), BATCH_SIZE * EPOCHS_PER_ITERATION)
    indices = np.random.choice(len(replay_buffer), size=sample_size, replace=False)
    # Indexing a deque away from its ends is O(n). Snapshot it to a list once
    # so that each random lookup is O(1).
    buffer = list(replay_buffer)
    batch = [buffer[i] for i in indices]

    states = np.array([s[0] for s in batch])
    policies = np.array([s[1] for s in batch])
    values = np.array([s[2] for s in batch]).reshape(-1, 1)

    history = model.fit(
        states,
        {"policy_logits": policies, "value": values},
        batch_size=BATCH_SIZE,
        epochs=EPOCHS_PER_ITERATION,
        verbose=1,
    )
    return (history.history["policy_logits_loss"][-1],
            history.history["value_loss"][-1])


def train():
    """Run the full self-play training loop and return the trained model.

    Each iteration does the following:
      * plays GAMES_PER_ITERATION self-play games in parallel with the current
        weights;
      * appends the samples to a replay buffer bounded by REPLAY_BUFFER_SIZE;
      * fits the model on a random sample from that buffer;
      * saves a checkpoint every CHECKPOINT_INTERVAL iterations.

    The final model is saved to ``CHECKPOINT_DIR/model_final.keras``.
    """
    model = build_model()
    print_model_info(model)

    num_workers = NUM_WORKERS if NUM_WORKERS > 0 else cpu_count()
    print(f"Using {num_workers} worker processes for self-play")

    replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    for iteration in range(1, NUM_ITERATIONS + 1):
        print(f"\n{'='*60}")
        print(f"Iteration {iteration}/{NUM_ITERATIONS}")
        print(f"{'='*60}")

        # ── Self-play (parallel) ───────────────────────────────
        results = _run_self_play(model, num_workers)
        for samples in results:
            replay_buffer.extend(samples)
        wins = _tally_outcomes(results)

        print(f" Self-play: {GAMES_PER_ITERATION} games "
              f"(P1 wins: {wins[1]}, P2 wins: {wins[-1]}, draws: {wins[0]})")
        print(f" Buffer size: {len(replay_buffer)}")

        # ── Train ───────────────────────────────────────────────
        losses = _train_on_buffer(model, replay_buffer)
        if losses is not None:
            policy_loss, value_loss = losses
            print(f" Policy loss: {policy_loss:.4f} Value loss: {value_loss:.4f}")

        # ── Checkpoint ──────────────────────────────────────────
        if iteration % CHECKPOINT_INTERVAL == 0:
            path = os.path.join(CHECKPOINT_DIR, f"model_iter{iteration}.keras")
            model.save(path)
            print(f" Saved checkpoint: {path}")

    final_path = os.path.join(CHECKPOINT_DIR, "model_final.keras")
    model.save(final_path)
    print(f"\nTraining complete. Final model saved to {final_path}")
    return model
|
|
|
|
|
|
# Script entry point. This module uses package-relative imports (`from .game
# import ...`), so it must be launched as a module (`python -m <package>.<module>`)
# rather than by file path. The guard also stops worker processes from
# re-running train() when multiprocessing uses the "spawn" start method.
if __name__ == "__main__":
    train()
|