From c5798979e9ee4fe7aa6bd3dac33d1840585ae03a Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Sat, 20 Sep 2025 21:42:41 -0500 Subject: [PATCH 1/3] Add YOLOv8n model check with ONNX import and test Introduces a new model-checks/yolov8n crate for Burn, including Cargo.toml, build script for ONNX codegen, Python script to download and process the YOLOv8n model and generate test data, and a Rust main.rs to validate model output against reference data. This enables automated verification of YOLOv8n ONNX import and inference correctness across supported backends. --- .../model-checks/yolov8n/Cargo.toml | 26 +++ .../burn-import/model-checks/yolov8n/build.rs | 32 ++++ .../model-checks/yolov8n/get_model.py | 173 ++++++++++++++++++ .../model-checks/yolov8n/src/main.rs | 148 +++++++++++++++ 4 files changed, 379 insertions(+) create mode 100644 crates/burn-import/model-checks/yolov8n/Cargo.toml create mode 100644 crates/burn-import/model-checks/yolov8n/build.rs create mode 100755 crates/burn-import/model-checks/yolov8n/get_model.py create mode 100644 crates/burn-import/model-checks/yolov8n/src/main.rs diff --git a/crates/burn-import/model-checks/yolov8n/Cargo.toml b/crates/burn-import/model-checks/yolov8n/Cargo.toml new file mode 100644 index 0000000000..54d56e1483 --- /dev/null +++ b/crates/burn-import/model-checks/yolov8n/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "burn-import-model-checks-yolov8n" +version = "0.1.0" +edition = "2024" +publish = false + +[workspace] + +[features] +default = ["tch"] +ndarray = [] +tch = [] +wgpu = [] +metal = [] + +[dependencies] +burn = { path = "../../../../crates/burn", features = [ + "ndarray", + "tch", + "wgpu", + "metal", +] } +burn-import = { path = "../../../burn-import", features = ["pytorch"] } + +[build-dependencies] +burn-import = { path = "../../../burn-import" } \ No newline at end of file diff --git a/crates/burn-import/model-checks/yolov8n/build.rs b/crates/burn-import/model-checks/yolov8n/build.rs new file mode 100644 index 0000000000..8c0600cdce --- /dev/null +++ b/crates/burn-import/model-checks/yolov8n/build.rs @@ -0,0 +1,32 @@ +use burn_import::onnx::ModelGen; +use std::path::Path; + +fn main() { + let onnx_path = "artifacts/yolov8n_opset16.onnx"; + let test_data_path = "artifacts/test_data.pt"; + + // Tell Cargo to only rebuild if these files change + println!("cargo:rerun-if-changed={}", onnx_path); + println!("cargo:rerun-if-changed={}", test_data_path); + println!("cargo:rerun-if-changed=build.rs"); + + // Check if the ONNX model file exists + if !Path::new(onnx_path).exists() { + eprintln!("Error: ONNX model file not found at '{}'", onnx_path); + eprintln!(); + eprintln!("Please run the following command to download and prepare the model:"); + eprintln!(" python get_model.py"); + eprintln!(); + eprintln!("Or if you prefer using uv:"); + eprintln!(" uv run get_model.py"); + eprintln!(); + eprintln!("This will download the YOLOv8n model and convert it to ONNX format."); + std::process::exit(1); + } + + // Generate the model code from the ONNX file + ModelGen::new() + .input(onnx_path) + .out_dir("model/") + .run_from_script(); +} diff --git a/crates/burn-import/model-checks/yolov8n/get_model.py b/crates/burn-import/model-checks/yolov8n/get_model.py new file mode 100755 index 0000000000..5c36386d12 --- /dev/null +++ b/crates/burn-import/model-checks/yolov8n/get_model.py @@ -0,0 +1,173 @@ +#!/usr/bin/env -S uv run --script + +# /// script +# dependencies = [ +# "onnx>=1.17.0", +# "onnxruntime>=1.18.0", +# "ultralytics>=8.3.0", +# "numpy", +# "pillow", +# "torch", +# ] +# /// + +import os +import sys +import onnx +from onnx import shape_inference, version_converter +import numpy as np +from pathlib import Path + + +def get_input_shape(model): + """Extract input shape from ONNX model.""" + input_info = model.graph.input[0] + shape = [] + for dim in input_info.type.tensor_type.shape.dim: + if dim.HasField('dim_value'): + shape.append(dim.dim_value) + else: + shape.append(1) # Default to 1 for dynamic dimensions + + # Ensure valid YOLO input shape + if len(shape) != 4 or shape[2] == 0 or shape[2] > 2000: + return [1, 3, 640, 640] + return shape + + +def download_and_convert_model(output_path): + """Download YOLOv8n model and export to ONNX format.""" + from ultralytics import YOLO + + print("Downloading YOLOv8n model...") + model = YOLO('yolov8n.pt') + + print("Exporting to ONNX format...") + model.export(format="onnx", simplify=True) + + # Move exported file to artifacts + exported_file = Path("yolov8n.onnx") + if exported_file.exists(): + exported_file.rename(output_path) + + # Clean up PyTorch file + pt_file = Path("yolov8n.pt") + if pt_file.exists(): + pt_file.unlink() + + if not output_path.exists(): + raise FileNotFoundError(f"Failed to create ONNX file at {output_path}") + + +def process_model(input_path, output_path, target_opset=16): + """Load, upgrade opset, and apply shape inference to model.""" + print(f"Loading model from {input_path}...") + model = onnx.load(input_path) + + # Check and upgrade opset if needed + current_opset = model.opset_import[0].version + if current_opset < target_opset: + print(f"Upgrading opset from {current_opset} to {target_opset}...") + model = version_converter.convert_version(model, target_opset) + + # Apply shape inference + print("Applying shape inference...") + model = shape_inference.infer_shapes(model) + + # Save processed model + onnx.save(model, output_path) + print(f"✓ Processed model saved to: {output_path}") + + return model + + +def generate_test_data(model_path, output_dir): + """Generate test input/output data and save as PyTorch tensors.""" + import torch + import onnxruntime as ort + + print("\nGenerating test data...") + + # Load model to get input shape + model = onnx.load(model_path) + input_shape = get_input_shape(model) + print(f" Input shape: {input_shape}") + + # Create reproducible test input + np.random.seed(42) + test_input = np.random.rand(*input_shape).astype(np.float32) + + # Run inference to get output + session = ort.InferenceSession(model_path) + input_name = session.get_inputs()[0].name + outputs = session.run(None, {input_name: test_input}) + + # Save as PyTorch tensors + test_data = { + 'input': torch.from_numpy(test_input), + 'output': torch.from_numpy(outputs[0]) + } + + test_data_path = Path(output_dir) / "test_data.pt" + torch.save(test_data, test_data_path) + + print(f" ✓ Test data saved to: {test_data_path}") + print(f" Input shape: {test_input.shape}, Output shape: {outputs[0].shape}") + + +def main(): + print("=" * 60) + print("YOLOv8n Model Preparation Tool") + print("=" * 60) + + # Setup paths + artifacts_dir = Path("artifacts") + artifacts_dir.mkdir(exist_ok=True) + + original_path = artifacts_dir / "yolov8n.onnx" + processed_path = artifacts_dir / "yolov8n_opset16.onnx" + test_data_path = artifacts_dir / "test_data.pt" + + # Check if we already have everything + if processed_path.exists() and test_data_path.exists(): + print(f"\n✓ All files already exist:") + print(f" Model: {processed_path}") + print(f" Test data: {test_data_path}") + print("\nNothing to do!") + return + + # Download and convert if needed + if not original_path.exists() and not processed_path.exists(): + print("\nStep 1: Downloading and converting YOLOv8n model...") + download_and_convert_model(original_path) + + # Process model if needed + if not processed_path.exists(): + print("\nStep 2: Processing model...") + process_model(original_path, processed_path, target_opset=16) + + # Clean up original if we have the processed version + if original_path.exists(): + original_path.unlink() + + # Generate test data if needed + if not test_data_path.exists(): + print("\nStep 3: Generating test data...") + generate_test_data(processed_path, artifacts_dir) + + print("\n" + "=" * 60) + print("✓ YOLOv8n model preparation completed!") + print(f" Model: {processed_path}") + print(f" Test data: {test_data_path}") + print("=" * 60) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\n⚠ Operation cancelled by user.") + sys.exit(1) + except Exception as e: + print(f"\n✗ Error: {str(e)}") + sys.exit(1) \ No newline at end of file diff --git a/crates/burn-import/model-checks/yolov8n/src/main.rs b/crates/burn-import/model-checks/yolov8n/src/main.rs new file mode 100644 index 0000000000..5c421cb831 --- /dev/null +++ b/crates/burn-import/model-checks/yolov8n/src/main.rs @@ -0,0 +1,148 @@ +extern crate alloc; + +use burn::module::Param; +use burn::prelude::*; +use burn::record::*; + +use burn_import::pytorch::PyTorchFileRecorder; +use std::path::Path; +use std::time::Instant; + +#[cfg(feature = "wgpu")] +pub type MyBackend = burn::backend::Wgpu; + +#[cfg(feature = "ndarray")] +pub type MyBackend = burn::backend::NdArray; + +#[cfg(feature = "tch")] +pub type MyBackend = burn::backend::LibTorch; + +#[cfg(feature = "metal")] +pub type MyBackend = burn::backend::Metal; + +// Import the generated model code as a module +pub mod yolov8n { + include!(concat!(env!("OUT_DIR"), "/model/yolov8n_opset16.rs")); +} + +#[derive(Debug, Module)] +struct TestData { + input: Param>, + output: Param>, +} + +fn main() { + println!("========================================"); + println!("YOLOv8n Burn Model Test"); + println!("========================================\n"); + + // Check if artifacts exist + let artifacts_dir = Path::new("artifacts"); + if !artifacts_dir.exists() { + eprintln!("Error: artifacts directory not found!"); + eprintln!("Please run get_model.py first to download the model and test data."); + std::process::exit(1); + } + + // Initialize the model (without weights for now) + println!("Initializing YOLOv8n model..."); + let start = Instant::now(); + let device = Default::default(); + let model: yolov8n::Model = yolov8n::Model::default(); + let init_time = start.elapsed(); + println!(" Model initialized in {:.2?}", init_time); + + // Save model structure to file + println!("\nSaving model structure to artifacts/model.txt..."); + let model_str = format!("{}", model); + std::fs::write("artifacts/model.txt", &model_str) + .expect("Failed to write model structure to file"); + println!(" Model structure saved"); + + // Load test data from PyTorch file + println!("\nLoading test data from artifacts/test_data.pt..."); + let start = Instant::now(); + let test_data: TestDataRecord = PyTorchFileRecorder::::new() + .load("artifacts/test_data.pt".into(), &device) + .expect("Failed to load test data"); + let load_time = start.elapsed(); + println!(" Data loaded in {:.2?}", load_time); + + // Get the input tensor from test data + let input = test_data.input.val(); + let input_shape = input.shape(); + println!(" Loaded input tensor with shape: {:?}", input_shape.dims); + + // Get the reference output from test data + let reference_output = test_data.output.val(); + let reference_shape = reference_output.shape(); + println!( + " Loaded reference output with shape: {:?}", + reference_shape.dims + ); + + // Run inference with the loaded input + println!("\nRunning model inference with test input..."); + let start = Instant::now(); + let output = model.forward(input); + let inference_time = start.elapsed(); + println!(" Inference completed in {:.2?}", inference_time); + + // Display output shape + let shape = output.shape(); + println!("\n Model output shape: {:?}", shape.dims); + + // Verify expected output shape for YOLOv8n + let expected_shape = [1, 84, 8400]; + if shape.dims == expected_shape { + println!(" ✓ Output shape matches expected: {:?}", expected_shape); + } else { + println!( + " ⚠ Warning: Expected shape {:?}, got {:?}", + expected_shape, shape.dims + ); + } + + // Compare outputs + println!("\nComparing model output with reference data..."); + + // Check if outputs are close + if output + .clone() + .all_close(reference_output.clone(), Some(1e-4), Some(1e-4)) + { + println!(" ✓ Model output matches reference data within tolerance (1e-4)!"); + } else { + println!(" ⚠ Model output differs from reference data!"); + + // Calculate and display the difference statistics + let diff = output.clone() - reference_output.clone(); + let abs_diff = diff.abs(); + let max_diff = abs_diff.clone().max().into_scalar(); + let mean_diff = abs_diff.mean().into_scalar(); + + println!(" Maximum absolute difference: {:.6}", max_diff); + println!(" Mean absolute difference: {:.6}", mean_diff); + + // Show some sample values for debugging + println!("\n Sample values comparison (first 5 elements):"); + let output_flat = output.clone().flatten::<1>(0, 2); + let reference_flat = reference_output.clone().flatten::<1>(0, 2); + + for i in 0..5.min(output_flat.dims()[0]) { + let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar(); + let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar(); + println!( + " [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}", + i, + model_val, + ref_val, + (model_val - ref_val).abs() + ); + } + } + + println!("\n========================================"); + println!("Model test completed!"); + println!("========================================"); +} From e5c3773258a16cab7f0d2c85ed9d80980f91880b Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Sun, 21 Sep 2025 14:18:16 -0500 Subject: [PATCH 2/3] Generalize YOLO model checks to support multiple variants Renamed yolov8n model-checks crate to yolo and refactored code to support multiple YOLO variants (yolov5s, yolov8n, yolov8s, yolov10n, yolo11x). Added model selection via YOLO_MODEL environment variable, updated build script, Python model preparation script, and main.rs to handle dynamic model selection and output. Added README with usage instructions and supported models. Removed yolov8n-specific files. --- .../model-checks/{yolov8n => yolo}/Cargo.toml | 2 +- .../burn-import/model-checks/yolo/README.md | 55 ++++++++++++ crates/burn-import/model-checks/yolo/build.rs | 85 +++++++++++++++++++ .../{yolov8n => yolo}/get_model.py | 71 +++++++++++----- .../{yolov8n => yolo}/src/main.rs | 68 +++++++++++---- .../burn-import/model-checks/yolov8n/build.rs | 32 ------- 6 files changed, 246 insertions(+), 67 deletions(-) rename crates/burn-import/model-checks/{yolov8n => yolo}/Cargo.toml (90%) create mode 100644 crates/burn-import/model-checks/yolo/README.md create mode 100644 crates/burn-import/model-checks/yolo/build.rs rename crates/burn-import/model-checks/{yolov8n => yolo}/get_model.py (62%) rename crates/burn-import/model-checks/{yolov8n => yolo}/src/main.rs (67%) delete mode 100644 crates/burn-import/model-checks/yolov8n/build.rs diff --git a/crates/burn-import/model-checks/yolov8n/Cargo.toml b/crates/burn-import/model-checks/yolo/Cargo.toml similarity index 90% rename from crates/burn-import/model-checks/yolov8n/Cargo.toml rename to crates/burn-import/model-checks/yolo/Cargo.toml index 54d56e1483..a3e86827b4 100644 --- a/crates/burn-import/model-checks/yolov8n/Cargo.toml +++ b/crates/burn-import/model-checks/yolo/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "burn-import-model-checks-yolov8n" +name = "burn-import-model-checks-yolo" version = "0.1.0" edition = "2024" publish = false diff --git a/crates/burn-import/model-checks/yolo/README.md b/crates/burn-import/model-checks/yolo/README.md new file mode 100644 index 0000000000..270d06054e --- /dev/null +++ b/crates/burn-import/model-checks/yolo/README.md @@ -0,0 +1,55 @@ +# YOLO Model Checks + +This crate provides a unified interface for testing multiple YOLO model variants with Burn. + +## Supported Models + +- `yolov5s` - YOLOv5 small variant +- `yolov8n` - YOLOv8 nano variant +- `yolov8s` - YOLOv8 small variant +- `yolo11x` - YOLO11 extra-large variant +- `yolov10n` - YOLOv10 nano variant (Note: Currently fails due to TopK operator issue) + +## Usage + +### 1. Download and prepare a model + +```bash +# Using Python directly +python get_model.py --model yolov8n + +# Or using uv +uv run get_model.py --model yolov8n + +# List available models +uv run get_model.py --list +``` + +### 2. Run the model test + +After building, you can run the test. The model is already compiled in: + +```bash +YOLO_MODEL=yolov8s cargo run --release +``` + +## Directory Structure + +``` +yolo/ +├── artifacts/ # Downloaded ONNX models and test data +│ ├── yolov8n_opset16.onnx +│ ├── yolov8n_test_data.pt +│ └── ... +├── src/ +│ └── main.rs # Test runner +├── build.rs # Build script that generates model code +├── get_model.py # Model download and preparation script +└── Cargo.toml +``` + +## Notes + +- All YOLO models (except v10) output shape `[1, 84, 8400]` for standard object detection +- YOLOv10n has a different architecture with output shape `[1, 300, 6]` and uses TopK operator +- The crate requires explicit model selection at build time (no default model) diff --git a/crates/burn-import/model-checks/yolo/build.rs b/crates/burn-import/model-checks/yolo/build.rs new file mode 100644 index 0000000000..8d3d71445e --- /dev/null +++ b/crates/burn-import/model-checks/yolo/build.rs @@ -0,0 +1,85 @@ +use burn_import::onnx::ModelGen; +use std::env; +use std::fs; +use std::path::Path; + +fn main() { + // Supported models + let supported_models = vec!["yolov5s", "yolov8n", "yolov8s", "yolov10n", "yolo11x"]; + + // Get the model name from environment variable (required) + let model_name = env::var("YOLO_MODEL").unwrap_or_else(|_| { + eprintln!("Error: YOLO_MODEL environment variable is not set."); + eprintln!(); + eprintln!("Please specify which YOLO model to build:"); + eprintln!(" YOLO_MODEL=yolov8n cargo build"); + eprintln!(); + eprintln!("Available models: {}", supported_models.join(", ")); + std::process::exit(1); + }); + + if !supported_models.contains(&model_name.as_str()) { + eprintln!( + "Error: Unsupported model '{}'. Supported models: {:?}", + model_name, supported_models + ); + std::process::exit(1); + } + + let onnx_path = format!("artifacts/{}_opset16.onnx", model_name); + let test_data_path = format!("artifacts/{}_test_data.pt", model_name); + + // Tell Cargo to only rebuild if these files change + println!("cargo:rerun-if-changed={}", onnx_path); + println!("cargo:rerun-if-changed={}", test_data_path); + println!("cargo:rerun-if-changed=build.rs"); + println!("cargo:rerun-if-env-changed=YOLO_MODEL"); + + // Check if the ONNX model file exists + if !Path::new(&onnx_path).exists() { + eprintln!("Error: ONNX model file not found at '{}'", onnx_path); + eprintln!(); + eprintln!( + "Please run the following command to download and prepare the {} model:", + model_name + ); + eprintln!(" python get_model.py --model {}", model_name); + eprintln!(); + eprintln!("Or if you prefer using uv:"); + eprintln!(" uv run get_model.py --model {}", model_name); + eprintln!(); + eprintln!("Available models: {}", supported_models.join(", ")); + std::process::exit(1); + } + + // Generate the model code from the ONNX file + ModelGen::new() + .input(&onnx_path) + .out_dir("model/") + .run_from_script(); + + // Write the model name to a file so main.rs can access it + let out_dir = env::var("OUT_DIR").unwrap(); + let model_info_path = Path::new(&out_dir).join("model_info.rs"); + + // Generate the include path for the model + let model_include = format!( + "include!(concat!(env!(\"OUT_DIR\"), \"/model/{}_opset16.rs\"));", + model_name + ); + + fs::write( + model_info_path, + format!( + r#"pub const MODEL_NAME: &str = "{}"; +pub const TEST_DATA_FILE: &str = "{}_test_data.pt"; + +// Include the generated model +pub mod yolo_model {{ + {} +}}"#, + model_name, model_name, model_include + ), + ) + .expect("Failed to write model info"); +} diff --git a/crates/burn-import/model-checks/yolov8n/get_model.py b/crates/burn-import/model-checks/yolo/get_model.py similarity index 62% rename from crates/burn-import/model-checks/yolov8n/get_model.py rename to crates/burn-import/model-checks/yolo/get_model.py index 5c36386d12..4506c861bf 100755 --- a/crates/burn-import/model-checks/yolov8n/get_model.py +++ b/crates/burn-import/model-checks/yolo/get_model.py @@ -17,6 +17,17 @@ from onnx import shape_inference, version_converter import numpy as np from pathlib import Path +import argparse + + +# Supported YOLO models configuration +SUPPORTED_MODELS = { + 'yolov5s': {'download_name': 'yolov5s.pt', 'display_name': 'YOLOv5s'}, + 'yolov8n': {'download_name': 'yolov8n.pt', 'display_name': 'YOLOv8n'}, + 'yolov8s': {'download_name': 'yolov8s.pt', 'display_name': 'YOLOv8s'}, + 'yolov10n': {'download_name': 'yolov10n.pt', 'display_name': 'YOLOv10n'}, + 'yolo11x': {'download_name': 'yolo11x.pt', 'display_name': 'YOLO11x'}, +} def get_input_shape(model): @@ -35,23 +46,28 @@ def get_input_shape(model): return shape -def download_and_convert_model(output_path): - """Download YOLOv8n model and export to ONNX format.""" +def download_and_convert_model(model_name, output_path): + """Download YOLO model and export to ONNX format.""" from ultralytics import YOLO - print("Downloading YOLOv8n model...") - model = YOLO('yolov8n.pt') + model_config = SUPPORTED_MODELS[model_name] + display_name = model_config['display_name'] + download_name = model_config['download_name'] + + print(f"Downloading {display_name} model...") + model = YOLO(download_name) print("Exporting to ONNX format...") model.export(format="onnx", simplify=True) # Move exported file to artifacts - exported_file = Path("yolov8n.onnx") + base_name = download_name.replace('.pt', '') + exported_file = Path(f"{base_name}.onnx") if exported_file.exists(): exported_file.rename(output_path) # Clean up PyTorch file - pt_file = Path("yolov8n.pt") + pt_file = Path(download_name) if pt_file.exists(): pt_file.unlink() @@ -81,7 +97,7 @@ def process_model(input_path, output_path, target_opset=16): return model -def generate_test_data(model_path, output_dir): +def generate_test_data(model_path, output_path, model_name): """Generate test input/output data and save as PyTorch tensors.""" import torch import onnxruntime as ort @@ -108,29 +124,46 @@ def generate_test_data(model_path, output_dir): 'output': torch.from_numpy(outputs[0]) } - test_data_path = Path(output_dir) / "test_data.pt" - torch.save(test_data, test_data_path) + torch.save(test_data, output_path) - print(f" ✓ Test data saved to: {test_data_path}") + print(f" ✓ Test data saved to: {output_path}") print(f" Input shape: {test_input.shape}, Output shape: {outputs[0].shape}") def main(): + parser = argparse.ArgumentParser(description='YOLO Model Preparation Tool') + parser.add_argument('--model', type=str, default='yolov8n', + choices=list(SUPPORTED_MODELS.keys()), + help=f'YOLO model to download and prepare (default: yolov8n). Choices: {", ".join(SUPPORTED_MODELS.keys())}') + parser.add_argument('--list', action='store_true', + help='List all supported models') + + args = parser.parse_args() + + if args.list: + print("Supported YOLO models:") + for model_id, config in SUPPORTED_MODELS.items(): + print(f" - {model_id:10s} ({config['display_name']})") + return + + model_name = args.model + display_name = SUPPORTED_MODELS[model_name]['display_name'] + print("=" * 60) - print("YOLOv8n Model Preparation Tool") + print(f"{display_name} Model Preparation Tool") print("=" * 60) # Setup paths artifacts_dir = Path("artifacts") artifacts_dir.mkdir(exist_ok=True) - original_path = artifacts_dir / "yolov8n.onnx" - processed_path = artifacts_dir / "yolov8n_opset16.onnx" - test_data_path = artifacts_dir / "test_data.pt" + original_path = artifacts_dir / f"{model_name}.onnx" + processed_path = artifacts_dir / f"{model_name}_opset16.onnx" + test_data_path = artifacts_dir / f"{model_name}_test_data.pt" # Check if we already have everything if processed_path.exists() and test_data_path.exists(): - print(f"\n✓ All files already exist:") + print(f"\n✓ All files already exist for {display_name}:") print(f" Model: {processed_path}") print(f" Test data: {test_data_path}") print("\nNothing to do!") @@ -138,8 +171,8 @@ def main(): # Download and convert if needed if not original_path.exists() and not processed_path.exists(): - print("\nStep 1: Downloading and converting YOLOv8n model...") - download_and_convert_model(original_path) + print(f"\nStep 1: Downloading and converting {display_name} model...") + download_and_convert_model(model_name, original_path) # Process model if needed if not processed_path.exists(): @@ -153,10 +186,10 @@ def main(): # Generate test data if needed if not test_data_path.exists(): print("\nStep 3: Generating test data...") - generate_test_data(processed_path, artifacts_dir) + generate_test_data(processed_path, test_data_path, model_name) print("\n" + "=" * 60) - print("✓ YOLOv8n model preparation completed!") + print(f"✓ {display_name} model preparation completed!") print(f" Model: {processed_path}") print(f" Test data: {test_data_path}") print("=" * 60) diff --git a/crates/burn-import/model-checks/yolov8n/src/main.rs b/crates/burn-import/model-checks/yolo/src/main.rs similarity index 67% rename from crates/burn-import/model-checks/yolov8n/src/main.rs rename to crates/burn-import/model-checks/yolo/src/main.rs index 5c421cb831..2c11abea52 100644 --- a/crates/burn-import/model-checks/yolov8n/src/main.rs +++ b/crates/burn-import/model-checks/yolo/src/main.rs @@ -5,6 +5,7 @@ use burn::prelude::*; use burn::record::*; use burn_import::pytorch::PyTorchFileRecorder; +use std::env; use std::path::Path; use std::time::Instant; @@ -20,10 +21,11 @@ pub type MyBackend = burn::backend::LibTorch; #[cfg(feature = "metal")] pub type MyBackend = burn::backend::Metal; -// Import the generated model code as a module -pub mod yolov8n { - include!(concat!(env!("OUT_DIR"), "/model/yolov8n_opset16.rs")); -} +// Import model info generated by build.rs (includes the yolo_model module) +include!(concat!(env!("OUT_DIR"), "/model_info.rs")); + +// Use the yolo_model module from model_info.rs +use yolo_model::Model; #[derive(Debug, Module)] struct TestData { @@ -31,9 +33,24 @@ struct TestData { output: Param>, } +fn get_model_display_name(model_name: &str) -> &str { + match model_name { + "yolov5s" => "YOLOv5s", + "yolov8n" => "YOLOv8n", + "yolov8s" => "YOLOv8s", + "yolov10n" => "YOLOv10n", + "yolo11x" => "YOLO11x", + _ => model_name, + } +} + fn main() { + // MODEL_NAME is set at build time from YOLO_MODEL env var + let model_name = MODEL_NAME; + let display_name = get_model_display_name(model_name); + println!("========================================"); - println!("YOLOv8n Burn Model Test"); + println!("{} Burn Model Test", display_name); println!("========================================\n"); // Check if artifacts exist @@ -41,29 +58,50 @@ fn main() { if !artifacts_dir.exists() { eprintln!("Error: artifacts directory not found!"); eprintln!("Please run get_model.py first to download the model and test data."); + eprintln!("Example: uv run get_model.py --model {}", model_name); + std::process::exit(1); + } + + // Check if model files exist for this specific model + let model_file = artifacts_dir.join(format!("{}_opset16.onnx", model_name)); + let test_data_file = artifacts_dir.join(format!("{}_test_data.pt", model_name)); + + if !model_file.exists() || !test_data_file.exists() { + eprintln!("Error: Model files not found for {}!", display_name); + eprintln!("Please run: uv run get_model.py --model {}", model_name); + eprintln!(); + eprintln!("Available models:"); + eprintln!(" - yolov5s"); + eprintln!(" - yolov8n"); + eprintln!(" - yolov8s"); + eprintln!(" - yolov10n"); + eprintln!(" - yolo11x"); std::process::exit(1); } // Initialize the model (without weights for now) - println!("Initializing YOLOv8n model..."); + println!("Initializing {} model...", display_name); let start = Instant::now(); let device = Default::default(); - let model: yolov8n::Model = yolov8n::Model::default(); + let model: Model = Model::default(); let init_time = start.elapsed(); println!(" Model initialized in {:.2?}", init_time); // Save model structure to file - println!("\nSaving model structure to artifacts/model.txt..."); + let model_txt_path = artifacts_dir.join(format!("{}_model.txt", model_name)); + println!( + "\nSaving model structure to {}...", + model_txt_path.display() + ); let model_str = format!("{}", model); - std::fs::write("artifacts/model.txt", &model_str) - .expect("Failed to write model structure to file"); + std::fs::write(&model_txt_path, &model_str).expect("Failed to write model structure to file"); println!(" Model structure saved"); // Load test data from PyTorch file - println!("\nLoading test data from artifacts/test_data.pt..."); + println!("\nLoading test data from {}...", test_data_file.display()); let start = Instant::now(); let test_data: TestDataRecord = PyTorchFileRecorder::::new() - .load("artifacts/test_data.pt".into(), &device) + .load(test_data_file.into(), &device) .expect("Failed to load test data"); let load_time = start.elapsed(); println!(" Data loaded in {:.2?}", load_time); @@ -92,14 +130,14 @@ fn main() { let shape = output.shape(); println!("\n Model output shape: {:?}", shape.dims); - // Verify expected output shape for YOLOv8n + // Verify expected output shape (most YOLO models use [1, 84, 8400]) let expected_shape = [1, 84, 8400]; if shape.dims == expected_shape { println!(" ✓ Output shape matches expected: {:?}", expected_shape); } else { println!( - " ⚠ Warning: Expected shape {:?}, got {:?}", - expected_shape, shape.dims + " ⚠ Note: Shape is {:?} (expected {:?} for most YOLO models)", + shape.dims, expected_shape ); } diff --git a/crates/burn-import/model-checks/yolov8n/build.rs b/crates/burn-import/model-checks/yolov8n/build.rs deleted file mode 100644 index 8c0600cdce..0000000000 --- a/crates/burn-import/model-checks/yolov8n/build.rs +++ /dev/null @@ -1,32 +0,0 @@ -use burn_import::onnx::ModelGen; -use std::path::Path; - -fn main() { - let onnx_path = "artifacts/yolov8n_opset16.onnx"; - let test_data_path = "artifacts/test_data.pt"; - - // Tell Cargo to only rebuild if these files change - println!("cargo:rerun-if-changed={}", onnx_path); - println!("cargo:rerun-if-changed={}", test_data_path); - println!("cargo:rerun-if-changed=build.rs"); - - // Check if the ONNX model file exists - if !Path::new(onnx_path).exists() { - eprintln!("Error: ONNX model file not found at '{}'", onnx_path); - eprintln!(); - eprintln!("Please run the following command to download and prepare the model:"); - eprintln!(" python get_model.py"); - eprintln!(); - eprintln!("Or if you prefer using uv:"); - eprintln!(" uv run get_model.py"); - eprintln!(); - eprintln!("This will download the YOLOv8n model and convert it to ONNX format."); - std::process::exit(1); - } - - // Generate the model code from the ONNX file - ModelGen::new() - .input(onnx_path) - .out_dir("model/") - .run_from_script(); -} From a1c69d5df99deff6e34a98b2af7f2299a1be7879 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Sun, 21 Sep 2025 14:20:57 -0500 Subject: [PATCH 3/3] Remove YOLO11x model check and related files Deleted the YOLO11x model check directory, including Cargo.toml, build script, model preparation Python script, and main Rust source. This removes support and tests for the YOLO11x model from burn-import/model-checks. --- .../model-checks/yolo11x/Cargo.toml | 26 --- .../burn-import/model-checks/yolo11x/build.rs | 32 ---- .../model-checks/yolo11x/get_model.py | 173 ------------------ .../model-checks/yolo11x/src/main.rs | 148 --------------- 4 files changed, 379 deletions(-) delete mode 100644 crates/burn-import/model-checks/yolo11x/Cargo.toml delete mode 100644 crates/burn-import/model-checks/yolo11x/build.rs delete mode 100755 crates/burn-import/model-checks/yolo11x/get_model.py delete mode 100644 crates/burn-import/model-checks/yolo11x/src/main.rs diff --git a/crates/burn-import/model-checks/yolo11x/Cargo.toml b/crates/burn-import/model-checks/yolo11x/Cargo.toml deleted file mode 100644 index 3a586ae632..0000000000 --- a/crates/burn-import/model-checks/yolo11x/Cargo.toml +++ /dev/null @@ -1,26 +0,0 @@ -[package] -name = "burn-import-model-checks-yolo11x" -version = "0.1.0" -edition = "2024" -publish = false - -[workspace] - -[features] -default = ["tch"] -ndarray = [] -tch = [] -wgpu = [] -metal = [] - -[dependencies] -burn = { path = "../../../../crates/burn", features = [ - "ndarray", - "tch", - "wgpu", - "metal", -] } -burn-import = { path = "../../../burn-import", features = ["pytorch"] } - -[build-dependencies] -burn-import = { path = "../../../burn-import" } diff --git a/crates/burn-import/model-checks/yolo11x/build.rs b/crates/burn-import/model-checks/yolo11x/build.rs deleted file mode 100644 index 625a22f2ba..0000000000 --- a/crates/burn-import/model-checks/yolo11x/build.rs +++ /dev/null @@ -1,32 +0,0 @@ -use burn_import::onnx::ModelGen; -use std::path::Path; - -fn main() { - let onnx_path = "artifacts/yolo11x_opset16.onnx"; - let test_data_path = "artifacts/test_data.pt"; - - // Tell Cargo to only rebuild if these files change - println!("cargo:rerun-if-changed={}", onnx_path); - println!("cargo:rerun-if-changed={}", test_data_path); - println!("cargo:rerun-if-changed=build.rs"); - - // Check if the ONNX model file exists - if !Path::new(onnx_path).exists() { - eprintln!("Error: ONNX model file not found at '{}'", onnx_path); - eprintln!(); - eprintln!("Please run the following command to download and prepare the model:"); - eprintln!(" python get_model.py"); - eprintln!(); - eprintln!("Or if you prefer using uv:"); - eprintln!(" uv run get_model.py"); - eprintln!(); - eprintln!("This will download the YOLO11x model and convert it to ONNX format."); - std::process::exit(1); - } - - // Generate the model code from the ONNX file - ModelGen::new() - .input(onnx_path) - .out_dir("model/") - .run_from_script(); -} \ No newline at end of file diff --git a/crates/burn-import/model-checks/yolo11x/get_model.py b/crates/burn-import/model-checks/yolo11x/get_model.py deleted file mode 100755 index 7e76229d8b..0000000000 --- a/crates/burn-import/model-checks/yolo11x/get_model.py +++ /dev/null @@ -1,173 +0,0 @@ -#!/usr/bin/env -S uv run --script - -# /// script -# dependencies = [ -# "onnx-weekly==1.19.0.dev20250419", -# "onnxruntime>=1.22.0", -# "ultralytics>=8.3.0", -# "numpy", -# "pillow", -# "torch", -# ] -# /// - -import os -import sys -import onnx -from onnx import shape_inference, version_converter -import numpy as np -from pathlib import Path - - -def get_input_shape(model): - """Extract input shape from ONNX model.""" - input_info = model.graph.input[0] - shape = [] - for dim in input_info.type.tensor_type.shape.dim: - if dim.HasField('dim_value'): - shape.append(dim.dim_value) - else: - shape.append(1) # Default to 1 for dynamic dimensions - - # Ensure valid YOLO input shape - if len(shape) != 4 or shape[2] == 0 or shape[2] > 2000: - return [1, 3, 640, 640] - return shape - - -def download_and_convert_model(output_path): - """Download YOLO11x model and export to ONNX format.""" - from ultralytics import YOLO - - print("Downloading YOLO11x model...") - pt_path = Path("artifacts/yolo11x.pt") - model = YOLO(str(pt_path)) - - print("Exporting to ONNX format...") - model.export(format="onnx", simplify=True) - - # Move exported file to artifacts - exported_file = Path("yolo11x.onnx") - if exported_file.exists(): - exported_file.rename(output_path) - - # Clean up PyTorch file - if pt_path.exists(): - pt_path.unlink() - - if not output_path.exists(): - raise FileNotFoundError(f"Failed to create ONNX file at {output_path}") - - -def process_model(input_path, output_path, target_opset=16): - """Load, upgrade opset, and apply shape inference to model.""" - print(f"Loading model from {input_path}...") - model = onnx.load(input_path) - - # Check and upgrade opset if needed - current_opset = model.opset_import[0].version - if current_opset < target_opset: - print(f"Upgrading opset from {current_opset} to {target_opset}...") - model = version_converter.convert_version(model, target_opset) - - # Apply shape inference - print("Applying shape inference...") - model = shape_inference.infer_shapes(model) - - # Save processed model - onnx.save(model, output_path) - print(f"✓ Processed model saved to: {output_path}") - - return model - - -def generate_test_data(model_path, output_dir): - """Generate test input/output data and save as PyTorch tensors.""" - import torch - import onnxruntime as ort - - print("\nGenerating test data...") - - # Load model to get input shape - model = onnx.load(model_path) - input_shape = get_input_shape(model) - print(f" Input shape: {input_shape}") - - # Create reproducible test input - np.random.seed(42) - test_input = np.random.rand(*input_shape).astype(np.float32) - - # Run inference to get output - session = ort.InferenceSession(model_path) - input_name = session.get_inputs()[0].name - outputs = session.run(None, {input_name: test_input}) - - # Save as PyTorch tensors - test_data = { - 'input': torch.from_numpy(test_input), - 'output': torch.from_numpy(outputs[0]) - } - - test_data_path = Path(output_dir) / "test_data.pt" - torch.save(test_data, test_data_path) - - print(f" ✓ Test data saved to: {test_data_path}") - print(f" Input shape: {test_input.shape}, Output shape: {outputs[0].shape}") - - -def main(): - print("=" * 60) - print("YOLO11x Model Preparation Tool") - print("=" * 60) - - # Setup paths - artifacts_dir = Path("artifacts") - artifacts_dir.mkdir(exist_ok=True) - - original_path = artifacts_dir / "yolo11x.onnx" - processed_path = artifacts_dir / "yolo11x_opset16.onnx" - test_data_path = artifacts_dir / "test_data.pt" - - # Check if we already have everything - if processed_path.exists() and test_data_path.exists(): - print(f"\n✓ All files already exist:") - print(f" Model: {processed_path}") - print(f" Test data: {test_data_path}") - print("\nNothing to do!") - return - - # Download and convert if needed - if not original_path.exists() and not processed_path.exists(): - print("\nStep 1: Downloading and converting YOLO11x model...") - download_and_convert_model(original_path) - - # Process model if needed - if not processed_path.exists(): - print("\nStep 2: Processing model...") - process_model(original_path, processed_path, target_opset=16) - - # Clean up original if we have the processed version - if original_path.exists(): - original_path.unlink() - - # Generate test data if needed - if not test_data_path.exists(): - print("\nStep 3: Generating test data...") - generate_test_data(processed_path, artifacts_dir) - - print("\n" + "=" * 60) - print("✓ YOLO11x model preparation completed!") - print(f" Model: {processed_path}") - print(f" Test data: {test_data_path}") - print("=" * 60) - - -if __name__ == "__main__": - try: - main() - except KeyboardInterrupt: - print("\n⚠ Operation cancelled by user.") - sys.exit(1) - except Exception as e: - print(f"\n✗ Error: {str(e)}") - sys.exit(1) diff --git a/crates/burn-import/model-checks/yolo11x/src/main.rs b/crates/burn-import/model-checks/yolo11x/src/main.rs deleted file mode 100644 index 7de0d2a150..0000000000 --- a/crates/burn-import/model-checks/yolo11x/src/main.rs +++ /dev/null @@ -1,148 +0,0 @@ -extern crate alloc; - -use burn::module::Param; -use burn::prelude::*; -use burn::record::*; - -use burn_import::pytorch::PyTorchFileRecorder; -use std::path::Path; -use std::time::Instant; - -#[cfg(feature = "wgpu")] -pub type MyBackend = burn::backend::Wgpu; - -#[cfg(feature = "ndarray")] -pub type MyBackend = burn::backend::NdArray; - -#[cfg(feature = "tch")] -pub type MyBackend = burn::backend::LibTorch; - -#[cfg(feature = "metal")] -pub type MyBackend = burn::backend::Metal; - -// Import the generated model code as a module -pub mod yolo11x { - include!(concat!(env!("OUT_DIR"), "/model/yolo11x_opset16.rs")); -} - -#[derive(Debug, Module)] -struct TestData { - input: Param>, - output: Param>, -} - -fn main() { - println!("========================================"); - println!("YOLO11x Burn Model Test"); - println!("========================================\n"); - - // Check if artifacts exist - let artifacts_dir = Path::new("artifacts"); - if !artifacts_dir.exists() { - eprintln!("Error: artifacts directory not found!"); - eprintln!("Please run get_model.py first to download the model and test data."); - std::process::exit(1); - } - - // Initialize the model (without weights for now) - println!("Initializing YOLO11x model..."); - let start = Instant::now(); - let device = Default::default(); - let model: yolo11x::Model = yolo11x::Model::default(); - let init_time = start.elapsed(); - println!(" Model initialized in {:.2?}", init_time); - - // Save model structure to file - println!("\nSaving model structure to artifacts/model.txt..."); - let model_str = format!("{}", model); - std::fs::write("artifacts/model.txt", &model_str) - .expect("Failed to write model structure to file"); - println!(" Model structure saved"); - - // Load test data from PyTorch file - println!("\nLoading test data from artifacts/test_data.pt..."); - let start = Instant::now(); - let test_data: TestDataRecord = PyTorchFileRecorder::::new() - .load("artifacts/test_data.pt".into(), &device) - .expect("Failed to load test data"); - let load_time = start.elapsed(); - println!(" Data loaded in {:.2?}", load_time); - - // Get the input tensor from test data - let input = test_data.input.val(); - let input_shape = input.shape(); - println!(" Loaded input tensor with shape: {:?}", input_shape.dims); - - // Get the reference output from test data - let reference_output = test_data.output.val(); - let reference_shape = reference_output.shape(); - println!( - " Loaded reference output with shape: {:?}", - reference_shape.dims - ); - - // Run inference with the loaded input - println!("\nRunning model inference with test input..."); - let start = Instant::now(); - let output = model.forward(input); - let inference_time = start.elapsed(); - println!(" Inference completed in {:.2?}", inference_time); - - // Display output shape - let shape = output.shape(); - println!("\n Model output shape: {:?}", shape.dims); - - // Verify expected output shape - let expected_shape = [1, 84, 8400]; - if shape.dims == expected_shape { - println!(" ✓ Output shape matches expected: {:?}", expected_shape); - } else { - println!( - " ⚠ Warning: Expected shape {:?}, got {:?}", - expected_shape, shape.dims - ); - } - - // Compare outputs - println!("\nComparing model output with reference data..."); - - // Check if outputs are close - if output - .clone() - .all_close(reference_output.clone(), Some(1e-4), Some(1e-4)) - { - println!(" ✓ Model output matches reference data within tolerance (1e-4)!"); - } else { - println!(" ⚠ Model output differs from reference data!"); - - // Calculate and display the difference statistics - let diff = output.clone() - reference_output.clone(); - let abs_diff = diff.abs(); - let max_diff = abs_diff.clone().max().into_scalar(); - let mean_diff = abs_diff.mean().into_scalar(); - - println!(" Maximum absolute difference: {:.6}", max_diff); - println!(" Mean absolute difference: {:.6}", mean_diff); - - // Show some sample values for debugging - println!("\n Sample values comparison (first 5 elements):"); - let output_flat = output.clone().flatten::<1>(0, 2); - let reference_flat = reference_output.clone().flatten::<1>(0, 2); - - for i in 0..5.min(output_flat.dims()[0]) { - let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar(); - let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar(); - println!( - " [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}", - i, - model_val, - ref_val, - (model_val - ref_val).abs() - ); - } - } - - println!("\n========================================"); - println!("Model test completed!"); - println!("========================================"); -}