[fix] Non heuristic moves...
This commit is contained in:
+103
@@ -0,0 +1,103 @@
|
||||
"""Monte Carlo Tree Search for self-play data generation."""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from .game import ConnectFour
|
||||
from .config import MCTS_C_PUCT
|
||||
|
||||
|
||||
class MCTSNode:
    """One node of the search tree, reached from ``parent`` by playing ``action``.

    ``value_sum`` accumulates the values backed up into this node.  As done by
    ``run_mcts`` in this module, those values are expressed from the viewpoint
    of the player to move *at this node* -- i.e. the opponent of the player
    who chooses among this node and its siblings.
    """

    __slots__ = ("parent", "action", "prior", "visit_count", "value_sum", "children", "game")

    def __init__(self, game, parent=None, action=None, prior=0.0):
        """Create an unvisited node.

        Args:
            game: Game state at this node, or ``None`` when the state has not
                been materialised yet (children are created that way by
                ``expand``).
            parent: Parent node, ``None`` for the root.
            action: Move (column) played from ``parent`` to reach this node.
            prior: Prior probability assigned to ``action`` by the policy.
        """
        self.game = game
        self.parent = parent
        self.action = action
        self.prior = prior
        self.visit_count = 0
        self.value_sum = 0.0
        self.children = {}  # action -> MCTSNode

    @property
    def q_value(self):
        """Mean backed-up value of this node; 0.0 while it is unvisited."""
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count

    def ucb_score(self, c_puct=None):
        """Return the PUCT score of this node as seen by its parent.

        Args:
            c_puct: Exploration constant; defaults to ``MCTS_C_PUCT``.

        Returns:
            ``-Q + c_puct * prior * sqrt(N_parent) / (1 + N)``.
        """
        if c_puct is None:
            c_puct = MCTS_C_PUCT
        # The root has no parent; it is scored as if the parent had one visit.
        parent_visits = self.parent.visit_count if self.parent is not None else 1
        exploration = c_puct * self.prior * math.sqrt(parent_visits) / (1 + self.visit_count)
        # q_value is from the viewpoint of the player to move at this node,
        # who is the opponent of the player selecting it, so it is negated:
        # a position that is bad for the opponent is a good move for the parent.
        return exploration - self.q_value

    def is_leaf(self):
        """Return True while this node has no children."""
        return not self.children

    def expand(self, policy_probs):
        """Create one child per legal move, using the policy as priors.

        ``self.game`` must be set before calling this, because the legal
        moves are read from it.  Children already present are left untouched.

        Args:
            policy_probs: Sequence indexable by move (column) giving the prior
                probability of each move.
        """
        for col in self.game.legal_moves():
            if col not in self.children:
                # The child's game state is filled in lazily by the search.
                self.children[col] = MCTSNode(
                    game=None, parent=self, action=col, prior=policy_probs[col]
                )

    def select_child(self, c_puct=None):
        """Return the child with the highest PUCT score.

        Args:
            c_puct: Exploration constant forwarded to ``ucb_score``; defaults
                to ``MCTS_C_PUCT``.

        Raises:
            ValueError: If the node has no children.
        """
        return max(self.children.values(), key=lambda child: child.ucb_score(c_puct))
|
||||
|
||||
|
||||
def _evaluate_position(game, model):
    """Run the network on ``game`` and return ``(policy, value)``.

    ``policy`` is the network's policy output restricted to legal moves and
    normalised; ``value`` is the scalar at ``[0, 0]`` of the value output.
    """
    state = game.get_state()
    # A leading batch axis is added; the model is called Keras-style and is
    # expected to return a (policy_logits, value) pair of batched arrays.
    policy_logits, value = model.predict(state[np.newaxis], verbose=0)
    policy = _mask_and_normalize(policy_logits[0], game.legal_moves_mask())
    return policy, value[0, 0]


def run_mcts(game, model, num_simulations):
    """Run MCTS from the current game state and return per-move visit counts.

    Args:
        game: Game to search from.  It is cloned, so the caller's object is
            not stepped by the search.
        model: Network exposing ``predict(batch, verbose=0)`` that returns
            ``(policy_logits, value)``.
        num_simulations: Number of select/evaluate/backup iterations.

    Returns:
        A float32 vector with one entry per policy output (7 for the standard
        board), holding the raw, un-normalised visit count of each root move.
        Moves that are illegal at the root stay at 0.
    """
    root = MCTSNode(game.clone())

    # Evaluate the root once so that it has children to select from.
    root_policy, _root_value = _evaluate_position(root.game, model)
    root.expand(root_policy)

    for _sim in range(num_simulations):
        node = root
        sim_game = root.game.clone()

        # SELECT -- walk down the tree picking the best-scoring child until an
        # unexpanded node or a finished game is reached.
        while not node.is_leaf() and not sim_game.done:
            node = node.select_child()
            sim_game.step(node.action)

        # EVALUATE -- leaf_value is expressed from the viewpoint of the player
        # to move at the leaf.
        if sim_game.done:
            if sim_game.winner == 0:
                # No winner (presumably a draw).
                leaf_value = 0.0
            else:
                # The side to move has already been switched, so the winner is
                # the previous player and the player to move here has lost.
                leaf_value = -1.0
        else:
            # sim_game is private to this simulation and is not stepped again,
            # so it can be stored on the node without another clone.
            node.game = sim_game
            policy, leaf_value = _evaluate_position(sim_game, model)
            node.expand(policy)

        # BACKUP -- propagate the value to the root, flipping the sign at each
        # level because the player to move alternates.
        while node is not None:
            node.visit_count += 1
            node.value_sum += leaf_value
            leaf_value = -leaf_value
            node = node.parent

    # Build the result from the visit counts of the root's children.
    visits = np.zeros(len(root_policy), dtype=np.float32)
    for col, child in root.children.items():
        visits[col] = child.visit_count
    return visits
|
||||
|
||||
|
||||
def _mask_and_normalize(logits, mask):
    """Softmax over the legal moves only; illegal moves get probability 0.

    Args:
        logits: Array-like of raw policy scores, one per move.
        mask: Array-like of the same shape; entries equal to 0 mark illegal
            moves, anything else marks a legal move.

    Returns:
        A float32 array shaped like ``logits`` that sums to 1.  When no move
        is legal the result is uniform over all entries.
    """
    logits = np.array(logits, dtype=np.float64)
    # Coerce the mask so that plain lists/tuples are compared element-wise.
    legal = np.asarray(mask) != 0

    if not legal.any():
        # Nothing to choose from: fall back to a uniform vector so callers
        # still receive a well-formed distribution.
        return np.full(logits.shape, 1.0 / logits.size, dtype=np.float32)

    # -inf (rather than a large finite sentinel) guarantees that an illegal
    # move can never outweigh a legal one, however negative the legal logits.
    masked = np.where(legal, logits, -np.inf)
    # Subtract the largest legal logit for numerical stability.
    exp = np.exp(masked - masked[legal].max())
    probs = exp / exp.sum()
    return probs.astype(np.float32)
|
||||
Reference in New Issue
Block a user