103 lines
3.1 KiB
Python
103 lines
3.1 KiB
Python
"""Connect Four game environment for self-play training."""
|
|
|
|
import numpy as np
|
|
|
|
ROWS = 6  # board height; row 0 is the top row, row ROWS - 1 is the bottom
COLS = 7  # board width, i.e. the number of playable columns
WIN_LENGTH = 4  # consecutive same-player pieces required to win


class ConnectFour:
    """Connect Four game with numpy board representation.

    Board encoding: 0 = empty, 1 = player 1, -1 = player 2.

    Attributes:
        board: (ROWS, COLS) int8 array; pieces stack up from the last row.
        current_player: 1 or -1, the player who moves next.
        done: True once a player has won or the board is full.
        winner: 1 or -1 for the winning player, 0 for no winner / draw.
        move_count: number of pieces dropped since the last reset.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Start a new game and return the initial state tensor."""
        self.board = np.zeros((ROWS, COLS), dtype=np.int8)
        self.current_player = 1
        self.done = False
        self.winner = 0  # 0 = no winner / draw, 1 or -1
        self.move_count = 0
        return self.get_state()

    def get_state(self):
        """Return board from current player's perspective as (6,7,2) tensor.

        Channel 0: current player's pieces (1s).
        Channel 1: opponent's pieces (1s).
        """
        state = np.zeros((ROWS, COLS, 2), dtype=np.float32)
        # Boolean masks are cast to 0.0 / 1.0 on assignment into the
        # float32 array, so no intermediate astype() copy is needed.
        state[:, :, 0] = self.board == self.current_player
        state[:, :, 1] = self.board == -self.current_player
        return state

    def legal_moves(self):
        """Return list of columns that are not full."""
        # A column is full exactly when its top cell (row 0) is occupied.
        return [c for c in range(COLS) if self.board[0, c] == 0]

    def legal_moves_mask(self):
        """Return binary float32 mask of legal columns (1.0 = playable)."""
        return (self.board[0] == 0).astype(np.float32)

    def step(self, col):
        """Play a move in the given column. Returns (state, reward, done).

        The reward is 1.0 when this move wins the game for the player who
        made it and 0.0 otherwise (including draws). The side to move is
        switched on every call, terminal moves included, so the returned
        state is always encoded from the perspective of the player who did
        NOT just move.

        Raises:
            ValueError: if the game is already over, or if ``col`` is out of
                range or refers to a full column.
        """
        if self.done:
            raise ValueError("Game is already over.")
        if col < 0 or col >= COLS or self.board[0, col] != 0:
            raise ValueError(f"Illegal move: column {col}")

        # Drop piece
        row = self._get_drop_row(col)
        self.board[row, col] = self.current_player
        self.move_count += 1

        # Check win first: a winning move that also fills the board is a
        # win, not a draw.
        reward = 0.0
        if self._check_win(row, col):
            self.done = True
            self.winner = self.current_player
            reward = 1.0
        elif self.move_count == ROWS * COLS:
            self.done = True
            self.winner = 0

        # Switch player
        self.current_player *= -1
        return self.get_state(), reward, self.done

    def _get_drop_row(self, col):
        """Return the lowest empty row in ``col``.

        Raises:
            ValueError: if the column has no empty cell.
        """
        # Scan from the bottom row upwards; the first empty cell is where
        # a dropped piece comes to rest.
        for r in range(ROWS - 1, -1, -1):
            if self.board[r, col] == 0:
                return r
        raise ValueError(f"Column {col} is full")

    def _check_win(self, row, col):
        """Return True if the piece at (row, col) completes a winning line.

        Only lines passing through (row, col) are examined, so this must be
        called for the most recently placed piece.
        """
        player = self.board[row, col]
        # Horizontal, vertical and the two diagonals; each axis is walked
        # in both directions from the placed piece.
        directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
        for dr, dc in directions:
            count = 1  # the placed piece itself
            for sign in (1, -1):
                r, c = row + sign * dr, col + sign * dc
                while 0 <= r < ROWS and 0 <= c < COLS and self.board[r, c] == player:
                    count += 1
                    r += sign * dr
                    c += sign * dc
            if count >= WIN_LENGTH:
                return True
        return False

    def clone(self):
        """Return an independent copy of the game (board array is copied)."""
        # Allocate without running __init__/reset(): every attribute that
        # reset() sets is overwritten below, so building a fresh board and
        # state tensor first would be wasted work.
        g = ConnectFour.__new__(ConnectFour)
        g.board = self.board.copy()
        g.current_player = self.current_player
        g.done = self.done
        g.winner = self.winner
        g.move_count = self.move_count
        return g
|