"""Connect Four game environment for self-play training."""

import operator

import numpy as np

ROWS = 6
COLS = 7
WIN_LENGTH = 4


class ConnectFour:
    """Connect Four game with numpy board representation.

    Board encoding: 0 = empty, 1 = player 1, -1 = player 2.
    Row 0 is the top row; a dropped piece lands in the lowest empty row of
    its column. Player 1 always moves first.

    The board size and the number of pieces in a line needed to win default
    to the module constants ``ROWS``, ``COLS`` and ``WIN_LENGTH`` (the
    standard 6x7, four-in-a-row game). They can be overridden, e.g. to
    train or test on smaller boards.
    """

    def __init__(self, rows=ROWS, cols=COLS, win_length=WIN_LENGTH):
        """Create a game and reset it to the empty starting position.

        Args:
            rows: Number of board rows (integer >= 1).
            cols: Number of board columns (integer >= 1).
            win_length: Pieces in a line required to win (integer >= 1).

        Raises:
            TypeError: If a dimension is not integer-like.
            ValueError: If a dimension is smaller than 1.
        """
        rows, cols, win_length = (
            operator.index(v) for v in (rows, cols, win_length)
        )
        if min(rows, cols, win_length) < 1:
            raise ValueError(
                f"rows, cols and win_length must be >= 1, "
                f"got {rows}, {cols}, {win_length}"
            )
        self.rows = rows
        self.cols = cols
        self.win_length = win_length
        self.reset()

    def reset(self):
        """Clear the board, give the move to player 1 and return the state."""
        self.board = np.zeros((self.rows, self.cols), dtype=np.int8)
        self.current_player = 1
        self.done = False
        self.winner = 0  # 0 = no winner / draw, 1 or -1
        self.move_count = 0
        return self.get_state()

    def get_state(self):
        """Return board from current player's perspective as (rows, cols, 2) tensor.

        Channel 0: current player's pieces (1s).
        Channel 1: opponent's pieces (1s).
        The array dtype is float32.
        """
        state = np.zeros((self.rows, self.cols, 2), dtype=np.float32)
        state[:, :, 0] = (self.board == self.current_player).astype(np.float32)
        state[:, :, 1] = (self.board == -self.current_player).astype(np.float32)
        return state

    def legal_moves(self):
        """Return list of columns that are not full.

        Note: this only inspects the board. It does not consult ``done``, so
        it can still list columns after the game has been won.
        """
        return [c for c in range(self.cols) if self.board[0, c] == 0]

    def legal_moves_mask(self):
        """Return a float32 binary mask of length ``cols`` (1.0 = column not full)."""
        return (self.board[0] == 0).astype(np.float32)

    def step(self, col):
        """Play a move for the current player in the given column.

        Args:
            col: Column index in ``[0, cols)``. Any integer-like value
                (Python ``int``, numpy integer) is accepted.

        Returns:
            Tuple ``(state, reward, done)``:

            * ``state``: the board from the perspective of the player to move
              NEXT. Players are swapped before returning.
            * ``reward``: 1.0 if this move won the game for the player who
              made it, otherwise 0.0 (ongoing game or draw).
            * ``done``: True once the game has ended (win or full board).

        Raises:
            ValueError: If the game is already over, or ``col`` is not an
                integer, is out of range, or names a full column.
        """
        if self.done:
            raise ValueError("Game is already over.")
        # Normalize integer-like columns. Anything else (e.g. 2.5) would
        # otherwise slip past the range check and fail inside numpy
        # indexing with an IndexError instead of the documented ValueError.
        try:
            index = operator.index(col)
        except TypeError:
            raise ValueError(f"Illegal move: column {col}") from None
        if not 0 <= index < self.cols or self.board[0, index] != 0:
            raise ValueError(f"Illegal move: column {col}")

        # Drop piece
        row = self._get_drop_row(index)
        self.board[row, index] = self.current_player
        self.move_count += 1

        reward = 0.0
        # A win is checked before the full-board test, so a winning final
        # move is scored as a win rather than a draw.
        if self._check_win(row, index):
            self.done = True
            self.winner = self.current_player
            reward = 1.0
        elif self.move_count == self.rows * self.cols:
            self.done = True
            self.winner = 0

        # Switch player (also after the game ends, so ``get_state`` is
        # always from the perspective of the side that would move next).
        self.current_player *= -1
        return self.get_state(), reward, self.done

    def _get_drop_row(self, col):
        """Return the lowest empty row in ``col``; raise ValueError if full."""
        for r in range(self.rows - 1, -1, -1):
            if self.board[r, col] == 0:
                return r
        raise ValueError(f"Column {col} is full")

    def _check_win(self, row, col):
        """Return True if the piece at (row, col) completes a winning line.

        Counts contiguous same-player pieces through (row, col) along the
        horizontal, vertical and both diagonal directions, extending both
        ways from the placed piece.
        """
        player = self.board[row, col]
        directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
        for dr, dc in directions:
            count = 1
            for sign in (1, -1):
                r, c = row + sign * dr, col + sign * dc
                while (
                    0 <= r < self.rows
                    and 0 <= c < self.cols
                    and self.board[r, c] == player
                ):
                    count += 1
                    r += sign * dr
                    c += sign * dc
            if count >= self.win_length:
                return True
        return False

    def clone(self):
        """Return an independent copy of the game (board array is copied)."""
        g = ConnectFour(self.rows, self.cols, self.win_length)
        g.board = self.board.copy()
        g.current_player = self.current_player
        g.done = self.done
        g.winner = self.winner
        g.move_count = self.move_count
        return g