Files
Connect-four-Esp32/rl/game.py
T
2026-03-27 12:17:25 +01:00

103 lines
3.1 KiB
Python

"""Connect Four game environment for self-play training."""
import numpy as np
# Board dimensions: the grid is ROWS cells high and COLS cells wide.
ROWS = 6
COLS = 7
# Number of aligned pieces of one player needed to win.
WIN_LENGTH = 4
class ConnectFour:
    """Connect Four game with numpy board representation.

    Board encoding: 0 = empty, 1 = player 1, -1 = player 2.
    Row 0 is the top of the board: a column is full once its row-0 cell is
    occupied, and a dropped piece settles at the largest free row index.

    Attributes:
        board: (ROWS, COLS) int8 array of cell owners.
        current_player: 1 or -1, the side to move next.
        done: True once a player has won or the board is full.
        winner: 1 or -1 for the winning side; 0 while playing or on a draw.
        move_count: Number of pieces placed so far.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the board, give the move to player 1 and return the state."""
        self.board = np.zeros((ROWS, COLS), dtype=np.int8)
        self.current_player = 1
        self.done = False
        self.winner = 0  # 0 = no winner / draw, 1 or -1
        self.move_count = 0
        return self.get_state()

    def get_state(self):
        """Return board from current player's perspective as (6,7,2) tensor.

        Channel 0: current player's pieces (1s).
        Channel 1: opponent's pieces (1s).
        """
        state = np.zeros((ROWS, COLS, 2), dtype=np.float32)
        state[:, :, 0] = (self.board == self.current_player).astype(np.float32)
        state[:, :, 1] = (self.board == -self.current_player).astype(np.float32)
        return state

    def legal_moves(self):
        """Return list of columns that are not full."""
        return [c for c in range(COLS) if self.board[0, c] == 0]

    def legal_moves_mask(self):
        """Return binary float32 mask of legal columns (1.0 = playable)."""
        return (self.board[0] == 0).astype(np.float32)

    def step(self, col):
        """Play a move for the current player in the given column.

        Args:
            col: Column index in [0, COLS).

        Returns:
            (state, reward, done). ``reward`` is 1.0 when the move just
            played wins the game and 0.0 otherwise (including draws). The
            side to move is flipped before ``state`` is built, even on the
            final move, so ``state`` is seen from the perspective of the
            player who did NOT make this move.

        Raises:
            ValueError: If the game is already over, or if the column is
                out of range or full.
        """
        if self.done:
            raise ValueError("Game is already over.")
        if col < 0 or col >= COLS or self.board[0, col] != 0:
            raise ValueError(f"Illegal move: column {col}")

        # Drop piece
        row = self._get_drop_row(col)
        self.board[row, col] = self.current_player
        self.move_count += 1

        # Only a winning move is rewarded; draws and ordinary moves give 0.
        reward = 0.0
        if self._check_win(row, col):
            self.done = True
            self.winner = self.current_player
            reward = 1.0
        elif self.move_count == ROWS * COLS:
            # Board is full without a winner: draw.
            self.done = True
            self.winner = 0

        # Switch player
        self.current_player *= -1
        return self.get_state(), reward, self.done

    def _get_drop_row(self, col):
        """Return the row where a piece dropped into ``col`` comes to rest.

        Raises:
            ValueError: If the column has no empty cell.
        """
        # Scan from the bottom row upwards for the first empty cell.
        for r in range(ROWS - 1, -1, -1):
            if self.board[r, col] == 0:
                return r
        raise ValueError(f"Column {col} is full")

    def _check_win(self, row, col):
        """Return True if the piece at (row, col) completes a winning line."""
        player = self.board[row, col]
        # Horizontal, vertical and the two diagonals; each axis is walked in
        # both directions starting from the piece itself.
        directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
        for dr, dc in directions:
            count = 1  # the piece at (row, col)
            for sign in (1, -1):
                r, c = row + sign * dr, col + sign * dc
                while 0 <= r < ROWS and 0 <= c < COLS and self.board[r, c] == player:
                    count += 1
                    r += sign * dr
                    c += sign * dc
            if count >= WIN_LENGTH:
                return True
        return False

    def clone(self):
        """Return an independent copy of the game (the board array is copied)."""
        # Bypass __init__: it would run reset(), allocating a board and a
        # state tensor that would be thrown away immediately.
        g = ConnectFour.__new__(ConnectFour)
        g.board = self.board.copy()
        g.current_player = self.current_player
        g.done = self.done
        g.winner = self.winner
        g.move_count = self.move_count
        return g