[fix] Non heuristic moves...
This commit is contained in:
+102
@@ -0,0 +1,102 @@
|
||||
"""Connect Four game environment for self-play training."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
ROWS = 6
|
||||
COLS = 7
|
||||
WIN_LENGTH = 4
|
||||
|
||||
|
||||
class ConnectFour:
|
||||
"""Connect Four game with numpy board representation.
|
||||
|
||||
Board encoding: 0 = empty, 1 = player 1, -1 = player 2.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.board = np.zeros((ROWS, COLS), dtype=np.int8)
|
||||
self.current_player = 1
|
||||
self.done = False
|
||||
self.winner = 0 # 0 = no winner / draw, 1 or -1
|
||||
self.move_count = 0
|
||||
return self.get_state()
|
||||
|
||||
def get_state(self):
|
||||
"""Return board from current player's perspective as (6,7,2) tensor.
|
||||
|
||||
Channel 0: current player's pieces (1s).
|
||||
Channel 1: opponent's pieces (1s).
|
||||
"""
|
||||
state = np.zeros((ROWS, COLS, 2), dtype=np.float32)
|
||||
state[:, :, 0] = (self.board == self.current_player).astype(np.float32)
|
||||
state[:, :, 1] = (self.board == -self.current_player).astype(np.float32)
|
||||
return state
|
||||
|
||||
def legal_moves(self):
|
||||
"""Return list of columns that are not full."""
|
||||
return [c for c in range(COLS) if self.board[0, c] == 0]
|
||||
|
||||
def legal_moves_mask(self):
|
||||
"""Return binary mask of legal columns."""
|
||||
return (self.board[0] == 0).astype(np.float32)
|
||||
|
||||
def step(self, col):
|
||||
"""Play a move in the given column. Returns (state, reward, done)."""
|
||||
if self.done:
|
||||
raise ValueError("Game is already over.")
|
||||
if col < 0 or col >= COLS or self.board[0, col] != 0:
|
||||
raise ValueError(f"Illegal move: column {col}")
|
||||
|
||||
# Drop piece
|
||||
row = self._get_drop_row(col)
|
||||
self.board[row, col] = self.current_player
|
||||
self.move_count += 1
|
||||
|
||||
# Check win
|
||||
if self._check_win(row, col):
|
||||
self.done = True
|
||||
self.winner = self.current_player
|
||||
reward = 1.0
|
||||
elif self.move_count == ROWS * COLS:
|
||||
self.done = True
|
||||
self.winner = 0
|
||||
reward = 0.0
|
||||
else:
|
||||
reward = 0.0
|
||||
|
||||
# Switch player
|
||||
self.current_player *= -1
|
||||
return self.get_state(), reward, self.done
|
||||
|
||||
def _get_drop_row(self, col):
|
||||
for r in range(ROWS - 1, -1, -1):
|
||||
if self.board[r, col] == 0:
|
||||
return r
|
||||
raise ValueError(f"Column {col} is full")
|
||||
|
||||
def _check_win(self, row, col):
|
||||
player = self.board[row, col]
|
||||
directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
|
||||
for dr, dc in directions:
|
||||
count = 1
|
||||
for sign in (1, -1):
|
||||
r, c = row + sign * dr, col + sign * dc
|
||||
while 0 <= r < ROWS and 0 <= c < COLS and self.board[r, c] == player:
|
||||
count += 1
|
||||
r += sign * dr
|
||||
c += sign * dc
|
||||
if count >= WIN_LENGTH:
|
||||
return True
|
||||
return False
|
||||
|
||||
def clone(self):
|
||||
g = ConnectFour()
|
||||
g.board = self.board.copy()
|
||||
g.current_player = self.current_player
|
||||
g.done = self.done
|
||||
g.winner = self.winner
|
||||
g.move_count = self.move_count
|
||||
return g
|
||||
Reference in New Issue
Block a user