87 lines
2.8 KiB
Python
87 lines
2.8 KiB
Python
"""Export trained Keras model to TFLite (optionally int8-quantized) for ESP32."""
|
|
|
|
import os
|
|
import numpy as np
|
|
|
|
from .game import ConnectFour, ROWS, COLS
|
|
from .config import EXPORT_DIR, QUANTIZE_INT8
|
|
|
|
|
|
def representative_dataset(num_samples=200, max_moves=20, seed=0):
    """Generate sample inputs for int8 calibration.

    Each sample is the state of a freshly reset ConnectFour game advanced by a
    random number of uniformly random legal moves. The playout stops early if
    the game finishes or no legal move remains, which gives a spread of board
    states for the quantizer to calibrate on.

    Called with no arguments by the TFLite converter; the defaults reproduce
    the original sampling scheme (200 boards, at most 19 random moves each).

    Args:
        num_samples: Number of boards to yield.
        max_moves: Exclusive upper bound on random moves per board; also
            capped at ROWS * COLS. A cap of 0 yields only empty boards.
        seed: Seed for a private RNG. Fixed by default so repeated exports of
            the same Keras model calibrate on identical boards (reproducible
            int8 output). Pass None for fresh randomness on every call.

    Yields:
        A one-element list holding ``game.get_state()`` with a new leading
        axis of size 1, cast to float32.
    """
    # Private generator: does not touch or depend on the global np.random state.
    rng = np.random.default_rng(seed)
    game = ConnectFour()
    move_cap = min(ROWS * COLS, max_moves)

    for _ in range(num_samples):
        game.reset()
        # rng.integers(0, 0) raises, so treat a non-positive cap as "no moves".
        n_moves = int(rng.integers(0, move_cap)) if move_cap > 0 else 0
        for _ in range(n_moves):
            if game.done:
                break
            legal = game.legal_moves()
            if not legal:
                break
            game.step(int(rng.choice(legal)))
        yield [game.get_state()[np.newaxis].astype(np.float32)]
|
|
|
|
|
|
def export_tflite(model_path, quantize=None):
    """Convert a saved Keras model to a TFLite flatbuffer plus a C header.

    Both files are written into ``EXPORT_DIR``, which is created if missing.
    Their names carry an ``_int8`` suffix when quantizing and ``_f32``
    otherwise.

    Args:
        model_path: Path to the .keras model file.
        quantize: Override quantization setting. If None, uses
            config.QUANTIZE_INT8. A truthy value selects full-integer (int8)
            conversion calibrated with ``representative_dataset``; a falsy
            value keeps float32.

    Returns:
        Path of the written ``.tflite`` file.
    """
    # Imported inside the function, presumably so the module can be imported
    # without TensorFlow installed.
    import tensorflow as tf

    use_int8 = QUANTIZE_INT8 if quantize is None else quantize

    os.makedirs(EXPORT_DIR, exist_ok=True)

    keras_model = tf.keras.models.load_model(model_path)
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    if not use_int8:
        suffix = "_f32"
    else:
        # Full-integer quantization: int8-only builtin kernels and int8
        # input/output tensors, calibrated on sample boards.
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        suffix = "_int8"

    flatbuffer = converter.convert()

    tflite_path = os.path.join(EXPORT_DIR, f"connect4{suffix}.tflite")
    with open(tflite_path, "wb") as out_file:
        out_file.write(flatbuffer)
    kilobytes = len(flatbuffer) / 1024
    print(f"Exported: {tflite_path} ({kilobytes:.1f} KB)")

    # Companion C header so the firmware can embed the model bytes directly.
    header_path = os.path.join(EXPORT_DIR, f"connect4_model{suffix}.h")
    _write_c_header(flatbuffer, header_path)
    print(f"C header: {header_path}")

    return tflite_path
|
|
|
|
|
|
def _write_c_header(model_bytes, path):
|
|
"""Write TFLite model as a C byte array for ESP32 firmware inclusion."""
|
|
with open(path, "w") as f:
|
|
f.write("#pragma once\n\n")
|
|
f.write(f"// Auto-generated — {len(model_bytes)} bytes\n")
|
|
f.write(f"const unsigned int connect4_model_len = {len(model_bytes)};\n")
|
|
f.write("alignas(16) const unsigned char connect4_model[] = {\n")
|
|
for i in range(0, len(model_bytes), 12):
|
|
chunk = model_bytes[i:i + 12]
|
|
f.write(" " + ", ".join(f"0x{b:02x}" for b in chunk) + ",\n")
|
|
f.write("};\n")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point. The model path stays an optional positional
    # argument (same default as before); --int8/--float expose the
    # export_tflite `quantize` override, and omitting both keeps the config
    # default.
    import argparse

    parser = argparse.ArgumentParser(
        description="Export a trained Keras model to TFLite and a C header for ESP32."
    )
    parser.add_argument(
        "model_path",
        nargs="?",
        default="rl/checkpoints/model_final.keras",
        help="Path to the .keras model file (default: %(default)s).",
    )
    precision = parser.add_mutually_exclusive_group()
    precision.add_argument(
        "--int8",
        dest="quantize",
        action="store_const",
        const=True,
        help="Force int8 full-integer quantization.",
    )
    precision.add_argument(
        "--float",
        dest="quantize",
        action="store_const",
        const=False,
        help="Force a float32 export.",
    )
    # None means "use config.QUANTIZE_INT8" inside export_tflite.
    parser.set_defaults(quantize=None)
    cli_args = parser.parse_args()

    export_tflite(cli_args.model_path, quantize=cli_args.quantize)
|