Files
2026-03-27 12:17:25 +01:00

87 lines
2.8 KiB
Python

"""Export trained Keras model to TFLite (optionally int8-quantized) for ESP32."""
import os
import numpy as np
from .game import ConnectFour, ROWS, COLS
from .config import EXPORT_DIR, QUANTIZE_INT8
def representative_dataset(num_samples=200, max_moves=20, seed=None):
    """Yield sample model inputs for int8 post-training quantization calibration.

    Each sample is the state of a fresh ``ConnectFour`` game after a random
    number of uniformly random legal moves, so calibration sees empty, early
    and later boards rather than only the starting position.

    The TFLite converter invokes this generator with no arguments, so every
    parameter defaults to the previously hard-coded behaviour.

    Args:
        num_samples: Number of board states to yield.
        max_moves: Exclusive upper bound on the random number of moves played
            per sample; additionally capped at ``ROWS * COLS``.
        seed: Optional seed for a private ``numpy.random.Generator``. Pass an
            int to make the calibration set (and therefore the exported
            quantization scales) reproducible; ``None`` uses fresh entropy.

    Yields:
        A one-element list holding ``game.get_state()`` with a leading
        size-1 (batch) axis added, cast to float32 — one entry per model input.
    """
    rng = np.random.default_rng(seed)
    move_cap = min(ROWS * COLS, max_moves)
    game = ConnectFour()
    for _ in range(num_samples):
        game.reset()
        # integers(0, 0) raises, so a non-positive cap simply means "no moves".
        n_moves = int(rng.integers(0, move_cap)) if move_cap > 0 else 0
        for _ in range(n_moves):
            if game.done:
                break
            legal = game.legal_moves()
            # len() instead of truthiness: an array-valued return would make
            # `not legal` raise "truth value of an array is ambiguous".
            if len(legal) == 0:
                break
            game.step(rng.choice(legal))
        yield [game.get_state()[np.newaxis].astype(np.float32)]
def export_tflite(model_path, quantize=None):
    """Convert a saved Keras model into a TFLite flatbuffer and a C header.

    Args:
        model_path: Path to the ``.keras`` model file to load.
        quantize: ``True`` for a full-int8 model, ``False`` for float32.
            ``None`` (the default) falls back to ``config.QUANTIZE_INT8``.

    Returns:
        Path of the written ``connect4<suffix>.tflite`` file in ``EXPORT_DIR``,
        where suffix is ``_int8`` or ``_f32``. A ``connect4_model<suffix>.h``
        C header with the same bytes is written next to it.
    """
    # Imported lazily, presumably so this module can be imported without
    # TensorFlow available.
    import tensorflow as tf

    use_int8 = QUANTIZE_INT8 if quantize is None else quantize
    os.makedirs(EXPORT_DIR, exist_ok=True)

    keras_model = tf.keras.models.load_model(model_path)
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    if not use_int8:
        suffix = "_f32"
    else:
        # Full-integer quantization: restrict to int8 builtin kernels, make
        # the input/output tensors int8, and calibrate activation ranges from
        # sample board states.
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        suffix = "_int8"

    flatbuffer = converter.convert()

    out_path = os.path.join(EXPORT_DIR, f"connect4{suffix}.tflite")
    with open(out_path, "wb") as fh:
        fh.write(flatbuffer)
    print(f"Exported: {out_path} ({len(flatbuffer) / 1024:.1f} KB)")

    # Same bytes as a C array so firmware can compile the model in directly.
    header_path = os.path.join(EXPORT_DIR, f"connect4_model{suffix}.h")
    _write_c_header(flatbuffer, header_path)
    print(f"C header: {header_path}")
    return out_path
def _write_c_header(model_bytes, path):
"""Write TFLite model as a C byte array for ESP32 firmware inclusion."""
with open(path, "w") as f:
f.write("#pragma once\n\n")
f.write(f"// Auto-generated — {len(model_bytes)} bytes\n")
f.write(f"const unsigned int connect4_model_len = {len(model_bytes)};\n")
f.write("alignas(16) const unsigned char connect4_model[] = {\n")
for i in range(0, len(model_bytes), 12):
chunk = model_bytes[i:i + 12]
f.write(" " + ", ".join(f"0x{b:02x}" for b in chunk) + ",\n")
f.write("};\n")
if __name__ == "__main__":
    import sys

    # Optional first CLI argument overrides the default checkpoint path;
    # any further arguments are ignored.
    cli_args = sys.argv[1:]
    model_path = cli_args[0] if cli_args else "rl/checkpoints/model_final.keras"
    export_tflite(model_path)