Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ half = { version = "2.7.1", features = [
], default-features = false }
macerator = { version = "0.2.9" }
matrixmultiply = { version = "0.3.10", default-features = false }
ndarray = { version = "0.17.1", default-features = false }
ndarray = { version = "0.17.1", default-features = false, features = [
"half",
] }
num-traits = { version = "0.2.19", default-features = false, features = [
"libm",
] } # libm is for no_std
Expand Down Expand Up @@ -191,3 +193,6 @@ tracel-xtask = { version = "=2.2.1" }

[profile.dev]
debug = 1 # Speed up compilation time and not necessary.

[patch.crates-io]
ndarray = { git = 'https://github.com/swfsql/ndarray.git', rev = 'bc68d2f3f99680ccd7aba1069efa686bcf94dedf' }
1 change: 1 addition & 0 deletions crates/burn-ndarray/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ openblas-src = { workspace = true, optional = true }
paste = { workspace = true }
rand = { workspace = true, default-features = false, features = ["small_rng"] }
spin = { workspace = true }
half = { workspace = true }

# SIMD
bytemuck = { workspace = true, optional = true }
Expand Down
64 changes: 45 additions & 19 deletions crates/burn-ndarray/src/element.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use burn_tensor::Element;
use ndarray::LinalgScalar;
use num_traits::Signed;
use num_traits::{AsPrimitive, Signed};

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
Expand Down Expand Up @@ -61,6 +61,8 @@ impl QuantElement for i8 {}

impl FloatNdArrayElement for f64 {}
impl FloatNdArrayElement for f32 {}
impl FloatNdArrayElement for half::bf16 {}
impl FloatNdArrayElement for half::f16 {}

impl IntNdArrayElement for i64 {}
impl IntNdArrayElement for i32 {}
Expand All @@ -83,28 +85,35 @@ macro_rules! make_elem {
impl ExpElement for $ty {
#[inline(always)]
fn exp_elem(self) -> Self {
(self as f64).exp() as $ty
let self_f64: f64 = self.as_();
self_f64.exp().as_()
}

#[inline(always)]
fn log_elem(self) -> Self {
(self as f64).ln() as $ty
let self_f64: f64 = self.as_();
self_f64.ln().as_()
}

#[inline(always)]
fn log1p_elem(self) -> Self {
log1p(self as f64) as $ty
let self_f64: f64 = self.as_();
log1p(self_f64).as_()
}

#[inline(always)]
fn powf_elem(self, value: f32) -> Self {
(self as f64).pow(value) as $ty
let self_f64: f64 = self.as_();
self_f64.pow(value).as_()
}

#[inline(always)]
fn powi_elem(self, value: i32) -> Self {
#[cfg(feature = "std")]
let val = f64::powi(self as f64, value) as $ty;
let val = {
let self_f64: f64 = self.as_();
f64::powi(self_f64, value).as_()
};

#[cfg(not(feature = "std"))]
let val = Self::powf_elem(self, value as f32);
Expand All @@ -114,17 +123,20 @@ macro_rules! make_elem {

#[inline(always)]
fn sqrt_elem(self) -> Self {
(self as f64).sqrt() as $ty
let self_f64: f64 = self.as_();
self_f64.sqrt().as_()
}

#[inline(always)]
fn abs_elem(self) -> Self {
(self as f64).abs() as $ty
let self_f64: f64 = self.as_();
self_f64.abs().as_()
}

#[inline(always)]
fn int_abs_elem(self) -> Self {
(self as i64).abs() as $ty
let self_i64: i64 = self.as_();
self_i64.abs().as_()
}
}
};
Expand All @@ -137,28 +149,35 @@ macro_rules! make_elem {
impl ExpElement for $ty {
#[inline(always)]
fn exp_elem(self) -> Self {
(self as f32).exp() as $ty
let self_f32: f32 = self.as_();
self_f32.exp().as_()
}

#[inline(always)]
fn log_elem(self) -> Self {
(self as f32).ln() as $ty
let self_f32: f32 = self.as_();
self_f32.ln().as_()
}

#[inline(always)]
fn log1p_elem(self) -> Self {
log1pf(self as f32) as $ty
let self_f32: f32 = self.as_();
log1pf(self_f32).as_()
}

#[inline(always)]
fn powf_elem(self, value: f32) -> Self {
(self as f32).pow(value) as $ty
let self_f32: f32 = self.as_();
self_f32.pow(value).as_()
}

#[inline(always)]
fn powi_elem(self, value: i32) -> Self {
#[cfg(feature = "std")]
let val = f32::powi(self as f32, value) as $ty;
let val = {
let self_f32: f32 = self.as_();
f32::powi(self_f32, value).as_()
};

#[cfg(not(feature = "std"))]
let val = Self::powf_elem(self, value as f32);
Expand All @@ -168,17 +187,20 @@ macro_rules! make_elem {

#[inline(always)]
fn sqrt_elem(self) -> Self {
(self as f32).sqrt() as $ty
let self_f32: f32 = self.as_();
self_f32.sqrt().as_()
}

#[inline(always)]
fn abs_elem(self) -> Self {
(self as f32).abs() as $ty
let self_f32: f32 = self.as_();
self_f32.abs().as_()
}

#[inline(always)]
fn int_abs_elem(self) -> Self {
(self as i32).unsigned_abs() as $ty
let self_i32: i32 = self.as_();
self_i32.unsigned_abs().as_()
}
}
};
Expand All @@ -190,8 +212,12 @@ make_elem!(double u64);

make_elem!(single f32);
make_elem!(single i32);
make_elem!(single i16);
make_elem!(single i8);
make_elem!(single u32);

make_elem!(single half::bf16);
make_elem!(single half::f16);
make_elem!(single i16);
make_elem!(single u16);

make_elem!(single i8);
make_elem!(single u8);
33 changes: 23 additions & 10 deletions crates/burn-ndarray/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use burn_tensor::{DType, quantization::QuantValue};
use burn_tensor::{ElementConversion, Slice};
use core::fmt::Debug;
use core::marker::PhantomData;
use half::f16;
use ndarray::IntoDimension;
use ndarray::SliceInfo;
use ndarray::Zip;
Expand Down Expand Up @@ -453,7 +454,7 @@ where
{
pub fn add(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {
let (lhs, rhs) = dispatch_binary_simd!(
E, VecAdd, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64
E, VecAdd, lhs, rhs, u8, i8, u16, i16, f16, u32, i32, f32, u64, i64, f64
);

let array = &lhs + &rhs;
Expand All @@ -470,6 +471,7 @@ where
i8,
u16,
i16,
f16,
u32,
i32,
f32,
Expand All @@ -484,7 +486,7 @@ where

pub fn sub(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {
let (lhs, rhs) = dispatch_binary_simd!(
E, VecSub, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64
E, VecSub, lhs, rhs, u8, i8, u16, i16, f16, u32, i32, f32, u64, i64, f64
);

let array = lhs - rhs;
Expand All @@ -501,6 +503,7 @@ where
i8,
u16,
i16,
f16,
u32,
i32,
f32,
Expand All @@ -515,7 +518,7 @@ where

pub fn mul(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {
let (lhs, rhs) =
dispatch_binary_simd!(noq, E, VecMul, lhs, rhs, u16, i16, u32, i32, f32, f64);
dispatch_binary_simd!(noq, E, VecMul, lhs, rhs, u16, i16, f16, u32, i32, f32, f64);

let array = lhs * rhs;
array.into_shared()
Expand All @@ -530,6 +533,7 @@ where
rhs.elem(),
u16,
i16,
f16,
u32,
i32,
f32,
Expand All @@ -541,14 +545,14 @@ where
}

pub fn div(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {
let (lhs, rhs) = dispatch_binary_simd!(noq, E, VecDiv, lhs, rhs, f32, f64);
let (lhs, rhs) = dispatch_binary_simd!(noq, E, VecDiv, lhs, rhs, f16, f32, f64);

let array = lhs / rhs;
array.into_shared()
}

pub fn div_scalar(lhs: SharedArray<E>, rhs: E) -> SharedArray<E> {
let lhs = dispatch_binary_scalar_simd!(noq, E, VecDiv, lhs, rhs.elem(), f32, f64);
let lhs = dispatch_binary_scalar_simd!(noq, E, VecDiv, lhs, rhs.elem(), f16, f32, f64);

let array = lhs / rhs;
array.into_shared()
Expand Down Expand Up @@ -823,6 +827,7 @@ where
i8,
u16,
i16,
f16,
u32,
i32,
f32,
Expand All @@ -849,6 +854,7 @@ where
i8,
u16,
i16,
f16,
u32,
i32,
f32,
Expand All @@ -875,6 +881,7 @@ where
i8,
u16,
i16,
f16,
u32,
i32,
f32,
Expand Down Expand Up @@ -933,14 +940,14 @@ where
}

pub(crate) fn abs(tensor: SharedArray<E>) -> SharedArray<E> {
let tensor = dispatch_unary_simd!(E, VecAbs, tensor, i8, i16, i32, f32, f64);
let tensor = dispatch_unary_simd!(E, VecAbs, tensor, i8, i16, f16, i32, f32, f64);

tensor.mapv_into(|a| a.abs_elem()).into_shared()
}

pub(crate) fn equal(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<bool> {
let (lhs, rhs) = dispatch_cmp_simd!(
E, VecEquals, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
E, VecEquals, lhs, rhs, u8, i8, u16, i16, f16, u32, f32, i32, u64, i64, f64
);

// Use the helper to broadcast both arrays to a common shape
Expand All @@ -962,6 +969,7 @@ where
i8,
u16,
i16,
f16,
u32,
f32,
i32,
Expand All @@ -975,7 +983,7 @@ where

pub(crate) fn greater(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<bool> {
let (lhs, rhs) = dispatch_cmp_simd!(
E, VecGreater, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
E, VecGreater, lhs, rhs, u8, i8, u16, i16, f16, u32, f32, i32, u64, i64, f64
);

// Use the helper to broadcast both arrays to a common shape
Expand All @@ -997,6 +1005,7 @@ where
i8,
u16,
i16,
f16,
u32,
f32,
i32,
Expand All @@ -1018,6 +1027,7 @@ where
i8,
u16,
i16,
f16,
u32,
f32,
i32,
Expand Down Expand Up @@ -1045,6 +1055,7 @@ where
i8,
u16,
i16,
f16,
u32,
f32,
i32,
Expand All @@ -1058,7 +1069,7 @@ where

pub(crate) fn lower_equal(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<bool> {
let (lhs, rhs) = dispatch_cmp_simd!(
E, VecLowerEq, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
E, VecLowerEq, lhs, rhs, u8, i8, u16, i16, f16, u32, f32, i32, u64, i64, f64
);

// Use the helper to broadcast both arrays to a common shape
Expand All @@ -1080,6 +1091,7 @@ where
i8,
u16,
i16,
f16,
u32,
f32,
i32,
Expand All @@ -1093,7 +1105,7 @@ where

pub(crate) fn lower(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<bool> {
let (lhs, rhs) = dispatch_cmp_simd!(
E, VecLower, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
E, VecLower, lhs, rhs, u8, i8, u16, i16, f16, u32, f32, i32, u64, i64, f64
);

// Use the helper to broadcast both arrays to a common shape
Expand All @@ -1116,6 +1128,7 @@ where
i8,
u16,
i16,
f16,
u32,
f32,
i32,
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-ndarray/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,9 @@ where
execute_with_int_dtype!(lhs, I, |lhs| -> NdArrayTensor {
execute_with_float_dtype!(rhs, E, |rhs| {
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &E| {
(a.elem::<i64>().pow(*b as u32)).elem()
use num_traits::AsPrimitive;
let b_u32: u32 = (*b).as_();
(a.elem::<i64>().pow(b_u32)).elem()
})
})
})
Expand Down
Loading
Loading