Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions lib/rust/hmll-sys/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
[package]
name = "hmll-sys"
version = "0.1.0"
edition = "2021"
authors = ["Morgan Funtowicz <[email protected]>"]
description = "Low-level FFI bindings to the hmll library"
license = "MIT OR Apache-2.0"
repository = "https://github.com/huggingface/hmll"
links = "hmll"

[dependencies]

[build-dependencies]
bindgen = "0.72.1"
cmake = "0.1"

[features]
default = ["io_uring"]
io_uring = []
safetensors = []
cuda = []

[profile.release]
lto = "fat" # Full LTO for maximum optimization across crates
opt-level = 3 # Maximum optimization level
codegen-units = 1 # Single codegen unit for better optimization
strip = true # Strip symbols for smaller binary size

[profile.bench]
inherits = "release"
lto = "fat"
opt-level = 3
codegen-units = 1
140 changes: 140 additions & 0 deletions lib/rust/hmll-sys/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
use std::env;
use std::path::PathBuf;

fn main() {
// Get the project root (3 levels up from hmll-sys: lib/rust/hmll-sys -> .)
let project_root = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap())
.parent()
.unwrap()
.parent()
.unwrap()
.parent()
.unwrap()
.to_path_buf();

println!("cargo:rerun-if-changed=../../..");
println!("cargo:rerun-if-changed=build.rs");

// Detect Rust build profile and map to CMake build type
let profile = env::var("PROFILE").unwrap_or_else(|_| "debug".to_string());
let cmake_build_type = match profile.as_str() {
"debug" => "Debug",
"release" => "Release",
"bench" => "Release",
_ => "RelWithDebInfo",
};

// Configure CMake build
let mut cmake_config = cmake::Config::new(&project_root);

// Set the CMake build type based on Rust profile
cmake_config.profile(cmake_build_type);

// Enable features based on Rust features
cmake_config
.define("HMLL_BUILD_STATIC", "ON")
.define("HMLL_BUILD_EXAMPLES", "OFF")
.define("HMLL_BUILD_TESTS", "OFF")
.define("HMLL_ENABLE_PYTHON", "OFF")
.build_target("libhmll");

#[cfg(feature = "io_uring")]
cmake_config.define("HMLL_ENABLE_IO_URING", "ON");

#[cfg(not(feature = "io_uring"))]
cmake_config.define("HMLL_ENABLE_IO_URING", "OFF");

#[cfg(feature = "safetensors")]
cmake_config.define("HMLL_ENABLE_SAFETENSORS", "ON");

#[cfg(not(feature = "safetensors"))]
cmake_config.define("HMLL_ENABLE_SAFETENSORS", "OFF");

#[cfg(feature = "cuda")]
cmake_config.define("HMLL_ENABLE_CUDA", "ON");

#[cfg(not(feature = "cuda"))]
cmake_config.define("HMLL_ENABLE_CUDA", "OFF");

// Build the library
let dst = cmake_config.build();

// Tell cargo to link the library
println!("cargo:rustc-link-search=native={}/build", dst.display());
println!("cargo:rustc-link-lib=static=libhmll");

// Link io_uring if enabled
#[cfg(all(target_os = "linux", feature = "io_uring"))]
{
println!("cargo:rustc-link-search=native={}/build/_deps/liburing-src/src", dst.display());
println!("cargo:rustc-link-lib=static=uring");
}

// Link yyjson if safetensors is enabled
#[cfg(feature = "safetensors")]
{
println!("cargo:rustc-link-search=native={}/build/_deps/yyjson-build", dst.display());
println!("cargo:rustc-link-lib=static=yyjson");
}

// Link CUDA runtime if enabled
#[cfg(feature = "cuda")]
{
// Try to find CUDA installation
if let Ok(cuda_path) = env::var("CUDA_PATH") {
println!("cargo:rustc-link-search=native={}/lib64", cuda_path);
println!("cargo:rustc-link-search=native={}/lib", cuda_path);
} else if let Ok(cuda_home) = env::var("CUDA_HOME") {
println!("cargo:rustc-link-search=native={}/lib64", cuda_home);
println!("cargo:rustc-link-search=native={}/lib", cuda_home);
} else {
// Try common default locations
println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");
println!("cargo:rustc-link-search=native=/usr/local/cuda/lib");
println!("cargo:rustc-link-search=native=/opt/cuda/lib64");
println!("cargo:rustc-link-search=native=/opt/cuda/lib");
}

println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=cuda");
}

// Generate bindings
let include_path = project_root.join("include");

let builder = bindgen::Builder::default()
.header(include_path.join("hmll/hmll.h").to_str().unwrap())
.clang_arg(format!("-I{}", include_path.display()))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.allowlist_function("hmll_.*")
.allowlist_type("hmll_.*")
.allowlist_var("HMLL_.*")
.derive_debug(true)
.derive_default(true)
.derive_copy(true)
.derive_eq(true)
.derive_hash(true)
.impl_debug(true)
.prepend_enum_name(false)
.size_t_is_usize(true)
.layout_tests(false);

// Add conditional defines based on features
#[cfg(feature = "safetensors")]
let builder = builder
.clang_arg("-D__HMLL_SAFETENSORS_ENABLED__=1")
.clang_arg("-D__HMLL_TENSORS_ENABLED__=1");

#[cfg(feature = "cuda")]
let builder = builder.clang_arg("-D__HMLL_CUDA_ENABLED__=1");

let bindings = builder
.generate()
.expect("Unable to generate bindings");

// Write the bindings to the $OUT_DIR/bindings.rs file
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
bindings
.write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings!");
}
31 changes: 31 additions & 0 deletions lib/rust/hmll/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
[package]
name = "hmll"
version = "0.1.0"
edition = "2021"
authors = ["HMLL Contributors"]
description = "Safe, idiomatic Rust bindings to the hmll library for high-performance ML model loading"
license = "MIT OR Apache-2.0"
repository = "https://github.com/huggingface/hmll"

[dependencies]
hmll-sys = { path = "../hmll-sys" }
thiserror = "2.0"

[features]
default = ["io_uring"]
io_uring = ["hmll-sys/io_uring"]
safetensors = ["hmll-sys/safetensors"]
cuda = ["hmll-sys/cuda"]

[profile.release]
lto = "fat" # Full LTO for maximum optimization across crates
opt-level = 3 # Maximum optimization level
codegen-units = 1 # Single codegen unit for better optimization
strip = true # Strip symbols for smaller binary size
panic = "abort" # Abort on panic for smaller code size

[profile.bench]
inherits = "release"
lto = "fat"
opt-level = 3
codegen-units = 1
81 changes: 81 additions & 0 deletions lib/rust/hmll/examples/basic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//! Basic example of loading data from a single model file.

use hmll::{Device, LoaderKind, Source, WeightLoader};
use std::env;
use std::str::FromStr;
use std::time::Instant;

fn main() -> Result<(), Box<dyn std::error::Error>> {
// Get the file path from command line arguments
let args: Vec<String> = env::args().collect();
if args.len() != 4 {
eprintln!("Usage: {} <model_file>", args[0]);
eprintln!("Example: {} model.safetensors", args[0]);
std::process::exit(1);
}

let file_path = &args[1];
let start = usize::from_str(&args[2]).expect("<start> parameter should be a number");
let end = usize::from_str(&args[3]).expect("<end> parameter should be a number");

println!("Opening file: {}", file_path);

// Open the source file
let source = Source::open(file_path)?;
println!("✓ File opened successfully");
println!(" Size: {} bytes ({:.2} MB)", source.size(), source.size() as f64 / 1_048_576.0);

// Store in an array to ensure proper lifetime
let sources = [source];

// Create a weight loader
println!("\nCreating weight loader...");
let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto)?;
println!("✓ Loader created successfully");
println!(" Device: {}", loader.device());
println!(" Number of sources: {}", loader.num_sources());

// Fetch some data from the beginning of the file
let fetch_size = end - start;
let actual_fetch_size = fetch_size.min(sources[0].size());
println!("\nFetching {} bytes ({:.2} MB)...", actual_fetch_size, actual_fetch_size as f64 / 1_048_576.0);

let start_time = Instant::now();
let buffer = loader.fetch(start..end, 0)?;
let elapsed = start_time.elapsed();

println!("✓ Data fetched successfully");
println!(" Buffer size: {} bytes", buffer.len());
println!(" Buffer device: {}", buffer.device());
println!(" Fetch time: {:.3}s", elapsed.as_secs_f64());

// Calculate throughput
let throughput_bytes_per_sec = buffer.len() as f64 / elapsed.as_secs_f64();
let throughput_mb_per_sec = throughput_bytes_per_sec / 1_048_576.0;
let throughput_gb_per_sec = throughput_bytes_per_sec / 1_073_741_824.0;

println!("\n📊 Throughput:");
println!(" {:.2} MB/s", throughput_mb_per_sec);
println!(" {:.2} GB/s", throughput_gb_per_sec);

// Access the data (for CPU buffers)
if let Some(data) = buffer.as_slice() {
println!("\n✓ Buffer accessible as slice");

// Print first 64 bytes as hex
let preview_len = 64usize.min(data.len());
println!(" First {} bytes (hex):", preview_len);
print!(" ");
for (i, byte) in data[..preview_len].iter().enumerate() {
print!("{:02x} ", byte);
if (i + 1) % 16 == 0 && i < preview_len - 1 {
print!("\n ");
}
}
println!();
}

println!("\n✓ All operations completed successfully!");

Ok(())
}
Loading