[fix] Non heuristic moves...
This commit is contained in:
@@ -0,0 +1 @@
|
||||
3.13
|
||||
+54
-2
@@ -166,12 +166,57 @@ function scanBoard(b) {
|
||||
return [0, []];
|
||||
}
|
||||
|
||||
function evaluateBoard(b, aiP, huP) {
|
||||
let score = 0;
|
||||
|
||||
// Center column bonus
|
||||
for (let r = 0; r < ROWS; r++) {
|
||||
if (b[3][r] === aiP) score += 3;
|
||||
else if (b[3][r] === huP) score -= 3;
|
||||
}
|
||||
|
||||
// Score a window of 4 cells by piece counts
|
||||
function scoreWindow(c, r, dc, dr) {
|
||||
let ai = 0, hu = 0;
|
||||
for (let i = 0; i < 4; i++) {
|
||||
const v = b[c + i * dc][r + i * dr];
|
||||
if (v === aiP) ai++;
|
||||
else if (v === huP) hu++;
|
||||
}
|
||||
if (ai > 0 && hu > 0) return 0;
|
||||
if (ai === 3) return 50;
|
||||
if (ai === 2) return 5;
|
||||
if (hu === 3) return -50;
|
||||
if (hu === 2) return -5;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Horizontal
|
||||
for (let r = 0; r < ROWS; r++)
|
||||
for (let c = 0; c <= COLS - 4; c++)
|
||||
score += scoreWindow(c, r, 1, 0);
|
||||
// Vertical
|
||||
for (let r = 0; r <= ROWS - 4; r++)
|
||||
for (let c = 0; c < COLS; c++)
|
||||
score += scoreWindow(c, r, 0, 1);
|
||||
// Diagonal up-right
|
||||
for (let r = 0; r <= ROWS - 4; r++)
|
||||
for (let c = 0; c <= COLS - 4; c++)
|
||||
score += scoreWindow(c, r, 1, 1);
|
||||
// Diagonal down-right
|
||||
for (let r = 3; r < ROWS; r++)
|
||||
for (let c = 0; c <= COLS - 4; c++)
|
||||
score += scoreWindow(c, r, 1, -1);
|
||||
|
||||
return score;
|
||||
}
|
||||
|
||||
// --- AI -----------------------------------------------------
|
||||
function minimax(b, depth, alpha, beta, isMax, aiP, huP) {
|
||||
const [winner] = scanBoard(b);
|
||||
if (winner === aiP) return 1000 + depth;
|
||||
if (winner === huP) return -1000 - depth;
|
||||
if (depth === 0 || isBoardFull(b)) return 0;
|
||||
if (depth === 0 || isBoardFull(b)) return evaluateBoard(b, aiP, huP);
|
||||
|
||||
let best = isMax ? -10000 : 10000;
|
||||
for (const c of COL_ORDER) {
|
||||
@@ -196,12 +241,19 @@ function performAiMove(b, aiP, lookAhead, isDemo = false, dPly = 4) {
|
||||
const huP = aiP === 1 ? 2 : 1;
|
||||
const ply = isDemo ? dPly : lookAhead;
|
||||
|
||||
// Phase 1: instant win / block
|
||||
// Phase 1a: check ALL columns for instant AI win
|
||||
for (let c = 0; c < COLS; c++) {
|
||||
const r = getFirstEmptyRow(b, c);
|
||||
if (r === -1) continue;
|
||||
b[c][r] = aiP;
|
||||
if (scanBoard(b)[0] === aiP) { b[c][r] = 0; return c; }
|
||||
b[c][r] = 0;
|
||||
}
|
||||
|
||||
// Phase 1b: check ALL columns for opponent block
|
||||
for (let c = 0; c < COLS; c++) {
|
||||
const r = getFirstEmptyRow(b, c);
|
||||
if (r === -1) continue;
|
||||
b[c][r] = huP;
|
||||
if (scanBoard(b)[0] === huP) { b[c][r] = 0; return c; }
|
||||
b[c][r] = 0;
|
||||
|
||||
+4
-1
@@ -2,11 +2,14 @@
|
||||
name = "connect-four-terminal"
|
||||
version = "1.0.0"
|
||||
description = "Connect Four terminal game with AI"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"rich>=13.0",
|
||||
"python-dotenv>=1.0",
|
||||
"readchar>=4.0",
|
||||
"tensorflow>=2.16",
|
||||
"numpy>=2.0",
|
||||
"pygame>=2.5",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
3.13
|
||||
@@ -0,0 +1,38 @@
|
||||
"""Entry point: python -m rl [train|export|info]"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||
|
||||
|
||||
def main():
|
||||
cmd = sys.argv[1] if len(sys.argv) > 1 else "train"
|
||||
|
||||
if cmd == "train":
|
||||
from .train import train
|
||||
train()
|
||||
|
||||
elif cmd == "export":
|
||||
from .export import export_tflite
|
||||
model_path = sys.argv[2] if len(sys.argv) > 2 else "rl/checkpoints/model_final.keras"
|
||||
export_tflite(model_path)
|
||||
|
||||
elif cmd == "visualize":
|
||||
from .visualize import run_visualized
|
||||
run_visualized()
|
||||
|
||||
elif cmd == "info":
|
||||
from .model import build_model, print_model_info
|
||||
model = build_model()
|
||||
print_model_info(model)
|
||||
|
||||
else:
|
||||
print(f"Unknown command: {cmd}")
|
||||
print("Usage: python -m rl [train|visualize|export|info]")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
main()
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Training hyperparameters — edit these to tune your model."""
|
||||
|
||||
# ── Model architecture ──────────────────────────────────────────────
|
||||
CONV_FILTERS = 32 # filters per conv layer (keep small for ESP32)
|
||||
NUM_CONV_LAYERS = 3 # number of convolutional blocks
|
||||
DENSE_UNITS = 64 # units in the dense layer before heads
|
||||
|
||||
# ── Training ────────────────────────────────────────────────────────
|
||||
LEARNING_RATE = 1e-3 # Adam learning rate
|
||||
BATCH_SIZE = 256 # training batch size
|
||||
EPOCHS_PER_ITERATION = 4 # epochs per training iteration
|
||||
REPLAY_BUFFER_SIZE = 50000 # max samples kept in replay buffer
|
||||
|
||||
# ── Self-play ───────────────────────────────────────────────────────
|
||||
NUM_ITERATIONS = 50 # total train iterations (self-play → train cycles)
|
||||
GAMES_PER_ITERATION = 100 # self-play games generated per iteration
|
||||
MCTS_SIMULATIONS = 50 # MCTS simulations per move
|
||||
MCTS_C_PUCT = 1.4 # exploration constant
|
||||
MCTS_TEMPERATURE = 1.0 # move selection temperature (1 = proportional, →0 = greedy)
|
||||
TEMP_DROP_MOVE = 10 # switch to greedy after this many moves
|
||||
|
||||
# ── Parallelism ────────────────────────────────────────────────────
|
||||
NUM_WORKERS = 0 # 0 = use all available CPU cores
|
||||
|
||||
# ── Reward shaping ──────────────────────────────────────────────────
|
||||
WIN_REWARD = 1.0
|
||||
DRAW_REWARD = 0.0
|
||||
LOSS_REWARD = -1.0
|
||||
|
||||
# ── Checkpointing ──────────────────────────────────────────────────
|
||||
CHECKPOINT_DIR = "rl/checkpoints"
|
||||
CHECKPOINT_INTERVAL = 5 # save model every N iterations
|
||||
EXPORT_DIR = "rl/export"
|
||||
|
||||
# ── ESP32 export ────────────────────────────────────────────────────
|
||||
QUANTIZE_INT8 = True # int8 quantization for TFLite (recommended for ESP32)
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Export trained Keras model to TFLite (optionally int8-quantized) for ESP32."""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from .game import ConnectFour, ROWS, COLS
|
||||
from .config import EXPORT_DIR, QUANTIZE_INT8
|
||||
|
||||
|
||||
def representative_dataset():
|
||||
"""Generate sample inputs for int8 calibration."""
|
||||
game = ConnectFour()
|
||||
for _ in range(200):
|
||||
game.reset()
|
||||
# Play random moves to get diverse board states
|
||||
moves = np.random.randint(0, min(ROWS * COLS, 20))
|
||||
for _ in range(moves):
|
||||
legal = game.legal_moves()
|
||||
if not legal or game.done:
|
||||
break
|
||||
game.step(np.random.choice(legal))
|
||||
yield [game.get_state()[np.newaxis].astype(np.float32)]
|
||||
|
||||
|
||||
def export_tflite(model_path, quantize=None):
|
||||
"""Convert a saved Keras model to TFLite.
|
||||
|
||||
Args:
|
||||
model_path: Path to the .keras model file.
|
||||
quantize: Override quantization setting. If None, uses config.QUANTIZE_INT8.
|
||||
"""
|
||||
import tensorflow as tf
|
||||
|
||||
if quantize is None:
|
||||
quantize = QUANTIZE_INT8
|
||||
|
||||
os.makedirs(EXPORT_DIR, exist_ok=True)
|
||||
|
||||
model = tf.keras.models.load_model(model_path)
|
||||
|
||||
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
||||
|
||||
if quantize:
|
||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
converter.representative_dataset = representative_dataset
|
||||
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
||||
converter.inference_input_type = tf.int8
|
||||
converter.inference_output_type = tf.int8
|
||||
suffix = "_int8"
|
||||
else:
|
||||
suffix = "_f32"
|
||||
|
||||
tflite_model = converter.convert()
|
||||
|
||||
out_path = os.path.join(EXPORT_DIR, f"connect4{suffix}.tflite")
|
||||
with open(out_path, "wb") as f:
|
||||
f.write(tflite_model)
|
||||
|
||||
size_kb = len(tflite_model) / 1024
|
||||
print(f"Exported: {out_path} ({size_kb:.1f} KB)")
|
||||
|
||||
# Also export as C header for direct embedding in firmware
|
||||
header_path = os.path.join(EXPORT_DIR, f"connect4_model{suffix}.h")
|
||||
_write_c_header(tflite_model, header_path)
|
||||
print(f"C header: {header_path}")
|
||||
|
||||
return out_path
|
||||
|
||||
|
||||
def _write_c_header(model_bytes, path):
|
||||
"""Write TFLite model as a C byte array for ESP32 firmware inclusion."""
|
||||
with open(path, "w") as f:
|
||||
f.write("#pragma once\n\n")
|
||||
f.write(f"// Auto-generated — {len(model_bytes)} bytes\n")
|
||||
f.write(f"const unsigned int connect4_model_len = {len(model_bytes)};\n")
|
||||
f.write("alignas(16) const unsigned char connect4_model[] = {\n")
|
||||
for i in range(0, len(model_bytes), 12):
|
||||
chunk = model_bytes[i:i + 12]
|
||||
f.write(" " + ", ".join(f"0x{b:02x}" for b in chunk) + ",\n")
|
||||
f.write("};\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
model_path = sys.argv[1] if len(sys.argv) > 1 else "rl/checkpoints/model_final.keras"
|
||||
export_tflite(model_path)
|
||||
+102
@@ -0,0 +1,102 @@
|
||||
"""Connect Four game environment for self-play training."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
ROWS = 6
|
||||
COLS = 7
|
||||
WIN_LENGTH = 4
|
||||
|
||||
|
||||
class ConnectFour:
|
||||
"""Connect Four game with numpy board representation.
|
||||
|
||||
Board encoding: 0 = empty, 1 = player 1, -1 = player 2.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.board = np.zeros((ROWS, COLS), dtype=np.int8)
|
||||
self.current_player = 1
|
||||
self.done = False
|
||||
self.winner = 0 # 0 = no winner / draw, 1 or -1
|
||||
self.move_count = 0
|
||||
return self.get_state()
|
||||
|
||||
def get_state(self):
|
||||
"""Return board from current player's perspective as (6,7,2) tensor.
|
||||
|
||||
Channel 0: current player's pieces (1s).
|
||||
Channel 1: opponent's pieces (1s).
|
||||
"""
|
||||
state = np.zeros((ROWS, COLS, 2), dtype=np.float32)
|
||||
state[:, :, 0] = (self.board == self.current_player).astype(np.float32)
|
||||
state[:, :, 1] = (self.board == -self.current_player).astype(np.float32)
|
||||
return state
|
||||
|
||||
def legal_moves(self):
|
||||
"""Return list of columns that are not full."""
|
||||
return [c for c in range(COLS) if self.board[0, c] == 0]
|
||||
|
||||
def legal_moves_mask(self):
|
||||
"""Return binary mask of legal columns."""
|
||||
return (self.board[0] == 0).astype(np.float32)
|
||||
|
||||
def step(self, col):
|
||||
"""Play a move in the given column. Returns (state, reward, done)."""
|
||||
if self.done:
|
||||
raise ValueError("Game is already over.")
|
||||
if col < 0 or col >= COLS or self.board[0, col] != 0:
|
||||
raise ValueError(f"Illegal move: column {col}")
|
||||
|
||||
# Drop piece
|
||||
row = self._get_drop_row(col)
|
||||
self.board[row, col] = self.current_player
|
||||
self.move_count += 1
|
||||
|
||||
# Check win
|
||||
if self._check_win(row, col):
|
||||
self.done = True
|
||||
self.winner = self.current_player
|
||||
reward = 1.0
|
||||
elif self.move_count == ROWS * COLS:
|
||||
self.done = True
|
||||
self.winner = 0
|
||||
reward = 0.0
|
||||
else:
|
||||
reward = 0.0
|
||||
|
||||
# Switch player
|
||||
self.current_player *= -1
|
||||
return self.get_state(), reward, self.done
|
||||
|
||||
def _get_drop_row(self, col):
|
||||
for r in range(ROWS - 1, -1, -1):
|
||||
if self.board[r, col] == 0:
|
||||
return r
|
||||
raise ValueError(f"Column {col} is full")
|
||||
|
||||
def _check_win(self, row, col):
|
||||
player = self.board[row, col]
|
||||
directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
|
||||
for dr, dc in directions:
|
||||
count = 1
|
||||
for sign in (1, -1):
|
||||
r, c = row + sign * dr, col + sign * dc
|
||||
while 0 <= r < ROWS and 0 <= c < COLS and self.board[r, c] == player:
|
||||
count += 1
|
||||
r += sign * dr
|
||||
c += sign * dc
|
||||
if count >= WIN_LENGTH:
|
||||
return True
|
||||
return False
|
||||
|
||||
def clone(self):
|
||||
g = ConnectFour()
|
||||
g.board = self.board.copy()
|
||||
g.current_player = self.current_player
|
||||
g.done = self.done
|
||||
g.winner = self.winner
|
||||
g.move_count = self.move_count
|
||||
return g
|
||||
@@ -0,0 +1,6 @@
|
||||
def main():
|
||||
print("Hello from rl!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+103
@@ -0,0 +1,103 @@
|
||||
"""Monte Carlo Tree Search for self-play data generation."""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from .game import ConnectFour
|
||||
from .config import MCTS_C_PUCT
|
||||
|
||||
|
||||
class MCTSNode:
|
||||
__slots__ = ("parent", "action", "prior", "visit_count", "value_sum", "children", "game")
|
||||
|
||||
def __init__(self, game, parent=None, action=None, prior=0.0):
|
||||
self.game = game
|
||||
self.parent = parent
|
||||
self.action = action
|
||||
self.prior = prior
|
||||
self.visit_count = 0
|
||||
self.value_sum = 0.0
|
||||
self.children = {}
|
||||
|
||||
@property
|
||||
def q_value(self):
|
||||
if self.visit_count == 0:
|
||||
return 0.0
|
||||
return self.value_sum / self.visit_count
|
||||
|
||||
def ucb_score(self):
|
||||
parent_visits = self.parent.visit_count if self.parent else 1
|
||||
exploration = MCTS_C_PUCT * self.prior * math.sqrt(parent_visits) / (1 + self.visit_count)
|
||||
return self.q_value + exploration
|
||||
|
||||
def is_leaf(self):
|
||||
return len(self.children) == 0
|
||||
|
||||
def expand(self, policy_probs):
|
||||
"""Expand node using network policy output."""
|
||||
legal = self.game.legal_moves()
|
||||
for col in legal:
|
||||
if col not in self.children:
|
||||
self.children[col] = MCTSNode(
|
||||
game=None, parent=self, action=col, prior=policy_probs[col]
|
||||
)
|
||||
|
||||
def select_child(self):
|
||||
return max(self.children.values(), key=lambda c: c.ucb_score())
|
||||
|
||||
|
||||
def run_mcts(game, model, num_simulations):
|
||||
"""Run MCTS from current game state, return visit-count policy vector."""
|
||||
root = MCTSNode(game.clone())
|
||||
|
||||
# Evaluate root
|
||||
state = root.game.get_state()
|
||||
policy_logits, value = model.predict(state[np.newaxis], verbose=0)
|
||||
policy = _mask_and_normalize(policy_logits[0], root.game.legal_moves_mask())
|
||||
root.expand(policy)
|
||||
|
||||
for _ in range(num_simulations):
|
||||
node = root
|
||||
sim_game = root.game.clone()
|
||||
|
||||
# SELECT — walk down tree picking best UCB child
|
||||
while not node.is_leaf() and not sim_game.done:
|
||||
node = node.select_child()
|
||||
sim_game.step(node.action)
|
||||
|
||||
# EVALUATE leaf
|
||||
if sim_game.done:
|
||||
# Terminal: value from perspective of player who just moved
|
||||
if sim_game.winner == 0:
|
||||
leaf_value = 0.0
|
||||
else:
|
||||
# The winner is sim_game.winner; current_player already switched
|
||||
leaf_value = -1.0 # current player lost (winner was previous player)
|
||||
else:
|
||||
node.game = sim_game.clone()
|
||||
state = sim_game.get_state()
|
||||
policy_logits, value = model.predict(state[np.newaxis], verbose=0)
|
||||
leaf_value = value[0, 0]
|
||||
policy = _mask_and_normalize(policy_logits[0], sim_game.legal_moves_mask())
|
||||
node.expand(policy)
|
||||
|
||||
# BACKUP — propagate value up, flipping sign each level
|
||||
while node is not None:
|
||||
node.visit_count += 1
|
||||
node.value_sum += leaf_value
|
||||
leaf_value = -leaf_value
|
||||
node = node.parent
|
||||
|
||||
# Build policy from visit counts
|
||||
visits = np.zeros(7, dtype=np.float32)
|
||||
for col, child in root.children.items():
|
||||
visits[col] = child.visit_count
|
||||
return visits
|
||||
|
||||
|
||||
def _mask_and_normalize(logits, mask):
|
||||
"""Apply legal-move mask and softmax."""
|
||||
logits = np.array(logits, dtype=np.float64)
|
||||
logits[mask == 0] = -1e9
|
||||
exp = np.exp(logits - np.max(logits))
|
||||
probs = exp / np.sum(exp)
|
||||
return probs.astype(np.float32)
|
||||
+54
@@ -0,0 +1,54 @@
|
||||
"""Compact dual-head neural network (policy + value) sized for ESP32."""
|
||||
|
||||
from .config import CONV_FILTERS, NUM_CONV_LAYERS, DENSE_UNITS, LEARNING_RATE
|
||||
|
||||
|
||||
def build_model():
|
||||
"""Build a small AlphaZero-style network.
|
||||
|
||||
Input: (6, 7, 2) — current player pieces / opponent pieces
|
||||
Output: policy (7,) — log-probabilities over columns
|
||||
value (1,) — board evaluation in [-1, 1]
|
||||
"""
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras import layers
|
||||
|
||||
inp = layers.Input(shape=(6, 7, 2), name="board")
|
||||
|
||||
x = inp
|
||||
for i in range(NUM_CONV_LAYERS):
|
||||
x = layers.Conv2D(
|
||||
CONV_FILTERS, 3, padding="same", activation="relu", name=f"conv{i}"
|
||||
)(x)
|
||||
x = layers.BatchNormalization(name=f"bn{i}")(x)
|
||||
|
||||
flat = layers.Flatten(name="flat")(x)
|
||||
shared = layers.Dense(DENSE_UNITS, activation="relu", name="shared_dense")(flat)
|
||||
|
||||
# Policy head
|
||||
policy = layers.Dense(7, name="policy_logits")(shared)
|
||||
|
||||
# Value head
|
||||
value = layers.Dense(1, activation="tanh", name="value")(shared)
|
||||
|
||||
model = keras.Model(inputs=inp, outputs=[policy, value], name="connect4_net")
|
||||
|
||||
model.compile(
|
||||
optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
|
||||
loss={
|
||||
"policy_logits": keras.losses.CategoricalCrossentropy(from_logits=True),
|
||||
"value": keras.losses.MeanSquaredError(),
|
||||
},
|
||||
loss_weights={"policy_logits": 1.0, "value": 1.0},
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def print_model_info(model):
|
||||
model.summary()
|
||||
total_params = model.count_params()
|
||||
approx_size_kb = total_params * 4 / 1024 # float32
|
||||
approx_int8_kb = total_params / 1024 # int8
|
||||
print(f"\nTotal parameters: {total_params:,}")
|
||||
print(f"Approx size (float32): {approx_size_kb:.1f} KB")
|
||||
print(f"Approx size (int8): {approx_int8_kb:.1f} KB")
|
||||
+143
@@ -0,0 +1,143 @@
|
||||
"""Self-play training loop with parallel game generation."""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from multiprocessing import Pool, cpu_count
|
||||
|
||||
from .game import ConnectFour
|
||||
from .model import build_model, print_model_info
|
||||
from .mcts import run_mcts
|
||||
from .config import (
|
||||
NUM_ITERATIONS, GAMES_PER_ITERATION, MCTS_SIMULATIONS,
|
||||
MCTS_TEMPERATURE, TEMP_DROP_MOVE,
|
||||
WIN_REWARD, DRAW_REWARD, LOSS_REWARD,
|
||||
BATCH_SIZE, EPOCHS_PER_ITERATION, REPLAY_BUFFER_SIZE,
|
||||
CHECKPOINT_DIR, CHECKPOINT_INTERVAL, NUM_WORKERS,
|
||||
)
|
||||
|
||||
# Per-worker global model (loaded once per process)
|
||||
_worker_model = None
|
||||
|
||||
|
||||
def _init_worker(weights_list):
|
||||
"""Initialize a worker process with its own model copy."""
|
||||
global _worker_model
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||
_worker_model = build_model()
|
||||
_worker_model.set_weights(weights_list)
|
||||
|
||||
|
||||
def _play_one_game(_):
|
||||
"""Play a single self-play game in a worker process."""
|
||||
game = ConnectFour()
|
||||
trajectory = []
|
||||
|
||||
while not game.done:
|
||||
state = game.get_state()
|
||||
visit_counts = run_mcts(game, _worker_model, MCTS_SIMULATIONS)
|
||||
|
||||
if game.move_count < TEMP_DROP_MOVE:
|
||||
temp = MCTS_TEMPERATURE
|
||||
else:
|
||||
temp = 0.1
|
||||
|
||||
if temp < 0.2:
|
||||
action = int(np.argmax(visit_counts))
|
||||
policy = np.zeros(7, dtype=np.float32)
|
||||
policy[action] = 1.0
|
||||
else:
|
||||
counts = visit_counts ** (1.0 / temp)
|
||||
policy = counts / counts.sum()
|
||||
action = np.random.choice(7, p=policy)
|
||||
|
||||
trajectory.append((state, policy, game.current_player))
|
||||
game.step(action)
|
||||
|
||||
samples = []
|
||||
for state, policy, player in trajectory:
|
||||
if game.winner == 0:
|
||||
value = DRAW_REWARD
|
||||
elif game.winner == player:
|
||||
value = WIN_REWARD
|
||||
else:
|
||||
value = LOSS_REWARD
|
||||
samples.append((state, policy, value))
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
def train():
|
||||
"""Main training entry point."""
|
||||
model = build_model()
|
||||
print_model_info(model)
|
||||
|
||||
num_workers = NUM_WORKERS if NUM_WORKERS > 0 else cpu_count()
|
||||
print(f"Using {num_workers} worker processes for self-play")
|
||||
|
||||
replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)
|
||||
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
|
||||
|
||||
for iteration in range(1, NUM_ITERATIONS + 1):
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Iteration {iteration}/{NUM_ITERATIONS}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# ── Self-play (parallel) ───────────────────────────────
|
||||
weights = model.get_weights()
|
||||
with Pool(processes=num_workers, initializer=_init_worker, initargs=(weights,)) as pool:
|
||||
results = pool.map(_play_one_game, range(GAMES_PER_ITERATION))
|
||||
|
||||
wins = {1: 0, -1: 0, 0: 0}
|
||||
for samples in results:
|
||||
replay_buffer.extend(samples)
|
||||
if samples:
|
||||
last_value = samples[-1][2]
|
||||
if last_value == WIN_REWARD:
|
||||
wins[1] += 1
|
||||
elif last_value == LOSS_REWARD:
|
||||
wins[-1] += 1
|
||||
else:
|
||||
wins[0] += 1
|
||||
|
||||
print(f" Self-play: {GAMES_PER_ITERATION} games "
|
||||
f"(P1 wins: {wins[1]}, P2 wins: {wins[-1]}, draws: {wins[0]})")
|
||||
print(f" Buffer size: {len(replay_buffer)}")
|
||||
|
||||
# ── Train ───────────────────────────────────────────────
|
||||
if len(replay_buffer) >= BATCH_SIZE:
|
||||
sample_size = min(len(replay_buffer), BATCH_SIZE * EPOCHS_PER_ITERATION)
|
||||
indices = np.random.choice(len(replay_buffer), size=sample_size, replace=False)
|
||||
batch = [replay_buffer[i] for i in indices]
|
||||
|
||||
states = np.array([s[0] for s in batch])
|
||||
policies = np.array([s[1] for s in batch])
|
||||
values = np.array([s[2] for s in batch]).reshape(-1, 1)
|
||||
|
||||
history = model.fit(
|
||||
states,
|
||||
{"policy_logits": policies, "value": values},
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS_PER_ITERATION,
|
||||
verbose=1,
|
||||
)
|
||||
policy_loss = history.history["policy_logits_loss"][-1]
|
||||
value_loss = history.history["value_loss"][-1]
|
||||
print(f" Policy loss: {policy_loss:.4f} Value loss: {value_loss:.4f}")
|
||||
|
||||
# ── Checkpoint ──────────────────────────────────────────
|
||||
if iteration % CHECKPOINT_INTERVAL == 0:
|
||||
path = os.path.join(CHECKPOINT_DIR, f"model_iter{iteration}.keras")
|
||||
model.save(path)
|
||||
print(f" Saved checkpoint: {path}")
|
||||
|
||||
final_path = os.path.join(CHECKPOINT_DIR, "model_final.keras")
|
||||
model.save(final_path)
|
||||
print(f"\nTraining complete. Final model saved to {final_path}")
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
+482
@@ -0,0 +1,482 @@
|
||||
"""Pygame visualization of Connect Four RL training.
|
||||
|
||||
Left panel: live self-play game board
|
||||
Right panel: loss curves + win-rate chart + training stats
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import pygame
|
||||
|
||||
from .game import ConnectFour, ROWS, COLS
|
||||
from .model import build_model, print_model_info
|
||||
from .mcts import run_mcts
|
||||
from .config import (
|
||||
NUM_ITERATIONS, GAMES_PER_ITERATION, MCTS_SIMULATIONS,
|
||||
MCTS_TEMPERATURE, TEMP_DROP_MOVE,
|
||||
WIN_REWARD, DRAW_REWARD, LOSS_REWARD,
|
||||
BATCH_SIZE, EPOCHS_PER_ITERATION, REPLAY_BUFFER_SIZE,
|
||||
CHECKPOINT_DIR, CHECKPOINT_INTERVAL, NUM_WORKERS,
|
||||
)
|
||||
from multiprocessing import Pool, cpu_count
|
||||
|
||||
# ── Layout constants ────────────────────────────────────────────────
|
||||
CELL = 80
|
||||
BOARD_W = COLS * CELL
|
||||
BOARD_H = ROWS * CELL
|
||||
PANEL_W = 420
|
||||
MARGIN = 20
|
||||
WIN_W = BOARD_W + PANEL_W + MARGIN * 3
|
||||
WIN_H = BOARD_H + MARGIN * 2
|
||||
FPS = 30
|
||||
|
||||
# ── Colors ──────────────────────────────────────────────────────────
|
||||
BG = (30, 30, 40)
|
||||
BOARD_BG = (0, 60, 180)
|
||||
EMPTY = (20, 20, 30)
|
||||
P1_COLOR = (255, 220, 50) # yellow
|
||||
P2_COLOR = (220, 40, 40) # red
|
||||
WIN_HIGHLIGHT = (100, 255, 100)
|
||||
GRID_LINE = (0, 40, 140)
|
||||
TEXT_COLOR = (220, 220, 220)
|
||||
CHART_BG = (40, 40, 55)
|
||||
POLICY_LINE = (80, 200, 255)
|
||||
VALUE_LINE = (255, 160, 60)
|
||||
P1_CHART = (255, 220, 50)
|
||||
P2_CHART = (220, 40, 40)
|
||||
DRAW_CHART = (140, 140, 140)
|
||||
|
||||
# ── Shared state between training thread and pygame loop ────────────
|
||||
_state = {
|
||||
"board": np.zeros((ROWS, COLS), dtype=np.int8),
|
||||
"iteration": 0,
|
||||
"game_num": 0,
|
||||
"phase": "init", # init / self-play / training / done
|
||||
"policy_losses": [],
|
||||
"value_losses": [],
|
||||
"win_history": [], # list of (p1_wins, p2_wins, draws) per iteration
|
||||
"move_delay": 0.3,
|
||||
"status": "Initializing...",
|
||||
"winner": 0,
|
||||
"running": True,
|
||||
}
|
||||
_lock = threading.Lock()
|
||||
|
||||
|
||||
# ── Worker setup (same as train.py) ─────────────────────────────────
|
||||
_worker_model = None
|
||||
|
||||
|
||||
def _init_worker(weights_list):
|
||||
global _worker_model
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||
_worker_model = build_model()
|
||||
_worker_model.set_weights(weights_list)
|
||||
|
||||
|
||||
def _play_one_game(_):
|
||||
game = ConnectFour()
|
||||
trajectory = []
|
||||
while not game.done:
|
||||
state = game.get_state()
|
||||
visit_counts = run_mcts(game, _worker_model, MCTS_SIMULATIONS)
|
||||
if game.move_count < TEMP_DROP_MOVE:
|
||||
temp = MCTS_TEMPERATURE
|
||||
else:
|
||||
temp = 0.1
|
||||
if temp < 0.2:
|
||||
action = int(np.argmax(visit_counts))
|
||||
policy = np.zeros(7, dtype=np.float32)
|
||||
policy[action] = 1.0
|
||||
else:
|
||||
counts = visit_counts ** (1.0 / temp)
|
||||
policy = counts / counts.sum()
|
||||
action = np.random.choice(7, p=policy)
|
||||
trajectory.append((state, policy, game.current_player))
|
||||
game.step(action)
|
||||
samples = []
|
||||
for state, policy, player in trajectory:
|
||||
if game.winner == 0:
|
||||
value = DRAW_REWARD
|
||||
elif game.winner == player:
|
||||
value = WIN_REWARD
|
||||
else:
|
||||
value = LOSS_REWARD
|
||||
samples.append((state, policy, value))
|
||||
return samples
|
||||
|
||||
|
||||
def _play_showcase_game(model):
|
||||
"""Play one game slowly on the main training thread, updating shared state."""
|
||||
game = ConnectFour()
|
||||
trajectory = []
|
||||
|
||||
with _lock:
|
||||
_state["board"] = game.board.copy()
|
||||
_state["winner"] = 0
|
||||
|
||||
while not game.done and _state["running"]:
|
||||
state = game.get_state()
|
||||
visit_counts = run_mcts(game, model, MCTS_SIMULATIONS)
|
||||
|
||||
if game.move_count < TEMP_DROP_MOVE:
|
||||
temp = MCTS_TEMPERATURE
|
||||
else:
|
||||
temp = 0.1
|
||||
if temp < 0.2:
|
||||
action = int(np.argmax(visit_counts))
|
||||
policy = np.zeros(7, dtype=np.float32)
|
||||
policy[action] = 1.0
|
||||
else:
|
||||
counts = visit_counts ** (1.0 / temp)
|
||||
policy = counts / counts.sum()
|
||||
action = np.random.choice(7, p=policy)
|
||||
|
||||
trajectory.append((state, policy, game.current_player))
|
||||
game.step(action)
|
||||
|
||||
with _lock:
|
||||
_state["board"] = game.board.copy()
|
||||
|
||||
time.sleep(_state["move_delay"])
|
||||
|
||||
with _lock:
|
||||
_state["winner"] = game.winner
|
||||
|
||||
samples = []
|
||||
for state, policy, player in trajectory:
|
||||
if game.winner == 0:
|
||||
value = DRAW_REWARD
|
||||
elif game.winner == player:
|
||||
value = WIN_REWARD
|
||||
else:
|
||||
value = LOSS_REWARD
|
||||
samples.append((state, policy, value))
|
||||
return samples
|
||||
|
||||
|
||||
def _training_thread():
|
||||
"""Run the full training loop, pushing updates to shared state."""
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
model = build_model()
|
||||
print_model_info(model)
|
||||
|
||||
num_workers = NUM_WORKERS if NUM_WORKERS > 0 else cpu_count()
|
||||
replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)
|
||||
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
|
||||
|
||||
with _lock:
|
||||
_state["status"] = f"Using {num_workers} workers"
|
||||
|
||||
for iteration in range(1, NUM_ITERATIONS + 1):
|
||||
if not _state["running"]:
|
||||
break
|
||||
|
||||
with _lock:
|
||||
_state["iteration"] = iteration
|
||||
_state["phase"] = "self-play"
|
||||
_state["status"] = f"Iteration {iteration}/{NUM_ITERATIONS} - Self-play"
|
||||
|
||||
# Play one showcase game visually
|
||||
with _lock:
|
||||
_state["game_num"] = 0
|
||||
showcase_samples = _play_showcase_game(model)
|
||||
replay_buffer.extend(showcase_samples)
|
||||
|
||||
# Play remaining games in parallel
|
||||
remaining = GAMES_PER_ITERATION - 1
|
||||
if remaining > 0 and _state["running"]:
|
||||
with _lock:
|
||||
_state["status"] = f"Iter {iteration} - Playing {remaining} games (parallel)..."
|
||||
|
||||
weights = model.get_weights()
|
||||
with Pool(processes=num_workers, initializer=_init_worker, initargs=(weights,)) as pool:
|
||||
results = pool.map(_play_one_game, range(remaining))
|
||||
|
||||
for samples in results:
|
||||
replay_buffer.extend(samples)
|
||||
|
||||
# Count wins across all games this iteration
|
||||
wins = {1: 0, -1: 0, 0: 0}
|
||||
# Showcase game
|
||||
if showcase_samples:
|
||||
last_val = showcase_samples[-1][2]
|
||||
if last_val == WIN_REWARD:
|
||||
wins[1] += 1
|
||||
elif last_val == LOSS_REWARD:
|
||||
wins[-1] += 1
|
||||
else:
|
||||
wins[0] += 1
|
||||
# Parallel games
|
||||
if remaining > 0 and _state["running"]:
|
||||
for samples in results:
|
||||
if samples:
|
||||
last_val = samples[-1][2]
|
||||
if last_val == WIN_REWARD:
|
||||
wins[1] += 1
|
||||
elif last_val == LOSS_REWARD:
|
||||
wins[-1] += 1
|
||||
else:
|
||||
wins[0] += 1
|
||||
|
||||
with _lock:
|
||||
_state["win_history"].append((wins[1], wins[-1], wins[0]))
|
||||
|
||||
# Train
|
||||
if len(replay_buffer) >= BATCH_SIZE and _state["running"]:
|
||||
with _lock:
|
||||
_state["phase"] = "training"
|
||||
_state["status"] = f"Iter {iteration} - Training..."
|
||||
|
||||
sample_size = min(len(replay_buffer), BATCH_SIZE * EPOCHS_PER_ITERATION)
|
||||
indices = np.random.choice(len(replay_buffer), size=sample_size, replace=False)
|
||||
batch = [replay_buffer[i] for i in indices]
|
||||
|
||||
states = np.array([s[0] for s in batch])
|
||||
policies = np.array([s[1] for s in batch])
|
||||
values = np.array([s[2] for s in batch]).reshape(-1, 1)
|
||||
|
||||
history = model.fit(
|
||||
states,
|
||||
{"policy_logits": policies, "value": values},
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS_PER_ITERATION,
|
||||
verbose=0,
|
||||
)
|
||||
|
||||
with _lock:
|
||||
_state["policy_losses"].append(history.history["policy_logits_loss"][-1])
|
||||
_state["value_losses"].append(history.history["value_loss"][-1])
|
||||
|
||||
# Checkpoint
|
||||
if iteration % CHECKPOINT_INTERVAL == 0:
|
||||
path = os.path.join(CHECKPOINT_DIR, f"model_iter{iteration}.keras")
|
||||
model.save(path)
|
||||
|
||||
if _state["running"]:
|
||||
final_path = os.path.join(CHECKPOINT_DIR, "model_final.keras")
|
||||
model.save(final_path)
|
||||
|
||||
with _lock:
|
||||
_state["phase"] = "done"
|
||||
_state["status"] = "Training complete!"
|
||||
|
||||
|
||||
# ── Drawing helpers ─────────────────────────────────────────────────
|
||||
|
||||
def _draw_board(surface, board, x0, y0):
|
||||
"""Draw the Connect Four board."""
|
||||
# Board background
|
||||
pygame.draw.rect(surface, BOARD_BG, (x0, y0, BOARD_W, BOARD_H), border_radius=8)
|
||||
|
||||
for r in range(ROWS):
|
||||
for c in range(COLS):
|
||||
cx = x0 + c * CELL + CELL // 2
|
||||
cy = y0 + r * CELL + CELL // 2
|
||||
radius = CELL // 2 - 6
|
||||
|
||||
val = board[r, c]
|
||||
if val == 1:
|
||||
color = P1_COLOR
|
||||
elif val == -1:
|
||||
color = P2_COLOR
|
||||
else:
|
||||
color = EMPTY
|
||||
|
||||
pygame.draw.circle(surface, color, (cx, cy), radius)
|
||||
pygame.draw.circle(surface, GRID_LINE, (cx, cy), radius, 2)
|
||||
|
||||
|
||||
def _draw_chart(surface, x, y, w, h, series_list, colors, title, font):
|
||||
"""Draw a simple line chart with multiple series."""
|
||||
pygame.draw.rect(surface, CHART_BG, (x, y, w, h), border_radius=6)
|
||||
pygame.draw.rect(surface, (60, 60, 75), (x, y, w, h), 1, border_radius=6)
|
||||
|
||||
# Title
|
||||
title_surf = font.render(title, True, TEXT_COLOR)
|
||||
surface.blit(title_surf, (x + 8, y + 4))
|
||||
|
||||
chart_x = x + 8
|
||||
chart_y = y + 24
|
||||
chart_w = w - 16
|
||||
chart_h = h - 32
|
||||
|
||||
if not any(series_list):
|
||||
return
|
||||
|
||||
# Find global min/max
|
||||
all_vals = [v for s in series_list if s for v in s]
|
||||
if not all_vals:
|
||||
return
|
||||
min_val = min(all_vals)
|
||||
max_val = max(all_vals)
|
||||
val_range = max_val - min_val if max_val != min_val else 1.0
|
||||
|
||||
for series, color in zip(series_list, colors):
|
||||
if len(series) < 2:
|
||||
continue
|
||||
points = []
|
||||
for i, v in enumerate(series):
|
||||
px = chart_x + int(i / (len(series) - 1) * chart_w)
|
||||
py = chart_y + chart_h - int((v - min_val) / val_range * chart_h)
|
||||
points.append((px, py))
|
||||
pygame.draw.lines(surface, color, False, points, 2)
|
||||
|
||||
|
||||
def _draw_stacked_bar(surface, x, y, w, h, win_history, font):
|
||||
"""Draw stacked bar chart of win rates."""
|
||||
pygame.draw.rect(surface, CHART_BG, (x, y, w, h), border_radius=6)
|
||||
pygame.draw.rect(surface, (60, 60, 75), (x, y, w, h), 1, border_radius=6)
|
||||
|
||||
title_surf = font.render("Win rates per iteration", True, TEXT_COLOR)
|
||||
surface.blit(title_surf, (x + 8, y + 4))
|
||||
|
||||
if not win_history:
|
||||
return
|
||||
|
||||
chart_x = x + 8
|
||||
chart_y = y + 24
|
||||
chart_w = w - 16
|
||||
chart_h = h - 48
|
||||
|
||||
n = len(win_history)
|
||||
bar_w = max(2, chart_w // max(n, 1))
|
||||
|
||||
for i, (p1, p2, dr) in enumerate(win_history):
|
||||
total = p1 + p2 + dr
|
||||
if total == 0:
|
||||
continue
|
||||
bx = chart_x + int(i / max(n, 1) * chart_w)
|
||||
|
||||
# Stack: P1 (bottom), draws (middle), P2 (top)
|
||||
h1 = int(p1 / total * chart_h)
|
||||
hd = int(dr / total * chart_h)
|
||||
h2 = chart_h - h1 - hd
|
||||
|
||||
by = chart_y
|
||||
pygame.draw.rect(surface, P2_CHART, (bx, by, bar_w - 1, h2))
|
||||
by += h2
|
||||
pygame.draw.rect(surface, DRAW_CHART, (bx, by, bar_w - 1, hd))
|
||||
by += hd
|
||||
pygame.draw.rect(surface, P1_CHART, (bx, by, bar_w - 1, h1))
|
||||
|
||||
# Legend
|
||||
ly = y + h - 18
|
||||
for label, color, lx in [("P1", P1_CHART, x + 8), ("Draw", DRAW_CHART, x + 70), ("P2", P2_CHART, x + 150)]:
|
||||
pygame.draw.rect(surface, color, (lx, ly, 12, 12))
|
||||
surface.blit(font.render(label, True, TEXT_COLOR), (lx + 16, ly - 2))
|
||||
|
||||
|
||||
def run_visualized():
|
||||
"""Launch pygame window and run training with live visualization."""
|
||||
pygame.init()
|
||||
screen = pygame.display.set_mode((WIN_W, WIN_H))
|
||||
pygame.display.set_caption("Connect Four RL Training")
|
||||
clock = pygame.time.Clock()
|
||||
font = pygame.font.SysFont("monospace", 14)
|
||||
font_big = pygame.font.SysFont("monospace", 18, bold=True)
|
||||
|
||||
# Start training in background thread
|
||||
train_thread = threading.Thread(target=_training_thread, daemon=True)
|
||||
train_thread.start()
|
||||
|
||||
running = True
|
||||
while running:
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
running = False
|
||||
_state["running"] = False
|
||||
elif event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_ESCAPE:
|
||||
running = False
|
||||
_state["running"] = False
|
||||
elif event.key == pygame.K_UP:
|
||||
_state["move_delay"] = max(0.05, _state["move_delay"] - 0.05)
|
||||
elif event.key == pygame.K_DOWN:
|
||||
_state["move_delay"] = min(2.0, _state["move_delay"] + 0.05)
|
||||
|
||||
screen.fill(BG)
|
||||
|
||||
with _lock:
|
||||
board = _state["board"].copy()
|
||||
iteration = _state["iteration"]
|
||||
phase = _state["phase"]
|
||||
status = _state["status"]
|
||||
policy_losses = list(_state["policy_losses"])
|
||||
value_losses = list(_state["value_losses"])
|
||||
win_history = list(_state["win_history"])
|
||||
winner = _state["winner"]
|
||||
delay = _state["move_delay"]
|
||||
|
||||
# ── Left: game board ────────────────────────────────────
|
||||
bx, by = MARGIN, MARGIN
|
||||
_draw_board(screen, board, bx, by)
|
||||
|
||||
# Winner overlay
|
||||
if winner != 0 and phase == "self-play":
|
||||
label = f"Player {1 if winner == 1 else 2} wins!"
|
||||
color = P1_COLOR if winner == 1 else P2_COLOR
|
||||
win_surf = font_big.render(label, True, color)
|
||||
wrect = win_surf.get_rect(center=(bx + BOARD_W // 2, by + BOARD_H + 2))
|
||||
if wrect.bottom < WIN_H:
|
||||
screen.blit(win_surf, wrect)
|
||||
|
||||
# ── Right panel ────────────────────────────────────────
|
||||
px = BOARD_W + MARGIN * 2
|
||||
py = MARGIN
|
||||
|
||||
# Status
|
||||
status_surf = font_big.render(status, True, TEXT_COLOR)
|
||||
screen.blit(status_surf, (px, py))
|
||||
py += 28
|
||||
|
||||
iter_surf = font.render(f"Iteration: {iteration}/{NUM_ITERATIONS} Phase: {phase}", True, TEXT_COLOR)
|
||||
screen.blit(iter_surf, (px, py))
|
||||
py += 20
|
||||
|
||||
delay_surf = font.render(f"Move delay: {delay:.2f}s (Up/Down to adjust)", True, (150, 150, 170))
|
||||
screen.blit(delay_surf, (px, py))
|
||||
py += 28
|
||||
|
||||
# Loss chart
|
||||
chart_h = 140
|
||||
_draw_chart(
|
||||
screen, px, py, PANEL_W, chart_h,
|
||||
[policy_losses, value_losses],
|
||||
[POLICY_LINE, VALUE_LINE],
|
||||
"Loss (blue=policy, orange=value)",
|
||||
font,
|
||||
)
|
||||
py += chart_h + 12
|
||||
|
||||
# Win rate chart
|
||||
bar_h = 160
|
||||
_draw_stacked_bar(screen, px, py, PANEL_W, bar_h, win_history, font)
|
||||
py += bar_h + 12
|
||||
|
||||
# Latest stats
|
||||
if policy_losses:
|
||||
pl = font.render(f"Policy loss: {policy_losses[-1]:.4f}", True, POLICY_LINE)
|
||||
screen.blit(pl, (px, py))
|
||||
py += 18
|
||||
if value_losses:
|
||||
vl = font.render(f"Value loss: {value_losses[-1]:.4f}", True, VALUE_LINE)
|
||||
screen.blit(vl, (px, py))
|
||||
py += 18
|
||||
if win_history:
|
||||
p1, p2, dr = win_history[-1]
|
||||
ws = font.render(f"Last iter: P1={p1} P2={p2} Draw={dr}", True, TEXT_COLOR)
|
||||
screen.blit(ws, (px, py))
|
||||
|
||||
pygame.display.flip()
|
||||
clock.tick(FPS)
|
||||
|
||||
pygame.quit()
|
||||
_state["running"] = False
|
||||
train_thread.join(timeout=5)
|
||||
+47
-4
@@ -121,6 +121,7 @@ bool checkGameEnd();
|
||||
void updateThinkingVisuals(int8_t pColor, int8_t column);
|
||||
void animateDrop(int col, int player);
|
||||
void moveDiscToCol(int startCol, int targetCol, int player, int speed);
|
||||
int evaluateBoard(int8_t aiP, int8_t huP);
|
||||
int minimax(int depth, int alpha, int beta, bool isMax, int8_t aiP, int8_t huP, int8_t rootCol);
|
||||
void performAiMove(int8_t aiP);
|
||||
void randomizeDemoPlies();
|
||||
@@ -265,6 +266,39 @@ int8_t scanBoard() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int evaluateBoard(int8_t aiP, int8_t huP) {
|
||||
int score = 0;
|
||||
|
||||
// Center column bonus
|
||||
for (int r = 0; r < ROWS; r++) {
|
||||
if (board[3][r] == aiP) score += 3;
|
||||
else if (board[3][r] == huP) score -= 3;
|
||||
}
|
||||
|
||||
// Score a window of 4 cells by piece counts
|
||||
auto scoreWindow = [&](int c, int r, int dc, int dr) -> int {
|
||||
int ai = 0, hu = 0;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int8_t v = board[c + i * dc][r + i * dr];
|
||||
if (v == aiP) ai++;
|
||||
else if (v == huP) hu++;
|
||||
}
|
||||
if (ai > 0 && hu > 0) return 0;
|
||||
if (ai == 3) return 50;
|
||||
if (ai == 2) return 5;
|
||||
if (hu == 3) return -50;
|
||||
if (hu == 2) return -5;
|
||||
return 0;
|
||||
};
|
||||
|
||||
for (int r = 0; r < 6; r++) for (int c = 0; c < 4; c++) score += scoreWindow(c, r, 1, 0);
|
||||
for (int r = 0; r < 3; r++) for (int c = 0; c < 7; c++) score += scoreWindow(c, r, 0, 1);
|
||||
for (int r = 0; r < 3; r++) for (int c = 0; c < 4; c++) score += scoreWindow(c, r, 1, 1);
|
||||
for (int r = 3; r < 6; r++) for (int c = 0; c < 4; c++) score += scoreWindow(c, r, 1, -1);
|
||||
|
||||
return score;
|
||||
}
|
||||
|
||||
bool checkGameEnd() {
|
||||
winnerPlayer = scanBoard();
|
||||
bool won = winnerPlayer != 0;
|
||||
@@ -331,7 +365,7 @@ int minimax(int depth, int alpha, int beta, bool isMax, int8_t aiP, int8_t huP,
|
||||
int8_t win = scanBoard();
|
||||
if (win == aiP) return 1000 + depth;
|
||||
if (win == huP) return -1000 - depth;
|
||||
if (depth == 0 || isBoardFull()) return 0;
|
||||
if (depth == 0 || isBoardFull()) return evaluateBoard(aiP, huP);
|
||||
|
||||
int best = isMax ? -10000 : 10000;
|
||||
for (int c : colOrder) {
|
||||
@@ -356,13 +390,22 @@ void performAiMove(int8_t aiP) {
|
||||
int originalPly = currentLookAhead;
|
||||
if (gameState == DEMO) currentLookAhead = demoPly[aiP - 1];
|
||||
|
||||
// Phase 1: always take an instant win or block an opponent's win
|
||||
// Phase 1a: check ALL columns for instant AI win
|
||||
bool found = false;
|
||||
for (int c = 0; c < COLS && !found; c++) {
|
||||
int r = getFirstEmptyRow(c);
|
||||
if (r != -1) {
|
||||
board[c][r] = aiP; if (scanBoard() == aiP) { board[c][r]=0; bestCol=c; found=true; break; }
|
||||
board[c][r] = huP; if (scanBoard() == huP) { board[c][r]=0; bestCol=c; found=true; break; }
|
||||
board[c][r] = aiP;
|
||||
if (scanBoard() == aiP) { board[c][r] = 0; bestCol = c; found = true; break; }
|
||||
board[c][r] = 0;
|
||||
}
|
||||
}
|
||||
// Phase 1b: check ALL columns for opponent block
|
||||
for (int c = 0; c < COLS && !found; c++) {
|
||||
int r = getFirstEmptyRow(c);
|
||||
if (r != -1) {
|
||||
board[c][r] = huP;
|
||||
if (scanBoard() == huP) { board[c][r] = 0; bestCol = c; found = true; break; }
|
||||
board[c][r] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user