/**
 * Heuristically scores a position from the AI's point of view.
 *
 * The board is indexed column-first (`b[col][row]`). The score is the sum of:
 *   - a centre-column bonus: +3 for each AI piece, -3 for each human piece;
 *   - one score per 4-cell window (horizontal, vertical, both diagonals):
 *     a window holding pieces of both players is worth 0, otherwise three
 *     pieces of one player are worth +/-50 and two pieces +/-5.
 *
 * The centre column and the first row of the "down-right" diagonal scan are
 * derived from COLS and the window length rather than hard-coded, so the
 * function stays consistent with the ROWS/COLS-driven loop bounds.
 *
 * @param {number[][]} b - Board cells, addressed as `b[col][row]`.
 * @param {number} aiP - Cell value that marks the AI's pieces.
 * @param {number} huP - Cell value that marks the human's pieces.
 * @returns {number} Positive when the position favours the AI, negative when
 *   it favours the human.
 */
function evaluateBoard(b, aiP, huP) {
  const WINDOW = 4;
  const centerCol = Math.floor(COLS / 2);
  let score = 0;

  // Centre column bonus
  for (let r = 0; r < ROWS; r++) {
    const cell = b[centerCol][r];
    if (cell === aiP) {
      score += 3;
    } else if (cell === huP) {
      score -= 3;
    }
  }

  // Score one window of WINDOW cells starting at (c, r) and stepping (dc, dr).
  const scoreWindow = (c, r, dc, dr) => {
    let ai = 0;
    let hu = 0;
    for (let i = 0; i < WINDOW; i++) {
      const cell = b[c + i * dc][r + i * dr];
      if (cell === aiP) {
        ai++;
      } else if (cell === huP) {
        hu++;
      }
    }
    // A window containing both colours can never become four-in-a-row.
    if (ai > 0 && hu > 0) return 0;
    if (ai === 3) return 50;
    if (ai === 2) return 5;
    if (hu === 3) return -50;
    if (hu === 2) return -5;
    return 0;
  };

  const lastCol = COLS - WINDOW;
  const lastRow = ROWS - WINDOW;
  const scans = [
    // Horizontal
    { dc: 1, dr: 0, rowFrom: 0, rowTo: ROWS - 1, colTo: lastCol },
    // Vertical
    { dc: 0, dr: 1, rowFrom: 0, rowTo: lastRow, colTo: COLS - 1 },
    // Diagonal up-right
    { dc: 1, dr: 1, rowFrom: 0, rowTo: lastRow, colTo: lastCol },
    // Diagonal down-right: starts high enough that r - (WINDOW - 1) >= 0
    { dc: 1, dr: -1, rowFrom: WINDOW - 1, rowTo: ROWS - 1, colTo: lastCol },
  ];

  for (const { dc, dr, rowFrom, rowTo, colTo } of scans) {
    for (let r = rowFrom; r <= rowTo; r++) {
      for (let c = 0; c <= colTo; c++) {
        score += scoreWindow(c, r, dc, dr);
      }
    }
  }

  return score;
}
"""Entry point: python -m rl [train|visualize|export|info]"""

import os
import sys

# Set before any sub-command module is imported: the handlers in main() import
# their modules lazily, i.e. only after these variables are in place.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = ""


def main(argv=None):
    """Dispatch to the sub-command named by the first command-line argument.

    Args:
        argv: Argument vector including the program name at index 0.
            Defaults to ``sys.argv``. With no sub-command given, ``train``
            is run.

    Raises:
        SystemExit: With exit code 1 when the sub-command is unknown.
    """
    args = sys.argv if argv is None else list(argv)
    cmd = args[1] if len(args) > 1 else "train"

    if cmd == "train":
        from .train import train
        train()

    elif cmd == "export":
        from .export import export_tflite
        model_path = args[2] if len(args) > 2 else "rl/checkpoints/model_final.keras"
        export_tflite(model_path)

    elif cmd == "visualize":
        from .visualize import run_visualized
        run_visualized()

    elif cmd == "info":
        from .model import build_model, print_model_info
        model = build_model()
        print_model_info(model)

    else:
        print(f"Unknown command: {cmd}")
        print("Usage: python -m rl [train|visualize|export|info]")
        sys.exit(1)


# `python -m rl` executes this file as "__main__"; the guard keeps a plain
# import of the module (tests, tooling) from starting a training run.
if __name__ == "__main__":
    main()
def representative_dataset():
    """Yield sample inputs for int8 calibration.

    Plays 200 games of random length (fewer than ``min(ROWS * COLS, 20)``
    random moves each, stopping early when the game ends) and yields the
    resulting state as a single-element list holding a float32 array with a
    leading batch axis.
    """
    game = ConnectFour()
    for _ in range(200):
        game.reset()
        # Play random moves to get diverse board states
        num_moves = np.random.randint(0, min(ROWS * COLS, 20))
        for _move in range(num_moves):
            legal = game.legal_moves()
            if not legal or game.done:
                break
            game.step(np.random.choice(legal))
        yield [game.get_state()[np.newaxis].astype(np.float32)]


def export_tflite(model_path, quantize=None):
    """Convert a saved Keras model to TFLite and to a C header.

    Args:
        model_path: Path to the .keras model file.
        quantize: Override quantization setting. If None, uses
            config.QUANTIZE_INT8.

    Returns:
        Path of the written ``.tflite`` file inside EXPORT_DIR.
    """
    import tensorflow as tf

    if quantize is None:
        quantize = QUANTIZE_INT8

    os.makedirs(EXPORT_DIR, exist_ok=True)

    model = tf.keras.models.load_model(model_path)
    converter = tf.lite.TFLiteConverter.from_keras_model(model)

    if quantize:
        # Full-integer quantization: int8 ops with int8 input and output.
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        suffix = "_int8"
    else:
        suffix = "_f32"

    tflite_model = converter.convert()

    out_path = os.path.join(EXPORT_DIR, f"connect4{suffix}.tflite")
    with open(out_path, "wb") as f:
        f.write(tflite_model)

    size_kb = len(tflite_model) / 1024
    print(f"Exported: {out_path} ({size_kb:.1f} KB)")

    # Also export as C header for direct embedding in firmware
    header_path = os.path.join(EXPORT_DIR, f"connect4_model{suffix}.h")
    _write_c_header(tflite_model, header_path)
    print(f"C header: {header_path}")

    return out_path


def _write_c_header(model_bytes, path):
    """Write the TFLite model as a C byte array (12 bytes per source line).

    The file is written as UTF-8 explicitly: the generated comment contains a
    non-ASCII dash, so relying on the locale default encoding would make the
    output (or even the ability to write it) platform dependent.

    Args:
        model_bytes: Serialized model as a bytes-like object.
        path: Destination path of the header file.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.write("#pragma once\n\n")
        f.write(f"// Auto-generated — {len(model_bytes)} bytes\n")
        f.write(f"const unsigned int connect4_model_len = {len(model_bytes)};\n")
        f.write("alignas(16) const unsigned char connect4_model[] = {\n")
        for start in range(0, len(model_bytes), 12):
            chunk = model_bytes[start:start + 12]
            f.write(" " + ", ".join(f"0x{b:02x}" for b in chunk) + ",\n")
        f.write("};\n")


if __name__ == "__main__":
    import sys
    model_path = sys.argv[1] if len(sys.argv) > 1 else "rl/checkpoints/model_final.keras"
    export_tflite(model_path)
ROWS = 6
COLS = 7
WIN_LENGTH = 4


class ConnectFour:
    """Connect Four environment backed by a (ROWS, COLS) int8 numpy board.

    Cell values: 0 = empty, 1 = player 1, -1 = player 2. Row 0 is the top of
    the board; pieces settle at the highest free row index of a column.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the board, make player 1 the mover and return the state."""
        self.board = np.zeros((ROWS, COLS), dtype=np.int8)
        self.current_player = 1
        self.done = False
        self.winner = 0  # 0 = no winner / draw, 1 or -1
        self.move_count = 0
        return self.get_state()

    def get_state(self):
        """Return a (6,7,2) float32 view of the board for the player to move.

        Channel 0 marks the current player's pieces, channel 1 the opponent's.
        """
        own = self.board == self.current_player
        other = self.board == -self.current_player
        return np.stack((own, other), axis=-1).astype(np.float32)

    def legal_moves(self):
        """Return the list of column indices whose top cell is still empty."""
        return np.flatnonzero(self.board[0] == 0).tolist()

    def legal_moves_mask(self):
        """Return a float32 vector with 1.0 for every playable column."""
        return (self.board[0] == 0).astype(np.float32)

    def step(self, col):
        """Drop the current player's piece into ``col``.

        Returns:
            (state, reward, done) where ``state`` is already seen from the
            next player's perspective and ``reward`` is 1.0 only when this
            move won the game.

        Raises:
            ValueError: If the game is over or the column is invalid/full.
        """
        if self.done:
            raise ValueError("Game is already over.")
        if not 0 <= col < COLS or self.board[0, col] != 0:
            raise ValueError(f"Illegal move: column {col}")

        mover = self.current_player
        row = self._get_drop_row(col)
        self.board[row, col] = mover
        self.move_count += 1

        reward = 0.0
        if self._check_win(row, col):
            self.done, self.winner, reward = True, mover, 1.0
        elif self.move_count == ROWS * COLS:
            # Board is full without a winner: draw.
            self.done, self.winner = True, 0

        # Hand the turn over before building the returned state.
        self.current_player = -mover
        return self.get_state(), reward, self.done

    def _get_drop_row(self, col):
        """Return the largest row index in ``col`` that is still empty."""
        free_rows = np.flatnonzero(self.board[:, col] == 0)
        if free_rows.size == 0:
            raise ValueError(f"Column {col} is full")
        return int(free_rows[-1])

    def _run_length(self, row, col, dr, dc, player):
        """Count consecutive ``player`` cells starting next to (row, col)."""
        length = 0
        r, c = row + dr, col + dc
        while 0 <= r < ROWS and 0 <= c < COLS and self.board[r, c] == player:
            length += 1
            r += dr
            c += dc
        return length

    def _check_win(self, row, col):
        """Return True if the piece at (row, col) completes a winning line."""
        player = self.board[row, col]
        for dr, dc in ((0, 1), (1, 0), (1, 1), (1, -1)):
            line = (
                1
                + self._run_length(row, col, dr, dc, player)
                + self._run_length(row, col, -dr, -dc, player)
            )
            if line >= WIN_LENGTH:
                return True
        return False

    def clone(self):
        """Return an independent copy of the game (the board array is copied)."""
        twin = ConnectFour()
        twin.board = self.board.copy()
        twin.current_player = self.current_player
        twin.done = self.done
        twin.winner = self.winner
        twin.move_count = self.move_count
        return twin
class MCTSNode:
    """One node of the search tree.

    ``visit_count`` and ``value_sum`` are accumulated by ``run_mcts`` from the
    perspective of the player who is to move *at this node*. ``game`` is only
    populated for the root and for nodes that have been expanded.
    """

    __slots__ = ("parent", "action", "prior", "visit_count", "value_sum", "children", "game")

    def __init__(self, game, parent=None, action=None, prior=0.0):
        self.game = game
        self.parent = parent
        self.action = action
        self.prior = prior
        self.visit_count = 0
        self.value_sum = 0.0
        self.children = {}

    @property
    def q_value(self):
        """Mean backed-up value for the player to move at this node."""
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count

    def ucb_score(self):
        """PUCT score of this node as seen by the parent that is choosing.

        ``q_value`` is expressed for the player to move at this node, i.e. the
        *opponent* of the player selecting it, so it is negated here. Without
        the negation the search prefers moves that are good for the opponent
        (an immediately winning move would score -1).
        """
        parent_visits = self.parent.visit_count if self.parent else 1
        exploration = MCTS_C_PUCT * self.prior * math.sqrt(parent_visits) / (1 + self.visit_count)
        return -self.q_value + exploration

    def is_leaf(self):
        return len(self.children) == 0

    def expand(self, policy_probs):
        """Create one child per legal move, using ``policy_probs`` as priors."""
        for col in self.game.legal_moves():
            if col not in self.children:
                self.children[col] = MCTSNode(
                    game=None, parent=self, action=col, prior=policy_probs[col]
                )

    def select_child(self):
        """Return the child with the highest PUCT score."""
        return max(self.children.values(), key=lambda c: c.ucb_score())


def run_mcts(game, model, num_simulations):
    """Run MCTS from the current game state and return per-column visit counts.

    Args:
        game: Game object providing ``clone``, ``get_state``, ``legal_moves``,
            ``legal_moves_mask``, ``step``, ``done`` and ``winner``.
        model: Object whose ``predict(batch, verbose=0)`` returns
            ``(policy_logits, value)`` for a batch of states.
        num_simulations: Number of select/evaluate/backup passes.

    Returns:
        float32 vector, one entry per policy output, holding the visit count
        of each root child (0 for moves that are not legal at the root).
    """
    root = MCTSNode(game.clone())

    # Evaluate root (only the policy is needed to seed the priors)
    state = root.game.get_state()
    policy_logits, _ = model.predict(state[np.newaxis], verbose=0)
    policy = _mask_and_normalize(policy_logits[0], root.game.legal_moves_mask())
    root.expand(policy)
    num_actions = len(policy)

    for _ in range(num_simulations):
        node = root
        sim_game = root.game.clone()

        # SELECT — walk down tree picking best UCB child
        while not node.is_leaf() and not sim_game.done:
            node = node.select_child()
            sim_game.step(node.action)

        # EVALUATE leaf, from the perspective of the player to move there
        if sim_game.done:
            if sim_game.winner == 0:
                leaf_value = 0.0
            else:
                # step() already switched players, so the player "to move"
                # at this terminal node is the one who just lost.
                leaf_value = -1.0
        else:
            node.game = sim_game.clone()
            state = sim_game.get_state()
            policy_logits, value = model.predict(state[np.newaxis], verbose=0)
            leaf_value = value[0, 0]
            policy = _mask_and_normalize(policy_logits[0], sim_game.legal_moves_mask())
            node.expand(policy)

        # BACKUP — propagate value up, flipping sign each level
        while node is not None:
            node.visit_count += 1
            node.value_sum += leaf_value
            leaf_value = -leaf_value
            node = node.parent

    # Build policy from visit counts
    visits = np.zeros(num_actions, dtype=np.float32)
    for col, child in root.children.items():
        visits[col] = child.visit_count
    return visits


def _mask_and_normalize(logits, mask):
    """Softmax over ``logits`` with entries where ``mask == 0`` suppressed.

    Masked logits are set to -1e9 before a max-shifted softmax; the result is
    returned as float32.
    """
    logits = np.array(logits, dtype=np.float64)
    logits[mask == 0] = -1e9
    exp = np.exp(logits - np.max(logits))
    probs = exp / np.sum(exp)
    return probs.astype(np.float32)
def build_model():
    """Build and compile a small AlphaZero-style network.

    Input:  (6, 7, 2) — current player pieces / opponent pieces
    Output: policy (7,) — unnormalised logits over columns (the head has no
                          activation and the loss is configured with
                          ``from_logits=True``)
            value  (1,) — board evaluation in [-1, 1] (tanh)

    The trunk is NUM_CONV_LAYERS blocks of 3x3 same-padded ReLU convolutions
    with CONV_FILTERS filters, each followed by batch normalisation, then a
    shared dense layer of DENSE_UNITS units feeding both heads.
    """
    from tensorflow import keras
    from tensorflow.keras import layers

    inp = layers.Input(shape=(6, 7, 2), name="board")

    x = inp
    for i in range(NUM_CONV_LAYERS):
        x = layers.Conv2D(
            CONV_FILTERS, 3, padding="same", activation="relu", name=f"conv{i}"
        )(x)
        x = layers.BatchNormalization(name=f"bn{i}")(x)

    flat = layers.Flatten(name="flat")(x)
    shared = layers.Dense(DENSE_UNITS, activation="relu", name="shared_dense")(flat)

    # Policy head (raw logits)
    policy = layers.Dense(7, name="policy_logits")(shared)

    # Value head
    value = layers.Dense(1, activation="tanh", name="value")(shared)

    model = keras.Model(inputs=inp, outputs=[policy, value], name="connect4_net")

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss={
            "policy_logits": keras.losses.CategoricalCrossentropy(from_logits=True),
            "value": keras.losses.MeanSquaredError(),
        },
        loss_weights={"policy_logits": 1.0, "value": 1.0},
    )
    return model


def print_model_info(model):
    """Print the model summary plus rough float32 / int8 size estimates.

    The estimates are parameter count times 4 bytes (float32) or 1 byte
    (int8), expressed in KB; they ignore any file-format overhead.
    """
    model.summary()
    total_params = model.count_params()
    approx_size_kb = total_params * 4 / 1024  # float32
    approx_int8_kb = total_params / 1024  # int8
    print(f"\nTotal parameters: {total_params:,}")
    print(f"Approx size (float32): {approx_size_kb:.1f} KB")
    print(f"Approx size (int8): {approx_int8_kb:.1f} KB")
# Per-worker global model (built once per worker process by _init_worker)
_worker_model = None


def _init_worker(weights_list):
    """Pool initializer: build this process's model and load the weights."""
    global _worker_model
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    _worker_model = build_model()
    _worker_model.set_weights(weights_list)


def _select_action(visit_counts, move_count):
    """Turn MCTS visit counts into ``(action, policy_target)``.

    Before TEMP_DROP_MOVE moves the configured temperature is used, afterwards
    0.1. Temperatures below 0.2 are treated as greedy: the most visited move
    is played and the target is one-hot. Otherwise the move is sampled from
    the temperature-scaled, normalised visit counts.
    """
    temp = MCTS_TEMPERATURE if move_count < TEMP_DROP_MOVE else 0.1
    if temp < 0.2:
        action = int(np.argmax(visit_counts))
        policy = np.zeros(len(visit_counts), dtype=np.float32)
        policy[action] = 1.0
    else:
        counts = visit_counts ** (1.0 / temp)
        policy = counts / counts.sum()
        action = np.random.choice(len(policy), p=policy)
    return action, policy


def _label_trajectory(trajectory, winner):
    """Attach the final outcome to every ``(state, policy, player)`` step."""
    samples = []
    for state, policy, player in trajectory:
        if winner == 0:
            value = DRAW_REWARD
        elif winner == player:
            value = WIN_REWARD
        else:
            value = LOSS_REWARD
        samples.append((state, policy, value))
    return samples


def _count_outcomes(results):
    """Tally game results as ``{1: P1 wins, -1: P2 wins, 0: draws}``.

    Player 1 always makes the first move of a game, so the value of a game's
    *first* sample is the outcome from player 1's perspective. (The last
    sample belongs to whoever made the final move — always the winner in a
    decisive game — so it cannot tell the two players apart.)
    """
    wins = {1: 0, -1: 0, 0: 0}
    for samples in results:
        if not samples:
            continue
        p1_value = samples[0][2]
        if p1_value == WIN_REWARD:
            wins[1] += 1
        elif p1_value == LOSS_REWARD:
            wins[-1] += 1
        else:
            wins[0] += 1
    return wins


def _play_one_game(_):
    """Play a single self-play game in a worker process.

    Returns a list of ``(state, policy_target, value)`` samples, one per move.
    The argument is ignored; it only exists so the function can be mapped
    over a range by the pool.
    """
    game = ConnectFour()
    trajectory = []

    while not game.done:
        state = game.get_state()
        visit_counts = run_mcts(game, _worker_model, MCTS_SIMULATIONS)
        action, policy = _select_action(visit_counts, game.move_count)
        trajectory.append((state, policy, game.current_player))
        game.step(action)

    return _label_trajectory(trajectory, game.winner)


def train():
    """Main training entry point: alternate parallel self-play and fitting.

    Each iteration generates GAMES_PER_ITERATION games in a process pool,
    appends the samples to a bounded replay buffer, fits the model on a random
    subset once the buffer holds at least BATCH_SIZE samples, and saves a
    checkpoint every CHECKPOINT_INTERVAL iterations. The final model is saved
    as ``model_final.keras`` and returned.
    """
    model = build_model()
    print_model_info(model)

    num_workers = NUM_WORKERS if NUM_WORKERS > 0 else cpu_count()
    print(f"Using {num_workers} worker processes for self-play")

    replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    for iteration in range(1, NUM_ITERATIONS + 1):
        print(f"\n{'='*60}")
        print(f"Iteration {iteration}/{NUM_ITERATIONS}")
        print(f"{'='*60}")

        # ── Self-play (parallel) ───────────────────────────────
        weights = model.get_weights()
        with Pool(processes=num_workers, initializer=_init_worker, initargs=(weights,)) as pool:
            results = pool.map(_play_one_game, range(GAMES_PER_ITERATION))

        for samples in results:
            replay_buffer.extend(samples)
        wins = _count_outcomes(results)

        print(f" Self-play: {GAMES_PER_ITERATION} games "
              f"(P1 wins: {wins[1]}, P2 wins: {wins[-1]}, draws: {wins[0]})")
        print(f" Buffer size: {len(replay_buffer)}")

        # ── Train ───────────────────────────────────────────────
        if len(replay_buffer) >= BATCH_SIZE:
            sample_size = min(len(replay_buffer), BATCH_SIZE * EPOCHS_PER_ITERATION)
            indices = np.random.choice(len(replay_buffer), size=sample_size, replace=False)
            batch = [replay_buffer[i] for i in indices]

            states = np.array([s[0] for s in batch])
            policies = np.array([s[1] for s in batch])
            values = np.array([s[2] for s in batch]).reshape(-1, 1)

            history = model.fit(
                states,
                {"policy_logits": policies, "value": values},
                batch_size=BATCH_SIZE,
                epochs=EPOCHS_PER_ITERATION,
                verbose=1,
            )
            policy_loss = history.history["policy_logits_loss"][-1]
            value_loss = history.history["value_loss"][-1]
            print(f" Policy loss: {policy_loss:.4f} Value loss: {value_loss:.4f}")

        # ── Checkpoint ──────────────────────────────────────────
        if iteration % CHECKPOINT_INTERVAL == 0:
            path = os.path.join(CHECKPOINT_DIR, f"model_iter{iteration}.keras")
            model.save(path)
            print(f" Saved checkpoint: {path}")

    final_path = os.path.join(CHECKPOINT_DIR, "model_final.keras")
    model.save(final_path)
    print(f"\nTraining complete. Final model saved to {final_path}")
    return model


if __name__ == "__main__":
    train()
_lock = threading.Lock()


# ── Worker setup (same as train.py) ─────────────────────────────────
_worker_model = None


def _init_worker(weights_list):
    """Pool initializer: build this process's model and load the weights."""
    global _worker_model
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    _worker_model = build_model()
    _worker_model.set_weights(weights_list)


def _select_action(visit_counts, move_count):
    """Turn MCTS visit counts into ``(action, policy_target)``.

    Uses MCTS_TEMPERATURE before TEMP_DROP_MOVE moves and 0.1 afterwards;
    temperatures below 0.2 select the most visited move with a one-hot
    target, otherwise the move is sampled from the scaled visit counts.
    """
    temp = MCTS_TEMPERATURE if move_count < TEMP_DROP_MOVE else 0.1
    if temp < 0.2:
        action = int(np.argmax(visit_counts))
        policy = np.zeros(len(visit_counts), dtype=np.float32)
        policy[action] = 1.0
    else:
        counts = visit_counts ** (1.0 / temp)
        policy = counts / counts.sum()
        action = np.random.choice(len(policy), p=policy)
    return action, policy


def _label_trajectory(trajectory, winner):
    """Attach the final outcome to every ``(state, policy, player)`` step."""
    samples = []
    for state, policy, player in trajectory:
        if winner == 0:
            value = DRAW_REWARD
        elif winner == player:
            value = WIN_REWARD
        else:
            value = LOSS_REWARD
        samples.append((state, policy, value))
    return samples


def _count_outcomes(results):
    """Tally game results as ``{1: P1 wins, -1: P2 wins, 0: draws}``.

    Player 1 always moves first, so a game's *first* sample carries the
    outcome from player 1's perspective. The last sample belongs to whoever
    made the final move (always the winner of a decisive game) and therefore
    cannot distinguish the players.
    """
    wins = {1: 0, -1: 0, 0: 0}
    for samples in results:
        if not samples:
            continue
        p1_value = samples[0][2]
        if p1_value == WIN_REWARD:
            wins[1] += 1
        elif p1_value == LOSS_REWARD:
            wins[-1] += 1
        else:
            wins[0] += 1
    return wins


def _play_one_game(_):
    """Play one self-play game in a worker process and return its samples."""
    game = ConnectFour()
    trajectory = []
    while not game.done:
        state = game.get_state()
        visit_counts = run_mcts(game, _worker_model, MCTS_SIMULATIONS)
        action, policy = _select_action(visit_counts, game.move_count)
        trajectory.append((state, policy, game.current_player))
        game.step(action)
    return _label_trajectory(trajectory, game.winner)


def _play_showcase_game(model):
    """Play one game slowly on the training thread, updating shared state.

    After every move the board is published to ``_state`` under the lock and
    the thread sleeps for ``_state["move_delay"]`` seconds. The loop also
    stops early when ``_state["running"]`` becomes false.
    """
    game = ConnectFour()
    trajectory = []

    with _lock:
        _state["board"] = game.board.copy()
        _state["winner"] = 0

    while not game.done and _state["running"]:
        state = game.get_state()
        visit_counts = run_mcts(game, model, MCTS_SIMULATIONS)
        action, policy = _select_action(visit_counts, game.move_count)

        trajectory.append((state, policy, game.current_player))
        game.step(action)

        with _lock:
            _state["board"] = game.board.copy()

        time.sleep(_state["move_delay"])

    with _lock:
        _state["winner"] = game.winner

    return _label_trajectory(trajectory, game.winner)


def _training_thread():
    """Run the full training loop, pushing updates to shared state.

    Per iteration: one showcase game played on this thread, the remaining
    games in a process pool, a win tally appended to ``win_history``, a
    training pass once the replay buffer holds at least BATCH_SIZE samples,
    and a checkpoint every CHECKPOINT_INTERVAL iterations.
    """
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    model = build_model()
    print_model_info(model)

    num_workers = NUM_WORKERS if NUM_WORKERS > 0 else cpu_count()
    replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    with _lock:
        _state["status"] = f"Using {num_workers} workers"

    for iteration in range(1, NUM_ITERATIONS + 1):
        if not _state["running"]:
            break

        with _lock:
            _state["iteration"] = iteration
            _state["phase"] = "self-play"
            _state["status"] = f"Iteration {iteration}/{NUM_ITERATIONS} - Self-play"
            _state["game_num"] = 0

        # Play one showcase game visually
        showcase_samples = _play_showcase_game(model)
        replay_buffer.extend(showcase_samples)

        # Play remaining games in parallel. `results` is reset every
        # iteration so the tally below never sees a stale or unbound list.
        results = []
        remaining = GAMES_PER_ITERATION - 1
        if remaining > 0 and _state["running"]:
            with _lock:
                _state["status"] = f"Iter {iteration} - Playing {remaining} games (parallel)..."

            weights = model.get_weights()
            with Pool(processes=num_workers, initializer=_init_worker, initargs=(weights,)) as pool:
                results = pool.map(_play_one_game, range(remaining))

            for samples in results:
                replay_buffer.extend(samples)

        # Count wins across all games this iteration
        wins = _count_outcomes([showcase_samples, *results])
        with _lock:
            _state["win_history"].append((wins[1], wins[-1], wins[0]))

        # Train
        if len(replay_buffer) >= BATCH_SIZE and _state["running"]:
            with _lock:
                _state["phase"] = "training"
                _state["status"] = f"Iter {iteration} - Training..."

            sample_size = min(len(replay_buffer), BATCH_SIZE * EPOCHS_PER_ITERATION)
            indices = np.random.choice(len(replay_buffer), size=sample_size, replace=False)
            batch = [replay_buffer[i] for i in indices]

            states = np.array([s[0] for s in batch])
            policies = np.array([s[1] for s in batch])
            values = np.array([s[2] for s in batch]).reshape(-1, 1)

            history = model.fit(
                states,
                {"policy_logits": policies, "value": values},
                batch_size=BATCH_SIZE,
                epochs=EPOCHS_PER_ITERATION,
                verbose=0,
            )

            with _lock:
                _state["policy_losses"].append(history.history["policy_logits_loss"][-1])
                _state["value_losses"].append(history.history["value_loss"][-1])

        # Checkpoint
        if iteration % CHECKPOINT_INTERVAL == 0:
            path = os.path.join(CHECKPOINT_DIR, f"model_iter{iteration}.keras")
            model.save(path)

    # Only write the final model when the run was not aborted by the user.
    if _state["running"]:
        final_path = os.path.join(CHECKPOINT_DIR, "model_final.keras")
        model.save(final_path)

    with _lock:
        _state["phase"] = "done"
        _state["status"] = "Training complete!"
# ── Drawing helpers ─────────────────────────────────────────────────

def _draw_board(surface, board, x0, y0):
    """Render the board with its top-left corner at (x0, y0).

    Cells equal to 1 use P1_COLOR, cells equal to -1 use P2_COLOR, anything
    else is drawn as an empty slot; every slot gets a GRID_LINE outline.
    """
    pygame.draw.rect(surface, BOARD_BG, (x0, y0, BOARD_W, BOARD_H), border_radius=8)

    piece_colors = {1: P1_COLOR, -1: P2_COLOR}
    radius = CELL // 2 - 6
    half = CELL // 2

    for r in range(ROWS):
        for c in range(COLS):
            center = (x0 + c * CELL + half, y0 + r * CELL + half)
            color = piece_colors.get(int(board[r, c]), EMPTY)
            pygame.draw.circle(surface, color, center, radius)
            pygame.draw.circle(surface, GRID_LINE, center, radius, 2)


def _draw_chart(surface, x, y, w, h, series_list, colors, title, font):
    """Draw a titled line chart; all series share one min/max value scale.

    Nothing is plotted when every series is empty, and a series with fewer
    than two points is skipped.
    """
    frame = (x, y, w, h)
    pygame.draw.rect(surface, CHART_BG, frame, border_radius=6)
    pygame.draw.rect(surface, (60, 60, 75), frame, 1, border_radius=6)

    # Title
    surface.blit(font.render(title, True, TEXT_COLOR), (x + 8, y + 4))

    plot_x, plot_y = x + 8, y + 24
    plot_w, plot_h = w - 16, h - 32

    # Global min/max over every value of every non-empty series
    samples = [v for series in series_list if series for v in series]
    if not samples:
        return
    lo, hi = min(samples), max(samples)
    span = hi - lo if hi != lo else 1.0  # avoid dividing by zero on flat data

    for series, color in zip(series_list, colors):
        last = len(series) - 1
        if last < 1:
            continue
        points = [
            (
                plot_x + int(i / last * plot_w),
                plot_y + plot_h - int((v - lo) / span * plot_h),
            )
            for i, v in enumerate(series)
        ]
        pygame.draw.lines(surface, color, False, points, 2)


def _draw_stacked_bar(surface, x, y, w, h, win_history, font):
    """Draw one stacked bar per iteration from ``(p1, p2, draws)`` tuples.

    Within each bar P2 is on top, draws in the middle and P1 at the bottom;
    iterations whose counts sum to zero are skipped. A legend is drawn along
    the bottom edge.
    """
    frame = (x, y, w, h)
    pygame.draw.rect(surface, CHART_BG, frame, border_radius=6)
    pygame.draw.rect(surface, (60, 60, 75), frame, 1, border_radius=6)

    surface.blit(font.render("Win rates per iteration", True, TEXT_COLOR), (x + 8, y + 4))

    if not win_history:
        return

    plot_x, plot_y = x + 8, y + 24
    plot_w, plot_h = w - 16, h - 48

    slots = max(len(win_history), 1)
    bar_w = max(2, plot_w // slots)

    for i, (p1, p2, dr) in enumerate(win_history):
        total = p1 + p2 + dr
        if total == 0:
            continue
        left = plot_x + int(i / slots * plot_w)

        h1 = int(p1 / total * plot_h)
        hd = int(dr / total * plot_h)
        h2 = plot_h - h1 - hd  # P2 absorbs the rounding remainder

        top = plot_y
        for color, seg_h in ((P2_CHART, h2), (DRAW_CHART, hd), (P1_CHART, h1)):
            pygame.draw.rect(surface, color, (left, top, bar_w - 1, seg_h))
            top += seg_h

    # Legend
    legend_y = y + h - 18
    legend = (("P1", P1_CHART, x + 8), ("Draw", DRAW_CHART, x + 70), ("P2", P2_CHART, x + 150))
    for label, color, legend_x in legend:
        pygame.draw.rect(surface, color, (legend_x, legend_y, 12, 12))
        surface.blit(font.render(label, True, TEXT_COLOR), (legend_x + 16, legend_y - 2))
screen.fill(BG) + + with _lock: + board = _state["board"].copy() + iteration = _state["iteration"] + phase = _state["phase"] + status = _state["status"] + policy_losses = list(_state["policy_losses"]) + value_losses = list(_state["value_losses"]) + win_history = list(_state["win_history"]) + winner = _state["winner"] + delay = _state["move_delay"] + + # ── Left: game board ──────────────────────────────────── + bx, by = MARGIN, MARGIN + _draw_board(screen, board, bx, by) + + # Winner overlay + if winner != 0 and phase == "self-play": + label = f"Player {1 if winner == 1 else 2} wins!" + color = P1_COLOR if winner == 1 else P2_COLOR + win_surf = font_big.render(label, True, color) + wrect = win_surf.get_rect(center=(bx + BOARD_W // 2, by + BOARD_H + 2)) + if wrect.bottom < WIN_H: + screen.blit(win_surf, wrect) + + # ── Right panel ──────────────────────────────────────── + px = BOARD_W + MARGIN * 2 + py = MARGIN + + # Status + status_surf = font_big.render(status, True, TEXT_COLOR) + screen.blit(status_surf, (px, py)) + py += 28 + + iter_surf = font.render(f"Iteration: {iteration}/{NUM_ITERATIONS} Phase: {phase}", True, TEXT_COLOR) + screen.blit(iter_surf, (px, py)) + py += 20 + + delay_surf = font.render(f"Move delay: {delay:.2f}s (Up/Down to adjust)", True, (150, 150, 170)) + screen.blit(delay_surf, (px, py)) + py += 28 + + # Loss chart + chart_h = 140 + _draw_chart( + screen, px, py, PANEL_W, chart_h, + [policy_losses, value_losses], + [POLICY_LINE, VALUE_LINE], + "Loss (blue=policy, orange=value)", + font, + ) + py += chart_h + 12 + + # Win rate chart + bar_h = 160 + _draw_stacked_bar(screen, px, py, PANEL_W, bar_h, win_history, font) + py += bar_h + 12 + + # Latest stats + if policy_losses: + pl = font.render(f"Policy loss: {policy_losses[-1]:.4f}", True, POLICY_LINE) + screen.blit(pl, (px, py)) + py += 18 + if value_losses: + vl = font.render(f"Value loss: {value_losses[-1]:.4f}", True, VALUE_LINE) + screen.blit(vl, (px, py)) + py += 18 + if win_history: + p1, 
p2, dr = win_history[-1] + ws = font.render(f"Last iter: P1={p1} P2={p2} Draw={dr}", True, TEXT_COLOR) + screen.blit(ws, (px, py)) + + pygame.display.flip() + clock.tick(FPS) + + pygame.quit() + _state["running"] = False + train_thread.join(timeout=5) diff --git a/src/main.cpp b/src/main.cpp index 54c7d7f..53befac 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -121,6 +121,7 @@ bool checkGameEnd(); void updateThinkingVisuals(int8_t pColor, int8_t column); void animateDrop(int col, int player); void moveDiscToCol(int startCol, int targetCol, int player, int speed); +int evaluateBoard(int8_t aiP, int8_t huP); int minimax(int depth, int alpha, int beta, bool isMax, int8_t aiP, int8_t huP, int8_t rootCol); void performAiMove(int8_t aiP); void randomizeDemoPlies(); @@ -265,6 +266,39 @@ int8_t scanBoard() { return 0; } +int evaluateBoard(int8_t aiP, int8_t huP) { + int score = 0; + + // Center column bonus + for (int r = 0; r < ROWS; r++) { + if (board[3][r] == aiP) score += 3; + else if (board[3][r] == huP) score -= 3; + } + + // Score a window of 4 cells by piece counts + auto scoreWindow = [&](int c, int r, int dc, int dr) -> int { + int ai = 0, hu = 0; + for (int i = 0; i < 4; i++) { + int8_t v = board[c + i * dc][r + i * dr]; + if (v == aiP) ai++; + else if (v == huP) hu++; + } + if (ai > 0 && hu > 0) return 0; + if (ai == 3) return 50; + if (ai == 2) return 5; + if (hu == 3) return -50; + if (hu == 2) return -5; + return 0; + }; + + for (int r = 0; r < 6; r++) for (int c = 0; c < 4; c++) score += scoreWindow(c, r, 1, 0); + for (int r = 0; r < 3; r++) for (int c = 0; c < 7; c++) score += scoreWindow(c, r, 0, 1); + for (int r = 0; r < 3; r++) for (int c = 0; c < 4; c++) score += scoreWindow(c, r, 1, 1); + for (int r = 3; r < 6; r++) for (int c = 0; c < 4; c++) score += scoreWindow(c, r, 1, -1); + + return score; +} + bool checkGameEnd() { winnerPlayer = scanBoard(); bool won = winnerPlayer != 0; @@ -331,7 +365,7 @@ int minimax(int depth, int alpha, int beta, bool 
isMax, int8_t aiP, int8_t huP, int8_t win = scanBoard(); if (win == aiP) return 1000 + depth; if (win == huP) return -1000 - depth; - if (depth == 0 || isBoardFull()) return 0; + if (depth == 0 || isBoardFull()) return evaluateBoard(aiP, huP); int best = isMax ? -10000 : 10000; for (int c : colOrder) { @@ -356,13 +390,22 @@ void performAiMove(int8_t aiP) { int originalPly = currentLookAhead; if (gameState == DEMO) currentLookAhead = demoPly[aiP - 1]; - // Phase 1: always take an instant win or block an opponent's win + // Phase 1a: check ALL columns for instant AI win bool found = false; for (int c = 0; c < COLS && !found; c++) { int r = getFirstEmptyRow(c); if (r != -1) { - board[c][r] = aiP; if (scanBoard() == aiP) { board[c][r]=0; bestCol=c; found=true; break; } - board[c][r] = huP; if (scanBoard() == huP) { board[c][r]=0; bestCol=c; found=true; break; } + board[c][r] = aiP; + if (scanBoard() == aiP) { board[c][r] = 0; bestCol = c; found = true; break; } + board[c][r] = 0; + } + } + // Phase 1b: check ALL columns for opponent block + for (int c = 0; c < COLS && !found; c++) { + int r = getFirstEmptyRow(c); + if (r != -1) { + board[c][r] = huP; + if (scanBoard() == huP) { board[c][r] = 0; bestCol = c; found = true; break; } board[c][r] = 0; } }