104 lines
3.4 KiB
Python
104 lines
3.4 KiB
Python
"""Monte Carlo Tree Search for self-play data generation."""
|
|
|
|
import math
|
|
import numpy as np
|
|
from .game import ConnectFour
|
|
from .config import MCTS_C_PUCT
|
|
|
|
|
|
class MCTSNode:
    """One position in the search tree built by ``run_mcts``.

    Attributes:
        game: Position at this node. Children created by :meth:`expand` start
            with ``None``; ``run_mcts`` attaches the position when the node is
            first evaluated.
        parent: Parent node, or ``None`` for the root.
        action: Column played from ``parent``'s position to reach this node.
        prior: Prior probability of ``action``, taken from the policy passed
            to the parent's :meth:`expand`.
        visit_count: Number of backups that passed through this node.
        value_sum: Sum of backed-up values. ``run_mcts`` backs values up from
            the perspective of the player to move at this node (a finished
            game won by the previous player contributes -1).
        children: Mapping from column to child ``MCTSNode``.
    """

    __slots__ = ("parent", "action", "prior", "visit_count", "value_sum", "children", "game")

    def __init__(self, game, parent=None, action=None, prior=0.0):
        self.game = game
        self.parent = parent
        self.action = action
        self.prior = prior
        self.visit_count = 0
        self.value_sum = 0.0
        self.children = {}

    @property
    def q_value(self):
        """Mean backed-up value of this node, or 0.0 if it was never visited."""
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count

    def ucb_score(self, c_puct=None):
        """PUCT score of this node from the point of view of the player choosing at ``parent``.

        Args:
            c_puct: Exploration constant; ``None`` means ``MCTS_C_PUCT``.

        Returns:
            ``-q_value + c_puct * prior * sqrt(max(N_parent, 1)) / (1 + N)``.
        """
        if c_puct is None:
            c_puct = MCTS_C_PUCT
        parent_visits = self.parent.visit_count if self.parent is not None else 1
        # q_value is stored from the perspective of the player to move at THIS
        # node, i.e. the opponent of the player choosing at the parent. Negate
        # it; otherwise the parent would prefer moves that are good for the
        # opponent (and, for instance, avoid moves that win immediately).
        exploitation = -self.q_value
        # A node expanded before any backup (the root on the first simulation)
        # still has 0 visits; sqrt(0) would cancel the exploration term and
        # make the choice ignore the priors entirely.
        exploration = (
            c_puct * self.prior * math.sqrt(max(parent_visits, 1)) / (1 + self.visit_count)
        )
        return exploitation + exploration

    def is_leaf(self):
        """Return True if this node has not been expanded (it has no children)."""
        return not self.children

    def expand(self, policy_probs):
        """Add a child for every legal move of ``self.game`` that has none yet.

        Args:
            policy_probs: Indexable by column; ``policy_probs[col]`` becomes the
                new child's prior. ``self.game`` must be set.
        """
        for col in self.game.legal_moves():
            if col not in self.children:
                self.children[col] = MCTSNode(
                    game=None, parent=self, action=col, prior=policy_probs[col]
                )

    def select_child(self, c_puct=None):
        """Return the child with the highest :meth:`ucb_score` (first one on ties).

        Args:
            c_puct: Forwarded to :meth:`ucb_score`; ``None`` means ``MCTS_C_PUCT``.

        Raises:
            ValueError: If the node has no children.
        """
        return max(self.children.values(), key=lambda child: child.ucb_score(c_puct))
|
|
|
|
|
|
def run_mcts(game, model, num_simulations):
    """Run MCTS from ``game``'s current position and return root visit counts.

    ``game`` is cloned, so the caller's position is not modified. Each
    simulation descends the tree with ``MCTSNode.select_child``, evaluates the
    reached leaf (the exact result for a finished game, otherwise the
    network's value output), expands it with the network's masked policy, and
    backs the value up to the root.

    Args:
        game: Position to search from; must provide ``clone()``, ``step()``,
            ``done``, ``winner``, ``get_state()``, ``legal_moves()`` and
            ``legal_moves_mask()``.
        model: Network whose ``predict(batch, verbose=0)`` returns
            ``(policy_logits, value)`` for a batch of states; the value is read
            as ``value[0, 0]``.
        num_simulations: Number of simulations to run.

    Returns:
        float32 array with one entry per policy output, holding each root
        child's visit count (not normalised; moves without a child stay 0).
    """
    root = MCTSNode(game.clone())
    _, root_policy = _evaluate_and_expand(root, root.game, model)

    for _ in range(num_simulations):
        node = root
        sim_game = root.game.clone()

        # SELECT: descend until an unexpanded node or a finished game.
        while not node.is_leaf() and not sim_game.done:
            node = node.select_child()
            sim_game.step(node.action)

        # EVALUATE: leaf_value is from the perspective of the player to move
        # in sim_game, i.e. the player to move at `node`.
        if sim_game.done:
            # winner == 0 is a draw; otherwise the player who just moved won,
            # so the player to move here has lost.
            leaf_value = 0.0 if sim_game.winner == 0 else -1.0
        else:
            # sim_game is a fresh clone for every simulation and is not touched
            # again, so the node can keep it without copying it once more.
            leaf_value, _ = _evaluate_and_expand(node, sim_game, model)

        # BACKUP: players alternate, so the sign flips at every level up.
        while node is not None:
            node.visit_count += 1
            node.value_sum += leaf_value
            leaf_value = -leaf_value
            node = node.parent

    # Size the output from the policy instead of assuming 7 columns.
    visits = np.zeros(len(root_policy), dtype=np.float32)
    for col, child in root.children.items():
        visits[col] = child.visit_count
    return visits


def _evaluate_and_expand(node, position, model):
    """Attach ``position`` to ``node``, query ``model`` on it and expand ``node``.

    Returns:
        ``(value, policy)``: the network's value for ``position`` as a Python
        float, and the legal-move-masked softmax of its policy output.
    """
    node.game = position
    state = position.get_state()
    policy_logits, value = model.predict(state[np.newaxis], verbose=0)
    policy = _mask_and_normalize(policy_logits[0], position.legal_moves_mask())
    node.expand(policy)
    return float(value[0, 0]), policy
|
|
|
|
|
|
def _mask_and_normalize(logits, mask):
    """Softmax over ``logits`` restricted to the moves allowed by ``mask``.

    Args:
        logits: 1-D array-like of raw policy scores, one per action. It is
            copied, never modified in place.
        mask: 1-D array-like of the same length; entries equal to 0 mark
            illegal actions, any other value marks a legal one.

    Returns:
        float32 numpy array of probabilities summing to 1, with (numerically)
        zero probability on illegal actions. If every action is illegal, the
        result is uniform over all actions.

    NOTE(review): this assumes the model's policy head emits raw logits; if it
    already applies a softmax, the probabilities are softmaxed twice — confirm.
    """
    logits = np.array(logits, dtype=np.float64)
    # asarray so plain Python lists work too: `[1, 0] == 0` is a single False,
    # and `logits[False] = ...` would silently skip the masking.
    mask = np.asarray(mask)
    # A large finite penalty (not -inf) keeps the all-illegal case finite:
    # every entry ties, giving a uniform distribution instead of NaNs.
    logits[mask == 0] = -1e9
    exp = np.exp(logits - np.max(logits))
    probs = exp / np.sum(exp)
    return probs.astype(np.float32)
|