"""Monte Carlo Tree Search for self-play data generation.""" import math import numpy as np from .game import ConnectFour from .config import MCTS_C_PUCT class MCTSNode: __slots__ = ("parent", "action", "prior", "visit_count", "value_sum", "children", "game") def __init__(self, game, parent=None, action=None, prior=0.0): self.game = game self.parent = parent self.action = action self.prior = prior self.visit_count = 0 self.value_sum = 0.0 self.children = {} @property def q_value(self): if self.visit_count == 0: return 0.0 return self.value_sum / self.visit_count def ucb_score(self): parent_visits = self.parent.visit_count if self.parent else 1 exploration = MCTS_C_PUCT * self.prior * math.sqrt(parent_visits) / (1 + self.visit_count) return self.q_value + exploration def is_leaf(self): return len(self.children) == 0 def expand(self, policy_probs): """Expand node using network policy output.""" legal = self.game.legal_moves() for col in legal: if col not in self.children: self.children[col] = MCTSNode( game=None, parent=self, action=col, prior=policy_probs[col] ) def select_child(self): return max(self.children.values(), key=lambda c: c.ucb_score()) def run_mcts(game, model, num_simulations): """Run MCTS from current game state, return visit-count policy vector.""" root = MCTSNode(game.clone()) # Evaluate root state = root.game.get_state() policy_logits, value = model.predict(state[np.newaxis], verbose=0) policy = _mask_and_normalize(policy_logits[0], root.game.legal_moves_mask()) root.expand(policy) for _ in range(num_simulations): node = root sim_game = root.game.clone() # SELECT — walk down tree picking best UCB child while not node.is_leaf() and not sim_game.done: node = node.select_child() sim_game.step(node.action) # EVALUATE leaf if sim_game.done: # Terminal: value from perspective of player who just moved if sim_game.winner == 0: leaf_value = 0.0 else: # The winner is sim_game.winner; current_player already switched leaf_value = -1.0 # current player lost (winner was previous player) else: node.game = sim_game.clone() state = sim_game.get_state() policy_logits, value = model.predict(state[np.newaxis], verbose=0) leaf_value = value[0, 0] policy = _mask_and_normalize(policy_logits[0], sim_game.legal_moves_mask()) node.expand(policy) # BACKUP — propagate value up, flipping sign each level while node is not None: node.visit_count += 1 node.value_sum += leaf_value leaf_value = -leaf_value node = node.parent # Build policy from visit counts visits = np.zeros(7, dtype=np.float32) for col, child in root.children.items(): visits[col] = child.visit_count return visits def _mask_and_normalize(logits, mask): """Apply legal-move mask and softmax.""" logits = np.array(logits, dtype=np.float64) logits[mask == 0] = -1e9 exp = np.exp(logits - np.max(logits)) probs = exp / np.sum(exp) return probs.astype(np.float32)