Files
2026-03-27 12:17:25 +01:00

104 lines
3.4 KiB
Python

"""Monte Carlo Tree Search for self-play data generation."""
import math
import numpy as np
from .game import ConnectFour
from .config import MCTS_C_PUCT
class MCTSNode:
    """A single node of the Monte Carlo search tree.

    Child nodes are created by ``expand`` with ``game=None``; a game snapshot
    must be assigned to ``node.game`` before that node is itself expanded,
    because ``expand`` reads ``self.game.legal_moves()``.

    ``select_child`` maximizes ``child.q_value + exploration``, so the values
    backed up into a child must be expressed from the perspective of the
    player choosing among that child and its siblings (the player to move
    at the parent).
    """

    __slots__ = ("parent", "action", "prior", "visit_count", "value_sum", "children", "game")

    def __init__(self, game, parent=None, action=None, prior=0.0):
        self.game = game          # game snapshot at this node, or None until assigned
        self.parent = parent      # parent MCTSNode, or None for the root
        self.action = action      # column played from the parent to reach this node
        self.prior = prior        # prior probability given by the parent's policy
        self.visit_count = 0      # number of backups that passed through this node
        self.value_sum = 0.0      # sum of the values backed up into this node
        self.children = {}        # column -> child MCTSNode

    @property
    def q_value(self):
        """Mean backed-up value; 0.0 for a node that has never been visited."""
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count

    def ucb_score(self):
        """Return the PUCT selection score: Q + c_puct * P * sqrt(N_parent) / (1 + N)."""
        parent_visits = self.parent.visit_count if self.parent else 1
        # Clamp to 1: on the first simulation the root has 0 visits, which would
        # zero the exploration term for every child and make the choice ignore
        # the priors (max() would just return the first child). All non-root
        # parents already have >= 1 visit when their children are scored.
        exploration = (
            MCTS_C_PUCT * self.prior * math.sqrt(max(parent_visits, 1)) / (1 + self.visit_count)
        )
        return self.q_value + exploration

    def is_leaf(self):
        """Return True if this node has not been expanded (has no children)."""
        return len(self.children) == 0

    def expand(self, policy_probs):
        """Create one child per legal move, using ``policy_probs[col]`` as its prior.

        Children that already exist are left untouched.
        """
        for col in self.game.legal_moves():
            if col not in self.children:
                self.children[col] = MCTSNode(
                    game=None, parent=self, action=col, prior=policy_probs[col]
                )

    def select_child(self):
        """Return the child with the highest PUCT score (first one on ties)."""
        return max(self.children.values(), key=lambda c: c.ucb_score())
def _evaluate(game, model):
    """Run the network on ``game``'s state; return (legal-masked prior policy, scalar value)."""
    state = game.get_state()
    policy_logits, value = model.predict(state[np.newaxis], verbose=0)
    policy = _mask_and_normalize(policy_logits[0], game.legal_moves_mask())
    return policy, float(value[0, 0])


def run_mcts(game, model, num_simulations):
    """Run MCTS from the current game state and return root visit counts.

    Args:
        game: position to search from; it is cloned and never mutated.
        model: network whose ``predict`` returns ``(policy_logits, value)``.
            NOTE(review): ``value`` is assumed to be from the perspective of
            the player to move in the evaluated state (the terminal branch
            below uses the same convention) — confirm against training code.
        num_simulations: number of select/evaluate/backup iterations.

    Returns:
        float32 array of raw (unnormalized) visit counts, one per action,
        indexed by column; illegal columns stay 0.
    """
    root = MCTSNode(game.clone())
    root_policy, _root_value = _evaluate(root.game, model)
    root.expand(root_policy)

    for _ in range(num_simulations):
        node = root
        sim_game = root.game.clone()

        # SELECT — walk down the tree following the best PUCT child.
        while not node.is_leaf() and not sim_game.done:
            node = node.select_child()
            sim_game.step(node.action)

        # EVALUATE — leaf_value is from the perspective of the player to move
        # in sim_game (i.e. the opponent of whoever played node.action).
        if sim_game.done:
            # Draw -> 0. Otherwise the player who just moved won, so the
            # player now to move has lost.
            leaf_value = 0.0 if sim_game.winner == 0 else -1.0
        else:
            node.game = sim_game.clone()
            policy, leaf_value = _evaluate(sim_game, model)
            node.expand(policy)

        # BACKUP — each node accumulates value for the player who moved INTO
        # it, because that player (the one to move at the parent) is the one
        # maximizing q_value in select_child. Hence flip once before starting,
        # then alternate sign at every level.
        value = -leaf_value
        while node is not None:
            node.visit_count += 1
            node.value_sum += value
            value = -value
            node = node.parent

    # Build the visit-count vector sized to the policy's action space.
    visits = np.zeros(len(root_policy), dtype=np.float32)
    for col, child in root.children.items():
        visits[col] = child.visit_count
    return visits
def _mask_and_normalize(logits, mask):
"""Apply legal-move mask and softmax."""
logits = np.array(logits, dtype=np.float64)
logits[mask == 0] = -1e9
exp = np.exp(logits - np.max(logits))
probs = exp / np.sum(exp)
return probs.astype(np.float32)