[fix] Non heuristic moves...
This commit is contained in:
+54
@@ -0,0 +1,54 @@
|
||||
"""Compact dual-head neural network (policy + value) sized for ESP32."""
|
||||
|
||||
from .config import CONV_FILTERS, NUM_CONV_LAYERS, DENSE_UNITS, LEARNING_RATE
|
||||
|
||||
|
||||
def build_model():
    """Build and compile a small AlphaZero-style policy/value network.

    Input:  (6, 7, 2) — current player pieces / opponent pieces
    Output: policy (7,) — raw, unnormalised logits over columns; apply a
                          softmax to turn them into move probabilities
            value  (1,) — board evaluation in [-1, 1] (tanh)

    The trunk is ``NUM_CONV_LAYERS`` 3x3 convolutions with ``CONV_FILTERS``
    filters each, followed by one shared dense layer of ``DENSE_UNITS``
    units that feeds both heads. The model is compiled with Adam at
    ``LEARNING_RATE``.

    Returns:
        A compiled ``keras.Model`` named ``"connect4_net"`` whose outputs
        are ``[policy_logits, value]``.
    """
    # Function-scope import: TensorFlow is only loaded when a model is built.
    from tensorflow import keras
    from tensorflow.keras import layers

    board_shape = (6, 7, 2)
    num_columns = 7

    inp = layers.Input(shape=board_shape, name="board")

    x = inp
    for i in range(NUM_CONV_LAYERS):
        # NOTE(review): the activation runs before batch norm here
        # (Conv -> ReLU -> BN). Left untouched because reordering would
        # change what already-trained weights mean.
        x = layers.Conv2D(
            CONV_FILTERS, 3, padding="same", activation="relu", name=f"conv{i}"
        )(x)
        x = layers.BatchNormalization(name=f"bn{i}")(x)

    flat = layers.Flatten(name="flat")(x)
    shared = layers.Dense(DENSE_UNITS, activation="relu", name="shared_dense")(flat)

    # Policy head: deliberately linear. The loss below is configured with
    # from_logits=True and applies the softmax itself, so adding an
    # activation here would normalise twice.
    policy = layers.Dense(num_columns, name="policy_logits")(shared)

    # Value head: tanh keeps the evaluation inside [-1, 1].
    value = layers.Dense(1, activation="tanh", name="value")(shared)

    model = keras.Model(inputs=inp, outputs=[policy, value], name="connect4_net")

    # Loss keys must match the output layer names given above.
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss={
            "policy_logits": keras.losses.CategoricalCrossentropy(from_logits=True),
            "value": keras.losses.MeanSquaredError(),
        },
        loss_weights={"policy_logits": 1.0, "value": 1.0},
    )
    return model
|
||||
|
||||
|
||||
def print_model_info(model):
    """Print the summary of *model* followed by rough size estimates.

    The estimates multiply the total parameter count by the bytes stored
    per parameter (4 for float32, 1 for int8) and report the result in KB.
    Nothing is returned; all output goes to stdout.
    """
    model.summary()

    param_count = model.count_params()
    bytes_per_kb = 1024
    float32_kb = param_count * 4 / bytes_per_kb  # four bytes per weight
    int8_kb = param_count / bytes_per_kb  # one byte per weight

    report = (
        f"\nTotal parameters: {param_count:,}",
        f"Approx size (float32): {float32_kb:.1f} KB",
        f"Approx size (int8): {int8_kb:.1f} KB",
    )
    for line in report:
        print(line)
|
||||
Reference in New Issue
Block a user