diff --git a/crates/burn-import/model-checks/yolo11x/Cargo.toml b/crates/burn-import/model-checks/yolo/Cargo.toml similarity index 80% rename from crates/burn-import/model-checks/yolo11x/Cargo.toml rename to crates/burn-import/model-checks/yolo/Cargo.toml index 3a586ae632..a3e86827b4 100644 --- a/crates/burn-import/model-checks/yolo11x/Cargo.toml +++ b/crates/burn-import/model-checks/yolo/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "burn-import-model-checks-yolo11x" +name = "burn-import-model-checks-yolo" version = "0.1.0" edition = "2024" publish = false @@ -23,4 +23,4 @@ burn = { path = "../../../../crates/burn", features = [ burn-import = { path = "../../../burn-import", features = ["pytorch"] } [build-dependencies] -burn-import = { path = "../../../burn-import" } +burn-import = { path = "../../../burn-import" } \ No newline at end of file diff --git a/crates/burn-import/model-checks/yolo/README.md b/crates/burn-import/model-checks/yolo/README.md new file mode 100644 index 0000000000..270d06054e --- /dev/null +++ b/crates/burn-import/model-checks/yolo/README.md @@ -0,0 +1,55 @@ +# YOLO Model Checks + +This crate provides a unified interface for testing multiple YOLO model variants with Burn. + +## Supported Models + +- `yolov5s` - YOLOv5 small variant +- `yolov8n` - YOLOv8 nano variant +- `yolov8s` - YOLOv8 small variant +- `yolo11x` - YOLO11 extra-large variant +- `yolov10n` - YOLOv10 nano variant (Note: Currently fails due to TopK operator issue) + +## Usage + +### 1. Download and prepare a model + +```bash +# Using Python directly +python get_model.py --model yolov8n + +# Or using uv +uv run get_model.py --model yolov8n + +# List available models +uv run get_model.py --list +``` + +### 2. Run the model test + +After building, you can run the test. The model is already compiled in: + +```bash +YOLO_MODEL=yolov8s cargo run --release +``` + +## Directory Structure + +``` +yolo/ +├── artifacts/ # Downloaded ONNX models and test data +│ ├── yolov8n_opset16.onnx +│ ├── yolov8n_test_data.pt +│ └── ... +├── src/ +│ └── main.rs # Test runner +├── build.rs # Build script that generates model code +├── get_model.py # Model download and preparation script +└── Cargo.toml +``` + +## Notes + +- All YOLO models (except v10) output shape `[1, 84, 8400]` for standard object detection +- YOLOv10n has a different architecture with output shape `[1, 300, 6]` and uses TopK operator +- The crate requires explicit model selection at build time (no default model) diff --git a/crates/burn-import/model-checks/yolo/build.rs b/crates/burn-import/model-checks/yolo/build.rs new file mode 100644 index 0000000000..8d3d71445e --- /dev/null +++ b/crates/burn-import/model-checks/yolo/build.rs @@ -0,0 +1,85 @@ +use burn_import::onnx::ModelGen; +use std::env; +use std::fs; +use std::path::Path; + +fn main() { + // Supported models + let supported_models = vec!["yolov5s", "yolov8n", "yolov8s", "yolov10n", "yolo11x"]; + + // Get the model name from environment variable (required) + let model_name = env::var("YOLO_MODEL").unwrap_or_else(|_| { + eprintln!("Error: YOLO_MODEL environment variable is not set."); + eprintln!(); + eprintln!("Please specify which YOLO model to build:"); + eprintln!(" YOLO_MODEL=yolov8n cargo build"); + eprintln!(); + eprintln!("Available models: {}", supported_models.join(", ")); + std::process::exit(1); + }); + + if !supported_models.contains(&model_name.as_str()) { + eprintln!( + "Error: Unsupported model '{}'. Supported models: {:?}", + model_name, supported_models + ); + std::process::exit(1); + } + + let onnx_path = format!("artifacts/{}_opset16.onnx", model_name); + let test_data_path = format!("artifacts/{}_test_data.pt", model_name); + + // Tell Cargo to only rebuild if these files change + println!("cargo:rerun-if-changed={}", onnx_path); + println!("cargo:rerun-if-changed={}", test_data_path); + println!("cargo:rerun-if-changed=build.rs"); + println!("cargo:rerun-if-env-changed=YOLO_MODEL"); + + // Check if the ONNX model file exists + if !Path::new(&onnx_path).exists() { + eprintln!("Error: ONNX model file not found at '{}'", onnx_path); + eprintln!(); + eprintln!( + "Please run the following command to download and prepare the {} model:", + model_name + ); + eprintln!(" python get_model.py --model {}", model_name); + eprintln!(); + eprintln!("Or if you prefer using uv:"); + eprintln!(" uv run get_model.py --model {}", model_name); + eprintln!(); + eprintln!("Available models: {}", supported_models.join(", ")); + std::process::exit(1); + } + + // Generate the model code from the ONNX file + ModelGen::new() + .input(&onnx_path) + .out_dir("model/") + .run_from_script(); + + // Write the model name to a file so main.rs can access it + let out_dir = env::var("OUT_DIR").unwrap(); + let model_info_path = Path::new(&out_dir).join("model_info.rs"); + + // Generate the include path for the model + let model_include = format!( + "include!(concat!(env!(\"OUT_DIR\"), \"/model/{}_opset16.rs\"));", + model_name + ); + + fs::write( + model_info_path, + format!( + r#"pub const MODEL_NAME: &str = "{}"; +pub const TEST_DATA_FILE: &str = "{}_test_data.pt"; + +// Include the generated model +pub mod yolo_model {{ + {} +}}"#, + model_name, model_name, model_include + ), + ) + .expect("Failed to write model info"); +} diff --git a/crates/burn-import/model-checks/yolo11x/get_model.py b/crates/burn-import/model-checks/yolo/get_model.py similarity index 61% rename from crates/burn-import/model-checks/yolo11x/get_model.py rename to crates/burn-import/model-checks/yolo/get_model.py index 7e76229d8b..4506c861bf 100755 --- a/crates/burn-import/model-checks/yolo11x/get_model.py +++ b/crates/burn-import/model-checks/yolo/get_model.py @@ -2,8 +2,8 @@ # /// script # dependencies = [ -# "onnx-weekly==1.19.0.dev20250419", -# "onnxruntime>=1.22.0", +# "onnx>=1.17.0", +# "onnxruntime>=1.18.0", # "ultralytics>=8.3.0", # "numpy", # "pillow", @@ -17,6 +17,17 @@ from onnx import shape_inference, version_converter import numpy as np from pathlib import Path +import argparse + + +# Supported YOLO models configuration +SUPPORTED_MODELS = { + 'yolov5s': {'download_name': 'yolov5s.pt', 'display_name': 'YOLOv5s'}, + 'yolov8n': {'download_name': 'yolov8n.pt', 'display_name': 'YOLOv8n'}, + 'yolov8s': {'download_name': 'yolov8s.pt', 'display_name': 'YOLOv8s'}, + 'yolov10n': {'download_name': 'yolov10n.pt', 'display_name': 'YOLOv10n'}, + 'yolo11x': {'download_name': 'yolo11x.pt', 'display_name': 'YOLO11x'}, +} def get_input_shape(model): @@ -35,25 +46,30 @@ def get_input_shape(model): return shape -def download_and_convert_model(output_path): - """Download YOLO11x model and export to ONNX format.""" +def download_and_convert_model(model_name, output_path): + """Download YOLO model and export to ONNX format.""" from ultralytics import YOLO - print("Downloading YOLO11x model...") - pt_path = Path("artifacts/yolo11x.pt") - model = YOLO(str(pt_path)) + model_config = SUPPORTED_MODELS[model_name] + display_name = model_config['display_name'] + download_name = model_config['download_name'] + + print(f"Downloading {display_name} model...") + model = YOLO(download_name) print("Exporting to ONNX format...") model.export(format="onnx", simplify=True) # Move exported file to artifacts - exported_file = Path("yolo11x.onnx") + base_name = download_name.replace('.pt', '') + exported_file = Path(f"{base_name}.onnx") if exported_file.exists(): exported_file.rename(output_path) # Clean up PyTorch file - if pt_path.exists(): - pt_path.unlink() + pt_file = Path(download_name) + if pt_file.exists(): + pt_file.unlink() if not output_path.exists(): raise FileNotFoundError(f"Failed to create ONNX file at {output_path}") @@ -81,7 +97,7 @@ def process_model(input_path, output_path, target_opset=16): return model -def generate_test_data(model_path, output_dir): +def generate_test_data(model_path, output_path, model_name): """Generate test input/output data and save as PyTorch tensors.""" import torch import onnxruntime as ort @@ -108,29 +124,46 @@ def generate_test_data(model_path, output_dir): 'output': torch.from_numpy(outputs[0]) } - test_data_path = Path(output_dir) / "test_data.pt" - torch.save(test_data, test_data_path) + torch.save(test_data, output_path) - print(f" ✓ Test data saved to: {test_data_path}") + print(f" ✓ Test data saved to: {output_path}") print(f" Input shape: {test_input.shape}, Output shape: {outputs[0].shape}") def main(): + parser = argparse.ArgumentParser(description='YOLO Model Preparation Tool') + parser.add_argument('--model', type=str, default='yolov8n', + choices=list(SUPPORTED_MODELS.keys()), + help=f'YOLO model to download and prepare (default: yolov8n). Choices: {", ".join(SUPPORTED_MODELS.keys())}') + parser.add_argument('--list', action='store_true', + help='List all supported models') + + args = parser.parse_args() + + if args.list: + print("Supported YOLO models:") + for model_id, config in SUPPORTED_MODELS.items(): + print(f" - {model_id:10s} ({config['display_name']})") + return + + model_name = args.model + display_name = SUPPORTED_MODELS[model_name]['display_name'] + print("=" * 60) - print("YOLO11x Model Preparation Tool") + print(f"{display_name} Model Preparation Tool") print("=" * 60) # Setup paths artifacts_dir = Path("artifacts") artifacts_dir.mkdir(exist_ok=True) - original_path = artifacts_dir / "yolo11x.onnx" - processed_path = artifacts_dir / "yolo11x_opset16.onnx" - test_data_path = artifacts_dir / "test_data.pt" + original_path = artifacts_dir / f"{model_name}.onnx" + processed_path = artifacts_dir / f"{model_name}_opset16.onnx" + test_data_path = artifacts_dir / f"{model_name}_test_data.pt" # Check if we already have everything if processed_path.exists() and test_data_path.exists(): - print(f"\n✓ All files already exist:") + print(f"\n✓ All files already exist for {display_name}:") print(f" Model: {processed_path}") print(f" Test data: {test_data_path}") print("\nNothing to do!") @@ -138,8 +171,8 @@ def main(): # Download and convert if needed if not original_path.exists() and not processed_path.exists(): - print("\nStep 1: Downloading and converting YOLO11x model...") - download_and_convert_model(original_path) + print(f"\nStep 1: Downloading and converting {display_name} model...") + download_and_convert_model(model_name, original_path) # Process model if needed if not processed_path.exists(): @@ -153,10 +186,10 @@ def main(): # Generate test data if needed if not test_data_path.exists(): print("\nStep 3: Generating test data...") - generate_test_data(processed_path, artifacts_dir) + generate_test_data(processed_path, test_data_path, model_name) print("\n" + "=" * 60) - print("✓ YOLO11x model preparation completed!") + print(f"✓ {display_name} model preparation completed!") print(f" Model: {processed_path}") print(f" Test data: {test_data_path}") print("=" * 60) @@ -170,4 +203,4 @@ def main(): sys.exit(1) except Exception as e: print(f"\n✗ Error: {str(e)}") - sys.exit(1) + sys.exit(1) \ No newline at end of file diff --git a/crates/burn-import/model-checks/yolo11x/src/main.rs b/crates/burn-import/model-checks/yolo/src/main.rs similarity index 67% rename from crates/burn-import/model-checks/yolo11x/src/main.rs rename to crates/burn-import/model-checks/yolo/src/main.rs index 7de0d2a150..2c11abea52 100644 --- a/crates/burn-import/model-checks/yolo11x/src/main.rs +++ b/crates/burn-import/model-checks/yolo/src/main.rs @@ -5,6 +5,7 @@ use burn::prelude::*; use burn::record::*; use burn_import::pytorch::PyTorchFileRecorder; +use std::env; use std::path::Path; use std::time::Instant; @@ -20,10 +21,11 @@ pub type MyBackend = burn::backend::LibTorch; #[cfg(feature = "metal")] pub type MyBackend = burn::backend::Metal; -// Import the generated model code as a module -pub mod yolo11x { - include!(concat!(env!("OUT_DIR"), "/model/yolo11x_opset16.rs")); -} +// Import model info generated by build.rs (includes the yolo_model module) +include!(concat!(env!("OUT_DIR"), "/model_info.rs")); + +// Use the yolo_model module from model_info.rs +use yolo_model::Model; #[derive(Debug, Module)] struct TestData { @@ -31,9 +33,24 @@ struct TestData { output: Param>, } +fn get_model_display_name(model_name: &str) -> &str { + match model_name { + "yolov5s" => "YOLOv5s", + "yolov8n" => "YOLOv8n", + "yolov8s" => "YOLOv8s", + "yolov10n" => "YOLOv10n", + "yolo11x" => "YOLO11x", + _ => model_name, + } +} + fn main() { + // MODEL_NAME is set at build time from YOLO_MODEL env var + let model_name = MODEL_NAME; + let display_name = get_model_display_name(model_name); + println!("========================================"); - println!("YOLO11x Burn Model Test"); + println!("{} Burn Model Test", display_name); println!("========================================\n"); // Check if artifacts exist @@ -41,29 +58,50 @@ fn main() { if !artifacts_dir.exists() { eprintln!("Error: artifacts directory not found!"); eprintln!("Please run get_model.py first to download the model and test data."); + eprintln!("Example: uv run get_model.py --model {}", model_name); + std::process::exit(1); + } + + // Check if model files exist for this specific model + let model_file = artifacts_dir.join(format!("{}_opset16.onnx", model_name)); + let test_data_file = artifacts_dir.join(format!("{}_test_data.pt", model_name)); + + if !model_file.exists() || !test_data_file.exists() { + eprintln!("Error: Model files not found for {}!", display_name); + eprintln!("Please run: uv run get_model.py --model {}", model_name); + eprintln!(); + eprintln!("Available models:"); + eprintln!(" - yolov5s"); + eprintln!(" - yolov8n"); + eprintln!(" - yolov8s"); + eprintln!(" - yolov10n"); + eprintln!(" - yolo11x"); std::process::exit(1); } // Initialize the model (without weights for now) - println!("Initializing YOLO11x model..."); + println!("Initializing {} model...", display_name); let start = Instant::now(); let device = Default::default(); - let model: yolo11x::Model = yolo11x::Model::default(); + let model: Model = Model::default(); let init_time = start.elapsed(); println!(" Model initialized in {:.2?}", init_time); // Save model structure to file - println!("\nSaving model structure to artifacts/model.txt..."); + let model_txt_path = artifacts_dir.join(format!("{}_model.txt", model_name)); + println!( + "\nSaving model structure to {}...", + model_txt_path.display() + ); let model_str = format!("{}", model); - std::fs::write("artifacts/model.txt", &model_str) - .expect("Failed to write model structure to file"); + std::fs::write(&model_txt_path, &model_str).expect("Failed to write model structure to file"); println!(" Model structure saved"); // Load test data from PyTorch file - println!("\nLoading test data from artifacts/test_data.pt..."); + println!("\nLoading test data from {}...", test_data_file.display()); let start = Instant::now(); let test_data: TestDataRecord = PyTorchFileRecorder::::new() - .load("artifacts/test_data.pt".into(), &device) + .load(test_data_file.into(), &device) .expect("Failed to load test data"); let load_time = start.elapsed(); println!(" Data loaded in {:.2?}", load_time); @@ -92,14 +130,14 @@ fn main() { let shape = output.shape(); println!("\n Model output shape: {:?}", shape.dims); - // Verify expected output shape + // Verify expected output shape (most YOLO models use [1, 84, 8400]) let expected_shape = [1, 84, 8400]; if shape.dims == expected_shape { println!(" ✓ Output shape matches expected: {:?}", expected_shape); } else { println!( - " ⚠ Warning: Expected shape {:?}, got {:?}", - expected_shape, shape.dims + " ⚠ Note: Shape is {:?} (expected {:?} for most YOLO models)", + shape.dims, expected_shape ); } diff --git a/crates/burn-import/model-checks/yolo11x/build.rs b/crates/burn-import/model-checks/yolo11x/build.rs deleted file mode 100644 index 625a22f2ba..0000000000 --- a/crates/burn-import/model-checks/yolo11x/build.rs +++ /dev/null @@ -1,32 +0,0 @@ -use burn_import::onnx::ModelGen; -use std::path::Path; - -fn main() { - let onnx_path = "artifacts/yolo11x_opset16.onnx"; - let test_data_path = "artifacts/test_data.pt"; - - // Tell Cargo to only rebuild if these files change - println!("cargo:rerun-if-changed={}", onnx_path); - println!("cargo:rerun-if-changed={}", test_data_path); - println!("cargo:rerun-if-changed=build.rs"); - - // Check if the ONNX model file exists - if !Path::new(onnx_path).exists() { - eprintln!("Error: ONNX model file not found at '{}'", onnx_path); - eprintln!(); - eprintln!("Please run the following command to download and prepare the model:"); - eprintln!(" python get_model.py"); - eprintln!(); - eprintln!("Or if you prefer using uv:"); - eprintln!(" uv run get_model.py"); - eprintln!(); - eprintln!("This will download the YOLO11x model and convert it to ONNX format."); - std::process::exit(1); - } - - // Generate the model code from the ONNX file - ModelGen::new() - .input(onnx_path) - .out_dir("model/") - .run_from_script(); -} \ No newline at end of file