diff --git a/Cargo.toml b/Cargo.toml
index 13679827..55fad579 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -38,6 +38,7 @@ members = [
     "crates/luminal_cpu",
     "crates/luminal_nn",
     "crates/luminal_training",
+    "crates/luminal_onnx",
     "docs/company",
 ]
 exclude = [
@@ -47,3 +48,4 @@ exclude = [
     "crates/luminal_2",
     "demos/flash_attention",
 ]
+
diff --git a/crates/luminal_onnx/Cargo.toml b/crates/luminal_onnx/Cargo.toml
new file mode 100644
index 00000000..1054c397
--- /dev/null
+++ b/crates/luminal_onnx/Cargo.toml
@@ -0,0 +1,19 @@
+[package]
+name = "luminal_onnx"
+version = "0.1.0"
+edition = "2021"
+license = "MIT OR Apache-2.0"
+description = "ONNX import for Luminal"
+
+[dependencies]
+luminal = { path = "../.." }
+prost = { version = "0.12", default-features = false, features = ["prost-derive"] }
+prost-types = { version = "0.12", default-features = false }
+thiserror = "1"
+
+[build-dependencies]
+prost-build = { version = "0.12" }
+protoc-bin-vendored = "3"
+
+[dev-dependencies]
+tempfile = "3"
diff --git a/crates/luminal_onnx/build.rs b/crates/luminal_onnx/build.rs
new file mode 100644
index 00000000..92cc6b66
--- /dev/null
+++ b/crates/luminal_onnx/build.rs
@@ -0,0 +1,11 @@
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // Ensure protoc is available without a system install.
+    let protoc = protoc_bin_vendored::protoc_bin_path()?;
+    std::env::set_var("PROTOC", protoc);
+
+    let mut config = prost_build::Config::new();
+    // Use BTreeMap for determinism where prost would generate HashMaps.
+    config.btree_map(["."]);
+    config.compile_protos(&["src/onnx/proto/onnx.proto"], &["src/onnx/proto/"])?;
+    Ok(())
+}
diff --git a/crates/luminal_onnx/src/lib.rs b/crates/luminal_onnx/src/lib.rs
new file mode 100644
index 00000000..b39b6a2f
--- /dev/null
+++ b/crates/luminal_onnx/src/lib.rs
@@ -0,0 +1,5 @@
+pub mod onnx;
+
+mod load;
+
+pub use load::{import_onnx, OnnxImportError, OnnxImportResult};
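A minimal usage sketch of the API exported above (illustrative only, not part of this diff; the file name and the "x"/"y" tensor names are placeholders for whatever the imported model declares):

    use luminal_onnx::import_onnx;

    fn run_model() -> Result<Vec<f32>, luminal_onnx::OnnxImportError> {
        // import_onnx builds a luminal Graph plus name -> GraphTensor maps for inputs and outputs.
        let mut res = import_onnx("model.onnx")?;
        // Feed data into the named input, execute the graph, then read the retrieved output back.
        res.inputs.get("x").copied().expect("input x").set(vec![0.0f32; 16]);
        res.graph.execute();
        Ok(res.outputs.get("y").copied().expect("output y").data())
    }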
diff --git a/crates/luminal_onnx/src/load.rs b/crates/luminal_onnx/src/load.rs
new file mode 100644
index 00000000..92b76b13
--- /dev/null
+++ b/crates/luminal_onnx/src/load.rs
@@ -0,0 +1,519 @@
+use std::{collections::HashMap, fs::File, io::Read, path::Path};
+
+use luminal::prelude::*;
+use prost::Message;
+use thiserror::Error;
+
+use crate::onnx::proto as onnx;
+
+#[derive(Debug, Error)]
+pub enum OnnxImportError {
+    #[error("failed to read file: {0}")]
+    Io(#[from] std::io::Error),
+    #[error("failed to decode ONNX: {0}")]
+    Decode(String),
+    #[error("unsupported or missing op input: {0}")]
+    MissingInput(String),
+    #[error("unsupported data type {0}")]
+    UnsupportedDtype(i32),
+    #[error("shape not found or unsupported for op {0}")]
+    BadShape(String),
+    #[error("unsupported operator: {0}")]
+    UnsupportedOp(String),
+}
+
+pub struct OnnxImportResult {
+    pub graph: Box<Graph>,
+    pub inputs: HashMap<String, GraphTensor>,
+    pub outputs: HashMap<String, GraphTensor>,
+}
+
+struct Ctx {
+    g: Box<Graph>,
+    // mapping from value name to tensor
+    env: HashMap<String, GraphTensor>,
+    // input/output maps
+    inputs: HashMap<String, GraphTensor>,
+    outputs: HashMap<String, GraphTensor>,
+    // map of dim_param -> char symbol
+    symmap: HashMap<String, char>,
+    next_sym: u8,
+}
+
+impl Ctx {
+    fn new() -> Self {
+        Self {
+            g: Box::new(Graph::new()),
+            env: HashMap::default(),
+            inputs: HashMap::default(),
+            outputs: HashMap::default(),
+            symmap: HashMap::default(),
+            next_sym: b'a',
+        }
+    }
+
+    fn sym_for(&mut self, s: &str) -> char {
+        if let Some(&c) = self.symmap.get(s) {
+            return c;
+        }
+        // allocate the next ascii letter; once we run past 'z', fall back to 'a'
+        let mut c = self.next_sym as char;
+        if !(c.is_ascii_alphabetic()) {
+            // fallback to 'a'
+            c = 'a';
+        }
+        self.next_sym = self.next_sym.wrapping_add(1);
+        self.symmap.insert(s.to_string(), c);
+        c
+    }
+}
+
+pub fn import_onnx(path: impl AsRef<Path>) -> Result<OnnxImportResult, OnnxImportError> {
+    let mut f = File::open(path)?;
+    let mut buf = Vec::new();
+    f.read_to_end(&mut buf)?;
+    let model =
+        onnx::ModelProto::decode(&*buf).map_err(|e| OnnxImportError::Decode(e.to_string()))?;
+    let mut ctx = Ctx::new();
+    let graph = model
+        .graph
+        .as_ref()
+        .ok_or_else(|| OnnxImportError::BadShape("model.graph".into()))?;
+    import_graph(&mut ctx, graph)?;
+    Ok(OnnxImportResult {
+        graph: ctx.g,
+        inputs: ctx.inputs,
+        outputs: ctx.outputs,
+    })
+}
+
+fn import_graph(ctx: &mut Ctx, gp: &onnx::GraphProto) -> Result<(), OnnxImportError> {
+    // Initializers
+    for t in &gp.initializer {
+        let name = t.name.clone();
+        let shape = to_dims(ctx, &t.dims.iter().map(|&d| Some(d)).collect::<Vec<_>>());
+        let data = tensor_to_f32(t)?;
+        let gt = ctx.g.tensor(shape).set(data);
+        ctx.env.insert(name, gt);
+    }
+
+    // Inputs (exclude initializers by name)
+    for vi in &gp.input {
+        if ctx.env.contains_key(&vi.name) {
+            continue;
+        }
+        if let Some(shape) = value_info_shape(ctx, vi) {
+            let gt = ctx.g.tensor(shape);
+            ctx.inputs.insert(vi.name.clone(), gt);
+            ctx.env.insert(vi.name.clone(), gt);
+        } else {
+            return Err(OnnxImportError::BadShape(format!("input {}", vi.name)));
+        }
+    }
+
+    // Nodes
+    for n in &gp.node {
+        let op = n.op_type.as_str();
+        match op {
+            "Constant" => {
+                // attribute 'value' -> TensorProto
+                let out_name = n.output.first().cloned().unwrap_or_default();
+                let attr = n
+                    .attribute
+                    .iter()
+                    .find(|a| a.name == "value")
+                    .and_then(|a| a.t.as_ref())
+                    .ok_or_else(|| {
+                        OnnxImportError::UnsupportedOp("Constant without value".into())
+                    })?;
+                let shape = to_dims(ctx, &attr.dims.iter().map(|&d| Some(d)).collect::<Vec<_>>());
+                let data = tensor_to_f32(attr)?;
+                let gt = ctx.g.tensor(shape).set(data);
+                ctx.env.insert(out_name, gt);
+            }
+            "Add" | "Div" | "Mul" | "Sub" | "Max" | "Min" => {
+                let mut a = get_input(ctx, n, 0)?;
+                let mut b = get_input(ctx, n, 1)?;
+                // naive broadcasting to the larger-rank tensor
+                if a.shape.len() > b.shape.len() {
+                    b = b.expand(a.shape);
+                } else if b.shape.len() > a.shape.len() {
+                    a = a.expand(b.shape);
+                } else {
+                    // same rank: try to expand dims where needed
+                    let target = a.shape;
+                    b = b.expand(target);
+                }
+                let out = match op {
+                    "Add" => a + b,
+                    "Sub" => a - b,
+                    "Mul" => a * b,
+                    "Div" => a / b,
+                    "Max" => a.maximum(b),
+                    "Min" => a.minimum(b),
+                    _ => unreachable!(),
+                };
+                set_output(ctx, n, out);
+            }
+            "Relu" => {
+                let a = get_input(ctx, n, 0)?;
+                set_output(ctx, n, a.relu());
+            }
+            "Sigmoid" => {
+                let a = get_input(ctx, n, 0)?;
+                set_output(ctx, n, a.sigmoid());
+            }
+            "Tanh" => {
+                let a = get_input(ctx, n, 0)?;
+                set_output(ctx, n, a.tanh());
+            }
+            "Sqrt" => {
+                let a = get_input(ctx, n, 0)?;
+                set_output(ctx, n, a.sqrt());
+            }
+            "MatMul" => {
+                let a = get_input(ctx, n, 0)?;
+                let b = get_input(ctx, n, 1)?;
+                set_output(ctx, n, a.matmul(b));
+            }
+            "Gemm" => {
+                // Y = alpha * A * B + beta * C
+                let a = get_input(ctx, n, 0)?;
+                let b = get_input(ctx, n, 1)?;
+                let c = get_input(ctx, n, 2).ok();
+                let mut alpha = 1.0f32;
+                let mut beta = 1.0f32;
+                let mut trans_a = false;
+                let mut trans_b = false;
+                for at in &n.attribute {
+                    match at.name.as_str() {
+                        "alpha" => alpha = at.f,
+                        "beta" => beta = at.f,
+                        "transA" => trans_a = at.i != 0,
+                        "transB" => trans_b = at.i != 0,
+                        _ => {}
+                    }
+                }
+                let mut aa = a;
+                let mut bb = b;
+                if trans_a {
+                    let (m, n) = aa.dims2();
+                    aa = aa.permute((1, 0)).reshape((n, m));
+                }
+                if trans_b {
+                    let (m, n) = bb.dims2();
+                    bb = bb.permute((1, 0)).reshape((n, m));
+                }
+                let mut y = aa.matmul(bb) * alpha;
+                if let Some(cc) = c {
+                    y += cc * beta;
+                }
+                set_output(ctx, n, y);
+            }
+            "Softmax" => {
+                // ONNX's default axis is 1 for older opsets and -1 for newer ones; honor the attr if present
+                let a = get_input(ctx, n, 0)?;
+                let mut axis: i64 = 1;
+                for at in &n.attribute {
+                    if at.name == "axis" {
+                        axis = at.i;
+                    }
+                }
+                // map negative axis
+                let axis = normalize_axis(axis, a.shape.len());
+                set_output(ctx, n, a.softmax(axis));
+            }
+            "Reshape" => {
+                // input[1] is shape; must be constant/initializer
+                let a = get_input(ctx, n, 0)?;
+                let shape_name = n.input.get(1).cloned().unwrap_or_default();
+                let shape_const = ctx
+                    .env
+                    .get(&shape_name)
+                    .copied()
+                    .ok_or_else(|| OnnxImportError::MissingInput(shape_name.clone()))?;
+                // Pull shape values from the tensor if constant
+                let new_shape_ints: Vec<i64> =
+                    shape_const.data().into_iter().map(|e| e as i64).collect();
+                let mut shape_spec: Vec<usize> = Vec::with_capacity(new_shape_ints.len());
+                for v in new_shape_ints.iter() {
+                    if *v == -1 {
+                        shape_spec.push(usize::MAX - 1);
+                    } else if *v == 0 {
+                        shape_spec.push(usize::MAX - 2);
+                    } else {
+                        shape_spec.push(*v as usize);
+                    }
+                }
+                let new_dims = infer_reshape_dims(a.shape.dims(), &shape_spec);
+                set_output(ctx, n, a.reshape(new_dims));
+            }
+            "Transpose" => {
+                let a = get_input(ctx, n, 0)?;
+                // perm attr
+                let mut perm: Vec<usize> = (0..a.shape.len()).collect();
+                for at in &n.attribute {
+                    if at.name == "perm" {
+                        perm = at.ints.iter().map(|&i| i as usize).collect();
+                    }
+                }
+                set_output(ctx, n, a.permute(perm));
+            }
+            "Unsqueeze" => {
+                let a = get_input(ctx, n, 0)?;
+                let mut axes: Vec<usize> = vec![];
+                // Prefer second input as tensor of axes if present
+                if n.input.len() > 1 {
+                    if let Ok(ax_t) = get_input(ctx, n, 1) {
+                        let vals: Vec<i64> = ax_t.data().into_iter().map(|e| e as i64).collect();
+                        axes = vals
+                            .iter()
+                            .map(|&i| normalize_axis(i, a.shape.len() + 1))
+                            .collect();
+                    }
+                }
+                if axes.is_empty() {
+                    for at in &n.attribute {
+                        if at.name == "axes" {
+                            axes = at
+                                .ints
+                                .iter()
+                                .map(|&i| normalize_axis(i, a.shape.len() + 1))
+                                .collect();
+                        }
+                    }
+                }
+                let mut out = a;
+                // Insert in sorted order to maintain correct indexes
+                axes.sort();
+                for ax in axes {
+                    out = out.unsqueeze(ax);
+                }
+                set_output(ctx, n, out);
+            }
+            "Squeeze" => {
+                let a = get_input(ctx, n, 0)?;
+                let mut axes: Vec<usize> = vec![];
+                if n.input.len() > 1 {
+                    if let Ok(ax_t) = get_input(ctx, n, 1) {
+                        let vals: Vec<i64> = ax_t.data().into_iter().map(|e| e as i64).collect();
+                        axes = vals
+                            .iter()
+                            .map(|&i| normalize_axis(i, a.shape.len()))
+                            .collect();
+                    }
+                }
+                if axes.is_empty() {
+                    for at in &n.attribute {
+                        if at.name == "axes" {
+                            axes = at
+                                .ints
+                                .iter()
+                                .map(|&i| normalize_axis(i, a.shape.len()))
+                                .collect();
+                        }
+                    }
+                }
+                let mut dims = a.dims();
+                if axes.is_empty() {
+                    dims.retain(|d| d.to_usize().unwrap_or(1) != 1);
+                } else {
+                    // remove given axes
+                    axes.sort();
+                    for ax in axes.into_iter().rev() {
+                        dims.remove(ax);
+                    }
+                }
+                set_output(ctx, n, a.reshape(dims));
+            }
+            "Concat" => {
+                let axis = n
+                    .attribute
+                    .iter()
+                    .find(|a| a.name == "axis")
+                    .map(|a| a.i)
+                    .unwrap_or(0);
+                let mut it = n.input.iter();
+                let first = get_input_by_name(ctx, it.next().unwrap())?;
+                let mut out = first;
+                for name in it {
+                    let t = get_input_by_name(ctx, name)?;
+                    out = out.concat_along(t, normalize_axis(axis, out.shape.len()));
+                }
+                set_output(ctx, n, out);
+            }
+            other => return Err(OnnxImportError::UnsupportedOp(other.to_string())),
+        }
+    }
+
+    // Outputs
+    for vo in &gp.output {
+        let name = &vo.name;
+        if let Some(&t) = ctx.env.get(name) {
+            ctx.outputs.insert(name.clone(), t.retrieve());
+        }
+    }
+    Ok(())
+}
+
+fn get_input(ctx: &Ctx, n: &onnx::NodeProto, idx: usize) -> Result<GraphTensor, OnnxImportError> {
+    let name = n.input.get(idx).cloned().unwrap_or_default();
+    get_input_by_name(ctx, &name)
+}
+
+fn get_input_by_name(ctx: &Ctx, name: &str) -> Result<GraphTensor, OnnxImportError> {
+    ctx.env
+        .get(name)
+        .copied()
+        .ok_or_else(|| OnnxImportError::MissingInput(name.to_string()))
+}
+
+fn set_output(ctx: &mut Ctx, n: &onnx::NodeProto, t: GraphTensor) {
+    for out in &n.output {
+        if !out.is_empty() {
+            ctx.env.insert(out.clone(), t);
+        }
+    }
+}
+
+fn normalize_axis(axis: i64, rank: usize) -> usize {
+    if axis >= 0 {
+        axis as usize
+    } else {
+        (rank as i64 + axis) as usize
+    }
+}
+
+fn infer_reshape_dims(old: Vec<Expression>, target: &[usize]) -> Vec<Expression> {
+    // implements ONNX reshape rules for -1 and 0
+    let mut new_dims: Vec<Expression> = Vec::with_capacity(target.len());
+    let mut known: usize = 1;
+    let mut infer_at: Option<usize> = None;
+    for (i, &d) in target.iter().enumerate() {
+        if d == usize::MAX - 1 {
+            // sentinel for -1
+            infer_at = Some(i);
+            new_dims.push(1.into());
+        } else if d == usize::MAX - 2 {
+            // sentinel for 0 -> copy from input
+            new_dims.push(old[i]);
+            known *= new_dims.last().unwrap().to_usize().unwrap_or(1);
+        } else {
+            new_dims.push(d.into());
+            known *= d;
+        }
+    }
+    if let Some(ix) = infer_at {
+        let total: usize = old.iter().map(|e| e.to_usize().unwrap_or(1)).product();
+        let inferred = total / known.max(1);
+        new_dims[ix] = inferred.into();
+    }
+    new_dims
+}
+
+fn value_info_shape(ctx: &mut Ctx, vi: &onnx::ValueInfoProto) -> Option<Vec<Expression>> {
+    let t = vi.r#type.as_ref()?;
+    let ten = match &t.value {
+        Some(onnx::type_proto::Value::TensorType(tt)) => tt,
+        _ => return None,
+    };
+    let dims = &ten.shape.as_ref()?.dim;
+    // Build an intermediate vector first to avoid borrow issues
+    let mut dim_vals: Vec<Option<i64>> = Vec::with_capacity(dims.len());
+    for d in dims {
+        match &d.value {
+            Some(onnx::tensor_shape_proto_dimension::Value::DimValue(v)) => dim_vals.push(Some(*v)),
+            Some(onnx::tensor_shape_proto_dimension::Value::DimParam(p)) => {
+                let c = ctx.sym_for(p);
+                dim_vals.push(Some(-(c as i64)));
+            }
+            None => dim_vals.push(None),
+        }
+    }
+    Some(to_dims(ctx, &dim_vals))
+}
+
+fn to_dims(ctx: &mut Ctx, dims: &[Option<i64>]) -> Vec<Expression> {
+    let mut out = vec![];
+    for (i, d) in dims.iter().enumerate() {
+        match d {
+            Some(v) if *v >= 0 => out.push(Expression::from(*v as usize)),
+            Some(v) if *v < 0 => {
+                // negative marker used for dim_param -> char
+                let c = (-*v) as u8 as char;
+                out.push(Expression::from(c));
+            }
+            None => {
+                // unknown -> allocate symbol based on position
+                let c = ctx.sym_for(&format!("dim_{i}"));
+                out.push(Expression::from(c));
+            }
+            _ => unreachable!(),
+        }
+    }
+    out
+}
+
+fn tensor_to_f32(t: &onnx::TensorProto) -> Result<Vec<f32>, OnnxImportError> {
+    use onnx::tensor_proto::DataType as Dt;
+    let dt = t.data_type;
+    let elem_count = t.dims.iter().map(|&d| d as usize).product::<usize>().max(1);
+    if !t.raw_data.is_empty() {
+        let raw = &t.raw_data;
+        match dt {
+            x if x == Dt::Float as i32 => {
+                let mut out = vec![0f32; raw.len() / 4];
+                for (i, chunk) in raw.chunks_exact(4).take(elem_count).enumerate() {
+                    out[i] = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
+                }
+                Ok(out)
+            }
+            x if x == Dt::Double as i32 => {
+                let mut out = vec![0f32; raw.len() / 8];
+                for (i, chunk) in raw.chunks_exact(8).take(elem_count).enumerate() {
+                    out[i] = f64::from_le_bytes([
+                        chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
+                        chunk[7],
+                    ]) as f32;
+                }
+                Ok(out)
+            }
+            x if x == Dt::Int64 as i32 => {
+                let mut out = vec![0f32; raw.len() / 8];
+                for (i, chunk) in raw.chunks_exact(8).take(elem_count).enumerate() {
+                    out[i] = i64::from_le_bytes([
+                        chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
+                        chunk[7],
+                    ]) as f32;
+                }
+                Ok(out)
+            }
+            x if x == Dt::Int32 as i32 => {
+                let mut out = vec![0f32; raw.len() / 4];
+                for (i, chunk) in raw.chunks_exact(4).take(elem_count).enumerate() {
+                    out[i] = i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) as f32;
+                }
+                Ok(out)
+            }
+            x => Err(OnnxImportError::UnsupportedDtype(x)),
+        }
+    } else if !t.float_data.is_empty() {
+        let mut out = t.float_data.clone();
+        out.truncate(elem_count);
+        Ok(out)
+    } else if !t.int64_data.is_empty() {
+        Ok(t.int64_data
+            .iter()
+            .take(elem_count)
+            .map(|&v| v as f32)
+            .collect())
+    } else if !t.int32_data.is_empty() {
+        Ok(t.int32_data
+            .iter()
+            .take(elem_count)
+            .map(|&v| v as f32)
+            .collect())
+    } else {
+        Ok(vec![0.0; elem_count])
+    }
+}
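A hedged sketch of the ONNX Reshape rules implemented by infer_reshape_dims above: a target entry of 0 copies the corresponding input dimension and -1 is inferred from the remaining element count. A unit test along these lines could sit at the bottom of load.rs (the helper is private to that module); the sentinels mirror the ones used in the "Reshape" arm:

    #[cfg(test)]
    mod reshape_rules {
        use super::*;

        #[test]
        fn zero_copies_and_minus_one_infers() {
            // Input shape [2, 3, 4]; ONNX target shape [0, -1] encoded with the sentinels above.
            let old = vec![
                Expression::from(2usize),
                Expression::from(3usize),
                Expression::from(4usize),
            ];
            let dims = infer_reshape_dims(old, &[usize::MAX - 2, usize::MAX - 1]);
            assert_eq!(dims[0].to_usize(), Some(2)); // 0 copied dim 0 of the input
            assert_eq!(dims[1].to_usize(), Some(12)); // -1 inferred as 24 / 2
        }
    }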
diff --git a/crates/luminal_onnx/src/onnx/mod.rs b/crates/luminal_onnx/src/onnx/mod.rs
new file mode 100644
index 00000000..b1788e4a
--- /dev/null
+++ b/crates/luminal_onnx/src/onnx/mod.rs
@@ -0,0 +1,2 @@
+// The ONNX protobuf definitions and helpers.
+pub mod proto;
diff --git a/crates/luminal_onnx/src/onnx/proto/mod.rs b/crates/luminal_onnx/src/onnx/proto/mod.rs
new file mode 100644
index 00000000..538b5421
--- /dev/null
+++ b/crates/luminal_onnx/src/onnx/proto/mod.rs
@@ -0,0 +1 @@
+include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
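For reference, everything re-exported through the module above is generated by prost at build time from the onnx.proto below; a rough decoding sketch (illustrative only, the exact generated contents depend on the prost version):

    use prost::Message;
    use luminal_onnx::onnx::proto::ModelProto;

    fn peek_ir_version(path: &str) -> Result<i64, Box<dyn std::error::Error>> {
        // Decode only the protobuf; no luminal graph is built here.
        let bytes = std::fs::read(path)?;
        let model = ModelProto::decode(bytes.as_slice())?;
        Ok(model.ir_version)
    }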
diff --git a/crates/luminal_onnx/src/onnx/proto/onnx.proto b/crates/luminal_onnx/src/onnx/proto/onnx.proto
new file mode 100644
index 00000000..49e89e93
--- /dev/null
+++ b/crates/luminal_onnx/src/onnx/proto/onnx.proto
@@ -0,0 +1,134 @@
+//
+// WARNING: This file is automatically generated! Please edit onnx.in.proto.
+// SPDX-License-Identifier: Apache-2.0
+
+syntax = "proto3";
+
+package onnx;
+
+// NOTE: Truncated ONNX proto to essential messages for import.
+// This subset includes only the parts used by this crate to parse
+// common models. For full spec, use the official onnx.proto.
+
+enum Version {
+  _START_VERSION = 0;
+  IR_VERSION_2019_1_22 = 0x0000000000000004; // Relaxed initializers constraint
+}
+
+message StringStringEntryProto { string key = 1; string value = 2; }
+
+message TensorShapeProto_Dimension {
+  oneof value { int64 dim_value = 1; string dim_param = 2; }
+}
+
+message TensorShapeProto { repeated TensorShapeProto_Dimension dim = 1; }
+
+message TypeProto_Tensor {
+  int32 elem_type = 1;
+  TensorShapeProto shape = 2;
+}
+
+message TypeProto {
+  oneof value { TypeProto_Tensor tensor_type = 1; }
+}
+
+message ValueInfoProto {
+  string name = 1;
+  TypeProto type = 2;
+  string doc_string = 3;
+  repeated StringStringEntryProto metadata_props = 4;
+}
+
+message TensorProto {
+  enum DataType {
+    UNDEFINED = 0;
+    FLOAT = 1;
+    UINT8 = 2;
+    INT8 = 3;
+    UINT16 = 4;
+    INT16 = 5;
+    INT32 = 6;
+    INT64 = 7;
+    STRING = 8;
+    BOOL = 9;
+    FLOAT16 = 10;
+    DOUBLE = 11;
+    UINT32 = 12;
+    UINT64 = 13;
+    COMPLEX64 = 14;
+    COMPLEX128 = 15;
+    BFLOAT16 = 16;
+    FLOAT8E4M3FN = 17;
+    FLOAT8E4M3FNUZ = 18;
+    FLOAT8E5M2 = 19;
+    FLOAT8E5M2FNUZ = 20;
+  }
+  int32 data_type = 2;
+  string name = 7;
+  repeated int64 dims = 1;
+
+  // Data fields (one of these is typically used)
+  repeated float float_data = 4;
+  repeated int32 int32_data = 5;
+  repeated int64 int64_data = 6;
+  bytes raw_data = 9;
+}
+
+message AttributeProto {
+  enum AttributeType {
+    UNDEFINED = 0;
+    FLOAT = 1;
+    INT = 2;
+    STRING = 3;
+    TENSOR = 4;
+    GRAPH = 5;
+    FLOATS = 6;
+    INTS = 7;
+    STRINGS = 8;
+    TENSORS = 9;
+    GRAPHS = 10;
+    SPARSE_TENSOR = 11;
+    SPARSE_TENSORS = 12;
+    TYPE_PROTO = 13;
+    TYPE_PROTOS = 14;
+  }
+  string name = 1;
+  string doc_string = 13;
+  int32 type = 20;
+  float f = 2;
+  int64 i = 3;
+  bytes s = 4;
+  TensorProto t = 5;
+  repeated float floats = 6;
+  repeated int64 ints = 7;
+  repeated bytes strings = 8;
+  repeated TensorProto tensors = 9;
+}
+
+message NodeProto {
+  repeated string input = 1;
+  repeated string output = 2;
+  string name = 3;
+  string op_type = 4;
+  string domain = 7;
+  repeated AttributeProto attribute = 5;
+  string doc_string = 6;
+}
+
+message GraphProto {
+  string name = 1;
+  repeated NodeProto node = 2;
+  repeated TensorProto initializer = 5;
+  repeated ValueInfoProto input = 11;
+  repeated ValueInfoProto output = 12;
+  repeated ValueInfoProto value_info = 13;
+}
+
+message OperatorSetIdProto { string domain = 1; int64 version = 2; string extension = 3; }
+
+message ModelProto {
+  int64 ir_version = 1;
+  repeated OperatorSetIdProto opset_import = 8;
+  GraphProto graph = 7;
+}
+
diff --git a/crates/luminal_onnx/tests/smoke.rs b/crates/luminal_onnx/tests/smoke.rs
new file mode 100644
index 00000000..f7c13549
--- /dev/null
+++ b/crates/luminal_onnx/tests/smoke.rs
@@ -0,0 +1,186 @@
+use luminal::prelude::*;
+use luminal_onnx::import_onnx;
+use luminal_onnx::onnx::proto as onnx;
+
+fn build_linear_softmax_model(n: i64, in_dim: i64, out_dim: i64) -> onnx::ModelProto {
+    use onnx::*;
+    // Shapes
+    let x_vi = ValueInfoProto {
+        name: "X".into(),
+        r#type: Some(TypeProto {
+            value: Some(type_proto::Value::TensorType(TypeProtoTensor {
+                elem_type: onnx::tensor_proto::DataType::Float as i32,
+                shape: Some(TensorShapeProto {
+                    dim: vec![
+                        TensorShapeProtoDimension {
+                            value: Some(onnx::tensor_shape_proto_dimension::Value::DimValue(n)),
+                        },
+                        TensorShapeProtoDimension {
+                            value: Some(onnx::tensor_shape_proto_dimension::Value::DimValue(
+                                in_dim,
+                            )),
+                        },
+                    ],
+                }),
+            })),
+        }),
+        doc_string: String::new(),
+        metadata_props: vec![],
+    };
+    let y_vi = ValueInfoProto {
+        name: "Y".into(),
+        r#type: Some(TypeProto {
+            value: Some(type_proto::Value::TensorType(TypeProtoTensor {
+                elem_type: onnx::tensor_proto::DataType::Float as i32,
+                shape: Some(TensorShapeProto {
+                    dim: vec![
+                        TensorShapeProtoDimension {
+                            value: Some(onnx::tensor_shape_proto_dimension::Value::DimValue(n)),
+                        },
+                        TensorShapeProtoDimension {
+                            value: Some(onnx::tensor_shape_proto_dimension::Value::DimValue(
+                                out_dim,
+                            )),
+                        },
+                    ],
+                }),
+            })),
+        }),
+        doc_string: String::new(),
+        metadata_props: vec![],
+    };
+
+    // Initializers W [in,out], B [out]
+    let mut w_vals = Vec::with_capacity((in_dim * out_dim) as usize);
+    for i in 0..(in_dim * out_dim) {
+        w_vals.push((i as f32 * 0.01).sin());
+    }
+    let w = TensorProto {
+        data_type: onnx::tensor_proto::DataType::Float as i32,
+        name: "W".into(),
+        dims: vec![in_dim, out_dim],
+        float_data: w_vals.clone(),
+        int32_data: vec![],
+        int64_data: vec![],
+        raw_data: vec![],
+    };
+    let mut b_vals = Vec::with_capacity(out_dim as usize);
+    for i in 0..out_dim {
+        b_vals.push((i as f32 * 0.1).cos());
+    }
+    let b = TensorProto {
+        data_type: onnx::tensor_proto::DataType::Float as i32,
+        name: "B".into(),
+        dims: vec![out_dim],
+        float_data: b_vals.clone(),
+        int32_data: vec![],
+        int64_data: vec![],
+        raw_data: vec![],
+    };
+
+    // Nodes: Y0=MatMul(X,W); Y1=Add(Y0,B); Y=Softmax(Y1)
+    let mm = NodeProto {
+        input: vec!["X".into(), "W".into()],
+        output: vec!["Y0".into()],
+        name: String::new(),
+        op_type: "MatMul".into(),
+        domain: String::new(),
+        attribute: vec![],
+        doc_string: String::new(),
+    };
+    let add = NodeProto {
+        input: vec!["Y0".into(), "B".into()],
+        output: vec!["Y1".into()],
+        name: String::new(),
+        op_type: "Add".into(),
+        domain: String::new(),
+        attribute: vec![],
+        doc_string: String::new(),
+    };
+    let sm = NodeProto {
+        input: vec!["Y1".into()],
+        output: vec!["Y".into()],
+        name: String::new(),
+        op_type: "Softmax".into(),
+        domain: String::new(),
+        attribute: vec![],
+        doc_string: String::new(),
+    };
+
+    let graph = GraphProto {
+        name: "linear_sm".into(),
+        node: vec![mm, add, sm],
+        initializer: vec![w, b],
+        input: vec![x_vi],
+        output: vec![y_vi],
+        value_info: vec![],
+    };
+    ModelProto {
+        ir_version: onnx::Version::IrVersion2019122 as i64,
+        opset_import: vec![OperatorSetIdProto {
+            domain: String::new(),
+            version: 13,
+            extension: String::new(),
+        }],
+        graph: Some(graph),
+    }
+}
+
+#[test]
+fn test_import_and_run() {
+    let n = 2;
+    let in_dim = 4;
+    let out_dim = 3;
+    let model = build_linear_softmax_model(n, in_dim, out_dim);
+    // Write to temp file
+    let mut tmpf = tempfile::NamedTempFile::new().unwrap();
+    use prost::Message as _;
+    let mut buf = Vec::new();
+    model.encode(&mut buf).unwrap();
+    std::io::Write::write_all(&mut tmpf, &buf).unwrap();
+    let path = tmpf.path().to_path_buf();
+
+    // Import
+    let mut res = import_onnx(&path).expect("import ok");
+
+    // Build inputs
+    let mut x_data = vec![0f32; (n * in_dim) as usize];
+    for (i, v) in x_data.iter_mut().enumerate() {
+        *v = ((i as f32) * 0.2).sin();
+    }
+    let _x = res
+        .inputs
+        .get("X")
+        .copied()
+        .expect("X exists")
+        .set(x_data.clone());
+
+    // Execute imported graph
+    res.graph.execute();
+    let y = res.outputs.get("Y").copied().expect("Y exists").data();
+
+    // Reference using luminal directly
+    let mut g2 = Graph::new();
+    // The same initializers W,B must be reconstructed from the onnx model initializers used above
+    let mut w_vals = Vec::with_capacity((in_dim * out_dim) as usize);
+    for i in 0..(in_dim * out_dim) {
+        w_vals.push((i as f32 * 0.01).sin());
+    }
+    let mut b_vals = Vec::with_capacity(out_dim as usize);
+    for i in 0..out_dim {
+        b_vals.push((i as f32 * 0.1).cos());
+    }
+    let x2 = g2.tensor((n as usize, in_dim as usize)).set(x_data);
+    let w2 = g2.tensor((in_dim as usize, out_dim as usize)).set(w_vals);
+    let b2 = g2.tensor((out_dim as usize,)).set(b_vals);
+    let y2 = (x2.matmul(w2) + b2.expand_dim(0, n as usize))
+        .softmax(1)
+        .retrieve();
+    g2.execute();
+    let y_ref = y2.data();
+
+    assert_eq!(y.len(), y_ref.len());
+    for (a, b) in y.iter().zip(y_ref.iter()) {
+        assert!((a - b).abs() < 1e-3, "{a} vs {b}");
+    }
+}
diff --git a/examples/moondream/src/main.rs b/examples/moondream/src/main.rs
index da55be3d..67949c6b 100644
--- a/examples/moondream/src/main.rs
+++ b/examples/moondream/src/main.rs
@@ -165,7 +165,7 @@ fn main() {
         let output_id = argmax(&logits.data());
         logits.drop();
         output_ids.push(output_id);
-        println!("ID: {}", output_id);
+        println!("ID: {output_id}");
 
         // Get the current decoded output
         let current_output = tokenizer.decode(&output_ids, false).unwrap();
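The new crate's smoke test above can be run on its own with the usual cargo package filter:

    cargo test -p luminal_onnx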