39 lines
931 B
Python
39 lines
931 B
Python
"""Entry point: python -m rl [train|visualize|export [MODEL_PATH]|info]"""
|
|
|
|
import os
|
|
import sys
|
|
|
|
# Process-wide environment configuration, applied once at import time and
# before main() lazily imports any subcommand module (those modules
# presumably pull in TensorFlow/Keras -- TODO confirm):
#   TF_CPP_MIN_LOG_LEVEL=3   -> suppress TensorFlow's native (C++) log output
#   TF_ENABLE_ONEDNN_OPTS=0  -> disable TensorFlow's oneDNN custom operations
#   CUDA_VISIBLE_DEVICES=""  -> expose no CUDA devices, i.e. force CPU execution
# NOTE: these unconditionally overwrite any values already set by the caller.
os.environ.update(
    TF_CPP_MIN_LOG_LEVEL="3",
    TF_ENABLE_ONEDNN_OPTS="0",
    CUDA_VISIBLE_DEVICES="",
)
|
|
|
|
|
|
def main(argv=None):
    """Dispatch a ``python -m rl`` subcommand.

    Args:
        argv: Command-line arguments *excluding* the program name. When
            ``None`` (the default), ``sys.argv[1:]`` is used, so running the
            module behaves exactly as before; passing a list lets tests and
            other scripts drive the CLI directly.

    Subcommands (``train`` is used when no subcommand is given):
        train               Run ``train()`` from the package's ``train`` module.
        visualize           Run ``run_visualized()`` from ``visualize``.
        export [MODEL_PATH] Call ``export_tflite(MODEL_PATH)`` from ``export``;
                            MODEL_PATH defaults to
                            ``rl/checkpoints/model_final.keras`` (a path
                            relative to the current working directory).
        info                Build the model with ``build_model()`` and pass it
                            to ``print_model_info()``.
        -h, --help, help    Print the usage line to stdout and return.

    Raises:
        SystemExit: with status 1 when the subcommand is not recognised
            (the error and usage line are written to stderr).
    """
    usage = "Usage: python -m rl [train|visualize|export [MODEL_PATH]|info]"

    args = sys.argv[1:] if argv is None else list(argv)
    cmd = args[0] if args else "train"

    # Each branch imports its module only when selected, so a command loads
    # just the code it needs.
    if cmd == "train":
        from .train import train

        train()

    elif cmd == "export":
        from .export import export_tflite

        model_path = args[1] if len(args) > 1 else "rl/checkpoints/model_final.keras"
        export_tflite(model_path)

    elif cmd == "visualize":
        from .visualize import run_visualized

        run_visualized()

    elif cmd == "info":
        from .model import build_model, print_model_info

        model = build_model()
        print_model_info(model)

    elif cmd in ("-h", "--help", "help"):
        # An explicit help request is not an error: stdout, normal return.
        print(usage)

    else:
        # Diagnostics belong on stderr so they don't pollute piped stdout.
        print(f"Unknown command: {cmd}", file=sys.stderr)
        print(usage, file=sys.stderr)
        sys.exit(1)
|
|
|
|
|
|
# Run the CLI only when executed as a script (``python -m rl``). Without this
# guard, merely importing the module (pydoc, test collection, or
# multiprocessing's spawn start method re-importing the main module in child
# processes) would execute a whole training/export run as a side effect.
if __name__ == "__main__":
    main()
|