"""Export trained Keras model to TFLite (optionally int8-quantized) for ESP32."""

import os

import numpy as np

from .game import ConnectFour, ROWS, COLS
from .config import EXPORT_DIR, QUANTIZE_INT8


def representative_dataset(num_samples=200, max_moves=20):
    """Generate sample inputs for int8 calibration.

    Each sample comes from resetting a ``ConnectFour`` game and playing a
    random number of uniformly random legal moves, so the converter sees a
    spread of board states rather than only the empty board.

    Args:
        num_samples: Number of calibration samples to yield. The TFLite
            converter calls this function with no arguments, so the default
            (200) is what ``export_tflite`` uses.
        max_moves: Exclusive upper bound on the number of random moves played
            per sample, further capped at ``ROWS * COLS``. Must be >= 1
            (``np.random.randint`` rejects an empty range).

    Yields:
        A one-element list holding ``game.get_state()`` with a leading batch
        axis added and cast to ``float32`` (the list-of-inputs form the
        converter's ``representative_dataset`` hook consumes).
    """
    game = ConnectFour()
    # Exclusive upper bound for randint; loop-invariant, so computed once.
    move_cap = min(ROWS * COLS, max_moves)
    for _ in range(num_samples):
        game.reset()
        # Play random moves to get diverse board states
        moves = np.random.randint(0, move_cap)
        for _ in range(moves):
            if game.done:
                break
            legal = game.legal_moves()
            # len() instead of truthiness: works whether legal_moves() returns
            # a list or a NumPy array (bool() of a multi-element array raises).
            if len(legal) == 0:
                break
            game.step(np.random.choice(legal))
        yield [game.get_state()[np.newaxis].astype(np.float32)]


def export_tflite(model_path, quantize=None):
    """Convert a saved Keras model to TFLite and also emit a C header.

    Writes ``connect4{suffix}.tflite`` and ``connect4_model{suffix}.h`` into
    ``EXPORT_DIR`` (created if missing), where ``suffix`` is ``_int8`` or
    ``_f32``.

    Args:
        model_path: Path to the .keras model file.
        quantize: Override quantization setting. If None, uses
            config.QUANTIZE_INT8.

    Returns:
        Path of the written ``.tflite`` file.
    """
    # Imported here so importing this module does not load TensorFlow.
    import tensorflow as tf

    if quantize is None:
        quantize = QUANTIZE_INT8

    os.makedirs(EXPORT_DIR, exist_ok=True)

    model = tf.keras.models.load_model(model_path)
    converter = tf.lite.TFLiteConverter.from_keras_model(model)

    if quantize:
        # Full-integer quantization calibrated on representative_dataset.
        # Restricting to TFLITE_BUILTINS_INT8 makes conversion fail instead of
        # silently keeping float kernels for ops that can't be quantized.
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        # int8 model I/O: the firmware must quantize inputs / dequantize
        # outputs using the tensors' scale and zero-point.
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        suffix = "_int8"
    else:
        suffix = "_f32"

    tflite_model = converter.convert()

    out_path = os.path.join(EXPORT_DIR, f"connect4{suffix}.tflite")
    with open(out_path, "wb") as f:
        f.write(tflite_model)

    size_kb = len(tflite_model) / 1024
    print(f"Exported: {out_path} ({size_kb:.1f} KB)")

    # Also export as C header for direct embedding in firmware
    header_path = os.path.join(EXPORT_DIR, f"connect4_model{suffix}.h")
    _write_c_header(tflite_model, header_path)
    print(f"C header: {header_path}")

    return out_path


def _write_c_header(model_bytes, path):
    """Write TFLite model as a C byte array for ESP32 firmware inclusion.

    The header defines ``connect4_model_len`` and a ``connect4_model`` byte
    array (12 hex bytes per line) declared ``alignas(16)``.

    Args:
        model_bytes: Serialized TFLite flatbuffer.
        path: Destination ``.h`` file path.
    """
    # Explicit UTF-8: the banner comment contains a non-ASCII em dash that the
    # platform default encoding (e.g. some Windows code pages) may not encode.
    with open(path, "w", encoding="utf-8") as f:
        f.write("#pragma once\n\n")
        f.write(f"// Auto-generated — {len(model_bytes)} bytes\n")
        f.write(f"const unsigned int connect4_model_len = {len(model_bytes)};\n")
        # alignas is a C++11 keyword (plain C needs <stdalign.h>); presumably
        # aligned to satisfy TFLite Micro's model-buffer alignment expectations.
        f.write("alignas(16) const unsigned char connect4_model[] = {\n")
        for i in range(0, len(model_bytes), 12):
            chunk = model_bytes[i:i + 12]
            f.write(" " + ", ".join(f"0x{b:02x}" for b in chunk) + ",\n")
        f.write("};\n")


if __name__ == "__main__":
    # The relative imports above mean this must be run as a module
    # (python -m <package>.<module>), not as a plain script path.
    import sys

    model_path = sys.argv[1] if len(sys.argv) > 1 else "rl/checkpoints/model_final.keras"
    export_tflite(model_path)