diff --git a/lib/rust/hmll-sys/Cargo.toml b/lib/rust/hmll-sys/Cargo.toml new file mode 100644 index 0000000..5b0fe8a --- /dev/null +++ b/lib/rust/hmll-sys/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "hmll-sys" +version = "0.1.0" +edition = "2021" +authors = ["Morgan Funtowicz "] +description = "Low-level FFI bindings to the hmll library" +license = "MIT OR Apache-2.0" +repository = "https://github.com/huggingface/hmll" +links = "hmll" + +[dependencies] + +[build-dependencies] +bindgen = "0.72.1" +cmake = "0.1" + +[features] +default = ["io_uring"] +io_uring = [] +safetensors = [] +cuda = [] + +[profile.release] +lto = "fat" # Full LTO for maximum optimization across crates +opt-level = 3 # Maximum optimization level +codegen-units = 1 # Single codegen unit for better optimization +strip = true # Strip symbols for smaller binary size + +[profile.bench] +inherits = "release" +lto = "fat" +opt-level = 3 +codegen-units = 1 diff --git a/lib/rust/hmll-sys/build.rs b/lib/rust/hmll-sys/build.rs new file mode 100644 index 0000000..a9f56ac --- /dev/null +++ b/lib/rust/hmll-sys/build.rs @@ -0,0 +1,140 @@ +use std::env; +use std::path::PathBuf; + +fn main() { + // Get the project root (3 levels up from hmll-sys: lib/rust/hmll-sys -> .) + let project_root = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()) + .parent() + .unwrap() + .parent() + .unwrap() + .parent() + .unwrap() + .to_path_buf(); + + println!("cargo:rerun-if-changed=../../.."); + println!("cargo:rerun-if-changed=build.rs"); + + // Detect Rust build profile and map to CMake build type + let profile = env::var("PROFILE").unwrap_or_else(|_| "debug".to_string()); + let cmake_build_type = match profile.as_str() { + "debug" => "Debug", + "release" => "Release", + "bench" => "Release", + _ => "RelWithDebInfo", + }; + + // Configure CMake build + let mut cmake_config = cmake::Config::new(&project_root); + + // Set the CMake build type based on Rust profile + cmake_config.profile(cmake_build_type); + + // Enable features based on Rust features + cmake_config + .define("HMLL_BUILD_STATIC", "ON") + .define("HMLL_BUILD_EXAMPLES", "OFF") + .define("HMLL_BUILD_TESTS", "OFF") + .define("HMLL_ENABLE_PYTHON", "OFF") + .build_target("libhmll"); + + #[cfg(feature = "io_uring")] + cmake_config.define("HMLL_ENABLE_IO_URING", "ON"); + + #[cfg(not(feature = "io_uring"))] + cmake_config.define("HMLL_ENABLE_IO_URING", "OFF"); + + #[cfg(feature = "safetensors")] + cmake_config.define("HMLL_ENABLE_SAFETENSORS", "ON"); + + #[cfg(not(feature = "safetensors"))] + cmake_config.define("HMLL_ENABLE_SAFETENSORS", "OFF"); + + #[cfg(feature = "cuda")] + cmake_config.define("HMLL_ENABLE_CUDA", "ON"); + + #[cfg(not(feature = "cuda"))] + cmake_config.define("HMLL_ENABLE_CUDA", "OFF"); + + // Build the library + let dst = cmake_config.build(); + + // Tell cargo to link the library + println!("cargo:rustc-link-search=native={}/build", dst.display()); + println!("cargo:rustc-link-lib=static=libhmll"); + + // Link io_uring if enabled + #[cfg(all(target_os = "linux", feature = "io_uring"))] + { + println!("cargo:rustc-link-search=native={}/build/_deps/liburing-src/src", dst.display()); + println!("cargo:rustc-link-lib=static=uring"); + } + + // Link yyjson if safetensors is enabled + #[cfg(feature = "safetensors")] + { + println!("cargo:rustc-link-search=native={}/build/_deps/yyjson-build", dst.display()); + println!("cargo:rustc-link-lib=static=yyjson"); + } + + // Link CUDA runtime if enabled + #[cfg(feature = "cuda")] + { + // Try to find CUDA installation + if let 
Ok(cuda_path) = env::var("CUDA_PATH") { + println!("cargo:rustc-link-search=native={}/lib64", cuda_path); + println!("cargo:rustc-link-search=native={}/lib", cuda_path); + } else if let Ok(cuda_home) = env::var("CUDA_HOME") { + println!("cargo:rustc-link-search=native={}/lib64", cuda_home); + println!("cargo:rustc-link-search=native={}/lib", cuda_home); + } else { + // Try common default locations + println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64"); + println!("cargo:rustc-link-search=native=/usr/local/cuda/lib"); + println!("cargo:rustc-link-search=native=/opt/cuda/lib64"); + println!("cargo:rustc-link-search=native=/opt/cuda/lib"); + } + + println!("cargo:rustc-link-lib=dylib=cudart"); + println!("cargo:rustc-link-lib=dylib=cuda"); + } + + // Generate bindings + let include_path = project_root.join("include"); + + let builder = bindgen::Builder::default() + .header(include_path.join("hmll/hmll.h").to_str().unwrap()) + .clang_arg(format!("-I{}", include_path.display())) + .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) + .allowlist_function("hmll_.*") + .allowlist_type("hmll_.*") + .allowlist_var("HMLL_.*") + .derive_debug(true) + .derive_default(true) + .derive_copy(true) + .derive_eq(true) + .derive_hash(true) + .impl_debug(true) + .prepend_enum_name(false) + .size_t_is_usize(true) + .layout_tests(false); + + // Add conditional defines based on features + #[cfg(feature = "safetensors")] + let builder = builder + .clang_arg("-D__HMLL_SAFETENSORS_ENABLED__=1") + .clang_arg("-D__HMLL_TENSORS_ENABLED__=1"); + + #[cfg(feature = "cuda")] + let builder = builder.clang_arg("-D__HMLL_CUDA_ENABLED__=1"); + + let bindings = builder + .generate() + .expect("Unable to generate bindings"); + + // Write the bindings to the $OUT_DIR/bindings.rs file + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + bindings + .write_to_file(out_path.join("bindings.rs")) + .expect("Couldn't write bindings!"); +} diff --git a/lib/rust/hmll/Cargo.toml b/lib/rust/hmll/Cargo.toml new file mode 100644 index 0000000..c8ca2af --- /dev/null +++ b/lib/rust/hmll/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "hmll" +version = "0.1.0" +edition = "2021" +authors = ["HMLL Contributors"] +description = "Safe, idiomatic Rust bindings to the hmll library for high-performance ML model loading" +license = "MIT OR Apache-2.0" +repository = "https://github.com/huggingface/hmll" + +[dependencies] +hmll-sys = { path = "../hmll-sys" } +thiserror = "2.0" + +[features] +default = ["io_uring"] +io_uring = ["hmll-sys/io_uring"] +safetensors = ["hmll-sys/safetensors"] +cuda = ["hmll-sys/cuda"] + +[profile.release] +lto = "fat" # Full LTO for maximum optimization across crates +opt-level = 3 # Maximum optimization level +codegen-units = 1 # Single codegen unit for better optimization +strip = true # Strip symbols for smaller binary size +panic = "abort" # Abort on panic for smaller code size + +[profile.bench] +inherits = "release" +lto = "fat" +opt-level = 3 +codegen-units = 1 \ No newline at end of file diff --git a/lib/rust/hmll/examples/basic.rs b/lib/rust/hmll/examples/basic.rs new file mode 100644 index 0000000..2dc8ad2 --- /dev/null +++ b/lib/rust/hmll/examples/basic.rs @@ -0,0 +1,81 @@ +//! Basic example of loading data from a single model file. 
+
+use hmll::{Device, LoaderKind, Source, WeightLoader};
+use std::env;
+use std::str::FromStr;
+use std::time::Instant;
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // Get the file path from command line arguments
+    let args: Vec<String> = env::args().collect();
+    if args.len() != 4 {
+        eprintln!("Usage: {} <file> <start> <end>", args[0]);
+        eprintln!("Example: {} model.safetensors 0 1024", args[0]);
+        std::process::exit(1);
+    }
+
+    let file_path = &args[1];
+    let start = usize::from_str(&args[2]).expect("<start> parameter should be a number");
+    let end = usize::from_str(&args[3]).expect("<end> parameter should be a number");
+
+    println!("Opening file: {}", file_path);
+
+    // Open the source file
+    let source = Source::open(file_path)?;
+    println!("āœ“ File opened successfully");
+    println!(" Size: {} bytes ({:.2} MB)", source.size(), source.size() as f64 / 1_048_576.0);
+
+    // Store in an array to ensure proper lifetime
+    let sources = [source];
+
+    // Create a weight loader
+    println!("\nCreating weight loader...");
+    let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto)?;
+    println!("āœ“ Loader created successfully");
+    println!(" Device: {}", loader.device());
+    println!(" Number of sources: {}", loader.num_sources());
+
+    // Fetch some data from the beginning of the file
+    let fetch_size = end - start;
+    let actual_fetch_size = fetch_size.min(sources[0].size());
+    println!("\nFetching {} bytes ({:.2} MB)...", actual_fetch_size, actual_fetch_size as f64 / 1_048_576.0);
+
+    let start_time = Instant::now();
+    let buffer = loader.fetch(start..end, 0)?;
+    let elapsed = start_time.elapsed();
+
+    println!("āœ“ Data fetched successfully");
+    println!(" Buffer size: {} bytes", buffer.len());
+    println!(" Buffer device: {}", buffer.device());
+    println!(" Fetch time: {:.3}s", elapsed.as_secs_f64());
+
+    // Calculate throughput
+    let throughput_bytes_per_sec = buffer.len() as f64 / elapsed.as_secs_f64();
+    let throughput_mb_per_sec = throughput_bytes_per_sec / 1_048_576.0;
+    let throughput_gb_per_sec = throughput_bytes_per_sec / 1_073_741_824.0;
+
+    println!("\nšŸ“Š Throughput:");
+    println!(" {:.2} MB/s", throughput_mb_per_sec);
+    println!(" {:.2} GB/s", throughput_gb_per_sec);
+
+    // Access the data (for CPU buffers)
+    if let Some(data) = buffer.as_slice() {
+        println!("\nāœ“ Buffer accessible as slice");
+
+        // Print first 64 bytes as hex
+        let preview_len = 64usize.min(data.len());
+        println!(" First {} bytes (hex):", preview_len);
+        print!(" ");
+        for (i, byte) in data[..preview_len].iter().enumerate() {
+            print!("{:02x} ", byte);
+            if (i + 1) % 16 == 0 && i < preview_len - 1 {
+                print!("\n ");
+            }
+        }
+        println!();
+    }
+
+    println!("\nāœ“ All operations completed successfully!");
+
+    Ok(())
+}
diff --git a/lib/rust/hmll/examples/multi_files.rs b/lib/rust/hmll/examples/multi_files.rs
new file mode 100644
index 0000000..9d29b17
--- /dev/null
+++ b/lib/rust/hmll/examples/multi_files.rs
@@ -0,0 +1,138 @@
+//! Example of loading data from multiple sharded model files.
+
+use hmll::{Device, LoaderKind, Source, WeightLoader};
+use std::env;
+use std::time::Instant;
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // Get file paths from command line arguments
+    let args: Vec<String> = env::args().collect();
+    if args.len() < 2 {
+        eprintln!("Usage: {} <file1> [file2] [file3] ...", args[0]);
+        eprintln!("Example: {} model-00001.safetensors model-00002.safetensors", args[0]);
+        std::process::exit(1);
+    }
+
+    let file_paths = &args[1..];
+
+    println!("Opening {} file(s)...", file_paths.len());
+
+    // Open all source files
+    let mut sources = Vec::new();
+    let mut total_size = 0u64;
+
+    for (i, path) in file_paths.iter().enumerate() {
+        print!(" [{}] Opening: {}... ", i, path);
+        match Source::open(path) {
+            Ok(source) => {
+                let size = source.size();
+                total_size += size as u64;
+                println!("āœ“ ({} bytes)", size);
+                sources.push(source);
+            }
+            Err(e) => {
+                eprintln!("āœ— Failed: {}", e);
+                return Err(e.into());
+            }
+        }
+    }
+
+    println!("\nāœ“ All files opened successfully");
+    println!(" Total size: {} bytes ({:.2} MB)", total_size, total_size as f64 / 1_048_576.0);
+
+    // Create a weight loader for all sources
+    println!("\nCreating weight loader for {} sources...", sources.len());
+    let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto)?;
+    println!("āœ“ Loader created successfully");
+
+    // Display information about each source
+    println!("\nSource information:");
+    for i in 0..loader.num_sources() {
+        if let Some(info) = loader.source_info(i) {
+            println!(" [{}] Size: {} bytes ({:.2} MB)",
+                i, info.size, info.size as f64 / 1_048_576.0);
+        }
+    }
+
+    // Fetch data from each file with throughput measurement
+    println!("\nFetching data from each source...");
+    let mut total_fetch_time = 0.0;
+    let mut total_fetched_bytes = 0;
+
+    for i in 0..loader.num_sources() {
+        if let Some(info) = loader.source_info(i) {
+            let fetch_size = 512usize.min(info.size);
+            print!(" [{}] Fetching {} bytes... ", i, fetch_size);
+
+            let start_time = Instant::now();
+            match loader.fetch(0..fetch_size, i) {
+                Ok(buffer) => {
+                    let elapsed = start_time.elapsed();
+                    total_fetch_time += elapsed.as_secs_f64();
+                    total_fetched_bytes += buffer.len();
+
+                    let throughput_mb = (buffer.len() as f64 / 1_048_576.0) / elapsed.as_secs_f64();
+                    println!("āœ“ {} bytes in {:.3}s ({:.2} MB/s)", buffer.len(), elapsed.as_secs_f64(), throughput_mb);
+
+                    // Show a preview of the data
+                    if let Some(data) = buffer.as_slice() {
+                        let preview_len = 16usize.min(data.len());
+                        print!(" Preview: ");
+                        for byte in &data[..preview_len] {
+                            print!("{:02x} ", byte);
+                        }
+                        if preview_len < data.len() {
+                            print!("...");
+                        }
+                        println!();
+                    }
+                }
+                Err(e) => {
+                    eprintln!("āœ— Failed: {}", e);
+                    return Err(e.into());
+                }
+            }
+        }
+    }
+
+    // Demonstrate fetching from different ranges in different files
+    println!("\nDemonstrating random access across files...");
+
+    for i in 0..loader.num_sources().min(3) {
+        if let Some(info) = loader.source_info(i) {
+            if info.size >= 2048 {
+                let ranges = [
+                    (0, 256),
+                    (512, 768),
+                    (1024, 1280),
+                ];
+
+                for (start, end) in ranges {
+                    if end <= info.size {
+                        print!(" [{}] Range {}..{}... ", i, start, end);
+                        match loader.fetch(start..end, i) {
+                            Ok(buffer) => println!("āœ“ {} bytes", buffer.len()),
+                            Err(e) => println!("āœ— {}", e),
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    println!("\nāœ“ All operations completed successfully!");
+    println!("\nSummary:");
+    println!(" Files loaded: {}", loader.num_sources());
+    println!(" Total file size: {:.2} MB", total_size as f64 / 1_048_576.0);
+    println!(" Device: {}", loader.device());
+
+    if total_fetched_bytes > 0 && total_fetch_time > 0.0 {
+        let avg_throughput_mb = (total_fetched_bytes as f64 / 1_048_576.0) / total_fetch_time;
+        println!("\nšŸ“Š Overall Throughput:");
+        println!(" Total fetched: {:.2} MB", total_fetched_bytes as f64 / 1_048_576.0);
+        println!(" Total time: {:.3}s", total_fetch_time);
+        println!(" Average: {:.2} MB/s", avg_throughput_mb);
+    }
+
+    Ok(())
+}
diff --git a/lib/rust/hmll/src/buffer.rs b/lib/rust/hmll/src/buffer.rs
new file mode 100644
index 0000000..e66b9f2
--- /dev/null
+++ b/lib/rust/hmll/src/buffer.rs
@@ -0,0 +1,187 @@
+//! Buffer and range types for data operations.
+
+use crate::Device;
+use std::ops;
+
+/// Represents a range of bytes to fetch.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct Range {
+    pub start: usize,
+    pub end: usize,
+}
+
+impl Range {
+    /// Create a new range.
+    ///
+    /// This can be evaluated at compile time for constant ranges.
+    #[inline(always)]
+    pub const fn new(start: usize, end: usize) -> Self {
+        Self { start, end }
+    }
+
+    /// Get the length of the range.
+    ///
+    /// Hot path - inline always for zero-cost abstraction.
+    #[inline(always)]
+    pub const fn len(&self) -> usize {
+        self.end.saturating_sub(self.start)
+    }
+
+    /// Check if the range is empty.
+    ///
+    /// Hot path - inline always for zero-cost abstraction.
+    #[inline(always)]
+    pub const fn is_empty(&self) -> bool {
+        self.start >= self.end
+    }
+
+    /// Convert to the underlying C struct.
+    ///
+    /// Hot path - always inline for FFI conversion.
+    #[inline(always)]
+    pub(crate) fn to_raw(self) -> hmll_sys::hmll_range {
+        hmll_sys::hmll_range {
+            start: self.start,
+            end: self.end,
+        }
+    }
+
+    /// Convert from the underlying C struct.
+    ///
+    /// Hot path - always inline for FFI conversion.
+    #[allow(unused)]
+    #[inline(always)]
+    pub(crate) const fn from_raw(range: hmll_sys::hmll_range) -> Self {
+        Self {
+            start: range.start,
+            end: range.end,
+        }
+    }
+}
+
+impl From<ops::Range<usize>> for Range {
+    /// Convert from standard library Range.
+    ///
+    /// Hot path - inline always for zero-cost conversion.
+    #[inline(always)]
+    fn from(range: ops::Range<usize>) -> Self {
+        Self {
+            start: range.start,
+            end: range.end,
+        }
+    }
+}
+
+impl From<Range> for ops::Range<usize> {
+    /// Convert to standard library Range.
+    ///
+    /// Hot path - inline always for zero-cost conversion.
+    #[inline(always)]
+    fn from(range: Range) -> Self {
+        range.start..range.end
+    }
+}
+
+/// A buffer containing fetched data.
+#[derive(Debug)]
+pub struct Buffer {
+    ptr: *mut u8,
+    size: usize,
+    device: Device,
+    // We own this memory, so we need to track whether to free it
+    #[allow(dead_code)]
+    owned: bool,
+}
+
+impl Buffer {
+    /// Create a new buffer from raw parts.
+    ///
+    /// # Safety
+    ///
+    /// The caller must ensure that `ptr` points to valid memory of at least `size` bytes.
+    ///
+    /// Hot path - inline always for construction.
+    #[inline(always)]
+    pub(crate) unsafe fn from_raw_parts(ptr: *mut u8, size: usize, device: Device, owned: bool) -> Self {
+        Self {
+            ptr,
+            size,
+            device,
+            owned,
+        }
+    }
+
+    /// Get the buffer as a byte slice (CPU only).
+    ///
+    /// Hot path - inline for efficient slice creation.
+    #[inline]
+    pub fn as_slice(&self) -> Option<&[u8]> {
+        if self.device == Device::Cpu && !self.ptr.is_null() {
+            unsafe { Some(std::slice::from_raw_parts(self.ptr, self.size)) }
+        } else {
+            None
+        }
+    }
+
+    /// Get the size of the buffer in bytes.
+    ///
+    /// Hot path - inline always for zero-cost field access.
+    #[inline(always)]
+    pub const fn len(&self) -> usize {
+        self.size
+    }
+
+    /// Check if the buffer is empty.
+    ///
+    /// Hot path - inline always for zero-cost check.
+    #[inline(always)]
+    pub const fn is_empty(&self) -> bool {
+        self.size == 0
+    }
+
+    /// Get the device where the buffer is located.
+    ///
+    /// Hot path - inline always for zero-cost field access.
+    #[inline(always)]
+    pub const fn device(&self) -> Device {
+        self.device
+    }
+
+    /// Get a raw pointer to the buffer.
+    ///
+    /// Hot path - inline always for zero-cost pointer access.
+    #[inline(always)]
+    pub const fn as_ptr(&self) -> *const u8 {
+        self.ptr as *const u8
+    }
+
+    /// Get a mutable raw pointer to the buffer.
+    ///
+    /// Hot path - inline always for zero-cost pointer access.
+    #[inline(always)]
+    pub fn as_mut_ptr(&mut self) -> *mut u8 {
+        self.ptr
+    }
+
+    /// Convert to a Vec (copies data if on CPU, panics if on GPU).
+    ///
+    /// This is a less common operation, so we use regular inline.
+    #[inline]
+    pub fn to_vec(&self) -> Vec<u8> {
+        self.as_slice()
+            .expect("Cannot convert GPU buffer to Vec")
+            .to_vec()
+    }
+}
+
+// Buffer is Send and Sync as long as the device supports it
+unsafe impl Send for Buffer {}
+unsafe impl Sync for Buffer {}
+
+impl Drop for Buffer {
+    fn drop(&mut self) {
+        // Note: In hmll, buffers are managed by the context
+        // We don't manually free them here as they're part of the arena allocator
+        // This is why we track `owned` - in the future we might need to handle this differently
+    }
+}
diff --git a/lib/rust/hmll/src/device.rs b/lib/rust/hmll/src/device.rs
new file mode 100644
index 0000000..af3e274
--- /dev/null
+++ b/lib/rust/hmll/src/device.rs
@@ -0,0 +1,57 @@
+//! Device types for specifying where data should be loaded.
+
+use hmll_sys::{hmll_device, HMLL_DEVICE_CPU, HMLL_DEVICE_CUDA};
+
+/// Represents a device where data can be loaded.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Device {
+    /// CPU memory
+    Cpu,
+    /// CUDA GPU memory
+    Cuda,
+}
+
+impl Device {
+    /// Convert to the underlying C enum value.
+    ///
+    /// Hot path - inline always for FFI conversion.
+    #[inline(always)]
+    pub(crate) const fn to_raw(self) -> hmll_device {
+        match self {
+            Device::Cpu => HMLL_DEVICE_CPU,
+            Device::Cuda => HMLL_DEVICE_CUDA,
+        }
+    }
+
+    /// Convert from the underlying C enum value.
+    ///
+    /// Hot path - inline always for FFI conversion.
+    #[allow(dead_code)]
+    #[inline(always)]
+    pub(crate) const fn from_raw(device: hmll_device) -> Option<Self> {
+        match device {
+            HMLL_DEVICE_CPU => Some(Device::Cpu),
+            HMLL_DEVICE_CUDA => Some(Device::Cuda),
+            _ => None,
+        }
+    }
+}
+
+impl Default for Device {
+    /// Default device is CPU.
+    ///
+    /// Hot path - inline always for zero-cost default.
+    #[inline(always)]
+    fn default() -> Self {
+        Device::Cpu
+    }
+}
+
+impl std::fmt::Display for Device {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Device::Cpu => write!(f, "CPU"),
+            Device::Cuda => write!(f, "CUDA"),
+        }
+    }
+}
diff --git a/lib/rust/hmll/src/error.rs b/lib/rust/hmll/src/error.rs
new file mode 100644
index 0000000..853872f
--- /dev/null
+++ b/lib/rust/hmll/src/error.rs
@@ -0,0 +1,143 @@
+//! Error types for hmll operations.
+
+use std::ffi::CStr;
+use thiserror::Error;
+
+/// Result type alias for hmll operations.
+pub type Result<T> = std::result::Result<T, Error>;
+
+/// Errors that can occur when using hmll.
+#[derive(Debug, Error)]
+pub enum Error {
+    #[error("Unsupported platform")]
+    UnsupportedPlatform,
+
+    #[error("Unsupported file format")]
+    UnsupportedFileFormat,
+
+    #[error("Unsupported device")]
+    UnsupportedDevice,
+
+    #[error("Memory allocation failed")]
+    AllocationFailed,
+
+    #[error("Table is empty")]
+    TableEmpty,
+
+    #[error("Tensor not found")]
+    TensorNotFound,
+
+    #[error("Invalid range")]
+    InvalidRange,
+
+    #[error("Buffer address not aligned")]
+    BufferAddrNotAligned,
+
+    #[error("Buffer too small")]
+    BufferTooSmall,
+
+    #[error("I/O error")]
+    IoError,
+
+    #[error("File not found: {0}")]
+    FileNotFound(String),
+
+    #[error("File is empty")]
+    FileEmpty,
+
+    #[error("Memory mapping failed")]
+    MmapFailed,
+
+    #[error("I/O buffer registration failed")]
+    IoBufferRegistrationFailed,
+
+    #[error("SafeTensors: Invalid JSON header")]
+    SafeTensorsJsonInvalidHeader,
+
+    #[error("SafeTensors: Malformed JSON header")]
+    SafeTensorsJsonMalformedHeader,
+
+    #[error("SafeTensors: Malformed JSON index")]
+    SafeTensorsJsonMalformedIndex,
+
+    #[error("CUDA not enabled")]
+    CudaNotEnabled,
+
+    #[error("No CUDA device available")]
+    CudaNoDevice,
+
+    #[error("System error: {0}")]
+    SystemError(String),
+
+    #[error("Unknown data type")]
+    UnknownDType,
+
+    #[error("Unknown error code: {0}")]
+    Unknown(u32),
+}
+
+impl Error {
+    /// Convert a hmll_error to a Rust Error.
+    ///
+    /// This is a cold path - errors should be rare in normal operation.
+    /// We mark it as cold and never inline to optimize the hot (success) path.
+    #[cold]
+    #[inline(never)]
+    pub(crate) fn from_hmll_error(err: hmll_sys::hmll_error) -> Self {
+        use hmll_sys::*;
+
+        // Check if it's a system error
+        if err.code == HMLL_ERR_SYSTEM {
+            let msg = unsafe {
+                let ptr = hmll_strerr(err);
+                if ptr.is_null() {
+                    format!("System error code: {}", err.sys_err)
+                } else {
+                    CStr::from_ptr(ptr)
+                        .to_string_lossy()
+                        .into_owned()
+                }
+            };
+            return Error::SystemError(msg);
+        }
+
+        // Map hmll error codes to Rust errors
+        match err.code {
+            HMLL_ERR_SUCCESS => unreachable!("Success is not an error"),
+            HMLL_ERR_UNSUPPORTED_PLATFORM => Error::UnsupportedPlatform,
+            HMLL_ERR_UNSUPPORTED_FILE_FORMAT => Error::UnsupportedFileFormat,
+            HMLL_ERR_UNSUPPORTED_DEVICE => Error::UnsupportedDevice,
+            HMLL_ERR_ALLOCATION_FAILED => Error::AllocationFailed,
+            HMLL_ERR_TABLE_EMPTY => Error::TableEmpty,
+            HMLL_ERR_TENSOR_NOT_FOUND => Error::TensorNotFound,
+            HMLL_ERR_INVALID_RANGE => Error::InvalidRange,
+            HMLL_ERR_BUFFER_ADDR_NOT_ALIGNED => Error::BufferAddrNotAligned,
+            HMLL_ERR_BUFFER_TOO_SMALL => Error::BufferTooSmall,
+            HMLL_ERR_IO_ERROR => Error::IoError,
+            HMLL_ERR_FILE_NOT_FOUND => Error::FileNotFound(String::new()),
+            HMLL_ERR_FILE_EMPTY => Error::FileEmpty,
+            HMLL_ERR_MMAP_FAILED => Error::MmapFailed,
+            HMLL_ERR_IO_BUFFER_REGISTRATION_FAILED => Error::IoBufferRegistrationFailed,
+            HMLL_ERR_SAFETENSORS_JSON_INVALID_HEADER => Error::SafeTensorsJsonInvalidHeader,
+            HMLL_ERR_SAFETENSORS_JSON_MALFORMED_HEADER => Error::SafeTensorsJsonMalformedHeader,
+            HMLL_ERR_SAFETENSORS_JSON_MALFORMED_INDEX => Error::SafeTensorsJsonMalformedIndex,
+            HMLL_ERR_CUDA_NOT_ENABLED => Error::CudaNotEnabled,
+            HMLL_ERR_CUDA_NO_DEVICE => Error::CudaNoDevice,
+            HMLL_ERR_UNKNOWN_DTYPE => Error::UnknownDType,
+            code => Error::Unknown(code),
+        }
+    }
+
+    /// Check if a hmll_error represents success.
+    ///
+    /// This is a hot path - inline always for zero-cost abstraction.
+    /// The success path should be optimized and the error path should branch predict as unlikely.
+    #[inline(always)]
+    pub(crate) fn check_hmll_error(err: hmll_sys::hmll_error) -> Result<()> {
+        if hmll_sys::hmll_is_success(err) {
+            Ok(())
+        } else {
+            Err(Self::from_hmll_error(err))
+        }
+    }
+}
diff --git a/lib/rust/hmll/src/lib.rs b/lib/rust/hmll/src/lib.rs
new file mode 100644
index 0000000..354b10a
--- /dev/null
+++ b/lib/rust/hmll/src/lib.rs
@@ -0,0 +1,57 @@
+//! Safe, idiomatic Rust bindings to the hmll library.
+//!
+//! This crate provides a safe, high-level interface to the hmll C library for
+//! high-performance loading of machine learning model files.
+//!
+//! # Features
+//!
+//! - **`io_uring`** (default): High-performance I/O using io_uring on Linux
+//! - **`safetensors`**: Native support for safetensors format
+//! - **`cuda`**: CUDA memory support for GPU operations
+//!
+//! # Example
+//!
+//! ```no_run
+//! use hmll::{Source, WeightLoader, Device, LoaderKind};
+//!
+//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
+//! // Open a model file
+//! let source = Source::open("model.safetensors")?;
+//! let sources = [source];
+//!
+//! // Create a weight loader
+//! let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto)?;
+//!
+//! // Fetch a range of bytes
+//! let data = loader.fetch(0..1024, 0)?;
+//! println!("Fetched {} bytes", data.len());
+//! # Ok(())
+//! # }
+//! ```
+
+mod error;
+mod source;
+mod loader;
+mod device;
+mod buffer;
+
+pub use error::{Error, Result};
+pub use source::Source;
+pub use loader::{WeightLoader, LoaderKind};
+pub use device::Device;
+pub use buffer::{Buffer, Range};
+
+// Re-export common types
+pub use hmll_sys::{
+    HMLL_ERR_SUCCESS,
+    HMLL_DEVICE_CPU,
+    HMLL_DEVICE_CUDA,
+};
+
+#[cfg(feature = "safetensors")]
+pub use hmll_sys::{
+    hmll_dtype as DType,
+    HMLL_DTYPE_FLOAT32,
+    HMLL_DTYPE_FLOAT16,
+    HMLL_DTYPE_BFLOAT16,
+};
diff --git a/lib/rust/hmll/src/loader.rs b/lib/rust/hmll/src/loader.rs
new file mode 100644
index 0000000..decffd8
--- /dev/null
+++ b/lib/rust/hmll/src/loader.rs
@@ -0,0 +1,293 @@
+//! Weight loader implementation for efficient model loading.
+
+use crate::{Buffer, Device, Error, Range, Result, Source};
+use std::marker::PhantomData;
+use std::ptr;
+
+/// Loader backend kind.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum LoaderKind {
+    /// Automatically select the best backend.
+    Auto,
+    /// Use io_uring backend (Linux only).
+    #[cfg(target_os = "linux")]
+    IoUring,
+}
+
+impl LoaderKind {
+    /// Convert to the underlying C enum value.
+    ///
+    /// Hot path - inline always for FFI conversion.
+    #[inline(always)]
+    pub(crate) const fn to_raw(self) -> hmll_sys::hmll_loader_kind {
+        match self {
+            LoaderKind::Auto => hmll_sys::HMLL_FETCHER_AUTO,
+            #[cfg(target_os = "linux")]
+            LoaderKind::IoUring => hmll_sys::HMLL_FETCHER_IO_URING,
+        }
+    }
+}
+
+impl Default for LoaderKind {
+    /// Default loader kind is Auto.
+    ///
+    /// Hot path - inline always for zero-cost default.
+    #[inline(always)]
+    fn default() -> Self {
+        LoaderKind::Auto
+    }
+}
+
+/// A high-performance weight loader for ML models.
+///
+/// `WeightLoader` encapsulates the hmll context, loader, and device configuration,
+/// providing a safe interface for fetching weight data from model files.
+///
+/// # Example
+///
+/// ```no_run
+/// use hmll::{Source, WeightLoader, Device, LoaderKind};
+///
+/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
+/// // Open source files
+/// let source1 = Source::open("model-00001-of-00003.safetensors")?;
+/// let source2 = Source::open("model-00002-of-00003.safetensors")?;
+/// let source3 = Source::open("model-00003-of-00003.safetensors")?;
+/// let sources = [source1, source2, source3];
+///
+/// // Create a loader
+/// let mut loader = WeightLoader::new(
+///     &sources,
+///     Device::Cpu,
+///     LoaderKind::Auto
+/// )?;
+///
+/// // Fetch data from the first file
+/// let data = loader.fetch(0..1024, 0)?;
+/// println!("Fetched {} bytes", data.len());
+/// # Ok(())
+/// # }
+/// ```
+pub struct WeightLoader<'a> {
+    context: Box<hmll_sys::hmll>,
+    sources: Vec<hmll_sys::hmll_source>,
+    device: Device,
+    _marker: PhantomData<&'a ()>,
+}
+
+impl<'a> WeightLoader<'a> {
+    /// Create a new weight loader.
+    ///
+    /// # Arguments
+    ///
+    /// * `sources` - Slice of source files to load from
+    /// * `device` - Target device (CPU or CUDA)
+    /// * `kind` - Loader backend kind
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the loader initialization fails.
+    pub fn new(sources: &'a [Source], device: Device, kind: LoaderKind) -> Result<Self> {
+        if sources.is_empty() {
+            return Err(Error::InvalidRange);
+        }
+
+        let sources_vec: Vec<hmll_sys::hmll_source> = sources
+            .iter()
+            .map(|s| *s.as_raw())
+            .collect();
+
+        let mut context = Box::new(hmll_sys::hmll {
+            fetcher: ptr::null_mut(),
+            sources: ptr::null(),
+            num_sources: 0,
+            error: hmll_sys::hmll_error {
+                code: hmll_sys::HMLL_ERR_SUCCESS,
+                sys_err: 0,
+            },
+        });
+
+        unsafe {
+            let err = hmll_sys::hmll_loader_init(
+                context.as_mut(),
+                sources_vec.as_ptr(),
+                sources_vec.len(),
+                device.to_raw(),
+                kind.to_raw(),
+            );
+            Error::check_hmll_error(err)?;
+        }
+
+        Ok(Self {
+            context,
+            sources: sources_vec,
+            device,
+            _marker: PhantomData,
+        })
+    }
+
+    /// Fetch a range of bytes from a specific source file.
+    ///
+    /// # Arguments
+    ///
+    /// * `range` - The byte range to fetch (start..end)
+    /// * `file_index` - Index of the source file to fetch from
+    ///
+    /// # Returns
+    ///
+    /// A `Buffer` containing the fetched data.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if:
+    /// - The file index is out of bounds
+    /// - The range is invalid
+    /// - The fetch operation fails
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// # use hmll::{Source, WeightLoader, Device, LoaderKind};
+    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
+    /// # let source = Source::open("model.safetensors")?;
+    /// # let sources = [source];
+    /// # let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto)?;
+    /// // Fetch first 1MB from the first file
+    /// let data = loader.fetch(0..1024*1024, 0)?;
+    /// println!("Fetched {} bytes", data.len());
+    /// # Ok(())
+    /// # }
+    /// ```
+    #[inline]
+    pub fn fetch<R: Into<Range>>(&mut self, range: R, file_index: usize) -> Result<Buffer> {
+        let range = range.into();
+
+        // Fast path: bounds check
+        if file_index >= self.sources.len() {
+            return Err(Error::InvalidRange);
+        }
+
+        // Fast path: empty range
+        if range.is_empty() {
+            return Ok(unsafe { Buffer::from_raw_parts(ptr::null_mut(), 0, self.device, false) });
+        }
+
+        // Get buffer for the requested range
+        let mut iobuf = unsafe {
+            hmll_sys::hmll_get_buffer_for_range(
+                self.context.as_mut(),
+                self.device.to_raw(),
+                range.to_raw(),
+            )
+        };
+
+        // Check allocation success (less common error path)
+        if iobuf.ptr.is_null() {
+            return Err(Error::AllocationFailed);
+        }
+
+        // Perform the actual fetch
+        let offsets = unsafe {
+            hmll_sys::hmll_fetch(
+                self.context.as_mut(),
+                &mut iobuf,
+                range.to_raw(),
+                file_index,
+            )
+        };
+
+        // Check for errors (less common error path)
+        if self.context.error.code != hmll_sys::HMLL_ERR_SUCCESS {
+            let err = self.context.error;
+            self.context.error = hmll_sys::hmll_error {
+                code: hmll_sys::HMLL_ERR_SUCCESS,
+                sys_err: 0,
+            };
+            return Err(Error::from_hmll_error(err));
+        }
+
+        // Success path: create buffer
+        Ok(unsafe {
+            Buffer::from_raw_parts(
+                (iobuf.ptr as *mut u8).add(offsets.start),
+                offsets.end - offsets.start,
+                self.device,
+                false, // hmll manages the memory
+            )
+        })
+    }
+
+    /// Get the device this loader is configured for.
+    ///
+    /// Hot path - inline always for zero-cost field access.
+    #[inline(always)]
+    pub const fn device(&self) -> Device {
+        self.device
+    }
+
+    /// Get the number of source files.
+    ///
+    /// Hot path - inline always for zero-cost length access.
+    #[inline(always)]
+    pub fn num_sources(&self) -> usize {
+        self.sources.len()
+    }
+
+    /// Get information about a specific source file.
+    ///
+    /// Hot path - inline for efficient bounds checking and struct creation.
+    #[inline]
+    pub fn source_info(&self, index: usize) -> Option<SourceInfo> {
+        if index < self.sources.len() {
+            Some(SourceInfo {
+                size: self.sources[index].size,
+                #[cfg(target_family = "unix")]
+                fd: self.sources[index].fd,
+            })
+        } else {
+            None
+        }
+    }
+}
+
+impl<'a> Drop for WeightLoader<'a> {
+    fn drop(&mut self) {
+        unsafe {
+            hmll_sys::hmll_destroy(self.context.as_mut());
+        }
+    }
+}
+
+// WeightLoader is Send but not Sync (mutable operations)
+unsafe impl<'a> Send for WeightLoader<'a> {}
+
+/// Information about a source file.
+#[derive(Debug, Clone, Copy)]
+pub struct SourceInfo {
+    /// Size of the file in bytes
+    pub size: usize,
+    /// File descriptor (Unix only)
+    #[cfg(target_family = "unix")]
+    pub fd: i32,
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_empty_sources() {
+        let result = WeightLoader::new(&[], Device::Cpu, LoaderKind::Auto);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_loader_kind_default() {
+        assert_eq!(LoaderKind::default(), LoaderKind::Auto);
+    }
+
+    #[test]
+    fn test_device_default() {
+        assert_eq!(Device::default(), Device::Cpu);
+    }
+}
diff --git a/lib/rust/hmll/src/source.rs b/lib/rust/hmll/src/source.rs
new file mode 100644
index 0000000..693543c
--- /dev/null
+++ b/lib/rust/hmll/src/source.rs
@@ -0,0 +1,134 @@
+//! Source file handling for hmll.
+
+use crate::error::{Error, Result};
+use std::ffi::CString;
+use std::path::Path;
+
+/// A source file for loading model weights.
+///
+/// This wraps a file descriptor and ensures proper cleanup when dropped.
+#[derive(Debug)]
+pub struct Source {
+    inner: hmll_sys::hmll_source,
+    path: Option<String>,
+}
+
+impl Source {
+    /// Open a source file from a path.
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// use hmll::Source;
+    ///
+    /// let source = Source::open("model.safetensors")?;
+    /// println!("Opened file with size: {} bytes", source.size());
+    /// # Ok::<(), hmll::Error>(())
+    /// ```
+    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
+        let path_ref = path.as_ref();
+        let path_str = path_ref
+            .to_str()
+            .ok_or_else(|| Error::FileNotFound("Invalid UTF-8 in path".to_string()))?;
+
+        let c_path = CString::new(path_str)
+            .map_err(|_| Error::FileNotFound("Path contains null byte".to_string()))?;
+
+        let mut source = hmll_sys::hmll_source {
+            fd: -1,
+            size: 0,
+        };
+
+        unsafe {
+            let err = hmll_sys::hmll_source_open(c_path.as_ptr(), &mut source);
+            Error::check_hmll_error(err)?;
+        }
+
+        Ok(Self {
+            inner: source,
+            path: Some(path_str.to_string()),
+        })
+    }
+
+    /// Get the size of the source file in bytes.
+    ///
+    /// Hot path - inline always for zero-cost field access.
+    #[inline(always)]
+    pub const fn size(&self) -> usize {
+        self.inner.size
+    }
+
+    /// Get the file descriptor (platform-specific).
+    ///
+    /// Hot path - inline always for zero-cost field access.
+    #[cfg(target_family = "unix")]
+    #[inline(always)]
+    pub const fn fd(&self) -> i32 {
+        self.inner.fd
+    }
+
+    /// Get the path of the source file if available.
+    ///
+    /// Hot path - inline for efficient option access.
+    #[inline]
+    pub fn path(&self) -> Option<&str> {
+        self.path.as_deref()
+    }
+
+    /// Get a reference to the underlying hmll_source.
+    ///
+    /// Hot path - inline always for zero-cost reference.
+    #[inline(always)]
+    pub(crate) const fn as_raw(&self) -> &hmll_sys::hmll_source {
+        &self.inner
+    }
+
+    /// Consume self and return the raw hmll_source.
+    ///
+    /// # Safety
+    ///
+    /// The caller is responsible for calling hmll_source_close on the returned source.
+ /// + /// Hot path - inline always for efficient ownership transfer. + #[allow(dead_code)] + #[inline(always)] + pub(crate) unsafe fn into_raw(mut self) -> hmll_sys::hmll_source { + let source = self.inner; + // Prevent Drop from running + self.inner.fd = -1; + source + } +} + +impl Drop for Source { + fn drop(&mut self) { + // only close if we have a valid file descriptor + if self.inner.fd >= 0 { + unsafe { + hmll_sys::hmll_source_close(&self.inner); + } + } + } +} + +// Source can be safely sent between threads +unsafe impl Send for Source {} +// Source can be safely shared between threads (read-only operations) +unsafe impl Sync for Source {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_source_invalid_path() { + let result = Source::open("/nonexistent/file.safetensors"); + assert!(result.is_err()); + } + + #[test] + fn test_source_null_byte() { + let result = Source::open("file\0name.safetensors"); + assert!(result.is_err()); + } +}
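
Beyond the bundled examples, here is a minimal sketch of how a downstream crate might consume the API introduced in this diff; the file name, byte count, and error-handling choices are illustrative and not part of the change:

```rust
// Hypothetical consumer of the `hmll` crate added above.
use hmll::{Device, Error, LoaderKind, Source, WeightLoader};

fn load_prefix(path: &str, n: usize) -> Result<Vec<u8>, Error> {
    // Sources must outlive the loader, so keep them in a local binding.
    let sources = [Source::open(path)?];
    let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto)?;

    // `fetch` accepts any `Into<hmll::Range>`, so a std `usize` range works directly.
    let end = n.min(sources[0].size());
    let buffer = loader.fetch(0..end, 0)?;

    // `as_slice`/`to_vec` are only valid for CPU buffers; copy the bytes out so they
    // outlive the loader-managed allocation.
    Ok(buffer.to_vec())
}

fn main() {
    // "model.safetensors" and 4096 are placeholders.
    match load_prefix("model.safetensors", 4096) {
        Ok(bytes) => println!("read {} bytes", bytes.len()),
        // `Error` derives `thiserror::Error`, so variants can be matched or bubbled up with `?`.
        Err(Error::FileNotFound(path)) => eprintln!("missing file: {}", path),
        Err(e) => eprintln!("hmll error: {}", e),
    }
}
```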
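A second, self-contained sketch of the `Range` conversions defined in `buffer.rs` (the `From<std::ops::Range<usize>>` impl and the saturating `len`); the values are arbitrary:

```rust
use hmll::Range;

fn main() {
    // `Range::new` and the `From<std::ops::Range<usize>>` impl produce equal values.
    let a = Range::new(128, 256);
    let b: Range = (128..256).into();
    assert_eq!(a, b);
    assert_eq!(a.len(), 128);

    // A reversed range reports as empty and its length saturates to zero,
    // mirroring the `saturating_sub` used in `Range::len`.
    let reversed = Range::new(256, 128);
    assert!(reversed.is_empty());
    assert_eq!(reversed.len(), 0);

    // Round-trip back into a std range.
    let std_range: std::ops::Range<usize> = a.into();
    assert_eq!(std_range, 128..256);
}
```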