[fix] Non heuristic moves...
This commit is contained in:
@@ -0,0 +1,86 @@
|
||||
"""Export trained Keras model to TFLite (optionally int8-quantized) for ESP32."""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from .game import ConnectFour, ROWS, COLS
|
||||
from .config import EXPORT_DIR, QUANTIZE_INT8
|
||||
|
||||
|
||||
def representative_dataset():
|
||||
"""Generate sample inputs for int8 calibration."""
|
||||
game = ConnectFour()
|
||||
for _ in range(200):
|
||||
game.reset()
|
||||
# Play random moves to get diverse board states
|
||||
moves = np.random.randint(0, min(ROWS * COLS, 20))
|
||||
for _ in range(moves):
|
||||
legal = game.legal_moves()
|
||||
if not legal or game.done:
|
||||
break
|
||||
game.step(np.random.choice(legal))
|
||||
yield [game.get_state()[np.newaxis].astype(np.float32)]
|
||||
|
||||
|
||||
def export_tflite(model_path, quantize=None):
|
||||
"""Convert a saved Keras model to TFLite.
|
||||
|
||||
Args:
|
||||
model_path: Path to the .keras model file.
|
||||
quantize: Override quantization setting. If None, uses config.QUANTIZE_INT8.
|
||||
"""
|
||||
import tensorflow as tf
|
||||
|
||||
if quantize is None:
|
||||
quantize = QUANTIZE_INT8
|
||||
|
||||
os.makedirs(EXPORT_DIR, exist_ok=True)
|
||||
|
||||
model = tf.keras.models.load_model(model_path)
|
||||
|
||||
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
||||
|
||||
if quantize:
|
||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
converter.representative_dataset = representative_dataset
|
||||
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
||||
converter.inference_input_type = tf.int8
|
||||
converter.inference_output_type = tf.int8
|
||||
suffix = "_int8"
|
||||
else:
|
||||
suffix = "_f32"
|
||||
|
||||
tflite_model = converter.convert()
|
||||
|
||||
out_path = os.path.join(EXPORT_DIR, f"connect4{suffix}.tflite")
|
||||
with open(out_path, "wb") as f:
|
||||
f.write(tflite_model)
|
||||
|
||||
size_kb = len(tflite_model) / 1024
|
||||
print(f"Exported: {out_path} ({size_kb:.1f} KB)")
|
||||
|
||||
# Also export as C header for direct embedding in firmware
|
||||
header_path = os.path.join(EXPORT_DIR, f"connect4_model{suffix}.h")
|
||||
_write_c_header(tflite_model, header_path)
|
||||
print(f"C header: {header_path}")
|
||||
|
||||
return out_path
|
||||
|
||||
|
||||
def _write_c_header(model_bytes, path):
|
||||
"""Write TFLite model as a C byte array for ESP32 firmware inclusion."""
|
||||
with open(path, "w") as f:
|
||||
f.write("#pragma once\n\n")
|
||||
f.write(f"// Auto-generated — {len(model_bytes)} bytes\n")
|
||||
f.write(f"const unsigned int connect4_model_len = {len(model_bytes)};\n")
|
||||
f.write("alignas(16) const unsigned char connect4_model[] = {\n")
|
||||
for i in range(0, len(model_bytes), 12):
|
||||
chunk = model_bytes[i:i + 12]
|
||||
f.write(" " + ", ".join(f"0x{b:02x}" for b in chunk) + ",\n")
|
||||
f.write("};\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
model_path = sys.argv[1] if len(sys.argv) > 1 else "rl/checkpoints/model_final.keras"
|
||||
export_tflite(model_path)
|
||||
Reference in New Issue
Block a user