483 lines
16 KiB
Python
483 lines
16 KiB
Python
"""Pygame visualization of Connect Four RL training.
|
|
|
|
Left panel: live self-play game board
|
|
Right panel: loss curves + win-rate chart + training stats
|
|
"""
|
|
|
|
import os
|
|
import threading
|
|
import time
|
|
from collections import deque
|
|
|
|
import numpy as np
|
|
import pygame
|
|
|
|
from .game import ConnectFour, ROWS, COLS
|
|
from .model import build_model, print_model_info
|
|
from .mcts import run_mcts
|
|
from .config import (
|
|
NUM_ITERATIONS, GAMES_PER_ITERATION, MCTS_SIMULATIONS,
|
|
MCTS_TEMPERATURE, TEMP_DROP_MOVE,
|
|
WIN_REWARD, DRAW_REWARD, LOSS_REWARD,
|
|
BATCH_SIZE, EPOCHS_PER_ITERATION, REPLAY_BUFFER_SIZE,
|
|
CHECKPOINT_DIR, CHECKPOINT_INTERVAL, NUM_WORKERS,
|
|
)
|
|
from multiprocessing import Pool, cpu_count
|
|
|
|
# ── Layout constants ────────────────────────────────────────────────
CELL = 80  # side length of one board cell, in pixels
BOARD_W, BOARD_H = COLS * CELL, ROWS * CELL
PANEL_W = 420  # width of the right-hand stats panel
MARGIN = 20
# Horizontal layout: margin | board | margin | panel | margin
WIN_W = MARGIN + BOARD_W + MARGIN + PANEL_W + MARGIN
# Vertical layout: margin | board | margin
WIN_H = MARGIN + BOARD_H + MARGIN
FPS = 30  # frame-rate cap passed to clock.tick() in the render loop
|
|
|
|
# ── Colors ──────────────────────────────────────────────────────────
# Window and board
BG = (30, 30, 40)
BOARD_BG = (0, 60, 180)
GRID_LINE = (0, 40, 140)
EMPTY = (20, 20, 30)
WIN_HIGHLIGHT = (100, 255, 100)

# Player discs
P1_COLOR = (255, 220, 50)  # yellow
P2_COLOR = (220, 40, 40)  # red

# Text and chart chrome
TEXT_COLOR = (220, 220, 220)
CHART_BG = (40, 40, 55)

# Chart series: loss lines, then win-rate bars (players reuse their disc colors)
POLICY_LINE = (80, 200, 255)
VALUE_LINE = (255, 160, 60)
P1_CHART = P1_COLOR
P2_CHART = P2_COLOR
DRAW_CHART = (140, 140, 140)
|
|
|
|
# ── Shared state between training thread and pygame loop ────────────
# Written by the training thread (_training_thread / _play_showcase_game) and
# read every frame by run_visualized; the UI itself writes "running" and
# "move_delay".
_state = dict(
    board=np.zeros((ROWS, COLS), dtype=np.int8),  # board of the showcase game
    iteration=0,
    game_num=0,
    phase="init",  # init / self-play / training / done
    policy_losses=[],
    value_losses=[],
    win_history=[],  # list of (p1_wins, p2_wins, draws) per iteration
    move_delay=0.3,  # seconds slept after each showcase move
    status="Initializing...",
    winner=0,
    running=True,  # cleared by the UI to ask the training thread to stop
)
_lock = threading.Lock()  # held around most reads/writes of _state
|
|
|
|
|
|
# ── Worker setup (same as train.py) ─────────────────────────────────
# Per-process model used by _play_one_game; built by _init_worker inside each
# pool worker (stays None in the parent process).
_worker_model = None


def _init_worker(weights_list):
    """Pool initializer: give this worker its own model loaded with *weights_list*.

    Args:
        weights_list: weights produced by ``model.get_weights()`` in the parent
            and passed through ``Pool(initargs=...)``.
    """
    global _worker_model
    # Silence TF logging / oneDNN notices and hide GPUs so workers run on CPU.
    # NOTE(review): these variables are read when TensorFlow initialises; if TF
    # was already imported in this process (e.g. by .model at import time) they
    # may have no effect — confirm.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    # With the "fork" start method (historically the Linux default), every
    # worker starts with a byte-identical copy of the parent's NumPy global RNG
    # state. Without this, all workers draw the same np.random samples and their
    # self-play games can come out identical. Python's `random` module reseeds
    # itself after fork; NumPy's legacy global RNG does not. seed() with no
    # argument pulls fresh entropy from the OS.
    np.random.seed()
    _worker_model = build_model()
    _worker_model.set_weights(weights_list)
|
|
|
|
|
|
def _play_one_game(_):
    """Play one self-play game with this worker's model and return its samples.

    The argument is ignored; it only exists so the function can be mapped over
    ``range(n)`` with ``Pool.map``.

    Returns:
        A list of ``(state, policy, value)`` tuples, one per move. ``policy`` is
        the move distribution derived from the MCTS visit counts. ``value`` is
        the final result from the point of view of the player who made that
        move: WIN_REWARD, LOSS_REWARD or DRAW_REWARD.
    """
    game = ConnectFour()
    trajectory = []

    while not game.done:
        state = game.get_state()
        visit_counts = run_mcts(game, _worker_model, MCTS_SIMULATIONS)

        # Sample proportionally to visits^(1/T) for the first TEMP_DROP_MOVE
        # moves, then play greedily.
        temp = MCTS_TEMPERATURE if game.move_count < TEMP_DROP_MOVE else 0.1
        if temp < 0.2:
            # Near-zero temperature: take the most-visited column, one-hot target.
            action = int(np.argmax(visit_counts))
            policy = np.zeros(COLS, dtype=np.float32)
            policy[action] = 1.0
        else:
            counts = visit_counts ** (1.0 / temp)
            policy = counts / counts.sum()
            action = int(np.random.choice(COLS, p=policy))

        trajectory.append((state, policy, game.current_player))
        game.step(action)

    samples = []
    for state, policy, player in trajectory:
        if game.winner == 0:
            value = DRAW_REWARD
        elif game.winner == player:
            value = WIN_REWARD
        else:
            value = LOSS_REWARD
        samples.append((state, policy, value))
    return samples
|
|
|
|
|
|
def _play_showcase_game(model):
    """Play one game slowly on the training thread, publishing each move to the UI.

    After every move the board is copied into ``_state["board"]``. The thread
    then sleeps ``_state["move_delay"]`` seconds so the game can be watched.
    When the game finishes, its winner is stored in ``_state["winner"]``.

    Args:
        model: the network passed to ``run_mcts`` for move selection.

    Returns:
        A list of ``(state, policy, value)`` samples, one per move. ``value`` is
        the game result from the point of view of the player who made that move.

        If the loop stops early because ``_state["running"]`` was cleared (the
        window was closed), an empty list is returned. An unfinished game has no
        result, so its moves must not be turned into training targets.
    """
    game = ConnectFour()
    trajectory = []

    with _lock:
        _state["board"] = game.board.copy()
        _state["winner"] = 0

    while not game.done and _state["running"]:
        state = game.get_state()
        visit_counts = run_mcts(game, model, MCTS_SIMULATIONS)

        # Sample proportionally to visits^(1/T) for the first TEMP_DROP_MOVE
        # moves, then play greedily.
        temp = MCTS_TEMPERATURE if game.move_count < TEMP_DROP_MOVE else 0.1
        if temp < 0.2:
            # Near-zero temperature: take the most-visited column, one-hot target.
            action = int(np.argmax(visit_counts))
            policy = np.zeros(COLS, dtype=np.float32)
            policy[action] = 1.0
        else:
            counts = visit_counts ** (1.0 / temp)
            policy = counts / counts.sum()
            action = int(np.random.choice(COLS, p=policy))

        trajectory.append((state, policy, game.current_player))
        game.step(action)

        with _lock:
            _state["board"] = game.board.copy()

        time.sleep(_state["move_delay"])

    if not game.done:
        # Interrupted by shutdown. _state["winner"] stays at the 0 set above.
        return []

    with _lock:
        _state["winner"] = game.winner

    samples = []
    for state, policy, player in trajectory:
        if game.winner == 0:
            value = DRAW_REWARD
        elif game.winner == player:
            value = WIN_REWARD
        else:
            value = LOSS_REWARD
        samples.append((state, policy, value))
    return samples
|
|
|
|
|
|
def _game_outcome(samples):
    """Return who won the game that produced *samples*.

    Returns 1 if player 1 won, -1 if player 2 won, 0 for a draw.

    ``samples[0]`` records the opening move, so its value is the result from
    the first mover's point of view. The last sample cannot be used: the game
    ends on the winning move, so the last value is WIN_REWARD for every
    decisive game no matter who won.

    NOTE(review): assumes ``ConnectFour()`` gives the first move to player 1
    (the player drawn as P1 / value 1) — confirm in game.py.
    """
    first_value = samples[0][2]
    if first_value == WIN_REWARD:
        return 1
    if first_value == LOSS_REWARD:
        return -1
    return 0


def _fit_on_replay_buffer(model, replay_buffer):
    """Fit *model* on a random subsample of *replay_buffer*; return its final losses.

    Draws up to ``BATCH_SIZE * EPOCHS_PER_ITERATION`` samples without
    replacement and trains for ``EPOCHS_PER_ITERATION`` epochs. Returns
    ``(policy_loss, value_loss)`` from the last epoch.
    """
    sample_size = min(len(replay_buffer), BATCH_SIZE * EPOCHS_PER_ITERATION)
    indices = np.random.choice(len(replay_buffer), size=sample_size, replace=False)
    batch = [replay_buffer[i] for i in indices]

    states = np.array([s[0] for s in batch])
    policies = np.array([s[1] for s in batch])
    values = np.array([s[2] for s in batch]).reshape(-1, 1)

    history = model.fit(
        states,
        {"policy_logits": policies, "value": values},
        batch_size=BATCH_SIZE,
        epochs=EPOCHS_PER_ITERATION,
        verbose=0,
    )
    return (
        history.history["policy_logits_loss"][-1],
        history.history["value_loss"][-1],
    )


def _training_thread():
    """Run the full training loop, pushing updates to shared state.

    Each iteration has four steps:
      1. Play one visible showcase game on this thread.
      2. Play the remaining ``GAMES_PER_ITERATION - 1`` games in a process pool.
      3. Record per-player win counts in ``_state["win_history"]``.
      4. Train on the replay buffer and checkpoint every ``CHECKPOINT_INTERVAL``
         iterations.

    Stops early once ``_state["running"]`` is cleared by the UI.
    """
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    model = build_model()
    print_model_info(model)

    num_workers = NUM_WORKERS if NUM_WORKERS > 0 else cpu_count()
    replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    with _lock:
        _state["status"] = f"Using {num_workers} workers"

    for iteration in range(1, NUM_ITERATIONS + 1):
        if not _state["running"]:
            break

        with _lock:
            _state["iteration"] = iteration
            _state["phase"] = "self-play"
            _state["status"] = f"Iteration {iteration}/{NUM_ITERATIONS} - Self-play"
            _state["game_num"] = 0

        # Play one showcase game visually
        showcase_samples = _play_showcase_game(model)
        replay_buffer.extend(showcase_samples)

        # Play remaining games in parallel. `results` is always defined, so the
        # tally below never depends on re-checking the running flag.
        results = []
        remaining = GAMES_PER_ITERATION - 1
        if remaining > 0 and _state["running"]:
            with _lock:
                _state["status"] = f"Iter {iteration} - Playing {remaining} games (parallel)..."

            weights = model.get_weights()
            with Pool(processes=num_workers, initializer=_init_worker, initargs=(weights,)) as pool:
                results = pool.map(_play_one_game, range(remaining))

            for samples in results:
                replay_buffer.extend(samples)

        # Count wins across all games this iteration (keys: 1 = P1, -1 = P2, 0 = draw)
        wins = {1: 0, -1: 0, 0: 0}
        for samples in [showcase_samples, *results]:
            if samples:
                wins[_game_outcome(samples)] += 1

        with _lock:
            _state["win_history"].append((wins[1], wins[-1], wins[0]))

        # Train
        if len(replay_buffer) >= BATCH_SIZE and _state["running"]:
            with _lock:
                _state["phase"] = "training"
                _state["status"] = f"Iter {iteration} - Training..."

            policy_loss, value_loss = _fit_on_replay_buffer(model, replay_buffer)

            with _lock:
                _state["policy_losses"].append(policy_loss)
                _state["value_losses"].append(value_loss)

        # Checkpoint
        if iteration % CHECKPOINT_INTERVAL == 0:
            path = os.path.join(CHECKPOINT_DIR, f"model_iter{iteration}.keras")
            model.save(path)

    if _state["running"]:
        final_path = os.path.join(CHECKPOINT_DIR, "model_final.keras")
        model.save(final_path)

    with _lock:
        _state["phase"] = "done"
        _state["status"] = "Training complete!"
|
|
|
|
|
|
# ── Drawing helpers ─────────────────────────────────────────────────
|
|
|
|
def _draw_board(surface, board, x0, y0):
    """Paint the Connect Four grid with its top-left corner at (x0, y0).

    Draws a rounded BOARD_BG rectangle, then one disc per cell:
      - P1_COLOR where the cell is 1,
      - P2_COLOR where it is -1,
      - EMPTY otherwise.
    Each disc is outlined in GRID_LINE.
    """
    pygame.draw.rect(surface, BOARD_BG, (x0, y0, BOARD_W, BOARD_H), border_radius=8)

    half_cell = CELL // 2
    disc_radius = half_cell - 6
    for row in range(ROWS):
        center_y = y0 + row * CELL + half_cell
        for col in range(COLS):
            center = (x0 + col * CELL + half_cell, center_y)
            cell = board[row, col]
            fill = P1_COLOR if cell == 1 else (P2_COLOR if cell == -1 else EMPTY)
            pygame.draw.circle(surface, fill, center, disc_radius)
            pygame.draw.circle(surface, GRID_LINE, center, disc_radius, 2)
|
|
|
|
|
|
def _draw_chart(surface, x, y, w, h, series_list, colors, title, font):
    """Draw a simple line chart with multiple series.

    Args:
        surface: pygame surface to draw on.
        x, y, w, h: outer rectangle of the chart, including the title row.
        series_list: sequences of numbers, one per line.
        colors: line color for each entry of ``series_list`` (paired via zip).
        title: caption rendered in the top-left corner.
        font: pygame font used for the title.

    Scaling: all series share one vertical scale (the combined min/max of their
    values), and each series is stretched across the full plot width.

    Series with fewer than two points are not drawn.

    Non-finite values (NaN/inf, e.g. from a diverged loss) are excluded and
    leave a gap in the line. A NaN would poison min/max, and ``int(nan)`` raises
    ``ValueError``, which would crash the render loop. A finite point isolated
    between gaps is drawn as a small dot.
    """
    pygame.draw.rect(surface, CHART_BG, (x, y, w, h), border_radius=6)
    pygame.draw.rect(surface, (60, 60, 75), (x, y, w, h), 1, border_radius=6)

    # Title
    title_surf = font.render(title, True, TEXT_COLOR)
    surface.blit(title_surf, (x + 8, y + 4))

    chart_x = x + 8
    chart_y = y + 24
    chart_w = w - 16
    chart_h = h - 32

    # Shared vertical scale over finite values only
    finite_vals = [v for s in series_list for v in s if np.isfinite(v)]
    if not finite_vals:
        return
    min_val = min(finite_vals)
    max_val = max(finite_vals)
    val_range = max_val - min_val if max_val != min_val else 1.0

    for series, color in zip(series_list, colors):
        if len(series) < 2:
            continue
        # Split into runs of consecutive finite points; each run is one polyline.
        runs = [[]]
        for i, v in enumerate(series):
            if not np.isfinite(v):
                if runs[-1]:
                    runs.append([])
                continue
            px = chart_x + int(i / (len(series) - 1) * chart_w)
            py = chart_y + chart_h - int((v - min_val) / val_range * chart_h)
            runs[-1].append((px, py))
        for run in runs:
            if len(run) >= 2:
                pygame.draw.lines(surface, color, False, run, 2)
            elif run:
                pygame.draw.circle(surface, color, run[0], 2)
|
|
|
|
|
|
def _draw_stacked_bar(surface, x, y, w, h, win_history, font):
    """Draw one stacked bar per iteration showing the share of each outcome.

    *win_history* is a sequence of ``(p1_wins, p2_wins, draws)`` tuples.

    Each bar is stacked, from bottom to top:
      - P1 wins,
      - draws,
      - P2 wins.
    Iterations with no games are left blank. A color legend is drawn along
    the bottom edge.
    """
    frame = (x, y, w, h)
    pygame.draw.rect(surface, CHART_BG, frame, border_radius=6)
    pygame.draw.rect(surface, (60, 60, 75), frame, 1, border_radius=6)
    surface.blit(font.render("Win rates per iteration", True, TEXT_COLOR), (x + 8, y + 4))

    if not win_history:
        return

    left, top = x + 8, y + 24
    width, height = w - 16, h - 48  # bottom strip is reserved for the legend

    count = len(win_history)  # at least 1 here
    bar_w = max(2, width // count)

    for idx, (p1, p2, dr) in enumerate(win_history):
        total = p1 + p2 + dr
        if total == 0:
            continue
        bar_x = left + int(idx / count * width)

        p1_h = int(p1 / total * height)
        draw_h = int(dr / total * height)
        p2_h = height - p1_h - draw_h  # remainder absorbs rounding

        # Paint segments top-down: P2, draws, P1
        cursor = top
        for seg_color, seg_h in ((P2_CHART, p2_h), (DRAW_CHART, draw_h), (P1_CHART, p1_h)):
            pygame.draw.rect(surface, seg_color, (bar_x, cursor, bar_w - 1, seg_h))
            cursor += seg_h

    # Legend
    legend_y = y + h - 18
    legend = (("P1", P1_CHART, x + 8), ("Draw", DRAW_CHART, x + 70), ("P2", P2_CHART, x + 150))
    for text, swatch, legend_x in legend:
        pygame.draw.rect(surface, swatch, (legend_x, legend_y, 12, 12))
        surface.blit(font.render(text, True, TEXT_COLOR), (legend_x + 16, legend_y - 2))
|
|
|
|
|
|
def run_visualized():
    """Launch the pygame window and run training with live visualization.

    Training runs in a daemon thread (``_training_thread``). This function owns
    the pygame event/render loop and copies a snapshot of ``_state`` under
    ``_lock`` every frame.

    Controls:
        Esc or closing the window: stop training and exit.
        Up / Down: decrease / increase the showcase-game move delay in 0.05 s
            steps, clamped to [0.05, 2.0] seconds.

    On exit, even when the render loop raises, this function:
      - clears ``_state["running"]`` so the training thread stops,
      - shuts pygame down,
      - joins the training thread for up to 5 s.
    """
    pygame.init()
    screen = pygame.display.set_mode((WIN_W, WIN_H))
    pygame.display.set_caption("Connect Four RL Training")
    clock = pygame.time.Clock()
    font = pygame.font.SysFont("monospace", 14)
    font_big = pygame.font.SysFont("monospace", 18, bold=True)

    # Start training in background thread
    train_thread = threading.Thread(target=_training_thread, daemon=True)
    train_thread.start()

    try:
        running = True
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                    _state["running"] = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_ESCAPE:
                        running = False
                        _state["running"] = False
                    elif event.key == pygame.K_UP:
                        _state["move_delay"] = max(0.05, _state["move_delay"] - 0.05)
                    elif event.key == pygame.K_DOWN:
                        _state["move_delay"] = min(2.0, _state["move_delay"] + 0.05)

            screen.fill(BG)

            # Snapshot shared state so drawing never holds the lock.
            with _lock:
                board = _state["board"].copy()
                iteration = _state["iteration"]
                phase = _state["phase"]
                status = _state["status"]
                policy_losses = list(_state["policy_losses"])
                value_losses = list(_state["value_losses"])
                win_history = list(_state["win_history"])
                winner = _state["winner"]
                delay = _state["move_delay"]

            # ── Left: game board ────────────────────────────────────
            bx, by = MARGIN, MARGIN
            _draw_board(screen, board, bx, by)

            # Winner overlay (only drawn if it fits inside the window)
            if winner != 0 and phase == "self-play":
                label = f"Player {1 if winner == 1 else 2} wins!"
                color = P1_COLOR if winner == 1 else P2_COLOR
                win_surf = font_big.render(label, True, color)
                wrect = win_surf.get_rect(center=(bx + BOARD_W // 2, by + BOARD_H + 2))
                if wrect.bottom < WIN_H:
                    screen.blit(win_surf, wrect)

            # ── Right panel ────────────────────────────────────────
            px = BOARD_W + MARGIN * 2
            py = MARGIN

            # Status
            status_surf = font_big.render(status, True, TEXT_COLOR)
            screen.blit(status_surf, (px, py))
            py += 28

            iter_surf = font.render(f"Iteration: {iteration}/{NUM_ITERATIONS}  Phase: {phase}", True, TEXT_COLOR)
            screen.blit(iter_surf, (px, py))
            py += 20

            delay_surf = font.render(f"Move delay: {delay:.2f}s  (Up/Down to adjust)", True, (150, 150, 170))
            screen.blit(delay_surf, (px, py))
            py += 28

            # Loss chart
            chart_h = 140
            _draw_chart(
                screen, px, py, PANEL_W, chart_h,
                [policy_losses, value_losses],
                [POLICY_LINE, VALUE_LINE],
                "Loss (blue=policy, orange=value)",
                font,
            )
            py += chart_h + 12

            # Win rate chart
            bar_h = 160
            _draw_stacked_bar(screen, px, py, PANEL_W, bar_h, win_history, font)
            py += bar_h + 12

            # Latest stats
            if policy_losses:
                pl = font.render(f"Policy loss: {policy_losses[-1]:.4f}", True, POLICY_LINE)
                screen.blit(pl, (px, py))
                py += 18
            if value_losses:
                vl = font.render(f"Value loss:  {value_losses[-1]:.4f}", True, VALUE_LINE)
                screen.blit(vl, (px, py))
                py += 18
            if win_history:
                p1, p2, dr = win_history[-1]
                ws = font.render(f"Last iter: P1={p1}  P2={p2}  Draw={dr}", True, TEXT_COLOR)
                screen.blit(ws, (px, py))

            pygame.display.flip()
            clock.tick(FPS)
    finally:
        # Always signal the training thread and release pygame, even if the
        # render loop raised.
        _state["running"] = False
        pygame.quit()
        train_thread.join(timeout=5)
|