Files
Connect-four-Esp32/rl/visualize.py
T
2026-03-27 12:17:25 +01:00

483 lines
16 KiB
Python

"""Pygame visualization of Connect Four RL training.
Left panel: live self-play game board
Right panel: loss curves + win-rate chart + training stats
"""
import os
import threading
import time
from collections import deque
import numpy as np
import pygame
from .game import ConnectFour, ROWS, COLS
from .model import build_model, print_model_info
from .mcts import run_mcts
from .config import (
NUM_ITERATIONS, GAMES_PER_ITERATION, MCTS_SIMULATIONS,
MCTS_TEMPERATURE, TEMP_DROP_MOVE,
WIN_REWARD, DRAW_REWARD, LOSS_REWARD,
BATCH_SIZE, EPOCHS_PER_ITERATION, REPLAY_BUFFER_SIZE,
CHECKPOINT_DIR, CHECKPOINT_INTERVAL, NUM_WORKERS,
)
from multiprocessing import Pool, cpu_count
# ── Layout constants ────────────────────────────────────────────────
# All values are in pixels except FPS.
CELL = 80  # side length of one square board cell
BOARD_W = COLS * CELL  # board width: one cell per column
BOARD_H = ROWS * CELL  # board height: one cell per row
PANEL_W = 420  # width of the right-hand panel (status text and charts)
MARGIN = 20  # gap around the board and between board and panel
WIN_W = BOARD_W + PANEL_W + MARGIN * 3  # board + panel + three horizontal margins
WIN_H = BOARD_H + MARGIN * 2  # board + top and bottom margins
FPS = 30  # frame-rate cap passed to clock.tick() in run_visualized()
# ── Colors ──────────────────────────────────────────────────────────
# All colors are (R, G, B) tuples handed directly to pygame drawing calls.
BG = (30, 30, 40)  # window background, filled every frame
BOARD_BG = (0, 60, 180)  # rectangle behind the board cells
EMPTY = (20, 20, 30)  # disc color for a cell that is neither 1 nor -1
P1_COLOR = (255, 220, 50) # yellow
P2_COLOR = (220, 40, 40) # red
WIN_HIGHLIGHT = (100, 255, 100)  # NOTE(review): not referenced in the visible code
GRID_LINE = (0, 40, 140)  # 2px outline drawn around every disc
TEXT_COLOR = (220, 220, 220)  # default color for status text and chart titles
CHART_BG = (40, 40, 55)  # background of the loss chart and win-rate chart
POLICY_LINE = (80, 200, 255)  # policy-loss curve and its numeric readout
VALUE_LINE = (255, 160, 60)  # value-loss curve and its numeric readout
P1_CHART = (255, 220, 50)  # player-1 segment and legend swatch in the win-rate bars
P2_CHART = (220, 40, 40)  # player-2 segment and legend swatch in the win-rate bars
DRAW_CHART = (140, 140, 140)  # draw segment and legend swatch in the win-rate bars
# ── Shared state between training thread and pygame loop ────────────
# Written by the training thread and by the key/quit handlers in
# run_visualized(); snapshotted once per frame by the render loop.
_state = {
# Latest board of the showcase game (ROWS x COLS, int8); redrawn every frame.
"board": np.zeros((ROWS, COLS), dtype=np.int8),
# Current training iteration (the training loop counts from 1); 0 before it starts.
"iteration": 0,
# Only ever reset to 0 in the visible code; never incremented or displayed.
"game_num": 0,
"phase": "init", # init / self-play / training / done
# Final-epoch policy loss of each training round, in order; plotted as a line.
"policy_losses": [],
# Final-epoch value loss of each training round, in order; plotted as a line.
"value_losses": [],
"win_history": [], # list of (p1_wins, p2_wins, draws) per iteration
# Seconds slept after each showcase move; adjusted with the Up/Down keys.
"move_delay": 0.3,
# Headline text rendered at the top of the right-hand panel.
"status": "Initializing...",
# game.winner of the last finished showcase game; reset to 0 when a new one starts.
"winner": 0,
# Cleared by the UI on quit/Escape to ask the training thread to stop.
"running": True,
}
# Guards multi-key updates and per-frame snapshots of _state. Some single-key
# accesses (e.g. "running", "move_delay") are made without taking it.
_lock = threading.Lock()
# ── Worker setup (same as train.py) ─────────────────────────────────
# Per-process model used by _play_one_game(); populated by _init_worker().
_worker_model = None


def _init_worker(weights_list):
    """Prepare a pool worker process for self-play.

    Used as the ``initializer`` of the multiprocessing pool: sets a few
    environment variables, builds a fresh model in this process and loads the
    given weights into it, storing the result in the module-level
    ``_worker_model``.

    Args:
        weights_list: Weights handed to ``set_weights`` on the new model
            (the caller passes the result of ``model.get_weights()``).
    """
    global _worker_model
    # Presumably: quiet TensorFlow logging, turn off oneDNN optimisations and
    # hide GPUs from the worker process -- confirm against the TF docs.
    worker_env = (
        ("TF_CPP_MIN_LOG_LEVEL", "3"),
        ("TF_ENABLE_ONEDNN_OPTS", "0"),
        ("CUDA_VISIBLE_DEVICES", ""),
    )
    for key, value in worker_env:
        os.environ[key] = value
    _worker_model = build_model()
    _worker_model.set_weights(weights_list)
def _play_one_game(_):
    """Play one full self-play game in a worker process.

    Moves are chosen from MCTS visit counts computed with the per-process
    ``_worker_model``. While ``game.move_count`` is below ``TEMP_DROP_MOVE``
    the move is sampled using ``MCTS_TEMPERATURE``; afterwards the temperature
    drops to 0.1, which is treated as greedy (arg-max, one-hot policy target).

    Args:
        _: Ignored; present only so the function can be used with ``Pool.map``.

    Returns:
        List of ``(state, policy, value)`` tuples, one per move played, where
        ``value`` is ``WIN_REWARD`` / ``LOSS_REWARD`` / ``DRAW_REWARD`` from the
        point of view of the player who made that move.
    """
    game = ConnectFour()
    trajectory = []
    while not game.done:
        state = game.get_state()
        visit_counts = run_mcts(game, _worker_model, MCTS_SIMULATIONS)
        # Size the policy from the MCTS output instead of hard-coding 7, so
        # the policy target always matches what run_mcts actually returned.
        num_actions = len(visit_counts)
        temp = MCTS_TEMPERATURE if game.move_count < TEMP_DROP_MOVE else 0.1
        if temp < 0.2:
            # Near-zero temperature: play the most visited move outright.
            action = int(np.argmax(visit_counts))
            policy = np.zeros(num_actions, dtype=np.float32)
            policy[action] = 1.0
        else:
            counts = visit_counts ** (1.0 / temp)
            policy = counts / counts.sum()
            # Cast so both branches hand game.step() a plain Python int.
            action = int(np.random.choice(num_actions, p=policy))
        # Record the mover *before* stepping; step() advances the game.
        trajectory.append((state, policy, game.current_player))
        game.step(action)

    # Label every position with the final result from its mover's perspective.
    winner = game.winner
    samples = []
    for state, policy, player in trajectory:
        if winner == 0:
            value = DRAW_REWARD
        elif winner == player:
            value = WIN_REWARD
        else:
            value = LOSS_REWARD
        samples.append((state, policy, value))
    return samples
def _play_showcase_game(model):
    """Play one self-play game slowly, mirroring every move into shared state.

    Runs on the training thread. After each move the board is published to
    ``_state["board"]`` and the thread sleeps for ``_state["move_delay"]``
    seconds so the pygame loop can show the game move by move. Move selection
    follows the same temperature schedule as ``_play_one_game``.

    Args:
        model: Network passed to ``run_mcts`` for move selection.

    Returns:
        List of ``(state, policy, value)`` tuples, one per move, with
        ``value`` taken from the mover's perspective. Returns an empty list
        if the game was interrupted because ``_state["running"]`` was
        cleared: an unfinished game has no outcome to label positions with.
    """
    game = ConnectFour()
    trajectory = []
    with _lock:
        _state["board"] = game.board.copy()
        _state["winner"] = 0

    while not game.done and _state["running"]:
        state = game.get_state()
        visit_counts = run_mcts(game, model, MCTS_SIMULATIONS)
        # Size the policy from the MCTS output instead of hard-coding 7.
        num_actions = len(visit_counts)
        temp = MCTS_TEMPERATURE if game.move_count < TEMP_DROP_MOVE else 0.1
        if temp < 0.2:
            # Near-zero temperature: play the most visited move outright.
            action = int(np.argmax(visit_counts))
            policy = np.zeros(num_actions, dtype=np.float32)
            policy[action] = 1.0
        else:
            counts = visit_counts ** (1.0 / temp)
            policy = counts / counts.sum()
            action = int(np.random.choice(num_actions, p=policy))
        # Record the mover *before* stepping; step() advances the game.
        trajectory.append((state, policy, game.current_player))
        game.step(action)
        with _lock:
            _state["board"] = game.board.copy()
            delay = _state["move_delay"]
        # Sleep outside the lock so the render loop is never blocked on it.
        time.sleep(delay)

    if not game.done:
        # Interrupted by shutdown: do not label a half-played game as a
        # result and feed it to the replay buffer / win statistics.
        return []

    winner = game.winner
    with _lock:
        _state["winner"] = winner

    samples = []
    for state, policy, player in trajectory:
        if winner == 0:
            value = DRAW_REWARD
        elif winner == player:
            value = WIN_REWARD
        else:
            value = LOSS_REWARD
        samples.append((state, policy, value))
    return samples
def _first_mover_outcome(samples):
    """Classify a finished game from its training samples.

    The value stored with the *first* sample is the result from the first
    mover's perspective, so it identifies the winner unambiguously. (The last
    sample cannot be used: the last mover either won or filled the board, so
    its value is never ``LOSS_REWARD``.)

    Returns:
        ``1`` if the first mover won, ``-1`` if the second mover won, ``0``
        for a draw, or ``None`` when ``samples`` is empty.
    """
    if not samples:
        return None
    first_value = samples[0][2]
    if first_value == WIN_REWARD:
        return 1
    if first_value == LOSS_REWARD:
        return -1
    return 0


def _training_thread():
    """Run the full training loop, pushing updates to shared state.

    For each iteration: play one showcase game on this thread (visible in the
    UI), play the remaining ``GAMES_PER_ITERATION - 1`` games in a process
    pool, record the per-iteration win tally, fit the model on a random
    sample of the replay buffer, and save a checkpoint every
    ``CHECKPOINT_INTERVAL`` iterations. The loop stops early when
    ``_state["running"]`` is cleared. A final model is saved only when
    training was not interrupted.
    """
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    model = build_model()
    print_model_info(model)
    num_workers = NUM_WORKERS if NUM_WORKERS > 0 else cpu_count()
    replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    with _lock:
        _state["status"] = f"Using {num_workers} workers"

    for iteration in range(1, NUM_ITERATIONS + 1):
        if not _state["running"]:
            break
        with _lock:
            _state["iteration"] = iteration
            _state["phase"] = "self-play"
            _state["status"] = f"Iteration {iteration}/{NUM_ITERATIONS} - Self-play"
            _state["game_num"] = 0

        # Play one showcase game visually
        showcase_samples = _play_showcase_game(model)
        replay_buffer.extend(showcase_samples)

        # Play remaining games in parallel
        remaining = GAMES_PER_ITERATION - 1
        results = []  # stays empty when the parallel phase is skipped
        if remaining > 0 and _state["running"]:
            with _lock:
                _state["status"] = f"Iter {iteration} - Playing {remaining} games (parallel)..."
            # A fresh pool per iteration: the initializer loads the current
            # weights into every worker process.
            weights = model.get_weights()
            with Pool(processes=num_workers, initializer=_init_worker, initargs=(weights,)) as pool:
                results = pool.map(_play_one_game, range(remaining))
            for samples in results:
                replay_buffer.extend(samples)

        # Count wins across all games this iteration. Keys mirror the board
        # encoding used elsewhere in this file: 1 = P1, -1 = P2, 0 = draw.
        # NOTE(review): assumes player 1 always makes the first move --
        # confirm against ConnectFour.
        wins = {1: 0, -1: 0, 0: 0}
        for samples in [showcase_samples, *results]:
            outcome = _first_mover_outcome(samples)
            if outcome is not None:
                wins[outcome] += 1
        with _lock:
            _state["win_history"].append((wins[1], wins[-1], wins[0]))

        # Train
        if len(replay_buffer) >= BATCH_SIZE and _state["running"]:
            with _lock:
                _state["phase"] = "training"
                _state["status"] = f"Iter {iteration} - Training..."
            sample_size = min(len(replay_buffer), BATCH_SIZE * EPOCHS_PER_ITERATION)
            indices = np.random.choice(len(replay_buffer), size=sample_size, replace=False)
            batch = [replay_buffer[i] for i in indices]
            states = np.array([s[0] for s in batch])
            policies = np.array([s[1] for s in batch])
            values = np.array([s[2] for s in batch]).reshape(-1, 1)
            history = model.fit(
                states,
                {"policy_logits": policies, "value": values},
                batch_size=BATCH_SIZE,
                epochs=EPOCHS_PER_ITERATION,
                verbose=0,
            )
            # Only the last epoch's losses are charted for this iteration.
            with _lock:
                _state["policy_losses"].append(history.history["policy_logits_loss"][-1])
                _state["value_losses"].append(history.history["value_loss"][-1])

        # Checkpoint
        if iteration % CHECKPOINT_INTERVAL == 0:
            path = os.path.join(CHECKPOINT_DIR, f"model_iter{iteration}.keras")
            model.save(path)

    # Skip the final save when the user aborted the run.
    if _state["running"]:
        final_path = os.path.join(CHECKPOINT_DIR, "model_final.keras")
        model.save(final_path)

    with _lock:
        _state["phase"] = "done"
        _state["status"] = "Training complete!"
# ── Drawing helpers ─────────────────────────────────────────────────
def _draw_board(surface, board, x0, y0):
    """Render the board rectangle and one disc per cell.

    Args:
        surface: pygame surface to draw on.
        board: 2-D array indexed as ``board[row, col]``; ``1`` is drawn in
            ``P1_COLOR``, ``-1`` in ``P2_COLOR``, anything else as ``EMPTY``.
        x0: Left pixel coordinate of the board.
        y0: Top pixel coordinate of the board.
    """
    pygame.draw.rect(surface, BOARD_BG, (x0, y0, BOARD_W, BOARD_H), border_radius=8)

    half_cell = CELL // 2
    disc_radius = half_cell - 6
    fill_by_value = {1: P1_COLOR, -1: P2_COLOR}

    for row in range(ROWS):
        center_y = y0 + row * CELL + half_cell
        for col in range(COLS):
            center = (x0 + col * CELL + half_cell, center_y)
            fill = fill_by_value.get(board[row, col], EMPTY)
            # Filled disc first, then a 2px outline on top of it.
            pygame.draw.circle(surface, fill, center, disc_radius)
            pygame.draw.circle(surface, GRID_LINE, center, disc_radius, 2)
def _draw_chart(surface, x, y, w, h, series_list, colors, title, font):
    """Draw a titled line chart holding one polyline per series.

    All series share one vertical scale spanning the global minimum and
    maximum; each series is stretched horizontally to the full plot width.
    Series with fewer than two points are skipped, and only the frame and
    title are drawn when there is no data at all.

    Args:
        surface: pygame surface to draw on.
        x, y, w, h: Outer rectangle of the chart in pixels.
        series_list: Sequence of value sequences to plot.
        colors: Line colors, paired with ``series_list`` by position.
        title: Text rendered in the top-left corner of the chart.
        font: pygame font used for the title.
    """
    frame = (x, y, w, h)
    pygame.draw.rect(surface, CHART_BG, frame, border_radius=6)
    pygame.draw.rect(surface, (60, 60, 75), frame, 1, border_radius=6)
    surface.blit(font.render(title, True, TEXT_COLOR), (x + 8, y + 4))

    # Every value from every non-empty series; empty means nothing to plot.
    values = [v for series in series_list if series for v in series]
    if not values:
        return

    lowest = min(values)
    highest = max(values)
    if highest == lowest:
        span = 1.0  # flat data: avoid dividing by zero
    else:
        span = highest - lowest

    # Plot area sits inside the frame, below the title line.
    plot_left = x + 8
    plot_w = w - 16
    plot_h = h - 32
    plot_bottom = y + 24 + plot_h

    for series, color in zip(series_list, colors):
        if len(series) < 2:
            continue
        last_index = len(series) - 1
        points = [
            (
                plot_left + int(idx / last_index * plot_w),
                plot_bottom - int((val - lowest) / span * plot_h),
            )
            for idx, val in enumerate(series)
        ]
        pygame.draw.lines(surface, color, False, points, 2)
def _draw_stacked_bar(surface, x, y, w, h, win_history, font):
    """Draw one stacked bar per iteration showing the split of game results.

    Each bar fills the plot height and is divided, top to bottom, into the
    P2, draw and P1 shares of that iteration's games. A legend is drawn under
    the bars. With an empty history only the frame and title are drawn.

    Args:
        surface: pygame surface to draw on.
        x, y, w, h: Outer rectangle of the chart in pixels.
        win_history: Sequence of ``(p1_wins, p2_wins, draws)`` tuples.
        font: pygame font used for the title and legend labels.
    """
    frame = (x, y, w, h)
    pygame.draw.rect(surface, CHART_BG, frame, border_radius=6)
    pygame.draw.rect(surface, (60, 60, 75), frame, 1, border_radius=6)
    surface.blit(font.render("Win rates per iteration", True, TEXT_COLOR), (x + 8, y + 4))

    if not win_history:
        return

    # Plot area: below the title, leaving a strip at the bottom for the legend.
    plot_left = x + 8
    plot_top = y + 24
    plot_w = w - 16
    plot_h = h - 48

    slots = max(len(win_history), 1)
    bar_w = max(2, plot_w // slots)

    for index, (p1, p2, dr) in enumerate(win_history):
        games = p1 + p2 + dr
        if games == 0:
            continue
        left = plot_left + int(index / slots * plot_w)
        p1_px = int(p1 / games * plot_h)
        draw_px = int(dr / games * plot_h)
        # P2 absorbs the rounding remainder so the bar always fills plot_h.
        p2_px = plot_h - p1_px - draw_px

        top = plot_top
        for seg_px, seg_color in ((p2_px, P2_CHART), (draw_px, DRAW_CHART), (p1_px, P1_CHART)):
            pygame.draw.rect(surface, seg_color, (left, top, bar_w - 1, seg_px))
            top += seg_px

    # Legend
    legend_y = y + h - 18
    legend_entries = (
        ("P1", P1_CHART, x + 8),
        ("Draw", DRAW_CHART, x + 70),
        ("P2", P2_CHART, x + 150),
    )
    for text, swatch, legend_x in legend_entries:
        pygame.draw.rect(surface, swatch, (legend_x, legend_y, 12, 12))
        surface.blit(font.render(text, True, TEXT_COLOR), (legend_x + 16, legend_y - 2))
def run_visualized():
    """Launch the pygame window and run training with live visualization.

    Starts ``_training_thread`` as a daemon thread, then runs the render loop
    at ``FPS`` frames per second until the window is closed or Escape is
    pressed. Up/Down shorten or lengthen the delay between showcase moves
    (clamped to 0.05-2.0 seconds).

    On exit -- including when the render loop raises -- the training thread
    is asked to stop, pygame is shut down, and the thread is joined for at
    most 5 seconds.
    """
    pygame.init()
    train_thread = None
    try:
        screen = pygame.display.set_mode((WIN_W, WIN_H))
        pygame.display.set_caption("Connect Four RL Training")
        clock = pygame.time.Clock()
        font = pygame.font.SysFont("monospace", 14)
        font_big = pygame.font.SysFont("monospace", 18, bold=True)

        # Start training in background thread
        train_thread = threading.Thread(target=_training_thread, daemon=True)
        train_thread.start()

        running = True
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_ESCAPE:
                        running = False
                    elif event.key == pygame.K_UP:
                        with _lock:
                            _state["move_delay"] = max(0.05, _state["move_delay"] - 0.05)
                    elif event.key == pygame.K_DOWN:
                        with _lock:
                            _state["move_delay"] = min(2.0, _state["move_delay"] + 0.05)
            if not running:
                # Tell the training thread right away; this frame still draws.
                with _lock:
                    _state["running"] = False

            screen.fill(BG)

            # Snapshot shared state so drawing happens without holding the lock.
            with _lock:
                board = _state["board"].copy()
                iteration = _state["iteration"]
                phase = _state["phase"]
                status = _state["status"]
                policy_losses = list(_state["policy_losses"])
                value_losses = list(_state["value_losses"])
                win_history = list(_state["win_history"])
                winner = _state["winner"]
                delay = _state["move_delay"]

            # ── Left: game board ────────────────────────────────────
            bx, by = MARGIN, MARGIN
            _draw_board(screen, board, bx, by)

            # Winner overlay
            if winner != 0 and phase == "self-play":
                label = f"Player {1 if winner == 1 else 2} wins!"
                color = P1_COLOR if winner == 1 else P2_COLOR
                win_surf = font_big.render(label, True, color)
                wrect = win_surf.get_rect(center=(bx + BOARD_W // 2, by + BOARD_H + 2))
                # Only draw the label if it fits inside the window.
                if wrect.bottom < WIN_H:
                    screen.blit(win_surf, wrect)

            # ── Right panel ────────────────────────────────────────
            px = BOARD_W + MARGIN * 2
            py = MARGIN

            # Status
            status_surf = font_big.render(status, True, TEXT_COLOR)
            screen.blit(status_surf, (px, py))
            py += 28
            iter_surf = font.render(f"Iteration: {iteration}/{NUM_ITERATIONS} Phase: {phase}", True, TEXT_COLOR)
            screen.blit(iter_surf, (px, py))
            py += 20
            delay_surf = font.render(f"Move delay: {delay:.2f}s (Up/Down to adjust)", True, (150, 150, 170))
            screen.blit(delay_surf, (px, py))
            py += 28

            # Loss chart
            chart_h = 140
            _draw_chart(
                screen, px, py, PANEL_W, chart_h,
                [policy_losses, value_losses],
                [POLICY_LINE, VALUE_LINE],
                "Loss (blue=policy, orange=value)",
                font,
            )
            py += chart_h + 12

            # Win rate chart
            bar_h = 160
            _draw_stacked_bar(screen, px, py, PANEL_W, bar_h, win_history, font)
            py += bar_h + 12

            # Latest stats
            if policy_losses:
                pl = font.render(f"Policy loss: {policy_losses[-1]:.4f}", True, POLICY_LINE)
                screen.blit(pl, (px, py))
                py += 18
            if value_losses:
                vl = font.render(f"Value loss: {value_losses[-1]:.4f}", True, VALUE_LINE)
                screen.blit(vl, (px, py))
                py += 18
            if win_history:
                p1, p2, dr = win_history[-1]
                ws = font.render(f"Last iter: P1={p1} P2={p2} Draw={dr}", True, TEXT_COLOR)
                screen.blit(ws, (px, py))

            pygame.display.flip()
            clock.tick(FPS)
    finally:
        # Runs on normal exit and on errors, so the training thread is always
        # signalled and the window is always closed.
        with _lock:
            _state["running"] = False
        pygame.quit()
        if train_thread is not None:
            # Bounded wait: the thread is a daemon and may be mid-computation.
            train_thread.join(timeout=5)