[fix] Non heuristic moves...
This commit is contained in:
+143
@@ -0,0 +1,143 @@
|
||||
"""Self-play training loop with parallel game generation."""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from multiprocessing import Pool, cpu_count
|
||||
|
||||
from .game import ConnectFour
|
||||
from .model import build_model, print_model_info
|
||||
from .mcts import run_mcts
|
||||
from .config import (
|
||||
NUM_ITERATIONS, GAMES_PER_ITERATION, MCTS_SIMULATIONS,
|
||||
MCTS_TEMPERATURE, TEMP_DROP_MOVE,
|
||||
WIN_REWARD, DRAW_REWARD, LOSS_REWARD,
|
||||
BATCH_SIZE, EPOCHS_PER_ITERATION, REPLAY_BUFFER_SIZE,
|
||||
CHECKPOINT_DIR, CHECKPOINT_INTERVAL, NUM_WORKERS,
|
||||
)
|
||||
|
||||
# Per-worker global model. Each pool worker builds it once in _init_worker
# and every game played by that process reuses it.
_worker_model = None


def _init_worker(weights_list):
    """Initialize a worker process with its own model copy.

    Intended to be passed to ``multiprocessing.Pool`` as ``initializer`` so
    that it runs once in each worker before that worker plays any game.

    Args:
        weights_list: Model weights (as produced by ``model.get_weights()``
            in the parent) that are loaded into the freshly built worker
            model through ``set_weights()``.
    """
    global _worker_model

    # Quiet TensorFlow logging, turn off oneDNN optimizations and hide GPUs
    # from the worker.
    # NOTE(review): these variables are usually read when TensorFlow is
    # imported/initialized; if the library is already loaded in this process
    # (e.g. a forked worker) they may have no effect -- confirm.
    os.environ.update(
        {
            "TF_CPP_MIN_LOG_LEVEL": "3",
            "TF_ENABLE_ONEDNN_OPTS": "0",
            "CUDA_VISIBLE_DEVICES": "",
        }
    )

    # Workers created by fork start with an identical copy of the parent's
    # NumPy global RNG state. Without reseeding, every worker would draw the
    # same "random" move sequence and generate duplicate games.
    np.random.seed()

    _worker_model = build_model()
    _worker_model.set_weights(weights_list)
|
||||
|
||||
|
||||
def _play_one_game(_):
    """Play a single self-play game in a worker process.

    Moves are chosen from MCTS visit counts computed with the module-level
    ``_worker_model`` (set up by ``_init_worker``). While
    ``game.move_count < TEMP_DROP_MOVE`` the configured ``MCTS_TEMPERATURE``
    is used; afterwards the temperature drops to 0.1. Any temperature below
    0.2 is treated as greedy: the most visited move is played and the
    recorded policy target is one-hot.

    Args:
        _: Ignored. Present only so the function can be mapped over a range
            by ``Pool.map``.

    Returns:
        A list of ``(state, policy, value)`` tuples, one per move played.
        ``value`` is ``WIN_REWARD``, ``LOSS_REWARD`` or ``DRAW_REWARD`` from
        the perspective of the player who was to move in ``state``.
    """
    game = ConnectFour()
    trajectory = []

    while not game.done:
        # NOTE(review): assumes get_state() returns a snapshot rather than a
        # live view that game.step() mutates -- confirm in ConnectFour.
        state = game.get_state()
        visit_counts = run_mcts(game, _worker_model, MCTS_SIMULATIONS)
        # Derive the action count from the search output instead of
        # hard-coding the board width.
        num_actions = len(visit_counts)

        temp = MCTS_TEMPERATURE if game.move_count < TEMP_DROP_MOVE else 0.1

        if temp < 0.2:
            # Near-zero temperature: play the most visited move.
            action = int(np.argmax(visit_counts))
            policy = np.zeros(num_actions, dtype=np.float32)
            policy[action] = 1.0
        else:
            # Sample proportionally to visit_count ** (1 / temp).
            counts = visit_counts ** (1.0 / temp)
            policy = counts / counts.sum()
            # Cast so game.step() gets a plain int in both branches.
            action = int(np.random.choice(num_actions, p=policy))

        # Record who is to move *before* stepping; the final value is
        # assigned from that player's point of view once the game ends.
        trajectory.append((state, policy, game.current_player))
        game.step(action)

    winner = game.winner

    def outcome_for(player):
        """Return the reward for ``player`` given the finished game."""
        if winner == 0:
            return DRAW_REWARD
        if winner == player:
            return WIN_REWARD
        return LOSS_REWARD

    return [
        (state, policy, outcome_for(player))
        for state, policy, player in trajectory
    ]
|
||||
|
||||
|
||||
def _tally_outcomes(results):
    """Count game outcomes from the opening player's point of view.

    The first sample of every game is recorded for the player who moves
    first, so its value says whether that player won, lost or drew. The last
    sample cannot be used for this: it belongs to whoever made the final
    move, and in a decisive game that player is always the winner.

    Args:
        results: Iterable of per-game sample lists as returned by
            ``_play_one_game``.

    Returns:
        Dict with keys ``1`` (first mover won), ``-1`` (first mover lost)
        and ``0`` (anything else, i.e. a draw). Empty games are not counted.
    """
    tally = {1: 0, -1: 0, 0: 0}
    for samples in results:
        if not samples:
            continue
        first_mover_value = samples[0][2]
        if first_mover_value == WIN_REWARD:
            tally[1] += 1
        elif first_mover_value == LOSS_REWARD:
            tally[-1] += 1
        else:
            tally[0] += 1
    return tally


def train():
    """Main training entry point.

    Runs ``NUM_ITERATIONS`` rounds of:

    1. parallel self-play (``GAMES_PER_ITERATION`` games on a fresh worker
       pool initialized with the current weights),
    2. a training step on a random sample of the replay buffer, once the
       buffer holds at least ``BATCH_SIZE`` samples,
    3. a checkpoint every ``CHECKPOINT_INTERVAL`` iterations.

    A final model is always written to ``CHECKPOINT_DIR/model_final.keras``.

    Returns:
        The trained model.
    """
    model = build_model()
    print_model_info(model)

    num_workers = NUM_WORKERS if NUM_WORKERS > 0 else cpu_count()
    print(f"Using {num_workers} worker processes for self-play")

    replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    for iteration in range(1, NUM_ITERATIONS + 1):
        print(f"\n{'='*60}")
        print(f"Iteration {iteration}/{NUM_ITERATIONS}")
        print(f"{'='*60}")

        # ── Self-play (parallel) ───────────────────────────────
        # A new pool per iteration so every worker is initialized with the
        # latest weights.
        weights = model.get_weights()
        with Pool(
            processes=num_workers,
            initializer=_init_worker,
            initargs=(weights,),
        ) as pool:
            results = pool.map(_play_one_game, range(GAMES_PER_ITERATION))

        for samples in results:
            replay_buffer.extend(samples)

        # NOTE(review): assumes the player who moves first is the one the
        # report calls P1 -- confirm against ConnectFour.
        wins = _tally_outcomes(results)

        print(f"  Self-play: {GAMES_PER_ITERATION} games "
              f"(P1 wins: {wins[1]}, P2 wins: {wins[-1]}, draws: {wins[0]})")
        print(f"  Buffer size: {len(replay_buffer)}")

        # ── Train ───────────────────────────────────────────────
        if len(replay_buffer) >= BATCH_SIZE:
            # Snapshot to a list: random indexing into the middle of a deque
            # is O(n) per access.
            buffer_snapshot = list(replay_buffer)
            sample_size = min(
                len(buffer_snapshot), BATCH_SIZE * EPOCHS_PER_ITERATION
            )
            indices = np.random.choice(
                len(buffer_snapshot), size=sample_size, replace=False
            )
            batch = [buffer_snapshot[i] for i in indices]

            states = np.array([s[0] for s in batch])
            policies = np.array([s[1] for s in batch])
            values = np.array([s[2] for s in batch]).reshape(-1, 1)

            history = model.fit(
                states,
                {"policy_logits": policies, "value": values},
                batch_size=BATCH_SIZE,
                epochs=EPOCHS_PER_ITERATION,
                verbose=1,
            )
            # Report the losses of the last epoch of this iteration.
            policy_loss = history.history["policy_logits_loss"][-1]
            value_loss = history.history["value_loss"][-1]
            print(f"  Policy loss: {policy_loss:.4f}  Value loss: {value_loss:.4f}")

        # ── Checkpoint ──────────────────────────────────────────
        if iteration % CHECKPOINT_INTERVAL == 0:
            path = os.path.join(CHECKPOINT_DIR, f"model_iter{iteration}.keras")
            model.save(path)
            print(f"  Saved checkpoint: {path}")

    final_path = os.path.join(CHECKPOINT_DIR, "model_final.keras")
    model.save(final_path)
    print(f"\nTraining complete. Final model saved to {final_path}")
    return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the full self-play training loop. The guard
    # also keeps pool workers that re-import this module from starting
    # training themselves.
    train()
|
||||
Reference in New Issue
Block a user