Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[package]
name = "burn-import-model-checks-yolo11x"
name = "burn-import-model-checks-yolo"
version = "0.1.0"
edition = "2024"
publish = false
Expand All @@ -23,4 +23,4 @@ burn = { path = "../../../../crates/burn", features = [
burn-import = { path = "../../../burn-import", features = ["pytorch"] }

[build-dependencies]
burn-import = { path = "../../../burn-import" }
burn-import = { path = "../../../burn-import" }
55 changes: 55 additions & 0 deletions crates/burn-import/model-checks/yolo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# YOLO Model Checks

This crate provides a unified interface for testing multiple YOLO model variants with Burn.

## Supported Models

- `yolov5s` - YOLOv5 small variant
- `yolov8n` - YOLOv8 nano variant
- `yolov8s` - YOLOv8 small variant
- `yolo11x` - YOLO11 extra-large variant
- `yolov10n` - YOLOv10 nano variant (Note: Currently fails due to TopK operator issue)

## Usage

### 1. Download and prepare a model

```bash
# Using Python directly
python get_model.py --model yolov8n

# Or using uv
uv run get_model.py --model yolov8n

# List available models
uv run get_model.py --list
```

### 2. Run the model test

After building, you can run the test. The model is already compiled in:

```bash
YOLO_MODEL=yolov8s cargo run --release
```

## Directory Structure

```
yolo/
├── artifacts/ # Downloaded ONNX models and test data
│ ├── yolov8n_opset16.onnx
│ ├── yolov8n_test_data.pt
│ └── ...
├── src/
│ └── main.rs # Test runner
├── build.rs # Build script that generates model code
├── get_model.py # Model download and preparation script
└── Cargo.toml
```

## Notes

- All YOLO models (except v10) output shape `[1, 84, 8400]` for standard object detection
- YOLOv10n has a different architecture with output shape `[1, 300, 6]` and uses TopK operator
- The crate requires explicit model selection at build time (no default model)
85 changes: 85 additions & 0 deletions crates/burn-import/model-checks/yolo/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use burn_import::onnx::ModelGen;
use std::env;
use std::fs;
use std::path::Path;

fn main() {
// Supported models
let supported_models = vec!["yolov5s", "yolov8n", "yolov8s", "yolov10n", "yolo11x"];

// Get the model name from environment variable (required)
let model_name = env::var("YOLO_MODEL").unwrap_or_else(|_| {
eprintln!("Error: YOLO_MODEL environment variable is not set.");
eprintln!();
eprintln!("Please specify which YOLO model to build:");
eprintln!(" YOLO_MODEL=yolov8n cargo build");
eprintln!();
eprintln!("Available models: {}", supported_models.join(", "));
std::process::exit(1);
});

if !supported_models.contains(&model_name.as_str()) {
eprintln!(
"Error: Unsupported model '{}'. Supported models: {:?}",
model_name, supported_models
);
std::process::exit(1);
}

let onnx_path = format!("artifacts/{}_opset16.onnx", model_name);
let test_data_path = format!("artifacts/{}_test_data.pt", model_name);

// Tell Cargo to only rebuild if these files change
println!("cargo:rerun-if-changed={}", onnx_path);
println!("cargo:rerun-if-changed={}", test_data_path);
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-env-changed=YOLO_MODEL");

// Check if the ONNX model file exists
if !Path::new(&onnx_path).exists() {
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
eprintln!();
eprintln!(
"Please run the following command to download and prepare the {} model:",
model_name
);
eprintln!(" python get_model.py --model {}", model_name);
eprintln!();
eprintln!("Or if you prefer using uv:");
eprintln!(" uv run get_model.py --model {}", model_name);
eprintln!();
eprintln!("Available models: {}", supported_models.join(", "));
std::process::exit(1);
}

// Generate the model code from the ONNX file
ModelGen::new()
.input(&onnx_path)
.out_dir("model/")
.run_from_script();

// Write the model name to a file so main.rs can access it
let out_dir = env::var("OUT_DIR").unwrap();
let model_info_path = Path::new(&out_dir).join("model_info.rs");

// Generate the include path for the model
let model_include = format!(
"include!(concat!(env!(\"OUT_DIR\"), \"/model/{}_opset16.rs\"));",
model_name
);

fs::write(
model_info_path,
format!(
r#"pub const MODEL_NAME: &str = "{}";
pub const TEST_DATA_FILE: &str = "{}_test_data.pt";

// Include the generated model
pub mod yolo_model {{
{}
}}"#,
model_name, model_name, model_include
),
)
.expect("Failed to write model info");
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

# /// script
# dependencies = [
# "onnx-weekly==1.19.0.dev20250419",
# "onnxruntime>=1.22.0",
# "onnx>=1.17.0",
# "onnxruntime>=1.18.0",
# "ultralytics>=8.3.0",
# "numpy",
# "pillow",
Expand All @@ -17,6 +17,17 @@
from onnx import shape_inference, version_converter
import numpy as np
from pathlib import Path
import argparse


# Supported YOLO models configuration
SUPPORTED_MODELS = {
'yolov5s': {'download_name': 'yolov5s.pt', 'display_name': 'YOLOv5s'},
'yolov8n': {'download_name': 'yolov8n.pt', 'display_name': 'YOLOv8n'},
'yolov8s': {'download_name': 'yolov8s.pt', 'display_name': 'YOLOv8s'},
'yolov10n': {'download_name': 'yolov10n.pt', 'display_name': 'YOLOv10n'},
'yolo11x': {'download_name': 'yolo11x.pt', 'display_name': 'YOLO11x'},
}


def get_input_shape(model):
Expand All @@ -35,25 +46,30 @@ def get_input_shape(model):
return shape


def download_and_convert_model(output_path):
"""Download YOLO11x model and export to ONNX format."""
def download_and_convert_model(model_name, output_path):
"""Download YOLO model and export to ONNX format."""
from ultralytics import YOLO

print("Downloading YOLO11x model...")
pt_path = Path("artifacts/yolo11x.pt")
model = YOLO(str(pt_path))
model_config = SUPPORTED_MODELS[model_name]
display_name = model_config['display_name']
download_name = model_config['download_name']

print(f"Downloading {display_name} model...")
model = YOLO(download_name)

print("Exporting to ONNX format...")
model.export(format="onnx", simplify=True)

# Move exported file to artifacts
exported_file = Path("yolo11x.onnx")
base_name = download_name.replace('.pt', '')
exported_file = Path(f"{base_name}.onnx")
if exported_file.exists():
exported_file.rename(output_path)

# Clean up PyTorch file
if pt_path.exists():
pt_path.unlink()
pt_file = Path(download_name)
if pt_file.exists():
pt_file.unlink()

if not output_path.exists():
raise FileNotFoundError(f"Failed to create ONNX file at {output_path}")
Expand Down Expand Up @@ -81,7 +97,7 @@ def process_model(input_path, output_path, target_opset=16):
return model


def generate_test_data(model_path, output_dir):
def generate_test_data(model_path, output_path, model_name):
"""Generate test input/output data and save as PyTorch tensors."""
import torch
import onnxruntime as ort
Expand All @@ -108,38 +124,55 @@ def generate_test_data(model_path, output_dir):
'output': torch.from_numpy(outputs[0])
}

test_data_path = Path(output_dir) / "test_data.pt"
torch.save(test_data, test_data_path)
torch.save(test_data, output_path)

print(f" ✓ Test data saved to: {test_data_path}")
print(f" ✓ Test data saved to: {output_path}")
print(f" Input shape: {test_input.shape}, Output shape: {outputs[0].shape}")


def main():
parser = argparse.ArgumentParser(description='YOLO Model Preparation Tool')
parser.add_argument('--model', type=str, default='yolov8n',
choices=list(SUPPORTED_MODELS.keys()),
help=f'YOLO model to download and prepare (default: yolov8n). Choices: {", ".join(SUPPORTED_MODELS.keys())}')
parser.add_argument('--list', action='store_true',
help='List all supported models')

args = parser.parse_args()

if args.list:
print("Supported YOLO models:")
for model_id, config in SUPPORTED_MODELS.items():
print(f" - {model_id:10s} ({config['display_name']})")
return

model_name = args.model
display_name = SUPPORTED_MODELS[model_name]['display_name']

print("=" * 60)
print("YOLO11x Model Preparation Tool")
print(f"{display_name} Model Preparation Tool")
print("=" * 60)

# Setup paths
artifacts_dir = Path("artifacts")
artifacts_dir.mkdir(exist_ok=True)

original_path = artifacts_dir / "yolo11x.onnx"
processed_path = artifacts_dir / "yolo11x_opset16.onnx"
test_data_path = artifacts_dir / "test_data.pt"
original_path = artifacts_dir / f"{model_name}.onnx"
processed_path = artifacts_dir / f"{model_name}_opset16.onnx"
test_data_path = artifacts_dir / f"{model_name}_test_data.pt"

# Check if we already have everything
if processed_path.exists() and test_data_path.exists():
print(f"\n✓ All files already exist:")
print(f"\n✓ All files already exist for {display_name}:")
print(f" Model: {processed_path}")
print(f" Test data: {test_data_path}")
print("\nNothing to do!")
return

# Download and convert if needed
if not original_path.exists() and not processed_path.exists():
print("\nStep 1: Downloading and converting YOLO11x model...")
download_and_convert_model(original_path)
print(f"\nStep 1: Downloading and converting {display_name} model...")
download_and_convert_model(model_name, original_path)

# Process model if needed
if not processed_path.exists():
Expand All @@ -153,10 +186,10 @@ def main():
# Generate test data if needed
if not test_data_path.exists():
print("\nStep 3: Generating test data...")
generate_test_data(processed_path, artifacts_dir)
generate_test_data(processed_path, test_data_path, model_name)

print("\n" + "=" * 60)
print("✓ YOLO11x model preparation completed!")
print(f"✓ {display_name} model preparation completed!")
print(f" Model: {processed_path}")
print(f" Test data: {test_data_path}")
print("=" * 60)
Expand All @@ -170,4 +203,4 @@ def main():
sys.exit(1)
except Exception as e:
print(f"\n✗ Error: {str(e)}")
sys.exit(1)
sys.exit(1)
Loading