"""Pygame visualization of Connect Four RL training. Left panel: live self-play game board Right panel: loss curves + win-rate chart + training stats """ import os import threading import time from collections import deque import numpy as np import pygame from .game import ConnectFour, ROWS, COLS from .model import build_model, print_model_info from .mcts import run_mcts from .config import ( NUM_ITERATIONS, GAMES_PER_ITERATION, MCTS_SIMULATIONS, MCTS_TEMPERATURE, TEMP_DROP_MOVE, WIN_REWARD, DRAW_REWARD, LOSS_REWARD, BATCH_SIZE, EPOCHS_PER_ITERATION, REPLAY_BUFFER_SIZE, CHECKPOINT_DIR, CHECKPOINT_INTERVAL, NUM_WORKERS, ) from multiprocessing import Pool, cpu_count # ── Layout constants ──────────────────────────────────────────────── CELL = 80 BOARD_W = COLS * CELL BOARD_H = ROWS * CELL PANEL_W = 420 MARGIN = 20 WIN_W = BOARD_W + PANEL_W + MARGIN * 3 WIN_H = BOARD_H + MARGIN * 2 FPS = 30 # ── Colors ────────────────────────────────────────────────────────── BG = (30, 30, 40) BOARD_BG = (0, 60, 180) EMPTY = (20, 20, 30) P1_COLOR = (255, 220, 50) # yellow P2_COLOR = (220, 40, 40) # red WIN_HIGHLIGHT = (100, 255, 100) GRID_LINE = (0, 40, 140) TEXT_COLOR = (220, 220, 220) CHART_BG = (40, 40, 55) POLICY_LINE = (80, 200, 255) VALUE_LINE = (255, 160, 60) P1_CHART = (255, 220, 50) P2_CHART = (220, 40, 40) DRAW_CHART = (140, 140, 140) # ── Shared state between training thread and pygame loop ──────────── _state = { "board": np.zeros((ROWS, COLS), dtype=np.int8), "iteration": 0, "game_num": 0, "phase": "init", # init / self-play / training / done "policy_losses": [], "value_losses": [], "win_history": [], # list of (p1_wins, p2_wins, draws) per iteration "move_delay": 0.3, "status": "Initializing...", "winner": 0, "running": True, } _lock = threading.Lock() # ── Worker setup (same as train.py) ───────────────────────────────── _worker_model = None def _init_worker(weights_list): global _worker_model os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" os.environ["CUDA_VISIBLE_DEVICES"] = "" _worker_model = build_model() _worker_model.set_weights(weights_list) def _play_one_game(_): game = ConnectFour() trajectory = [] while not game.done: state = game.get_state() visit_counts = run_mcts(game, _worker_model, MCTS_SIMULATIONS) if game.move_count < TEMP_DROP_MOVE: temp = MCTS_TEMPERATURE else: temp = 0.1 if temp < 0.2: action = int(np.argmax(visit_counts)) policy = np.zeros(7, dtype=np.float32) policy[action] = 1.0 else: counts = visit_counts ** (1.0 / temp) policy = counts / counts.sum() action = np.random.choice(7, p=policy) trajectory.append((state, policy, game.current_player)) game.step(action) samples = [] for state, policy, player in trajectory: if game.winner == 0: value = DRAW_REWARD elif game.winner == player: value = WIN_REWARD else: value = LOSS_REWARD samples.append((state, policy, value)) return samples def _play_showcase_game(model): """Play one game slowly on the main training thread, updating shared state.""" game = ConnectFour() trajectory = [] with _lock: _state["board"] = game.board.copy() _state["winner"] = 0 while not game.done and _state["running"]: state = game.get_state() visit_counts = run_mcts(game, model, MCTS_SIMULATIONS) if game.move_count < TEMP_DROP_MOVE: temp = MCTS_TEMPERATURE else: temp = 0.1 if temp < 0.2: action = int(np.argmax(visit_counts)) policy = np.zeros(7, dtype=np.float32) policy[action] = 1.0 else: counts = visit_counts ** (1.0 / temp) policy = counts / counts.sum() action = np.random.choice(7, p=policy) trajectory.append((state, policy, game.current_player)) game.step(action) with _lock: _state["board"] = game.board.copy() time.sleep(_state["move_delay"]) with _lock: _state["winner"] = game.winner samples = [] for state, policy, player in trajectory: if game.winner == 0: value = DRAW_REWARD elif game.winner == player: value = WIN_REWARD else: value = LOSS_REWARD samples.append((state, policy, value)) return samples def _training_thread(): """Run the full training loop, pushing updates to shared state.""" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" model = build_model() print_model_info(model) num_workers = NUM_WORKERS if NUM_WORKERS > 0 else cpu_count() replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE) os.makedirs(CHECKPOINT_DIR, exist_ok=True) with _lock: _state["status"] = f"Using {num_workers} workers" for iteration in range(1, NUM_ITERATIONS + 1): if not _state["running"]: break with _lock: _state["iteration"] = iteration _state["phase"] = "self-play" _state["status"] = f"Iteration {iteration}/{NUM_ITERATIONS} - Self-play" # Play one showcase game visually with _lock: _state["game_num"] = 0 showcase_samples = _play_showcase_game(model) replay_buffer.extend(showcase_samples) # Play remaining games in parallel remaining = GAMES_PER_ITERATION - 1 if remaining > 0 and _state["running"]: with _lock: _state["status"] = f"Iter {iteration} - Playing {remaining} games (parallel)..." weights = model.get_weights() with Pool(processes=num_workers, initializer=_init_worker, initargs=(weights,)) as pool: results = pool.map(_play_one_game, range(remaining)) for samples in results: replay_buffer.extend(samples) # Count wins across all games this iteration wins = {1: 0, -1: 0, 0: 0} # Showcase game if showcase_samples: last_val = showcase_samples[-1][2] if last_val == WIN_REWARD: wins[1] += 1 elif last_val == LOSS_REWARD: wins[-1] += 1 else: wins[0] += 1 # Parallel games if remaining > 0 and _state["running"]: for samples in results: if samples: last_val = samples[-1][2] if last_val == WIN_REWARD: wins[1] += 1 elif last_val == LOSS_REWARD: wins[-1] += 1 else: wins[0] += 1 with _lock: _state["win_history"].append((wins[1], wins[-1], wins[0])) # Train if len(replay_buffer) >= BATCH_SIZE and _state["running"]: with _lock: _state["phase"] = "training" _state["status"] = f"Iter {iteration} - Training..." sample_size = min(len(replay_buffer), BATCH_SIZE * EPOCHS_PER_ITERATION) indices = np.random.choice(len(replay_buffer), size=sample_size, replace=False) batch = [replay_buffer[i] for i in indices] states = np.array([s[0] for s in batch]) policies = np.array([s[1] for s in batch]) values = np.array([s[2] for s in batch]).reshape(-1, 1) history = model.fit( states, {"policy_logits": policies, "value": values}, batch_size=BATCH_SIZE, epochs=EPOCHS_PER_ITERATION, verbose=0, ) with _lock: _state["policy_losses"].append(history.history["policy_logits_loss"][-1]) _state["value_losses"].append(history.history["value_loss"][-1]) # Checkpoint if iteration % CHECKPOINT_INTERVAL == 0: path = os.path.join(CHECKPOINT_DIR, f"model_iter{iteration}.keras") model.save(path) if _state["running"]: final_path = os.path.join(CHECKPOINT_DIR, "model_final.keras") model.save(final_path) with _lock: _state["phase"] = "done" _state["status"] = "Training complete!" # ── Drawing helpers ───────────────────────────────────────────────── def _draw_board(surface, board, x0, y0): """Draw the Connect Four board.""" # Board background pygame.draw.rect(surface, BOARD_BG, (x0, y0, BOARD_W, BOARD_H), border_radius=8) for r in range(ROWS): for c in range(COLS): cx = x0 + c * CELL + CELL // 2 cy = y0 + r * CELL + CELL // 2 radius = CELL // 2 - 6 val = board[r, c] if val == 1: color = P1_COLOR elif val == -1: color = P2_COLOR else: color = EMPTY pygame.draw.circle(surface, color, (cx, cy), radius) pygame.draw.circle(surface, GRID_LINE, (cx, cy), radius, 2) def _draw_chart(surface, x, y, w, h, series_list, colors, title, font): """Draw a simple line chart with multiple series.""" pygame.draw.rect(surface, CHART_BG, (x, y, w, h), border_radius=6) pygame.draw.rect(surface, (60, 60, 75), (x, y, w, h), 1, border_radius=6) # Title title_surf = font.render(title, True, TEXT_COLOR) surface.blit(title_surf, (x + 8, y + 4)) chart_x = x + 8 chart_y = y + 24 chart_w = w - 16 chart_h = h - 32 if not any(series_list): return # Find global min/max all_vals = [v for s in series_list if s for v in s] if not all_vals: return min_val = min(all_vals) max_val = max(all_vals) val_range = max_val - min_val if max_val != min_val else 1.0 for series, color in zip(series_list, colors): if len(series) < 2: continue points = [] for i, v in enumerate(series): px = chart_x + int(i / (len(series) - 1) * chart_w) py = chart_y + chart_h - int((v - min_val) / val_range * chart_h) points.append((px, py)) pygame.draw.lines(surface, color, False, points, 2) def _draw_stacked_bar(surface, x, y, w, h, win_history, font): """Draw stacked bar chart of win rates.""" pygame.draw.rect(surface, CHART_BG, (x, y, w, h), border_radius=6) pygame.draw.rect(surface, (60, 60, 75), (x, y, w, h), 1, border_radius=6) title_surf = font.render("Win rates per iteration", True, TEXT_COLOR) surface.blit(title_surf, (x + 8, y + 4)) if not win_history: return chart_x = x + 8 chart_y = y + 24 chart_w = w - 16 chart_h = h - 48 n = len(win_history) bar_w = max(2, chart_w // max(n, 1)) for i, (p1, p2, dr) in enumerate(win_history): total = p1 + p2 + dr if total == 0: continue bx = chart_x + int(i / max(n, 1) * chart_w) # Stack: P1 (bottom), draws (middle), P2 (top) h1 = int(p1 / total * chart_h) hd = int(dr / total * chart_h) h2 = chart_h - h1 - hd by = chart_y pygame.draw.rect(surface, P2_CHART, (bx, by, bar_w - 1, h2)) by += h2 pygame.draw.rect(surface, DRAW_CHART, (bx, by, bar_w - 1, hd)) by += hd pygame.draw.rect(surface, P1_CHART, (bx, by, bar_w - 1, h1)) # Legend ly = y + h - 18 for label, color, lx in [("P1", P1_CHART, x + 8), ("Draw", DRAW_CHART, x + 70), ("P2", P2_CHART, x + 150)]: pygame.draw.rect(surface, color, (lx, ly, 12, 12)) surface.blit(font.render(label, True, TEXT_COLOR), (lx + 16, ly - 2)) def run_visualized(): """Launch pygame window and run training with live visualization.""" pygame.init() screen = pygame.display.set_mode((WIN_W, WIN_H)) pygame.display.set_caption("Connect Four RL Training") clock = pygame.time.Clock() font = pygame.font.SysFont("monospace", 14) font_big = pygame.font.SysFont("monospace", 18, bold=True) # Start training in background thread train_thread = threading.Thread(target=_training_thread, daemon=True) train_thread.start() running = True while running: for event in pygame.event.get(): if event.type == pygame.QUIT: running = False _state["running"] = False elif event.type == pygame.KEYDOWN: if event.key == pygame.K_ESCAPE: running = False _state["running"] = False elif event.key == pygame.K_UP: _state["move_delay"] = max(0.05, _state["move_delay"] - 0.05) elif event.key == pygame.K_DOWN: _state["move_delay"] = min(2.0, _state["move_delay"] + 0.05) screen.fill(BG) with _lock: board = _state["board"].copy() iteration = _state["iteration"] phase = _state["phase"] status = _state["status"] policy_losses = list(_state["policy_losses"]) value_losses = list(_state["value_losses"]) win_history = list(_state["win_history"]) winner = _state["winner"] delay = _state["move_delay"] # ── Left: game board ──────────────────────────────────── bx, by = MARGIN, MARGIN _draw_board(screen, board, bx, by) # Winner overlay if winner != 0 and phase == "self-play": label = f"Player {1 if winner == 1 else 2} wins!" color = P1_COLOR if winner == 1 else P2_COLOR win_surf = font_big.render(label, True, color) wrect = win_surf.get_rect(center=(bx + BOARD_W // 2, by + BOARD_H + 2)) if wrect.bottom < WIN_H: screen.blit(win_surf, wrect) # ── Right panel ──────────────────────────────────────── px = BOARD_W + MARGIN * 2 py = MARGIN # Status status_surf = font_big.render(status, True, TEXT_COLOR) screen.blit(status_surf, (px, py)) py += 28 iter_surf = font.render(f"Iteration: {iteration}/{NUM_ITERATIONS} Phase: {phase}", True, TEXT_COLOR) screen.blit(iter_surf, (px, py)) py += 20 delay_surf = font.render(f"Move delay: {delay:.2f}s (Up/Down to adjust)", True, (150, 150, 170)) screen.blit(delay_surf, (px, py)) py += 28 # Loss chart chart_h = 140 _draw_chart( screen, px, py, PANEL_W, chart_h, [policy_losses, value_losses], [POLICY_LINE, VALUE_LINE], "Loss (blue=policy, orange=value)", font, ) py += chart_h + 12 # Win rate chart bar_h = 160 _draw_stacked_bar(screen, px, py, PANEL_W, bar_h, win_history, font) py += bar_h + 12 # Latest stats if policy_losses: pl = font.render(f"Policy loss: {policy_losses[-1]:.4f}", True, POLICY_LINE) screen.blit(pl, (px, py)) py += 18 if value_losses: vl = font.render(f"Value loss: {value_losses[-1]:.4f}", True, VALUE_LINE) screen.blit(vl, (px, py)) py += 18 if win_history: p1, p2, dr = win_history[-1] ws = font.render(f"Last iter: P1={p1} P2={p2} Draw={dr}", True, TEXT_COLOR) screen.blit(ws, (px, py)) pygame.display.flip() clock.tick(FPS) pygame.quit() _state["running"] = False train_thread.join(timeout=5)