diff --git a/Cargo.toml b/Cargo.toml index 68c3cb2dfd..56f069c9ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -147,7 +147,9 @@ half = { version = "2.7.1", features = [ ], default-features = false } macerator = { version = "0.2.9" } matrixmultiply = { version = "0.3.10", default-features = false } -ndarray = { version = "0.17.1", default-features = false } +ndarray = { version = "0.17.1", default-features = false, features = [ + "half", +] } num-traits = { version = "0.2.19", default-features = false, features = [ "libm", ] } # libm is for no_std @@ -191,3 +193,6 @@ tracel-xtask = { version = "=2.2.1" } [profile.dev] debug = 1 # Speed up compilation time and not necessary. + +[patch.crates-io] +ndarray = { git = 'https://github.com/swfsql/ndarray.git', rev = 'bc68d2f3f99680ccd7aba1069efa686bcf94dedf' } diff --git a/crates/burn-ndarray/Cargo.toml b/crates/burn-ndarray/Cargo.toml index 9a5b5fc807..938687bc59 100644 --- a/crates/burn-ndarray/Cargo.toml +++ b/crates/burn-ndarray/Cargo.toml @@ -71,6 +71,7 @@ openblas-src = { workspace = true, optional = true } paste = { workspace = true } rand = { workspace = true, default-features = false, features = ["small_rng"] } spin = { workspace = true } +half = { workspace = true } # SIMD bytemuck = { workspace = true, optional = true } diff --git a/crates/burn-ndarray/src/element.rs b/crates/burn-ndarray/src/element.rs index eff7837739..59e8e4c805 100644 --- a/crates/burn-ndarray/src/element.rs +++ b/crates/burn-ndarray/src/element.rs @@ -1,6 +1,6 @@ use burn_tensor::Element; use ndarray::LinalgScalar; -use num_traits::Signed; +use num_traits::{AsPrimitive, Signed}; #[cfg(not(feature = "std"))] #[allow(unused_imports)] @@ -61,6 +61,8 @@ impl QuantElement for i8 {} impl FloatNdArrayElement for f64 {} impl FloatNdArrayElement for f32 {} +impl FloatNdArrayElement for half::bf16 {} +impl FloatNdArrayElement for half::f16 {} impl IntNdArrayElement for i64 {} impl IntNdArrayElement for i32 {} @@ -83,28 +85,35 @@ macro_rules! make_elem { impl ExpElement for $ty { #[inline(always)] fn exp_elem(self) -> Self { - (self as f64).exp() as $ty + let self_f64: f64 = self.as_(); + self_f64.exp().as_() } #[inline(always)] fn log_elem(self) -> Self { - (self as f64).ln() as $ty + let self_f64: f64 = self.as_(); + self_f64.ln().as_() } #[inline(always)] fn log1p_elem(self) -> Self { - log1p(self as f64) as $ty + let self_f64: f64 = self.as_(); + log1p(self_f64).as_() } #[inline(always)] fn powf_elem(self, value: f32) -> Self { - (self as f64).pow(value) as $ty + let self_f64: f64 = self.as_(); + self_f64.pow(value).as_() } #[inline(always)] fn powi_elem(self, value: i32) -> Self { #[cfg(feature = "std")] - let val = f64::powi(self as f64, value) as $ty; + let val = { + let self_f64: f64 = self.as_(); + f64::powi(self_f64, value).as_() + }; #[cfg(not(feature = "std"))] let val = Self::powf_elem(self, value as f32); @@ -114,17 +123,20 @@ macro_rules! make_elem { #[inline(always)] fn sqrt_elem(self) -> Self { - (self as f64).sqrt() as $ty + let self_f64: f64 = self.as_(); + self_f64.sqrt().as_() } #[inline(always)] fn abs_elem(self) -> Self { - (self as f64).abs() as $ty + let self_f64: f64 = self.as_(); + self_f64.abs().as_() } #[inline(always)] fn int_abs_elem(self) -> Self { - (self as i64).abs() as $ty + let self_i64: i64 = self.as_(); + self_i64.abs().as_() } } }; @@ -137,28 +149,35 @@ macro_rules! make_elem { impl ExpElement for $ty { #[inline(always)] fn exp_elem(self) -> Self { - (self as f32).exp() as $ty + let self_f32: f32 = self.as_(); + self_f32.exp().as_() } #[inline(always)] fn log_elem(self) -> Self { - (self as f32).ln() as $ty + let self_f32: f32 = self.as_(); + self_f32.ln().as_() } #[inline(always)] fn log1p_elem(self) -> Self { - log1pf(self as f32) as $ty + let self_f32: f32 = self.as_(); + log1pf(self_f32).as_() } #[inline(always)] fn powf_elem(self, value: f32) -> Self { - (self as f32).pow(value) as $ty + let self_f32: f32 = self.as_(); + self_f32.pow(value).as_() } #[inline(always)] fn powi_elem(self, value: i32) -> Self { #[cfg(feature = "std")] - let val = f32::powi(self as f32, value) as $ty; + let val = { + let self_f32: f32 = self.as_(); + f32::powi(self_f32, value).as_() + }; #[cfg(not(feature = "std"))] let val = Self::powf_elem(self, value as f32); @@ -168,17 +187,20 @@ macro_rules! make_elem { #[inline(always)] fn sqrt_elem(self) -> Self { - (self as f32).sqrt() as $ty + let self_f32: f32 = self.as_(); + self_f32.sqrt().as_() } #[inline(always)] fn abs_elem(self) -> Self { - (self as f32).abs() as $ty + let self_f32: f32 = self.as_(); + self_f32.abs().as_() } #[inline(always)] fn int_abs_elem(self) -> Self { - (self as i32).unsigned_abs() as $ty + let self_i32: i32 = self.as_(); + self_i32.unsigned_abs().as_() } } }; @@ -190,8 +212,12 @@ make_elem!(double u64); make_elem!(single f32); make_elem!(single i32); -make_elem!(single i16); -make_elem!(single i8); make_elem!(single u32); + +make_elem!(single half::bf16); +make_elem!(single half::f16); +make_elem!(single i16); make_elem!(single u16); + +make_elem!(single i8); make_elem!(single u8); diff --git a/crates/burn-ndarray/src/ops/base.rs b/crates/burn-ndarray/src/ops/base.rs index ebf9d92000..b3fc63d066 100644 --- a/crates/burn-ndarray/src/ops/base.rs +++ b/crates/burn-ndarray/src/ops/base.rs @@ -4,6 +4,7 @@ use burn_tensor::{DType, quantization::QuantValue}; use burn_tensor::{ElementConversion, Slice}; use core::fmt::Debug; use core::marker::PhantomData; +use half::f16; use ndarray::IntoDimension; use ndarray::SliceInfo; use ndarray::Zip; @@ -453,7 +454,7 @@ where { pub fn add(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_binary_simd!( - E, VecAdd, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64 + E, VecAdd, lhs, rhs, u8, i8, u16, i16, f16, u32, i32, f32, u64, i64, f64 ); let array = &lhs + &rhs; @@ -470,6 +471,7 @@ where i8, u16, i16, + f16, u32, i32, f32, @@ -484,7 +486,7 @@ where pub fn sub(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_binary_simd!( - E, VecSub, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64 + E, VecSub, lhs, rhs, u8, i8, u16, i16, f16, u32, i32, f32, u64, i64, f64 ); let array = lhs - rhs; @@ -501,6 +503,7 @@ where i8, u16, i16, + f16, u32, i32, f32, @@ -515,7 +518,7 @@ where pub fn mul(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = - dispatch_binary_simd!(noq, E, VecMul, lhs, rhs, u16, i16, u32, i32, f32, f64); + dispatch_binary_simd!(noq, E, VecMul, lhs, rhs, u16, i16, f16, u32, i32, f32, f64); let array = lhs * rhs; array.into_shared() @@ -530,6 +533,7 @@ where rhs.elem(), u16, i16, + f16, u32, i32, f32, @@ -541,14 +545,14 @@ where } pub fn div(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - let (lhs, rhs) = dispatch_binary_simd!(noq, E, VecDiv, lhs, rhs, f32, f64); + let (lhs, rhs) = dispatch_binary_simd!(noq, E, VecDiv, lhs, rhs, f16, f32, f64); let array = lhs / rhs; array.into_shared() } pub fn div_scalar(lhs: SharedArray, rhs: E) -> SharedArray { - let lhs = dispatch_binary_scalar_simd!(noq, E, VecDiv, lhs, rhs.elem(), f32, f64); + let lhs = dispatch_binary_scalar_simd!(noq, E, VecDiv, lhs, rhs.elem(), f16, f32, f64); let array = lhs / rhs; array.into_shared() @@ -823,6 +827,7 @@ where i8, u16, i16, + f16, u32, i32, f32, @@ -849,6 +854,7 @@ where i8, u16, i16, + f16, u32, i32, f32, @@ -875,6 +881,7 @@ where i8, u16, i16, + f16, u32, i32, f32, @@ -933,14 +940,14 @@ where } pub(crate) fn abs(tensor: SharedArray) -> SharedArray { - let tensor = dispatch_unary_simd!(E, VecAbs, tensor, i8, i16, i32, f32, f64); + let tensor = dispatch_unary_simd!(E, VecAbs, tensor, i8, i16, f16, i32, f32, f64); tensor.mapv_into(|a| a.abs_elem()).into_shared() } pub(crate) fn equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_cmp_simd!( - E, VecEquals, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 + E, VecEquals, lhs, rhs, u8, i8, u16, i16, f16, u32, f32, i32, u64, i64, f64 ); // Use the helper to broadcast both arrays to a common shape @@ -962,6 +969,7 @@ where i8, u16, i16, + f16, u32, f32, i32, @@ -975,7 +983,7 @@ where pub(crate) fn greater(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_cmp_simd!( - E, VecGreater, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 + E, VecGreater, lhs, rhs, u8, i8, u16, i16, f16, u32, f32, i32, u64, i64, f64 ); // Use the helper to broadcast both arrays to a common shape @@ -997,6 +1005,7 @@ where i8, u16, i16, + f16, u32, f32, i32, @@ -1018,6 +1027,7 @@ where i8, u16, i16, + f16, u32, f32, i32, @@ -1045,6 +1055,7 @@ where i8, u16, i16, + f16, u32, f32, i32, @@ -1058,7 +1069,7 @@ where pub(crate) fn lower_equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_cmp_simd!( - E, VecLowerEq, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 + E, VecLowerEq, lhs, rhs, u8, i8, u16, i16, f16, u32, f32, i32, u64, i64, f64 ); // Use the helper to broadcast both arrays to a common shape @@ -1080,6 +1091,7 @@ where i8, u16, i16, + f16, u32, f32, i32, @@ -1093,7 +1105,7 @@ where pub(crate) fn lower(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_cmp_simd!( - E, VecLower, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 + E, VecLower, lhs, rhs, u8, i8, u16, i16, f16, u32, f32, i32, u64, i64, f64 ); // Use the helper to broadcast both arrays to a common shape @@ -1116,6 +1128,7 @@ where i8, u16, i16, + f16, u32, f32, i32, diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs index ec32715c54..aeec4aaed4 100644 --- a/crates/burn-ndarray/src/ops/int_tensor.rs +++ b/crates/burn-ndarray/src/ops/int_tensor.rs @@ -357,7 +357,9 @@ where execute_with_int_dtype!(lhs, I, |lhs| -> NdArrayTensor { execute_with_float_dtype!(rhs, E, |rhs| { NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &E| { - (a.elem::().pow(*b as u32)).elem() + use num_traits::AsPrimitive; + let b_u32: u32 = (*b).as_(); + (a.elem::().pow(b_u32)).elem() }) }) }) diff --git a/crates/burn-ndarray/src/ops/module.rs b/crates/burn-ndarray/src/ops/module.rs index e56bd5fe41..847b67debc 100644 --- a/crates/burn-ndarray/src/ops/module.rs +++ b/crates/burn-ndarray/src/ops/module.rs @@ -25,6 +25,20 @@ macro_rules! module_op { (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{ #[allow(unused_parens, unreachable_patterns)] match ($($x),+) { + ($(NdArrayTensor::F16($x)),+) => { + type $element = half::f16; + $op( + $($x),+ + $(, $opt.map(|o| match o { NdArrayTensor::F16(val) => val, _ => panic!("Optional argument type mismatch") }))* + ) + } + ($(NdArrayTensor::BF16($x)),+) => { + type $element = half::bf16; + $op( + $($x),+ + $(, $opt.map(|o| match o { NdArrayTensor::BF16(val) => val, _ => panic!("Optional argument type mismatch") }))* + ) + } ($(NdArrayTensor::F32($x)),+) => { type $element = f32; $op( diff --git a/crates/burn-ndarray/src/ops/simd/avgpool.rs b/crates/burn-ndarray/src/ops/simd/avgpool.rs index 2b381789d6..0cbd67787a 100644 --- a/crates/burn-ndarray/src/ops/simd/avgpool.rs +++ b/crates/burn-ndarray/src/ops/simd/avgpool.rs @@ -42,6 +42,15 @@ pub(crate) fn try_avg_pool2d_simd( padding, with_pad, ))), + DType::F16 if is_accelerated::(PhantomData) => { + Ok(cast(avg_pool_nhwc::( + cast(x), + ksize, + stride, + padding, + with_pad, + ))) + } _ => Err(x), } } diff --git a/crates/burn-ndarray/src/ops/simd/base.rs b/crates/burn-ndarray/src/ops/simd/base.rs index 005316f72b..25a8b165e3 100644 --- a/crates/burn-ndarray/src/ops/simd/base.rs +++ b/crates/burn-ndarray/src/ops/simd/base.rs @@ -94,6 +94,16 @@ macro_rules! impl_minmax { impl_minmax!(u8, i8, u16, i16, u32, i32, u64, i64); +impl MinMax for half::f16 { + fn min(self, other: Self) -> Self { + self.min(other) + } + + fn max(self, other: Self) -> Self { + self.max(other) + } +} + impl MinMax for f32 { fn min(self, other: Self) -> Self { self.min(other) diff --git a/crates/burn-ndarray/src/ops/simd/conv.rs b/crates/burn-ndarray/src/ops/simd/conv.rs index ef7bf634a3..078a3eab5f 100644 --- a/crates/burn-ndarray/src/ops/simd/conv.rs +++ b/crates/burn-ndarray/src/ops/simd/conv.rs @@ -25,6 +25,7 @@ pub fn try_conv2d_simd( match E::dtype() { DType::F64 => conv2d::(x, weight, bias, options, PhantomData), DType::F32 => conv2d::(x, weight, bias, options, PhantomData), + DType::F16 => conv2d::(x, weight, bias, options, PhantomData), DType::I64 => conv2d::(x, weight, bias, options, PhantomData), DType::I32 => conv2d::(x, weight, bias, options, PhantomData), DType::I16 => conv2d::(x, weight, bias, options, PhantomData), diff --git a/crates/burn-ndarray/src/ops/simd/maxpool.rs b/crates/burn-ndarray/src/ops/simd/maxpool.rs index 62eaf83154..661695d7ba 100644 --- a/crates/burn-ndarray/src/ops/simd/maxpool.rs +++ b/crates/burn-ndarray/src/ops/simd/maxpool.rs @@ -23,6 +23,7 @@ macro_rules! launch_kernel { match <$ty as Element>::dtype() { DType::F64 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::F32 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), + DType::F16 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::I64 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::I32 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::I16 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), diff --git a/crates/burn-ndarray/src/ops/tensor.rs b/crates/burn-ndarray/src/ops/tensor.rs index fa4c243682..f69c1069ee 100644 --- a/crates/burn-ndarray/src/ops/tensor.rs +++ b/crates/burn-ndarray/src/ops/tensor.rs @@ -453,7 +453,7 @@ where } fn float_cat(tensors: Vec>, dim: usize) -> FloatTensor { - cat_with_dtype!(tensors, dim, [F64, F32]) + cat_with_dtype!(tensors, dim, [F64, F32, BF16, F16]) } fn float_clamp_min(tensor: FloatTensor, min: E) -> FloatTensor { @@ -488,7 +488,7 @@ where execute_with_float_dtype!((lhs, rhs), E, |lhs, rhs| NdArrayMathOps::elementwise_op( lhs, rhs, - |a: &E, b: &E| a.powf(*b) + |a: &E, b: &E| num_traits::Float::powf(*a, *b) )) } diff --git a/crates/burn-ndarray/src/tensor.rs b/crates/burn-ndarray/src/tensor.rs index 3dbb859513..e5719537c4 100644 --- a/crates/burn-ndarray/src/tensor.rs +++ b/crates/burn-ndarray/src/tensor.rs @@ -18,6 +18,8 @@ pub type SharedArray = ArcArray; pub enum NdArrayTensor { F64(SharedArray), F32(SharedArray), + BF16(SharedArray), + F16(SharedArray), I64(SharedArray), I32(SharedArray), I16(SharedArray), @@ -54,6 +56,8 @@ where DType::F64 => cast::(array).into(), DType::F32 => cast::(array).into(), DType::Flex32 => cast::(array).into(), + DType::BF16 => cast::(array).into(), + DType::F16 => cast::(array).into(), DType::I64 => cast::(array).into(), DType::I32 => cast::(array).into(), DType::I16 => cast::(array).into(), @@ -78,7 +82,7 @@ macro_rules! impl_from { } impl_from!( - f64 => F64, f32 => F32, + f64 => F64, f32 => F32, half::bf16 => BF16, half::f16 => F16, i64 => I64, i32 => I32, i16 => I16, i8 => I8, u64 => U64, u32 => U32, u16 => U16, u8 => U8, bool => Bool @@ -116,7 +120,7 @@ macro_rules! execute_with_dtype { // Binary op: generic type cannot be inferred for an operation (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{ $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [ - F64 => f64, F32 => f32, + F64 => f64, F32 => f32, BF16 => half::bf16, F16 => half::f16, I64 => i64, I32 => i32, I16 => i16, I8 => i8, U64 => u64, U32 => u32, U16 => u16, U8 => u8, Bool => bool @@ -144,7 +148,7 @@ macro_rules! execute_with_dtype { // Unary op: generic type cannot be inferred for an operation ($tensor:expr, $element:ident, $op:expr) => {{ $crate::execute_with_dtype!($tensor, $element, $op, [ - F64 => f64, F32 => f32, + F64 => f64, F32 => f32, BF16 => half::bf16, F16 => half::f16, I64 => i64, I32 => i32, I16 => i16, I8 => i8, U64 => u64, U32 => u32, U16 => u16, U8 => u8, Bool => bool @@ -168,7 +172,7 @@ macro_rules! execute_with_float_dtype { // Binary op: generic type cannot be inferred for an operation (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{ $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [ - F64 => f64, F32 => f32 + F64 => f64, F32 => f32, BF16 => half::bf16, F16 => half::f16 ]) }}; @@ -180,7 +184,7 @@ macro_rules! execute_with_float_dtype { // Unary op: generic type cannot be inferred for an operation ($tensor:expr, $element:ident, $op:expr) => {{ $crate::execute_with_dtype!($tensor, $element, $op, [ - F64 => f64, F32 => f32 + F64 => f64, F32 => f32, BF16 => half::bf16, F16 => half::f16 ]) }}; } @@ -236,7 +240,7 @@ macro_rules! execute_with_numeric_dtype { // Binary op: generic type cannot be inferred for an operation (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{ $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [ - F64 => f64, F32 => f32, + F64 => f64, F32 => f32, BF16 => half::bf16, F16 => half::f16, I64 => i64, I32 => i32, I16 => i16, I8 => i8, U64 => u64, U32 => u32, U16 => u16, U8 => u8 ]) @@ -250,7 +254,7 @@ macro_rules! execute_with_numeric_dtype { // Unary op: generic type cannot be inferred for an operation ($tensor:expr, $element:ident, $op:expr) => {{ $crate::execute_with_dtype!($tensor, $element, $op, [ - F64 => f64, F32 => f32, + F64 => f64, F32 => f32, BF16 => half::bf16, F16 => half::f16, I64 => i64, I32 => i32, I16 => i16, I8 => i8, U64 => u64, U32 => u32, U16 => u16, U8 => u8 ]) @@ -273,7 +277,7 @@ macro_rules! cat_with_dtype { if let NdArrayTensor::$dtype(tensor) = t { tensor.view() } else { - panic!("Concatenate data type mismatch (expected f32, got f64)") + panic!("Concatenate data type mismatch (expected $dtype, got {:?})", t.dtype()) } }) .collect::>(); @@ -289,6 +293,8 @@ impl TensorMetadata for NdArrayTensor { match self { NdArrayTensor::F64(_) => DType::F64, NdArrayTensor::F32(_) => DType::F32, + NdArrayTensor::BF16(_) => DType::BF16, + NdArrayTensor::F16(_) => DType::F16, NdArrayTensor::I64(_) => DType::I64, NdArrayTensor::I32(_) => DType::I32, NdArrayTensor::I16(_) => DType::I16, @@ -476,7 +482,7 @@ impl NdArrayTensor { } execute!(data, [ - F64 => f64, F32 => f32, + F64 => f64, F32 => f32, BF16 => half::bf16, F16 => half::f16, I64 => i64, I32 => i32, I16 => i16, I8 => i8, U64 => u64, U32 => u32, U16 => u16, U8 => u8, Bool => bool