diff --git a/burn-book/src/advanced/backend-extension/README.md b/burn-book/src/advanced/backend-extension/README.md
index 59ec2ce519..3e62eb072e 100644
--- a/burn-book/src/advanced/backend-extension/README.md
+++ b/burn-book/src/advanced/backend-extension/README.md
@@ -34,7 +34,7 @@ pub trait Backend: burn::tensor::backend::Backend {
 You can then implement your new custom backend trait for any backend that you want to support:
 
 ```rust, ignore
-impl Backend for burn_tch::LibTorch {
+impl Backend for burn_tch::LibTorch {
     fn my_new_function(tensor: TchTensor) -> TchTensor {
         // My Tch implementation
     }
@@ -63,7 +63,7 @@ impl Backend for burn_autodiff::Autodiff {
     }
 }
 
-impl Backend for burn_autodiff::Autodiff> {
+impl Backend for burn_autodiff::Autodiff> {
     fn my_new_function(tensor: AutodiffTensor) -> AutodiffTensor {
         // My own backward implementation, generic over a backend implementation.
         //
diff --git a/crates/burn-backend/src/backend/base.rs b/crates/burn-backend/src/backend/base.rs
index 98d9a27663..31a75bd4d0 100644
--- a/crates/burn-backend/src/backend/base.rs
+++ b/crates/burn-backend/src/backend/base.rs
@@ -6,8 +6,8 @@ use serde::{Deserialize, Serialize};
 use thiserror::Error;
 
 use crate::element::Element;
-use crate::ops::*;
 use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
+use crate::{ElementComparison, ops::*};
 use crate::{QTensorPrimitive, TensorData, TensorMetadata};
 
 use super::DeviceOps;
@@ -83,12 +83,12 @@ pub trait Backend:
     /// Tensor primitive to be used for all float operations.
     type FloatTensorPrimitive: TensorMetadata + 'static;
     /// Default float element type.
-    type FloatElem: Element;
+    type FloatElem: Element + ElementComparison;
 
     /// Tensor primitive to be used for all int operations.
     type IntTensorPrimitive: TensorMetadata + 'static;
     /// Int element type.
-    type IntElem: Element;
+    type IntElem: Element + ElementComparison;
 
     /// Tensor primitive to be used for all bool operations.
     type BoolTensorPrimitive: TensorMetadata + 'static;
diff --git a/crates/burn-backend/src/backend/ops/sort.rs b/crates/burn-backend/src/backend/ops/sort.rs
index 261d7218dc..5350c57653 100644
--- a/crates/burn-backend/src/backend/ops/sort.rs
+++ b/crates/burn-backend/src/backend/ops/sort.rs
@@ -34,7 +34,7 @@ pub fn sort<B: Backend, K: TensorKind<B> + BasicOps<B>>(
     descending: bool,
 ) -> K::Primitive
 where
-    <K as BasicOps<B>>::Elem: Element,
+    <K as BasicOps<B>>::Elem: ElementComparison,
 {
     let device = K::device(&tensor);
     let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";
@@ -51,7 +51,7 @@ pub fn sort_data<B: Backend, K: TensorKind<B> + BasicOps<B>>(
     descending: bool,
 ) -> K::Primitive
 where
-    <K as BasicOps<B>>::Elem: Element,
+    <K as BasicOps<B>>::Elem: ElementComparison,
 {
     let dims = data.shape.clone();
     let data_slice = data.as_mut_slice().unwrap();
@@ -92,7 +92,7 @@ pub fn sort_with_indices<B: Backend, K: TensorKind<B> + BasicOps<B>>(
     descending: bool,
 ) -> (K::Primitive, IntTensor<B>)
 where
-    <K as BasicOps<B>>::Elem: Element,
+    <K as BasicOps<B>>::Elem: ElementComparison,
 {
     let device = K::device(&tensor);
     let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";
@@ -109,7 +109,7 @@ fn sort_data_with_indices<B: Backend, K: TensorKind<B> + BasicOps<B>>(
     descending: bool,
 ) -> (K::Primitive, IntTensor<B>)
 where
-    <K as BasicOps<B>>::Elem: Element,
+    <K as BasicOps<B>>::Elem: Element + ElementComparison,
 {
     let dims = data.shape.clone();
     let mut indices_data = dim_indices::<B>(&dims, dim);
@@ -191,7 +191,7 @@ pub fn argsort<B: Backend, K: TensorKind<B> + BasicOps<B>>(
     descending: bool,
 ) -> IntTensor<B>
 where
-    <K as BasicOps<B>>::Elem: Element,
+    <K as BasicOps<B>>::Elem: ElementComparison,
 {
     let device = K::device(&tensor);
     let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";
@@ -209,7 +209,7 @@ fn argsort_data<B: Backend, K: TensorKind<B> + BasicOps<B>>(
     descending: bool,
 ) -> IntTensor<B>
 where
-    <K as BasicOps<B>>::Elem: Element,
+    <K as BasicOps<B>>::Elem: ElementComparison,
 {
     let dims = data.shape.clone();
     let mut indices_data = dim_indices::<B>(&dims, dim);
@@ -252,7 +252,7 @@ fn sort_slice<B: Backend, K: BasicOps<B>>(
     permute_both: bool,
     descending: bool,
 ) where
-    <K as BasicOps<B>>::Elem: Element,
+    <K as BasicOps<B>>::Elem: ElementComparison,
 {
     let ndims = dims.len();
     let strides = compute_strides(dims);
diff --git a/crates/burn-backend/src/data/compare.rs b/crates/burn-backend/src/data/compare.rs
index 71702cd520..4a953cefb4 100644
--- a/crates/burn-backend/src/data/compare.rs
+++ b/crates/burn-backend/src/data/compare.rs
@@ -4,7 +4,7 @@ use burn_std::{DType, bf16, f16};
 use num_traits::{Float, ToPrimitive};
 
 use super::TensorData;
-use crate::element::Element;
+use crate::{ElementComparison, element::Element};
 
 /// The tolerance used to compare two floating point numbers.
 ///
@@ -269,7 +269,7 @@ impl TensorData {
         let mut num_diff = 0;
         let max_num_diff = 5;
 
         for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() {
-            if a.cmp(&b).is_ne() {
+            if !a.eq(&b) {
                 // Only print the first 5 different values.
                 if num_diff < max_num_diff {
                     message += format!("\n => Position {i}: {a} != {b}").as_str();
@@ -362,7 +362,7 @@ impl TensorData {
     ///
     /// If any value is not within the half-open range bounded inclusively below
     /// and exclusively above (`start..end`).
-    pub fn assert_within_range<E: Element>(&self, range: core::ops::Range<E>) {
+    pub fn assert_within_range<E: Element + ElementComparison>(&self, range: core::ops::Range<E>) {
        for elem in self.iter::<E>() {
            if elem.cmp(&range.start).is_lt() || elem.cmp(&range.end).is_ge() {
                panic!("Element ({elem:?}) is not within range {range:?}");
@@ -379,7 +379,10 @@ impl TensorData {
     ///
     /// # Panics
     ///
     /// If any value is not within the closed range bounded inclusively (`start..=end`).
-    pub fn assert_within_range_inclusive<E: Element>(&self, range: core::ops::RangeInclusive<E>) {
+    pub fn assert_within_range_inclusive<E: Element + ElementComparison>(
+        &self,
+        range: core::ops::RangeInclusive<E>,
+    ) {
        let start = range.start();
        let end = range.end();
diff --git a/crates/burn-backend/src/element/base.rs b/crates/burn-backend/src/element/base.rs
index 620fc57274..a5e09ec89a 100644
--- a/crates/burn-backend/src/element/base.rs
+++ b/crates/burn-backend/src/element/base.rs
@@ -14,7 +14,7 @@ pub trait Element:
     ToElement
     + ElementRandom
     + ElementConversion
-    + ElementComparison
+    + ElementEquality
     + ElementLimits
     + bytemuck::CheckedBitPattern
     + bytemuck::NoUninit
@@ -63,6 +63,12 @@ pub trait ElementRandom {
     fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self;
 }
 
+/// Element trait for equality of tensor elements.
+pub trait ElementEquality {
+    /// Returns whether `self` and `other` are equal.
+    fn eq(&self, other: &Self) -> bool;
+}
+
 /// Element ordering trait.
 pub trait ElementComparison {
     /// Returns an [Ordering] between `self` and `other`.
@@ -104,6 +110,11 @@ macro_rules! make_element {
             $dtype
         }
     }
+        impl ElementEquality for $type {
+            fn eq(&self, other: &Self) -> bool {
+                self == other
+            }
+        }
 
         impl ElementConversion for $type {
             #[inline(always)]
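Since `Element` now only implies `ElementEquality`, and `ElementComparison` becomes a separate ordering bound, a custom element type outside the `make_element!` macro has to provide both to work with the sorting paths above. A minimal sketch — the `MyFixed` type is hypothetical, and the re-export path of the two traits from `burn_backend` is an assumption:

```rust, ignore
use core::cmp::Ordering;

use burn_backend::{ElementComparison, ElementEquality};

/// Hypothetical fixed-point element, for illustration only.
#[derive(Clone, Copy, Debug)]
struct MyFixed(i32);

impl ElementEquality for MyFixed {
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl ElementComparison for MyFixed {
    fn cmp(&self, other: &Self) -> Ordering {
        // Delegates to the total order on the underlying integer.
        self.0.cmp(&other.0)
    }
}
```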
diff --git a/crates/burn-backend/src/tensor/ops/float.rs b/crates/burn-backend/src/tensor/ops/float.rs
index 979b0794b2..e4af31bad6 100644
--- a/crates/burn-backend/src/tensor/ops/float.rs
+++ b/crates/burn-backend/src/tensor/ops/float.rs
@@ -6,7 +6,8 @@ use crate::{
     element::ElementConversion,
     ops::TransactionPrimitive,
     tensor::{
-        BasicAutodiffOps, BasicOps, Device, Float, IndexingUpdateOp, IntTensor, Numeric, TensorKind,
+        BasicAutodiffOps, BasicOps, Device, Float, IndexingUpdateOp, IntTensor, Numeric, Ordered,
+        TensorKind,
     },
 };
@@ -440,6 +441,102 @@ impl<B: Backend> Numeric<B> for Float {
         }
     }
 
+    fn abs(tensor: Self::Primitive) -> Self::Primitive {
+        match tensor {
+            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
+            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
+        }
+    }
+
+    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
+        q_bin_ops!(lhs, rhs, float_powf, q_powf)
+    }
+
+    fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
+        match lhs {
+            TensorPrimitive::Float(lhs) => {
+                TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem()))
+            }
+            TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs.elem()),
+        }
+    }
+
+    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
+        q_bin_ops!(lhs, rhs, float_powf, q_powf)
+    }
+
+    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
+        match lhs {
+            TensorPrimitive::Float(lhs) => {
+                TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs.elem()))
+            }
+            TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs.elem()),
+        }
+    }
+
+    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
+        TensorPrimitive::Float(B::float_random(shape, distribution, device))
+    }
+
+    fn sign(tensor: Self::Primitive) -> Self::Primitive {
+        TensorPrimitive::Float(B::float_sign(tensor.tensor()))
+    }
+
+    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
+        match tensor {
+            TensorPrimitive::Float(tensor) => {
+                TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
+            }
+            TensorPrimitive::QFloat(tensor) => {
+                TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
+            }
+        }
+    }
+
+    fn sort_with_indices(
+        tensor: Self::Primitive,
+        dim: usize,
+        descending: bool,
+    ) -> (Self::Primitive, IntTensor<B>) {
+        match tensor {
+            TensorPrimitive::Float(tensor) => {
+                let (values, indices) = B::float_sort_with_indices(tensor, dim, descending);
+                (TensorPrimitive::Float(values), indices)
+            }
+            TensorPrimitive::QFloat(tensor) => {
+                let (values, indices) = B::q_sort_with_indices(tensor, dim, descending);
+                (TensorPrimitive::QFloat(values), indices)
+            }
+        }
+    }
+
+    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
+        match tensor {
+            TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending),
+            TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending),
+        }
+    }
+
+    /// Applies the matrix multiplication operation.
+    ///
+    /// `C = AB`
+    ///
+    /// # Panics
+    ///
+    /// If the two tensors don't have a compatible shape.
+    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
+        match (lhs, rhs) {
+            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
+                TensorPrimitive::Float(B::float_matmul(lhs, rhs))
+            }
+            (lhs, rhs) => B::q_matmul(lhs, rhs),
+        }
+    }
+}
+impl<B: Backend> Ordered<B> for Float
+where
+    <B as Backend>::FloatElem: crate::element::ElementComparison,
+{
     fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
         match tensor {
             TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummin(tensor, dim)),
@@ -560,140 +657,48 @@ impl<B: Backend> Numeric<B> for Float {
         }
     }
 
-    fn clamp(tensor: Self::Primitive, min: B::FloatElem, max: B::FloatElem) -> Self::Primitive {
-        match tensor {
-            TensorPrimitive::Float(tensor) => {
-                TensorPrimitive::Float(B::float_clamp(tensor, min, max))
-            }
-            TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max),
-        }
-    }
-
-    fn clamp_min(tensor: Self::Primitive, min: B::FloatElem) -> Self::Primitive {
+    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
         match tensor {
-            TensorPrimitive::Float(tensor) => {
-                TensorPrimitive::Float(B::float_clamp_min(tensor, min))
-            }
-            TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min),
+            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),
+            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),
         }
     }
 
-    fn clamp_max(tensor: Self::Primitive, max: B::FloatElem) -> Self::Primitive {
+    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
         match tensor {
             TensorPrimitive::Float(tensor) => {
-                TensorPrimitive::Float(B::float_clamp_max(tensor, max))
-            }
-            TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max),
-        }
-    }
-
-    fn abs(tensor: Self::Primitive) -> Self::Primitive {
-        match tensor {
-            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
-            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
-        }
-    }
-
-    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
-        q_bin_ops!(lhs, rhs, float_powf, q_powf)
-    }
-
-    fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
-        match lhs {
-            TensorPrimitive::Float(lhs) => {
-                TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem()))
+                TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))
             }
-            TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs.elem()),
-        }
-    }
-
-    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
-        q_bin_ops!(lhs, rhs, float_powf, q_powf)
-    }
-
-    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
-        match lhs {
-            TensorPrimitive::Float(lhs) => {
-                TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs.elem()))
+            TensorPrimitive::QFloat(tensor) => {
+                TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))
             }
-            TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs.elem()),
-        }
-    }
-
-    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
-        TensorPrimitive::Float(B::float_random(shape, distribution, device))
-    }
-
-    fn sign(tensor: Self::Primitive) -> Self::Primitive {
-        TensorPrimitive::Float(B::float_sign(tensor.tensor()))
-    }
-
-    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
+    fn clamp(tensor: Self::Primitive, min: B::FloatElem, max: B::FloatElem) -> Self::Primitive {
         match tensor {
             TensorPrimitive::Float(tensor) => {
-                TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
-            }
-            TensorPrimitive::QFloat(tensor) => {
-                TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
+                TensorPrimitive::Float(B::float_clamp(tensor, min, max))
             }
+            TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max),
         }
     }
 
-    fn sort_with_indices(
-        tensor: Self::Primitive,
-        dim: usize,
-        descending: bool,
-    ) -> (Self::Primitive, IntTensor<B>) {
+    fn clamp_min(tensor: Self::Primitive, min: B::FloatElem) -> Self::Primitive {
         match tensor {
             TensorPrimitive::Float(tensor) => {
-                let (values, indices) = B::float_sort_with_indices(tensor, dim, descending);
-                (TensorPrimitive::Float(values), indices)
-            }
-            TensorPrimitive::QFloat(tensor) => {
-                let (values, indices) = B::q_sort_with_indices(tensor, dim, descending);
-                (TensorPrimitive::QFloat(values), indices)
+                TensorPrimitive::Float(B::float_clamp_min(tensor, min))
             }
+            TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min),
         }
     }
 
-    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
-        match tensor {
-            TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending),
-            TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending),
-        }
-    }
-
-    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
-        match tensor {
-            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),
-            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),
-        }
-    }
-
-    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
+    fn clamp_max(tensor: Self::Primitive, max: B::FloatElem) -> Self::Primitive {
         match tensor {
             TensorPrimitive::Float(tensor) => {
-                TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))
-            }
-            TensorPrimitive::QFloat(tensor) => {
-                TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))
-            }
-        }
-    }
-
-    /// Applies the matrix multiplication operation.
-    ///
-    /// `C = AB`
-    ///
-    /// # Panics
-    ///
-    /// If the two tensors don't have a compatible shape.
-    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
-        match (lhs, rhs) {
-            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
-                TensorPrimitive::Float(B::float_matmul(lhs, rhs))
+                TensorPrimitive::Float(B::float_clamp_max(tensor, max))
             }
-            (lhs, rhs) => B::q_matmul(lhs, rhs),
+            TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max),
         }
     }
 }
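For callers on the standard backends nothing changes, because every built-in element type implements `ElementComparison`; these ops simply resolve through the new `Ordered` impl for `Float`. A usage sketch at the public `Tensor` level — backend choice and exact re-export paths are assumptions:

```rust, ignore
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let x = Tensor::<NdArray, 2>::from_floats([[3.0, 1.0, 2.0], [0.5, -1.0, 4.0]], &device);

    // All three now route through `Ordered`, which requires
    // `B::FloatElem: ElementComparison`.
    let _sorted = x.clone().sort(1);
    let _ranks = x.clone().argsort(1);
    let _clamped = x.clamp(-1.0, 3.0);
}
```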
diff --git a/crates/burn-backend/src/tensor/ops/int.rs b/crates/burn-backend/src/tensor/ops/int.rs
index 7cfca73d16..7dfb3d253c 100644
--- a/crates/burn-backend/src/tensor/ops/int.rs
+++ b/crates/burn-backend/src/tensor/ops/int.rs
@@ -2,12 +2,12 @@ use alloc::vec::Vec;
 use burn_std::{DType, Shape, Slice};
 
 use crate::{
-    AutodiffBackend, Backend, Distribution, ExecutionError, TensorData,
+    AutodiffBackend, Backend, Distribution, ExecutionError, TensorData, backend,
     element::ElementConversion,
     ops::TransactionPrimitive,
     tensor::{
         BasicAutodiffOps, BasicOps, BoolTensor, Device, IndexingUpdateOp, Int, IntTensor, Numeric,
-        TensorKind,
+        Ordered, TensorKind,
     },
 };
@@ -259,6 +259,82 @@ impl<B: Backend> Numeric<B> for Int {
         B::int_cumprod(tensor, dim)
     }
 
+    fn abs(tensor: Self::Primitive) -> Self::Primitive {
+        B::int_abs(tensor)
+    }
+
+    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
+        B::int_powf(lhs, B::int_into_float(rhs))
+    }
+
+    fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
+        B::int_powf_scalar(lhs, rhs.elem())
+    }
+
+    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
+        B::int_powi(lhs, rhs)
+    }
+
+    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
+        B::int_powi_scalar(lhs, rhs.elem())
+    }
+
+    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
+        B::int_random(shape, distribution, device)
+    }
+
+    fn sign(tensor: Self::Primitive) -> Self::Primitive {
+        B::int_sign(tensor)
+    }
+
+    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
+        B::int_sort(tensor, dim, descending)
+    }
+
+    fn sort_with_indices(
+        tensor: Self::Primitive,
+        dim: usize,
+        descending: bool,
+    ) -> (Self::Primitive, IntTensor<B>) {
+        B::int_sort_with_indices(tensor, dim, descending)
+    }
+
+    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
+        B::int_argsort(tensor, dim, descending)
+    }
+
+    /// Applies the matrix multiplication operation.
+    ///
+    /// `C = AB`
+    ///
+    /// # Panics
+    ///
+    /// If the two tensors don't have a compatible shape.
+    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
+        B::int_matmul(lhs, rhs)
+    }
+}
+
+impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
+    type InnerKind = Int;
+
+    fn inner(
+        tensor: <Self as TensorKind<B>>::Primitive,
+    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
+        B::int_inner(tensor)
+    }
+
+    fn from_inner(
+        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
+    ) -> <Self as TensorKind<B>>::Primitive {
+        B::int_from_inner(inner)
+    }
+}
+
+impl<B: Backend> Ordered<B> for Int
+where
+    <B as Backend>::IntElem: crate::element::ElementComparison,
+{
     fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
         B::int_cummin(tensor, dim)
     }
@@ -356,75 +432,4 @@ impl<B: Backend> Numeric<B> for Int {
     fn clamp_max(tensor: Self::Primitive, max: B::IntElem) -> Self::Primitive {
         B::int_clamp_max(tensor, max)
     }
-
-    fn abs(tensor: Self::Primitive) -> Self::Primitive {
-        B::int_abs(tensor)
-    }
-
-    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
-        B::int_powf(lhs, B::int_into_float(rhs))
-    }
-
-    fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
-        B::int_powf_scalar(lhs, rhs.elem())
-    }
-
-    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
-        B::int_powi(lhs, rhs)
-    }
-
-    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
-        B::int_powi_scalar(lhs, rhs.elem())
-    }
-
-    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
-        B::int_random(shape, distribution, device)
-    }
-
-    fn sign(tensor: Self::Primitive) -> Self::Primitive {
-        B::int_sign(tensor)
-    }
-
-    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
-        B::int_sort(tensor, dim, descending)
-    }
-
-    fn sort_with_indices(
-        tensor: Self::Primitive,
-        dim: usize,
-        descending: bool,
-    ) -> (Self::Primitive, IntTensor<B>) {
-        B::int_sort_with_indices(tensor, dim, descending)
-    }
-
-    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
-        B::int_argsort(tensor, dim, descending)
-    }
-
-    /// Applies the matrix multiplication operation.
-    ///
-    /// `C = AB`
-    ///
-    /// # Panics
-    ///
-    /// If the two tensors don't have a compatible shape.
-    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
-        B::int_matmul(lhs, rhs)
-    }
-}
-
-impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
-    type InnerKind = Int;
-
-    fn inner(
-        tensor: <Self as TensorKind<B>>::Primitive,
-    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
-        B::int_inner(tensor)
-    }
-
-    fn from_inner(
-        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
-    ) -> <Self as TensorKind<B>>::Primitive {
-        B::int_from_inner(inner)
-    }
 }
diff --git a/crates/burn-backend/src/tensor/ops/mod.rs b/crates/burn-backend/src/tensor/ops/mod.rs
index 5a5514c487..6852450021 100644
--- a/crates/burn-backend/src/tensor/ops/mod.rs
+++ b/crates/burn-backend/src/tensor/ops/mod.rs
@@ -4,10 +4,12 @@ mod bool;
 mod float;
 mod int;
 mod numeric;
+mod orderable;
 
 pub use autodiff::*;
 pub use base::*;
 pub use numeric::*;
+pub use orderable::*;
 
 /// Computation to be used to update the existing values in indexed assignment operations (scatter/select).
 #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
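Generic code written against the kind traits is the main thing affected by this split: `sort`, `argsort`, the comparisons, and the clamp family now hang off `Ordered<B>` instead of `Numeric<B>`. A migration sketch — the helper is hypothetical, and the exact bounds and import paths on the public API are assumptions:

```rust, ignore
use burn::tensor::{BasicOps, Int, Ordered, Tensor, backend::Backend};

/// Hypothetical helper that ranks elements along `dim`.
/// `Numeric<B>` alone no longer provides `argsort`; the `Ordered<B>`
/// bound (whose supertrait is `Numeric<B>`) does.
fn rank_along<B: Backend, const D: usize, K>(
    tensor: Tensor<B, D, K>,
    dim: usize,
) -> Tensor<B, D, Int>
where
    K: BasicOps<B> + Ordered<B>,
{
    tensor.argsort(dim)
}
```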
diff --git a/crates/burn-backend/src/tensor/ops/numeric.rs b/crates/burn-backend/src/tensor/ops/numeric.rs
index 6b0b859686..4ae62f494c 100644
--- a/crates/burn-backend/src/tensor/ops/numeric.rs
+++ b/crates/burn-backend/src/tensor/ops/numeric.rs
@@ -477,557 +477,6 @@ where
     /// function, which is more high-level and designed for public use.
     fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
 
-    /// Computes the cumulative minimum of elements along a dimension.
-    ///
-    /// # Arguments
-    ///
-    /// * `tensor` - The tensor to compute the cumulative minimum of.
-    /// * `dim` - The dimension along which to compute the cumulative minimum.
-    ///
-    /// # Returns
-    ///
-    /// A tensor with the same shape as the input tensor, where each element is the minimum
-    /// of all elements up to and including that position along the specified dimension.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For computing the cumulative minimum of elements along a dimension, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("cummin"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::cummin`")]
-    /// function, which is more high-level and designed for public use.
-    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
-
-    /// Computes the cumulative maximum of elements along a dimension.
-    ///
-    /// # Arguments
-    ///
-    /// * `tensor` - The tensor to compute the cumulative maximum of.
-    /// * `dim` - The dimension along which to compute the cumulative maximum.
-    ///
-    /// # Returns
-    ///
-    /// A tensor with the same shape as the input tensor, where each element is the maximum
-    /// of all elements up to and including that position along the specified dimension.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For computing the cumulative maximum of elements along a dimension, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("cummax"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::cummax`")]
-    /// function, which is more high-level and designed for public use.
-    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
-
-    /// Element-wise greater than comparison between two tensors.
-    ///
-    /// # Arguments
-    ///
-    /// * `lhs` - The left hand side tensor.
-    /// * `rhs` - The right hand side tensor.
-    ///
-    /// # Returns
-    ///
-    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
-    /// corresponding element of the left hand side tensor is greater than the corresponding element
-    /// of the right hand side tensor, and false otherwise.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For element-wise greater than comparison between two tensors, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("greater"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::greater`")]
-    /// function, which is more high-level and designed for public use.
-    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
-
-    /// Element-wise greater than comparison between a tensor and a scalar.
-    ///
-    /// # Arguments
-    ///
-    /// * `lhs` - The left hand side tensor.
-    /// * `rhs` - The right hand side scalar.
-    ///
-    /// # Returns
-    ///
-    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
-    /// corresponding element of the left hand side tensor is greater than the right hand side
-    /// scalar, and false otherwise.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For element-wise greater than comparison between a tensor and a scalar, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("greater_elem"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::greater_elem`")]
-    /// function, which is more high-level and designed for public use.
-    fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
-
-    /// Element-wise greater than or equal comparison between two tensors.
-    ///
-    /// # Arguments
-    ///
-    /// * `lhs` - The left hand side tensor.
-    /// * `rhs` - The right hand side tensor.
-    ///
-    /// # Returns
-    ///
-    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
-    /// corresponding element of the left hand side tensor is greater than or equal to the
-    /// corresponding element of the right hand side tensor, and false otherwise.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For element-wise greater than or equal comparison between two tensors, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("greater_equal"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::greater_equal`")]
-    /// function, which is more high-level and designed for public use.
-    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
-
-    /// Element-wise greater than or equal comparison between a tensor and a scalar.
-    ///
-    /// # Arguments
-    ///
-    /// * `lhs` - The left hand side tensor.
-    /// * `rhs` - The right hand side scalar.
-    ///
-    /// # Returns
-    ///
-    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
-    /// corresponding element of the left hand side tensor is greater than or equal to the right
-    /// hand side scalar, and false otherwise.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("greater_equal_elem"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::greater_equal_elem`")]
-    /// function, which is more high-level and designed for public use.
-    fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
-
-    /// Element-wise less than comparison between two tensors.
-    ///
-    /// # Arguments
-    ///
-    /// * `lhs` - The left hand side tensor.
-    /// * `rhs` - The right hand side tensor.
-    ///
-    /// # Returns
-    ///
-    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
-    /// corresponding element of the left hand side tensor is less than the corresponding element of
-    /// the right hand side tensor, and false otherwise.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For element-wise less than comparison between two tensors, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("lower"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::lower`")]
-    /// function, which is more high-level and designed for public use.
-    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
-
-    /// Element-wise less than comparison between a tensor and a scalar.
-    ///
-    /// # Arguments
-    ///
-    /// * `lhs` - The left hand side tensor.
-    /// * `rhs` - The right hand side scalar.
-    ///
-    /// # Returns
-    ///
-    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
-    /// corresponding element of the left hand side tensor is less than the right hand side scalar,
-    /// and false otherwise.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For element-wise less than comparison between a tensor and a scalar, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("lower_elem"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::lower_elem`")]
-    /// function, which is more high-level and designed for public use.
-    fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
-
-    /// Element-wise less than or equal comparison between two tensors.
-    ///
-    /// # Arguments
-    ///
-    /// * `lhs` - The left hand side tensor.
-    /// * `rhs` - The right hand side tensor.
-    ///
-    /// # Returns
-    ///
-    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
-    /// corresponding element of the left hand side tensor is less than or equal to the corresponding
-    /// element of the right hand side tensor, and false otherwise.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For element-wise less than or equal comparison between two tensors, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("lower_equal"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::lower_equal`")]
-    /// function, which is more high-level and designed for public use.
-    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
-
-    /// Element-wise less than or equal comparison between a tensor and a scalar.
-    ///
-    /// # Arguments
-    ///
-    /// * `lhs` - The left hand side tensor.
-    /// * `rhs` - The right hand side scalar.
-    ///
-    /// # Returns
-    ///
-    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
-    /// corresponding element of the left hand side tensor is less than or equal to the right hand
-    /// side scalar, and false otherwise.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("lower_equal_elem"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::lower_equal_elem`")]
-    /// function, which is more high-level and designed for public use.
-    fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
-
-    /// Gets the indices of the maximum elements of a tensor along an axis.
-    ///
-    /// # Arguments
-    ///
-    /// * `dim` - The axis along which to get the indices of the maximum elements.
-    /// * `tensor` - The tensor to get the indices of the maximum elements from.
-    ///
-    /// # Returns
-    ///
-    /// A tensor with the same shape as the input tensor, where each element is the index of the
-    /// maximum element of the input tensor at the corresponding index along the specified axis.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("argmax"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::argmax`")]
-    /// function, which is more high-level and designed for public use.
-    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B>;
-
-    /// Gets the indices of the minimum elements of a tensor along an axis.
-    ///
-    /// # Arguments
-    ///
-    /// * `dim` - The axis along which to get the indices of the minimum elements.
-    /// * `tensor` - The tensor to get the indices of the minimum elements from.
-    ///
-    /// # Returns
-    ///
-    /// A tensor with the same shape as the input tensor, where each element is the index of the
-    /// minimum element of the input tensor at the corresponding index along the specified axis.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("argmin"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::argmin`")]
-    /// function, which is more high-level and designed for public use.
-    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B>;
-
-    /// Gets the maximum elements of a tensor along an axis.
-    ///
-    /// # Arguments
-    ///
-    /// * `dim` - The axis along which to get the maximum elements.
-    ///
-    /// # Returns
-    ///
-    /// A single-element tensor containing the maximum element of the input tensor.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For getting the maximum elements of a tensor along an axis, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("max"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::max`")]
-    /// function, which is more high-level and designed for public use.
-    fn max(tensor: Self::Primitive) -> Self::Primitive;
-
-    /// Gets the maximum elements of a tensor along an axis.
-    ///
-    /// # Arguments
-    ///
-    /// * `tensor` - The tensor to get the maximum elements from.
-    /// * `dim` - The axis along which to get the maximum elements.
-    ///
-    /// # Returns
-    ///
-    /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1.
-    /// Each element is the maximum element of the corresponding input dim.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For getting the maximum elements of a tensor along an axis, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("max_dim"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::max_dim`")]
-    /// function, which is more high-level and designed for public use.
-    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
-
-    /// Gets the maximum elements of a tensor along an axis.
-    ///
-    /// # Arguments
-    ///
-    /// * `tensor` - The tensor to get the maximum elements from.
-    /// * `dim` - The axis along which to get the maximum elements.
-    ///
-    /// # Returns
-    ///
-    /// A tuple containing the maximum element of the input tensor, and a tensor with the same shape
-    /// as the input tensor, where each element is the index of the maximum element of the input tensor
-    /// at the corresponding index along the specified axis.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For getting the maximum elements of a tensor along an axis, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("max_dim_with_indices"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::max_dim_with_indices`")]
-    /// function, which is more high-level and designed for public use.
-    fn max_dim_with_indices(tensor: Self::Primitive, dim: usize)
-        -> (Self::Primitive, IntTensor<B>);
-
-    /// Gets the maximum elements of a tensor along an axis.
-    ///
-    /// # Arguments
-    ///
-    /// * `dim` - The axis along which to get the maximum elements.
-    ///
-    /// # Returns
-    ///
-    /// A single-element tensor containing the maximum absolute element of the input tensor.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For getting the maximum absolute elements of a tensor, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("max_abs"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::max_abs`")]
-    /// function, which is more high-level and designed for public use.
-    fn max_abs(tensor: Self::Primitive) -> Self::Primitive;
-
-    /// Gets the maximum elements of a tensor along an axis.
-    ///
-    /// # Arguments
-    ///
-    /// * `tensor` - The tensor to get the maximum elements from.
-    /// * `dim` - The axis along which to get the maximum elements.
-    ///
-    /// # Returns
-    ///
-    /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1.
-    /// Each element is the maximum absolute element of the corresponding input dim.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For getting the maximum elements of a tensor along an axis, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("max_abs_dim"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::max_abs_dim`")]
-    /// function, which is more high-level and designed for public use.
-    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
-
-    /// Gets the minimum elements of a tensor along an axis.
-    ///
-    /// # Arguments
-    ///
-    /// * `tensor` - The tensor to get the minimum elements from.
-    ///
-    /// # Returns
-    ///
-    /// A single-element tensor containing the minimum element of the input tensor.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For getting the minimum elements of a tensor along an axis, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("min"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::min`")]
-    /// function, which is more high-level and designed for public use.
-    fn min(tensor: Self::Primitive) -> Self::Primitive;
-
-    /// Gets the minimum elements of a tensor along an axis.
-    ///
-    /// # Arguments
-    ///
-    /// * `tensor` - The tensor to get the minimum elements from.
-    /// * `dim` - The axis along which to get the minimum elements.
-    ///
-    /// # Returns
-    ///
-    /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1.
-    /// Each element is the minimum element of the corresponding input dim.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For getting the minimum elements of a tensor along an axis, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("min_dim"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::min_dim`")]
-    /// function, which is more high-level and designed for public use.
-    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
-
-    /// Gets the minimum elements and indices of a tensor along an axis.
-    ///
-    /// # Arguments
-    ///
-    /// * `tensor` - The tensor to get the minimum elements from.
-    ///
-    /// # Returns
-    ///
-    /// A tensor with the same shape as the input tensor and corresponding indices, where
-    /// each element is the minimum element of the input tensor at the corresponding index
-    /// along the specified axis.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
-    /// or use this function directly.
-    ///
-    /// For getting the minimum elements of a tensor along an axis, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("min_dim_with_indices"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::min_dim_with_indices`")]
-    /// function, which is more high-level and designed for public use.
-    fn min_dim_with_indices(tensor: Self::Primitive, dim: usize)
-        -> (Self::Primitive, IntTensor<B>);
-
-    /// Clamp the tensor between the given min and max values.
-    ///
-    /// # Arguments
-    ///
-    /// * `min` - The minimum value.
-    /// * `max` - The maximum value.
-    ///
-    /// # Returns
-    ///
-    /// A new tensor with the values clamped between the given min and max values.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users.
-    ///
-    /// For clamping a tensor between the given min and max values, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("clamp"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::clamp`")]
-    /// function, which is more high-level and designed for public use.
-    fn clamp(tensor: Self::Primitive, min: Self::Elem, max: Self::Elem) -> Self::Primitive;
-
-    /// Clamps a tensor under a minimum value.
-    ///
-    /// # Arguments
-    ///
-    /// * `tensor` - The tensor to clamp.
-    /// * `min` - The minimum value.
-    ///
-    /// # Returns
-    ///
-    /// A new tensor with the values clamped under the given min value.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users.
-    ///
-    /// For clamping a tensor under a minimum value, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("clamp_min"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::clamp_min`")]
-    /// function, which is more high-level and designed for public use.
-    fn clamp_min(tensor: Self::Primitive, min: Self::Elem) -> Self::Primitive;
-
-    /// Clamps a tensor over a maximum value.
-    ///
-    /// # Arguments
-    ///
-    /// * `tensor` - The tensor to clamp.
-    /// * `max` - The maximum value.
-    ///
-    /// # Returns
-    ///
-    /// A new tensor with the values clamped over the given max value.
-    ///
-    /// # Remarks
-    ///
-    /// This is a low-level function used internally by the library to call different backend functions
-    /// with static dispatch. It is not designed for direct usage by users.
-    ///
-    /// For clamping a tensor over a maximum value, users should prefer the
-    #[cfg_attr(doc, doc = crate::doc_tensor!("clamp_max"))]
-    #[cfg_attr(not(doc), doc = "`Tensor::clamp_max`")]
-    /// function, which is more high-level and designed for public use.
-    fn clamp_max(tensor: Self::Primitive, max: Self::Elem) -> Self::Primitive;
-
     /// Calculate absolute value on all elements of a tensor
     ///
     /// # Arguments
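The moved `sort_with_indices` family keeps its documented contract: it returns the sorted values together with the indices that produced them, so selecting with those indices reproduces the sorted tensor. A small illustration at the public API level — backend choice, paths, and the commented values are assumptions:

```rust, ignore
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let x = Tensor::<NdArray, 1>::from_floats([3.0, 1.0, 2.0], &device);

    let (values, indices) = x.clone().sort_with_indices(0);
    // values  == [1.0, 2.0, 3.0]
    // indices == [1, 2, 0], so `x.select(0, indices)` equals `values`.
    let _ = (values, indices);
}
```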
diff --git a/crates/burn-backend/src/tensor/ops/orderable.rs b/crates/burn-backend/src/tensor/ops/orderable.rs
new file mode 100644
index 0000000000..f114004bc2
--- /dev/null
+++ b/crates/burn-backend/src/tensor/ops/orderable.rs
@@ -0,0 +1,568 @@
+use crate::{
+    Backend, ElementComparison,
+    tensor::{IntTensor, Numeric},
+};
+
+/// Trait that lists all ordering-based operations that can be applied on numerical tensors.
+///
+/// # Warnings
+///
+/// This is an internal trait, use the public API provided by the
+#[cfg_attr(doc, doc = crate::doc_tensor!())]
+#[cfg_attr(not(doc), doc = "`Tensor`")]
+/// struct.
+pub trait Ordered<B: Backend>: Numeric<B>
+where
+    Self::Elem: ElementComparison,
+{
+    /// Computes the cumulative minimum of elements along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to compute the cumulative minimum of.
+    /// * `dim` - The dimension along which to compute the cumulative minimum.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as the input tensor, where each element is the minimum
+    /// of all elements up to and including that position along the specified dimension.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For computing the cumulative minimum of elements along a dimension, users should prefer the
+    #[cfg_attr(doc, doc = crate::doc_tensor!("cummin"))]
+    #[cfg_attr(not(doc), doc = "`Tensor::cummin`")]
+    /// function, which is more high-level and designed for public use.
+    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
+
+    /// Computes the cumulative maximum of elements along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to compute the cumulative maximum of.
+    /// * `dim` - The dimension along which to compute the cumulative maximum.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as the input tensor, where each element is the maximum
+    /// of all elements up to and including that position along the specified dimension.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For computing the cumulative maximum of elements along a dimension, users should prefer the
+    #[cfg_attr(doc, doc = crate::doc_tensor!("cummax"))]
+    #[cfg_attr(not(doc), doc = "`Tensor::cummax`")]
+    /// function, which is more high-level and designed for public use.
+    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
+
+    /// Element-wise greater than comparison between two tensors.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side tensor.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
+    /// corresponding element of the left hand side tensor is greater than the corresponding element
+    /// of the right hand side tensor, and false otherwise.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For element-wise greater than comparison between two tensors, users should prefer the
+    #[cfg_attr(doc, doc = crate::doc_tensor!("greater"))]
+    #[cfg_attr(not(doc), doc = "`Tensor::greater`")]
+    /// function, which is more high-level and designed for public use.
+    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
+
+    /// Element-wise greater than comparison between a tensor and a scalar.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side scalar.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
+    /// corresponding element of the left hand side tensor is greater than the right hand side
+    /// scalar, and false otherwise.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For element-wise greater than comparison between a tensor and a scalar, users should prefer the
+    #[cfg_attr(doc, doc = crate::doc_tensor!("greater_elem"))]
+    #[cfg_attr(not(doc), doc = "`Tensor::greater_elem`")]
+    /// function, which is more high-level and designed for public use.
+    fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
+
+    /// Element-wise greater than or equal comparison between two tensors.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side tensor.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
+    /// corresponding element of the left hand side tensor is greater than or equal to the
+    /// corresponding element of the right hand side tensor, and false otherwise.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For element-wise greater than or equal comparison between two tensors, users should prefer the
+    #[cfg_attr(doc, doc = crate::doc_tensor!("greater_equal"))]
+    #[cfg_attr(not(doc), doc = "`Tensor::greater_equal`")]
+    /// function, which is more high-level and designed for public use.
+    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
+
+    /// Element-wise greater than or equal comparison between a tensor and a scalar.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side scalar.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
+    /// corresponding element of the left hand side tensor is greater than or equal to the right
+    /// hand side scalar, and false otherwise.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer the
+    #[cfg_attr(doc, doc = crate::doc_tensor!("greater_equal_elem"))]
+    #[cfg_attr(not(doc), doc = "`Tensor::greater_equal_elem`")]
+    /// function, which is more high-level and designed for public use.
+    fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
+
+    /// Element-wise less than comparison between two tensors.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side tensor.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
+    /// corresponding element of the left hand side tensor is less than the corresponding element of
+    /// the right hand side tensor, and false otherwise.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For element-wise less than comparison between two tensors, users should prefer the
+    #[cfg_attr(doc, doc = crate::doc_tensor!("lower"))]
+    #[cfg_attr(not(doc), doc = "`Tensor::lower`")]
+    /// function, which is more high-level and designed for public use.
+    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
+
+    /// Element-wise less than comparison between a tensor and a scalar.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side scalar.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
+    /// corresponding element of the left hand side tensor is less than the right hand side scalar,
+    /// and false otherwise.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For element-wise less than comparison between a tensor and a scalar, users should prefer the
+    #[cfg_attr(doc, doc = crate::doc_tensor!("lower_elem"))]
+    #[cfg_attr(not(doc), doc = "`Tensor::lower_elem`")]
+    /// function, which is more high-level and designed for public use.
+    fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
+
+    /// Element-wise less than or equal comparison between two tensors.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side tensor.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
+    /// corresponding element of the left hand side tensor is less than or equal to the corresponding
+    /// element of the right hand side tensor, and false otherwise.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+ /// + /// For element-wise less than or equal comparison between two tensors, users should prefer the + #[cfg_attr(doc, doc = crate::doc_tensor!("lower_equal"))] + #[cfg_attr(not(doc), doc = "`Tensor::lower_equal`")] + /// function, which is more high-level and designed for public use. + fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive; + + /// Element-wise less than or equal comparison between a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the same shape as the input tensor, where each element is true if the + /// corresponding element of the left hand side tensor is less than or equal to the right hand + /// side scalar, and false otherwise. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer the + #[cfg_attr(doc, doc = crate::doc_tensor!("lower_equal_elem"))] + #[cfg_attr(not(doc), doc = "`Tensor::lower_equal_elem`")] + /// function, which is more high-level and designed for public use. + fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive; + + /// Gets the indices of the maximum elements of a tensor along an axis. + /// + /// # Arguments + /// + /// * `dim` - The axis along which to get the indices of the maximum elements. + /// * `tensor` - The tensor to get the indices of the maximum elements from. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is the index of the + /// maximum element of the input tensor at the corresponding index along the specified axis. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the + #[cfg_attr(doc, doc = crate::doc_tensor!("argmax"))] + #[cfg_attr(not(doc), doc = "`Tensor::argmax`")] + /// function, which is more high-level and designed for public use. + fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor; + + /// Gets the indices of the minimum elements of a tensor along an axis. + /// + /// # Arguments + /// + /// * `dim` - The axis along which to get the indices of the minimum elements. + /// * `tensor` - The tensor to get the indices of the minimum elements from. + /// + /// # Returns + /// + /// A tensor with the same shape as the input tensor, where each element is the index of the + /// minimum element of the input tensor at the corresponding index along the specified axis. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. 
+ ///
+ /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the
+ #[cfg_attr(doc, doc = crate::doc_tensor!("argmin"))]
+ #[cfg_attr(not(doc), doc = "`Tensor::argmin`")]
+ /// function, which is more high-level and designed for public use.
+ fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B>;
+
+ /// Gets the maximum element of a tensor.
+ ///
+ /// # Arguments
+ ///
+ /// * `tensor` - The tensor to get the maximum element from.
+ ///
+ /// # Returns
+ ///
+ /// A single-element tensor containing the maximum element of the input tensor.
+ ///
+ /// # Remarks
+ ///
+ /// This is a low-level function used internally by the library to call different backend functions
+ /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+ /// or use this function directly.
+ ///
+ /// For getting the maximum element of a tensor, users should prefer the
+ #[cfg_attr(doc, doc = crate::doc_tensor!("max"))]
+ #[cfg_attr(not(doc), doc = "`Tensor::max`")]
+ /// function, which is more high-level and designed for public use.
+ fn max(tensor: Self::Primitive) -> Self::Primitive;
+
+ /// Gets the maximum elements of a tensor along an axis.
+ ///
+ /// # Arguments
+ ///
+ /// * `tensor` - The tensor to get the maximum elements from.
+ /// * `dim` - The axis along which to get the maximum elements.
+ ///
+ /// # Returns
+ ///
+ /// A tensor with the same rank as the input tensor, but with the given dimension reduced to size 1.
+ /// Each element is the maximum element of the corresponding slice of the input along that dimension.
+ ///
+ /// # Remarks
+ ///
+ /// This is a low-level function used internally by the library to call different backend functions
+ /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+ /// or use this function directly.
+ ///
+ /// For getting the maximum elements of a tensor along an axis, users should prefer the
+ #[cfg_attr(doc, doc = crate::doc_tensor!("max_dim"))]
+ #[cfg_attr(not(doc), doc = "`Tensor::max_dim`")]
+ /// function, which is more high-level and designed for public use.
+ fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
+
+ /// Gets the maximum elements of a tensor along an axis, together with their indices.
+ ///
+ /// # Arguments
+ ///
+ /// * `tensor` - The tensor to get the maximum elements from.
+ /// * `dim` - The axis along which to get the maximum elements.
+ ///
+ /// # Returns
+ ///
+ /// A tuple of two tensors, each with the same rank as the input but with the given dimension
+ /// reduced to size 1: the first contains the maximum values, and the second contains the indices
+ /// of those values along the specified axis.
+ ///
+ /// # Remarks
+ ///
+ /// This is a low-level function used internally by the library to call different backend functions
+ /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+ /// or use this function directly.
+ ///
+ /// For getting the maximum elements of a tensor along an axis together with their indices, users should prefer the
+ #[cfg_attr(doc, doc = crate::doc_tensor!("max_dim_with_indices"))]
+ #[cfg_attr(not(doc), doc = "`Tensor::max_dim_with_indices`")]
+ /// function, which is more high-level and designed for public use.
+ fn max_dim_with_indices(tensor: Self::Primitive, dim: usize)
+ -> (Self::Primitive, IntTensor<B>);
+
+ /// Gets the maximum absolute element of a tensor.
+ ///
+ /// # Arguments
+ ///
+ /// * `tensor` - The tensor to get the maximum absolute element from.
+ ///
+ /// # Returns
+ ///
+ /// A single-element tensor containing the maximum absolute element of the input tensor.
+ ///
+ /// # Remarks
+ ///
+ /// This is a low-level function used internally by the library to call different backend functions
+ /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+ /// or use this function directly.
+ ///
+ /// For getting the maximum absolute element of a tensor, users should prefer the
+ #[cfg_attr(doc, doc = crate::doc_tensor!("max_abs"))]
+ #[cfg_attr(not(doc), doc = "`Tensor::max_abs`")]
+ /// function, which is more high-level and designed for public use.
+ fn max_abs(tensor: Self::Primitive) -> Self::Primitive;
+
+ /// Gets the maximum absolute elements of a tensor along an axis.
+ ///
+ /// # Arguments
+ ///
+ /// * `tensor` - The tensor to get the maximum absolute elements from.
+ /// * `dim` - The axis along which to get the maximum absolute elements.
+ ///
+ /// # Returns
+ ///
+ /// A tensor with the same rank as the input tensor, but with the given dimension reduced to size 1.
+ /// Each element is the maximum absolute element of the corresponding slice of the input along that dimension.
+ ///
+ /// # Remarks
+ ///
+ /// This is a low-level function used internally by the library to call different backend functions
+ /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+ /// or use this function directly.
+ ///
+ /// For getting the maximum absolute elements of a tensor along an axis, users should prefer the
+ #[cfg_attr(doc, doc = crate::doc_tensor!("max_abs_dim"))]
+ #[cfg_attr(not(doc), doc = "`Tensor::max_abs_dim`")]
+ /// function, which is more high-level and designed for public use.
+ fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
+
+ /// Gets the minimum element of a tensor.
+ ///
+ /// # Arguments
+ ///
+ /// * `tensor` - The tensor to get the minimum element from.
+ ///
+ /// # Returns
+ ///
+ /// A single-element tensor containing the minimum element of the input tensor.
+ ///
+ /// # Remarks
+ ///
+ /// This is a low-level function used internally by the library to call different backend functions
+ /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+ /// or use this function directly.
+ ///
+ /// For getting the minimum element of a tensor, users should prefer the
+ #[cfg_attr(doc, doc = crate::doc_tensor!("min"))]
+ #[cfg_attr(not(doc), doc = "`Tensor::min`")]
+ /// function, which is more high-level and designed for public use.
+ fn min(tensor: Self::Primitive) -> Self::Primitive;
+
+ /// Gets the minimum elements of a tensor along an axis.
+ ///
+ /// # Arguments
+ ///
+ /// * `tensor` - The tensor to get the minimum elements from.
+ /// * `dim` - The axis along which to get the minimum elements.
+ ///
+ /// # Returns
+ ///
+ /// A tensor with the same rank as the input tensor, but with the given dimension reduced to size 1.
+ /// Each element is the minimum element of the corresponding slice of the input along that dimension.
+ ///
+ /// # Remarks
+ ///
+ /// This is a low-level function used internally by the library to call different backend functions
+ /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+ /// or use this function directly.
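+ ///
+ /// # Example (sketch)
+ ///
+ /// A minimal sketch of the preferred high-level call; illustrative only, and
+ /// `x` is assumed to be a rank-2 tensor:
+ ///
+ /// ```rust, ignore
+ /// // Minimum values along dimension 0; a `[2, 3]` input yields a `[1, 3]` result.
+ /// let mins = x.min_dim(0);
+ /// ```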
+ ///
+ /// For getting the minimum elements of a tensor along an axis, users should prefer the
+ #[cfg_attr(doc, doc = crate::doc_tensor!("min_dim"))]
+ #[cfg_attr(not(doc), doc = "`Tensor::min_dim`")]
+ /// function, which is more high-level and designed for public use.
+ fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
+
+ /// Gets the minimum elements and indices of a tensor along an axis.
+ ///
+ /// # Arguments
+ ///
+ /// * `tensor` - The tensor to get the minimum elements from.
+ /// * `dim` - The axis along which to get the minimum elements.
+ ///
+ /// # Returns
+ ///
+ /// A tuple of two tensors, each with the same rank as the input but with the given dimension
+ /// reduced to size 1: the first contains the minimum values, and the second contains the indices
+ /// of those values along the specified axis.
+ ///
+ /// # Remarks
+ ///
+ /// This is a low-level function used internally by the library to call different backend functions
+ /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+ /// or use this function directly.
+ ///
+ /// For getting the minimum elements of a tensor along an axis together with their indices, users should prefer the
+ #[cfg_attr(doc, doc = crate::doc_tensor!("min_dim_with_indices"))]
+ #[cfg_attr(not(doc), doc = "`Tensor::min_dim_with_indices`")]
+ /// function, which is more high-level and designed for public use.
+ fn min_dim_with_indices(tensor: Self::Primitive, dim: usize)
+ -> (Self::Primitive, IntTensor<B>);
+
+ /// Clamp the tensor between the given min and max values.
+ ///
+ /// # Arguments
+ ///
+ /// * `tensor` - The tensor to clamp.
+ /// * `min` - The minimum value.
+ /// * `max` - The maximum value.
+ ///
+ /// # Returns
+ ///
+ /// A new tensor with the values clamped between the given min and max values.
+ ///
+ /// # Remarks
+ ///
+ /// This is a low-level function used internally by the library to call different backend functions
+ /// with static dispatch. It is not designed for direct usage by users.
+ ///
+ /// For clamping a tensor between the given min and max values, users should prefer the
+ #[cfg_attr(doc, doc = crate::doc_tensor!("clamp"))]
+ #[cfg_attr(not(doc), doc = "`Tensor::clamp`")]
+ /// function, which is more high-level and designed for public use.
+ fn clamp(tensor: Self::Primitive, min: Self::Elem, max: Self::Elem) -> Self::Primitive;
+
+ /// Clamps a tensor under a minimum value.
+ ///
+ /// # Arguments
+ ///
+ /// * `tensor` - The tensor to clamp.
+ /// * `min` - The minimum value.
+ ///
+ /// # Returns
+ ///
+ /// A new tensor where every value lower than the given min value is raised to that value.
+ ///
+ /// # Remarks
+ ///
+ /// This is a low-level function used internally by the library to call different backend functions
+ /// with static dispatch. It is not designed for direct usage by users.
+ ///
+ /// For clamping a tensor under a minimum value, users should prefer the
+ #[cfg_attr(doc, doc = crate::doc_tensor!("clamp_min"))]
+ #[cfg_attr(not(doc), doc = "`Tensor::clamp_min`")]
+ /// function, which is more high-level and designed for public use.
+ fn clamp_min(tensor: Self::Primitive, min: Self::Elem) -> Self::Primitive;
+
+ /// Clamps a tensor over a maximum value.
+ ///
+ /// # Arguments
+ ///
+ /// * `tensor` - The tensor to clamp.
+ /// * `max` - The maximum value.
+ ///
+ /// # Returns
+ ///
+ /// A new tensor where every value greater than the given max value is lowered to that value.
+ ///
+ /// # Remarks
+ ///
+ /// This is a low-level function used internally by the library to call different backend functions
+ /// with static dispatch. It is not designed for direct usage by users.
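+ ///
+ /// # Example (sketch)
+ ///
+ /// A minimal sketch of the preferred high-level call; illustrative only, and
+ /// `x` is assumed to be a float tensor:
+ ///
+ /// ```rust, ignore
+ /// // Every value greater than 5.0 is lowered to 5.0; other values are unchanged.
+ /// let clamped = x.clamp_max(5.0);
+ /// ```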
+ /// + /// For clamping a tensor over a maximum value, users should prefer the + #[cfg_attr(doc, doc = crate::doc_tensor!("clamp_max"))] + #[cfg_attr(not(doc), doc = "`Tensor::clamp_max`")] + /// function, which is more high-level and designed for public use. + fn clamp_max(tensor: Self::Primitive, max: Self::Elem) -> Self::Primitive; +} diff --git a/crates/burn-candle/src/element.rs b/crates/burn-candle/src/element.rs index bc78c13712..33aba05054 100644 --- a/crates/burn-candle/src/element.rs +++ b/crates/burn-candle/src/element.rs @@ -1,14 +1,14 @@ use std::borrow::Borrow; -use burn_backend::{Element, bf16, f16}; +use burn_backend::{Element, ElementComparison, bf16, f16}; use candle_core::{FloatDType, Tensor, WithDType}; /// Candle element pub trait CandleElement: Element + WithDType {} /// Candle float element -pub trait FloatCandleElement: CandleElement + FloatDType {} +pub trait FloatCandleElement: CandleElement + FloatDType + ElementComparison {} /// Candle int element -pub trait IntCandleElement: CandleElement {} +pub trait IntCandleElement: CandleElement + ElementComparison {} impl CandleElement for f64 {} impl FloatCandleElement for f64 {} diff --git a/crates/burn-cubecl/src/element.rs b/crates/burn-cubecl/src/element.rs index c9a02b7d04..e77ed23b03 100644 --- a/crates/burn-cubecl/src/element.rs +++ b/crates/burn-cubecl/src/element.rs @@ -1,4 +1,4 @@ -use burn_backend::{Element, bf16, f16}; +use burn_backend::{Element, ElementComparison, bf16, f16}; use cubecl::{ CubeElement as CubeElem, flex32, prelude::{Float, Int, Numeric}, @@ -18,11 +18,11 @@ pub trait MatmulElement: } /// The float element type for the jit backend. -pub trait FloatElement: MatmulElement + Float {} +pub trait FloatElement: MatmulElement + Float + ElementComparison {} /// The int element type for the jit backend. pub trait IntElement: - MatmulElement + Int + ReducePrecision + MatmulElement + Int + ReducePrecision + ElementComparison { } diff --git a/crates/burn-ndarray/src/element.rs b/crates/burn-ndarray/src/element.rs index 5aea29c2ca..dce91c2380 100644 --- a/crates/burn-ndarray/src/element.rs +++ b/crates/burn-ndarray/src/element.rs @@ -1,4 +1,4 @@ -use burn_backend::Element; +use burn_backend::{Element, ElementComparison}; use num_traits::Signed; #[cfg(not(feature = "std"))] @@ -10,14 +10,14 @@ use num_traits::Pow; use libm::{log1p, log1pf}; /// A float element for ndarray backend. -pub trait FloatNdArrayElement: NdArrayElement + Signed +pub trait FloatNdArrayElement: NdArrayElement + Signed + ElementComparison where Self: Sized, { } /// An int element for ndarray backend. -pub trait IntNdArrayElement: NdArrayElement {} +pub trait IntNdArrayElement: NdArrayElement + ElementComparison {} /// A general element for ndarray backend. pub trait NdArrayElement: diff --git a/crates/burn-router/src/channel/base.rs b/crates/burn-router/src/channel/base.rs index b2cb35082e..5ce6d730da 100644 --- a/crates/burn-router/src/channel/base.rs +++ b/crates/burn-router/src/channel/base.rs @@ -1,4 +1,5 @@ use alloc::string::String; +use burn_backend::ElementComparison; use burn_backend::{DType, Element, Shape, backend::DeviceOps}; use burn_ir::TensorIr; @@ -16,9 +17,9 @@ pub trait RunnerChannel: Clone + Send + Sync + 'static + Sized { /// Client type. type Client: RunnerClient; /// Float element type. - type FloatElem: Element; + type FloatElem: Element + ElementComparison; /// Int element type. - type IntElem: Element; + type IntElem: Element + ElementComparison; /// Bool element type. 
type BoolElem: Element; diff --git a/crates/burn-tch/src/backend.rs b/crates/burn-tch/src/backend.rs index 968cf66ff5..dea6e78937 100644 --- a/crates/burn-tch/src/backend.rs +++ b/crates/burn-tch/src/backend.rs @@ -1,4 +1,6 @@ -use crate::IntoKind; +use std::marker::PhantomData; + +use crate::{IntoKind, TchFloatElement}; use super::TchTensor; use super::element::TchElement; @@ -105,19 +107,19 @@ impl DeviceOps for LibTorchDevice {} /// /// Refer to the [tch] crate for more information. #[derive(Clone, Copy, Default, Debug)] -pub struct LibTorch { - _e: E, +pub struct LibTorch { + _e: PhantomData, + _f: PhantomData, } -impl Backend for LibTorch { +impl Backend for LibTorch { type Device = LibTorchDevice; type FloatTensorPrimitive = TchTensor; - type FloatElem = E; + type FloatElem = F; type IntTensorPrimitive = TchTensor; type IntElem = i64; - type BoolTensorPrimitive = TchTensor; type BoolElem = bool; diff --git a/crates/burn-tch/src/element.rs b/crates/burn-tch/src/element.rs index 36db5ce8ea..73d33159c1 100644 --- a/crates/burn-tch/src/element.rs +++ b/crates/burn-tch/src/element.rs @@ -1,4 +1,4 @@ -use burn_backend::Element; +use burn_backend::{Element, ElementComparison}; use burn_backend::{bf16, f16}; /// The element type for the tch backend. @@ -9,9 +9,18 @@ pub trait TchElement: Element + tch::kind::Element { } } +/// The float element type for the tch backend. +pub trait TchFloatElement: TchElement + ElementComparison {} + +/// The int element type for the tch backend. +pub trait TchIntElement: TchElement + ElementComparison {} + impl TchElement for f64 {} +impl TchFloatElement for f64 {} impl TchElement for f32 {} +impl TchFloatElement for f32 {} impl TchElement for f16 {} +impl TchFloatElement for f16 {} impl TchElement for bf16 { fn kind() -> tch::Kind { let mut kind = ::KIND; @@ -24,9 +33,14 @@ impl TchElement for bf16 { } impl TchElement for i64 {} + +impl TchIntElement for i64 {} impl TchElement for i32 {} +impl TchIntElement for i32 {} impl TchElement for i16 {} +impl TchIntElement for i16 {} impl TchElement for i8 {} +impl TchIntElement for i8 {} impl TchElement for u8 {} diff --git a/crates/burn-tch/src/ops/activation.rs b/crates/burn-tch/src/ops/activation.rs index b0636bdb1a..967e188634 100644 --- a/crates/burn-tch/src/ops/activation.rs +++ b/crates/burn-tch/src/ops/activation.rs @@ -1,7 +1,7 @@ -use crate::{LibTorch, TchTensor, element::TchElement}; +use crate::{LibTorch, TchFloatElement, TchTensor, element::TchElement}; use burn_backend::ops::ActivationOps; -impl ActivationOps for LibTorch { +impl ActivationOps for LibTorch { fn relu(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu()) } diff --git a/crates/burn-tch/src/ops/bool_tensor.rs b/crates/burn-tch/src/ops/bool_tensor.rs index c2dae4bf6d..1c58b0deb0 100644 --- a/crates/burn-tch/src/ops/bool_tensor.rs +++ b/crates/burn-tch/src/ops/bool_tensor.rs @@ -1,4 +1,5 @@ use super::TchOps; +use crate::TchFloatElement; use crate::{LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement}; use burn_backend::ElementConversion; use burn_backend::ExecutionError; @@ -7,7 +8,7 @@ use burn_backend::tensor::BoolTensor; use burn_backend::tensor::IntTensor; use burn_backend::{Backend, Shape, TensorData, TensorMetadata, ops::BoolTensorOps}; -impl BoolTensorOps for LibTorch { +impl BoolTensorOps for LibTorch { fn bool_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor { match data.dtype { burn_backend::DType::Bool => TchTensor::from_data::(data, 
(*device).into()), @@ -38,7 +39,7 @@ impl BoolTensorOps for LibTorch { tensor.tensor.device().into() } - fn bool_empty(shape: Shape, device: & as Backend>::Device) -> TchTensor { + fn bool_empty(shape: Shape, device: & as Backend>::Device) -> TchTensor { let tensor = tch::Tensor::empty( TchShape::from(shape).dims, (tch::Kind::Bool, (*device).into()), @@ -47,7 +48,7 @@ impl BoolTensorOps for LibTorch { TchTensor::new(tensor) } - fn bool_zeros(shape: Shape, device: & as Backend>::Device) -> TchTensor { + fn bool_zeros(shape: Shape, device: & as Backend>::Device) -> TchTensor { let tensor = tch::Tensor::zeros( TchShape::from(shape).dims, (tch::Kind::Bool, (*device).into()), @@ -56,7 +57,7 @@ impl BoolTensorOps for LibTorch { TchTensor::new(tensor) } - fn bool_ones(shape: Shape, device: & as Backend>::Device) -> TchTensor { + fn bool_ones(shape: Shape, device: & as Backend>::Device) -> TchTensor { let tensor = tch::Tensor::ones( TchShape::from(shape).dims, (tch::Kind::Bool, (*device).into()), diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs index 6f935a3d42..20366b70fd 100644 --- a/crates/burn-tch/src/ops/int_tensor.rs +++ b/crates/burn-tch/src/ops/int_tensor.rs @@ -6,11 +6,13 @@ use burn_backend::{ tensor::{IntElem, IntTensor}, }; -use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement}; +use crate::{ + IntoKind, LibTorch, LibTorchDevice, TchFloatElement, TchShape, TchTensor, element::TchElement, +}; use super::TchOps; -impl IntTensorOps for LibTorch { +impl IntTensorOps for LibTorch { fn int_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor { match data.dtype { burn_backend::DType::I64 => TchTensor::from_data::(data, (*device).into()), @@ -43,7 +45,7 @@ impl IntTensorOps for LibTorch { fn int_empty( shape: Shape, - device: & as Backend>::Device, + device: & as Backend>::Device, dtype: IntDType, ) -> TchTensor { let tensor = tch::Tensor::empty( @@ -206,7 +208,7 @@ impl IntTensorOps for LibTorch { fn int_zeros( shape: Shape, - device: & as Backend>::Device, + device: & as Backend>::Device, dtype: IntDType, ) -> TchTensor { let shape = TchShape::from(shape); @@ -217,7 +219,7 @@ impl IntTensorOps for LibTorch { fn int_ones( shape: Shape, - device: & as Backend>::Device, + device: & as Backend>::Device, dtype: IntDType, ) -> TchTensor { let shape = TchShape::from(shape); @@ -229,7 +231,7 @@ impl IntTensorOps for LibTorch { fn int_full( shape: Shape, fill_value: i64, - device: & as Backend>::Device, + device: & as Backend>::Device, dtype: IntDType, ) -> TchTensor { let shape = TchShape::from(shape); diff --git a/crates/burn-tch/src/ops/module.rs b/crates/burn-tch/src/ops/module.rs index 6ff386c3e1..d34745d95c 100644 --- a/crates/burn-tch/src/ops/module.rs +++ b/crates/burn-tch/src/ops/module.rs @@ -1,4 +1,4 @@ -use crate::{LibTorch, TchTensor, element::TchElement}; +use crate::{LibTorch, TchFloatElement, TchTensor, element::TchElement}; use burn_backend::{ TensorMetadata, ops::{ @@ -8,7 +8,7 @@ use burn_backend::{ }, }; -impl ModuleOps for LibTorch { +impl ModuleOps for LibTorch { fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor { // Workaround for MPS "Placeholder storage has not been allocated" error. 
// See: https://github.com/pytorch/pytorch/issues/123995 @@ -289,7 +289,7 @@ impl ModuleOps for LibTorch { padding: usize, dilation: usize, ceil_mode: bool, - ) -> MaxPool1dWithIndices> { + ) -> MaxPool1dWithIndices> { let (tensor, indices) = tch::Tensor::max_pool1d_with_indices( &x.tensor, kernel_size as i64, @@ -329,7 +329,7 @@ impl ModuleOps for LibTorch { padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, - ) -> MaxPool2dWithIndices> { + ) -> MaxPool2dWithIndices> { let (tensor, indices) = tch::Tensor::max_pool2d_with_indices( &x.tensor, [kernel_size[0] as i64, kernel_size[1] as i64], @@ -351,7 +351,7 @@ impl ModuleOps for LibTorch { ceil_mode: bool, output_grad: TchTensor, indices: TchTensor, - ) -> MaxPool2dBackward> { + ) -> MaxPool2dBackward> { let grad = tch::Tensor::max_pool2d_with_indices_backward( &x.tensor, &output_grad.tensor, diff --git a/crates/burn-tch/src/ops/qtensor.rs b/crates/burn-tch/src/ops/qtensor.rs index 7d82b7d20e..238b417535 100644 --- a/crates/burn-tch/src/ops/qtensor.rs +++ b/crates/burn-tch/src/ops/qtensor.rs @@ -5,9 +5,9 @@ use burn_backend::{ tensor::{Device, FloatTensor, IntTensor, QuantizedTensor}, }; -use crate::{LibTorch, LibTorchDevice, TchElement}; +use crate::{LibTorch, LibTorchDevice, TchElement, TchFloatElement}; -impl QTensorOps for LibTorch { +impl QTensorOps for LibTorch { fn q_from_data(_data: TensorData, _device: &LibTorchDevice) -> QuantizedTensor { unimplemented!() } diff --git a/crates/burn-tch/src/ops/tensor.rs b/crates/burn-tch/src/ops/tensor.rs index f42739c33d..8f70bf2f36 100644 --- a/crates/burn-tch/src/ops/tensor.rs +++ b/crates/burn-tch/src/ops/tensor.rs @@ -1,4 +1,5 @@ use super::TchOps; +use crate::TchFloatElement; use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement}; use burn_backend::backend::ExecutionError; use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor}; @@ -8,7 +9,7 @@ use burn_backend::{ }; use burn_backend::{bf16, f16}; -impl FloatTensorOps for LibTorch { +impl FloatTensorOps for LibTorch { fn float_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor { match data.dtype { DType::F64 => TchTensor::from_data::(data, (*device).into()), @@ -100,7 +101,7 @@ impl FloatTensorOps for LibTorch { fn float_empty( shape: Shape, - device: & as Backend>::Device, + device: & as Backend>::Device, dtype: FloatDType, ) -> TchTensor { let tensor = tch::Tensor::empty( @@ -115,7 +116,7 @@ impl FloatTensorOps for LibTorch { TchOps::add(lhs, rhs) } - fn float_add_scalar(lhs: TchTensor, rhs: E) -> TchTensor { + fn float_add_scalar(lhs: TchTensor, rhs: F) -> TchTensor { let rhs: f64 = rhs.elem(); lhs.unary_ops( @@ -128,7 +129,7 @@ impl FloatTensorOps for LibTorch { TchOps::sub(lhs, rhs) } - fn float_sub_scalar(lhs: TchTensor, rhs: E) -> TchTensor { + fn float_sub_scalar(lhs: TchTensor, rhs: F) -> TchTensor { let rhs: f64 = rhs.elem(); lhs.unary_ops( @@ -141,7 +142,7 @@ impl FloatTensorOps for LibTorch { TchOps::mul(lhs, rhs) } - fn float_mul_scalar(lhs: TchTensor, rhs: E) -> TchTensor { + fn float_mul_scalar(lhs: TchTensor, rhs: F) -> TchTensor { let rhs: f64 = rhs.elem(); lhs.unary_ops( @@ -154,7 +155,7 @@ impl FloatTensorOps for LibTorch { TchOps::div(lhs, rhs) } - fn float_div_scalar(lhs: TchTensor, rhs: E) -> TchTensor { + fn float_div_scalar(lhs: TchTensor, rhs: F) -> TchTensor { let rhs: f64 = rhs.elem(); lhs.unary_ops( @@ -167,7 +168,7 @@ impl FloatTensorOps for LibTorch { TchOps::remainder(lhs, rhs) } - fn float_remainder_scalar(lhs: TchTensor, rhs: E) -> TchTensor { + 
fn float_remainder_scalar(lhs: TchTensor, rhs: F) -> TchTensor { let rhs: f64 = rhs.elem(); lhs.unary_ops( @@ -187,7 +188,7 @@ impl FloatTensorOps for LibTorch { } fn float_neg(tensor: TchTensor) -> TchTensor { - Self::float_mul_scalar(tensor, (-1f32).elem::()) + Self::float_mul_scalar(tensor, (-1f32).elem::()) } fn float_recip(tensor: TchTensor) -> TchTensor { @@ -246,7 +247,7 @@ impl FloatTensorOps for LibTorch { TchTensor::new(output) } - fn float_mask_fill(tensor: TchTensor, mask: TchTensor, value: E) -> TchTensor { + fn float_mask_fill(tensor: TchTensor, mask: TchTensor, value: F) -> TchTensor { let value: f64 = value.elem(); tensor.unary_ops( @@ -259,7 +260,7 @@ impl FloatTensorOps for LibTorch { TchOps::equal(lhs, rhs) } - fn float_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { + fn float_equal_elem(lhs: TchTensor, rhs: F) -> TchTensor { TchOps::equal_elem(lhs, rhs.elem::()) } @@ -267,7 +268,7 @@ impl FloatTensorOps for LibTorch { TchOps::greater(lhs, rhs) } - fn float_greater_elem(lhs: TchTensor, rhs: E) -> TchTensor { + fn float_greater_elem(lhs: TchTensor, rhs: F) -> TchTensor { TchOps::greater_elem(lhs, rhs.elem::()) } @@ -275,7 +276,7 @@ impl FloatTensorOps for LibTorch { TchOps::greater_equal(lhs, rhs) } - fn float_greater_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { + fn float_greater_equal_elem(lhs: TchTensor, rhs: F) -> TchTensor { TchOps::greater_equal_elem(lhs, rhs.elem::()) } @@ -283,7 +284,7 @@ impl FloatTensorOps for LibTorch { TchOps::lower(lhs, rhs) } - fn float_lower_elem(lhs: TchTensor, rhs: E) -> TchTensor { + fn float_lower_elem(lhs: TchTensor, rhs: F) -> TchTensor { TchOps::lower_elem(lhs, rhs.elem::()) } @@ -291,7 +292,7 @@ impl FloatTensorOps for LibTorch { TchOps::lower_equal(lhs, rhs) } - fn float_lower_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor { + fn float_lower_equal_elem(lhs: TchTensor, rhs: F) -> TchTensor { TchOps::lower_equal_elem(lhs, rhs.elem::()) } @@ -462,18 +463,21 @@ impl FloatTensorOps for LibTorch { TchOps::cat(tensors, dim) } - fn float_clamp_min(tensor: TchTensor, min: E) -> TchTensor { + fn float_clamp_min(tensor: TchTensor, min: F) -> TchTensor { TchOps::clamp_min(tensor, min.elem::()) } - fn float_clamp_max(tensor: TchTensor, max: as Backend>::FloatElem) -> TchTensor { + fn float_clamp_max( + tensor: TchTensor, + max: as Backend>::FloatElem, + ) -> TchTensor { TchOps::clamp_max(tensor, max.elem::()) } fn float_clamp( tensor: TchTensor, - min: as Backend>::FloatElem, - max: as Backend>::FloatElem, + min: as Backend>::FloatElem, + max: as Backend>::FloatElem, ) -> TchTensor { TchOps::clamp(tensor, min.elem::(), max.elem::()) } diff --git a/crates/burn-tch/src/ops/transaction.rs b/crates/burn-tch/src/ops/transaction.rs index 323c25b228..dd417872c6 100644 --- a/crates/burn-tch/src/ops/transaction.rs +++ b/crates/burn-tch/src/ops/transaction.rs @@ -1,5 +1,5 @@ use burn_backend::ops::TransactionOps; -use crate::{LibTorch, TchElement}; +use crate::{LibTorch, TchElement, TchFloatElement}; -impl TransactionOps for LibTorch {} +impl TransactionOps for LibTorch {} diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index b346ff4bd0..0b4071496d 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -1,9 +1,11 @@ use crate::ops::FloatElem; -use crate::{BasicOps, Numeric, Shape, Slice, Tensor, backend::Backend, cast::ToElement}; +use crate::{BasicOps, Shape, Slice, Tensor, backend::Backend, cast::ToElement}; use alloc::format; use 
alloc::string::{String, ToString}; use alloc::vec; use alloc::vec::Vec; +use burn_backend::ElementComparison; +use burn_backend::tensor::Ordered; /// The struct should always be used with the [check](crate::check) macro. /// @@ -466,10 +468,13 @@ impl TensorCheck { check } - pub(crate) fn one_hot_tensor>( + pub(crate) fn one_hot_tensor>( index_tensor: Tensor, num_classes: usize, - ) -> Self { + ) -> Self + where + >::Elem: ElementComparison, + { let mut check = Self::Ok; if index_tensor .clone() diff --git a/crates/burn-tensor/src/tensor/api/mod.rs b/crates/burn-tensor/src/tensor/api/mod.rs index a2cce869dc..a69ea78c48 100644 --- a/crates/burn-tensor/src/tensor/api/mod.rs +++ b/crates/burn-tensor/src/tensor/api/mod.rs @@ -8,6 +8,7 @@ mod float; mod fmod; mod int; mod numeric; +mod orderable; mod pad; mod take; mod transaction; diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 18de21e998..feaa06054a 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -659,771 +659,6 @@ where Self::new(K::cumprod(self.primitive, dim)) } - /// Computes the cumulative minimum of elements along the given *dimension* or *axis*. - /// - /// # Arguments - /// - /// * `dim` - The dimension or axis along which to compute the cumulative minimum. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[3.0, 5.0, 2.0], [4.0, 1.0, 6.0]], &device); - /// let result = tensor.clone().cummin(0); - /// println!("{result}"); - /// // [[3.0, 5.0, 2.0], [3.0, 1.0, 2.0]] - /// let result = tensor.cummin(1); - /// println!("{result}"); - /// // [[3.0, 3.0, 2.0], [4.0, 1.0, 1.0]] - /// } - /// ``` - pub fn cummin(self, dim: usize) -> Self { - check!(TensorCheck::aggregate_dim::("CumMin", dim)); - Self::new(K::cummin(self.primitive, dim)) - } - - /// Computes the cumulative maximum of elements along the given *dimension* or *axis*. - /// - /// # Arguments - /// - /// * `dim` - The dimension or axis along which to compute the cumulative maximum. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]], &device); - /// let result = tensor.clone().cummax(0); - /// println!("{result}"); - /// // [[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]] - /// let result = tensor.cummax(1); - /// println!("{result}"); - /// // [[3.0, 3.0, 3.0], [4.0, 5.0, 5.0]] - /// } - /// ``` - pub fn cummax(self, dim: usize) -> Self { - check!(TensorCheck::aggregate_dim::("CumMax", dim)); - Self::new(K::cummax(self.primitive, dim)) - } - /// Applies element wise greater comparison and returns a boolean tensor. - /// - /// # Panics - /// - /// If the two tensors don't have the same shape. 
- /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor2 = Tensor::::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); - /// let tensor = tensor1.greater(tensor2); - /// println!("{tensor}"); - /// // [[false, false, false], [true, true, true]] - /// } - /// ``` - pub fn greater(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Greater", &self, &other)); - Tensor::new(K::greater(self.primitive, other.primitive)) - } - - /// Applies element wise greater-equal comparison and returns a boolean tensor. - /// - /// # Panics - /// - /// If the two tensors don't have the same shape. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor2 = Tensor::::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); - /// let tensor = tensor1.greater_equal(tensor2); - /// println!("{tensor}"); - /// // [[true, false, false], [true, true, true]] - /// } - /// ``` - pub fn greater_equal(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other)); - Tensor::new(K::greater_equal(self.primitive, other.primitive)) - } - - /// Applies element wise lower comparison and returns a boolean tensor. - /// - /// # Panics - /// - /// If the two tensors don't have the same shape. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor2 = Tensor::::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); - /// let tensor = tensor1.lower(tensor2); - /// println!("{tensor}"); - /// // [[false, true, true], [false, false, false]] - /// } - /// ``` - pub fn lower(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Lower", &self, &other)); - Tensor::new(K::lower(self.primitive, other.primitive)) - } - - /// Applies element wise lower-equal comparison and returns a boolean tensor. - /// - /// # Panics - /// - /// If the two tensors don't have the same shape. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor2 = Tensor::::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); - /// let tensor = tensor1.lower_equal(tensor2); - /// println!("{tensor}"); - /// // [[true, true, true], [false, false, false]] - /// } - /// ``` - pub fn lower_equal(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other)); - Tensor::new(K::lower_equal(self.primitive, other.primitive)) - } - - /// Applies greater than `other` comparison and returns a boolean tensor. - /// - /// # Arguments - /// - /// * `other` - The element to compare. 
- /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor = tensor.greater_elem(3.0); - /// println!("{tensor}"); - /// // [[false, false, true], [true, true, true]] - /// } - /// ``` - pub fn greater_elem(self, other: E) -> Tensor { - Tensor::new(K::greater_elem(self.primitive, other.elem())) - } - - /// Applies greater-equal than `other` comparison and returns a boolean tensor. - /// - /// # Arguments - /// - /// * `other` - The element to compare. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor = tensor.greater_equal_elem(3.0); - /// println!("{tensor}"); - /// // [[false, false, true], [true, true, true]] - /// } - /// ``` - pub fn greater_equal_elem(self, other: E) -> Tensor { - Tensor::new(K::greater_equal_elem(self.primitive, other.elem())) - } - - /// Applies lower than `other` comparison and returns a boolean tensor. - /// - /// # Arguments - /// - /// * `other` - The element to compare. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor = tensor.lower_elem(3.0); - /// println!("{tensor}"); - /// // [[true, true, false], [false, false, false]] - /// } - /// ``` - pub fn lower_elem(self, other: E) -> Tensor { - Tensor::new(K::lower_elem(self.primitive, other.elem())) - } - - /// Applies lower-equal than `other` comparison and returns a boolean tensor. - /// - /// # Arguments - /// - /// * `other` - The element to compare. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor = tensor.lower_equal_elem(3.0); - /// println!("{tensor}"); - /// // [[true, true, true], [false, false, false]] - /// } - /// ``` - pub fn lower_equal_elem(self, other: E) -> Tensor { - Tensor::new(K::lower_equal_elem(self.primitive, other.elem())) - } - - /// Applies the argmax function along the given dimension and returns an integer tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::ones(Shape::new([2, 3, 3]), &device); - /// let tensor = tensor.argmax(1); - /// println!("{:?}", tensor.shape()); - /// // Shape { dims: [2, 1, 3] } - /// } - /// ``` - pub fn argmax(self, dim: usize) -> Tensor { - Tensor::new(K::argmax(self.primitive, dim)) - } - - /// Find the maximum value. 
- /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor = tensor.max(); - /// println!("{tensor}"); - /// // [9.0] - /// } - /// ``` - pub fn max(self) -> Tensor { - Tensor::new(K::max(self.primitive)) - } - - /// Find the maximum value along the given dimension. - /// - /// # Arguments - /// - /// * `dim` - The dimension or axis along which to aggregate the elements; - /// supports negative indexing. - /// - /// # Returns - /// - /// The returned tensor will have the same rank, - /// but the aggregated dimension will have size 1. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor = tensor.max_dim(0); - /// println!("{tensor}"); - /// // [[5.0, 9.0, 6.0]] - /// } - /// ``` - pub fn max_dim(self, dim: I) -> Self { - let dim = dim.expect_dim_index(D); - check!(TensorCheck::aggregate_dim::("Max", dim)); - Tensor::new(K::max_dim(self.primitive, dim)) - } - - /// Find the maximum value along the given dimensions. - /// - /// # Arguments - /// - /// * `dims` - The dimensions or axis along which to aggregate the elements; - /// supports negative indexing. - /// - /// # Returns - /// - /// The returned tensor will have the same rank, - /// but the aggregated dimensions will have size 1. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor = tensor.max_dims(&[0, 1]); - /// println!("{tensor}"); - /// // [[9.0]] - /// } - /// ``` - pub fn max_dims(self, dims: &[I]) -> Self { - dims.iter().fold(self, |tensor, &dim| tensor.max_dim(dim)) - } - - /// Find the maximum value along the given dimension. - /// - /// Also returns the indices. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let (tensor, index) = tensor.max_dim_with_indices(0); - /// // [[5.0, 9.0, 6.0]] - /// println!("{tensor}"); - /// // [[1, 1, 1]] - /// println!("{index}"); - /// } - /// ``` - pub fn max_dim_with_indices(self, dim: I) -> (Self, Tensor) { - let dim = dim.expect_dim_index(D); - check!(TensorCheck::aggregate_dim::("Max", dim)); - - let (tensor, index) = K::max_dim_with_indices(self.primitive, dim); - - let tensor = Tensor::new(tensor); - let index = Tensor::new(index); - - (tensor, index) - } - - /// Finds the maximum pair wise values with another tensor. - /// - /// # Arguments - /// - /// * `other` - Other tensor to find maximum elements with - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensors containing the maximum value found - /// in the input tensors. 
- /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor2 = Tensor::::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); - /// let tensor = tensor1.max_pair(tensor2); - /// println!("{tensor}"); - /// // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]] - /// } - /// ``` - pub fn max_pair(self, other: Self) -> Self { - let mask = self.clone().lower(other.clone()); - self.mask_where(mask, other) - } - - /// Find the maximum absolute value. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device); - /// let tensor = tensor.max_abs(); - /// println!("{tensor}"); - /// // [7.0] - /// } - /// ``` - pub fn max_abs(self) -> Tensor { - Tensor::new(K::max_abs(self.primitive)) - } - - /// Find the maximum absolute value along the given dimension. - /// - /// # Arguments - /// - /// * `dim` - The dimension or axis along which to aggregate the elements, - /// supports negative indexing. - /// - /// # Returns - /// - /// The returned tensor will have the same rank, - /// but the aggregated dimension will have size 1. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor = tensor.max_dim(0); - /// println!("{tensor}"); - /// // [[5.0, 9.0, 6.0]] - /// } - /// ``` - pub fn max_abs_dim(self, dim: I) -> Self { - let dim = dim.expect_dim_index(D); - check!(TensorCheck::aggregate_dim::("MaxAbs", dim)); - - Tensor::new(K::max_abs_dim(self.primitive, dim)) - } - - /// Find the maximum absolute value along the given dimensions. - /// - /// # Arguments - /// - /// * `dims` - The dimensions or axes along which to aggregate the elements, - /// supports negative indexing. - /// - /// # Returns - /// - /// The returned tensor will have the same rank, - /// but the aggregated dimensions will have size 1. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor = tensor.max_abs_dims(&[0, 1]); - /// println!("{tensor}"); - /// // [[9.0]] - /// } - /// ``` - pub fn max_abs_dims(self, dims: &[I]) -> Self { - dims.iter() - .fold(self, |tensor, &dim| tensor.max_abs_dim(dim)) - } - - /// Applies the argmin function along the given dimension and returns an integer tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = Default::default(); - /// let tensor = Tensor::::ones(Shape::new([2, 3, 3]), &device); - /// let tensor = tensor.argmin(1); - /// println!("{:?}", tensor.shape()); - /// // Shape { dims: [2, 1, 3] } - /// } - /// ``` - pub fn argmin(self, dim: usize) -> Tensor { - Tensor::new(K::argmin(self.primitive, dim)) - } - - /// Find the minimum value. 
- /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor = tensor.min(); - /// println!("{tensor}"); - /// // [-2.0] - /// } - /// ``` - pub fn min(self) -> Tensor { - Tensor::new(K::min(self.primitive)) - } - - /// Find the minimum value along the given dimension. - /// - /// # Arguments - /// - /// * `dim` - The dimension or axis along which to aggregate the elements; - /// supports negative indexing. - /// - /// # Returns - /// - /// The returned tensor will have the same rank, - /// but the aggregated dimension will have size 1. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor = tensor.min_dim(0); - /// println!("{tensor}"); - /// // [[1.0, -2.0, 3.0]] - /// } - /// ``` - pub fn min_dim(self, dim: I) -> Self { - let dim = dim.expect_dim_index(D); - check!(TensorCheck::aggregate_dim::("Min", dim)); - Tensor::new(K::min_dim(self.primitive, dim)) - } - - /// Find the minimum value along the given dimensions. - /// - /// # Arguments - /// - /// * `dims` - The dimensions or axes along which to aggregate the elements; - /// supports negative indexing. - /// - /// # Returns - /// - /// The returned tensor will have the same rank, - /// but the aggregated dimensions will have size 1. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor = tensor.min_dims(&[0, 1]); - /// println!("{tensor}"); - /// // [[-2.0]] - /// } - /// ``` - pub fn min_dims(self, dims: &[I]) -> Self { - dims.iter().fold(self, |tensor, &dim| tensor.min_dim(dim)) - } - - /// Find the minimum value along the given dimension. - /// - /// Also returns the indices. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor = Tensor::::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let (tensor, index) = tensor.min_dim_with_indices(0); - /// println!("{tensor}"); - /// // [[5.0, -2.0, 3.0]] - /// println!("{}", index); - /// // [[1, 0, 0]] - /// } - /// ``` - pub fn min_dim_with_indices(self, dim: I) -> (Self, Tensor) { - let dim = dim.expect_dim_index(D); - check!(TensorCheck::aggregate_dim::("Min", dim)); - - let (tensor, index) = K::min_dim_with_indices(self.primitive, dim); - - let tensor = Tensor::new(tensor); - let index = Tensor::new(index); - - (tensor, index) - } - - /// Finds the minimum pair wise values with another tensor. - /// - /// # Arguments - /// - /// * `other` - Other tensor to find minimum elements with - /// - /// # Returns - /// - /// A tensor with the same shape as the input tensors containing the minimum value found - /// between each element of the two source tensors. 
- /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Shape}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); - /// let tensor2 = Tensor::::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); - /// let tensor = tensor1.min_pair(tensor2); - /// println!("{tensor}"); - /// // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]] - /// } - pub fn min_pair(self, other: Self) -> Self { - let mask = other.clone().lower(self.clone()); - self.mask_where(mask, other) - } - - /// Clamp element wise between the given min and max values. - /// - /// # Arguments - /// - /// * `min` - The minimum value. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped between the given min and max values. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Int, Tensor}; - /// - /// fn example() { - /// let device = Default::default(); - /// let tensor = Tensor::::from_ints( - /// [ - /// [1, 2, 3], - /// [4, 5, 6], - /// [7, 8, 9] - /// ], - /// &device); - /// let tensor = tensor.clamp(2, 6); - /// println!("{tensor}"); - /// // [[2, 2, 3], [4, 5, 6], [6, 6, 6]] - /// } - /// ``` - pub fn clamp(self, min: E, max: E) -> Self { - Self::new(K::clamp(self.primitive, min.elem(), max.elem())) - } - - /// Clamp element wise under a minimum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `min` - The minimum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped under the given min value. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Int, Tensor}; - /// - /// fn example() { - /// let device = Default::default(); - /// let tensor = Tensor::::from_ints( - /// [[1, 2, 3], [4, 5, 6], [7, 8, 9]], - /// &device); - /// let tensor = tensor.clamp_min(4); - /// println!("{tensor}"); - /// // [[4, 4, 4], [4, 5, 6], [7, 8, 9]] - /// } - /// ``` - pub fn clamp_min(self, min: E) -> Self { - Self::new(K::clamp_min(self.primitive, min.elem())) - } - - /// Clamp element wise over a maximum value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to clamp. - /// * `max` - The maximum value. - /// - /// # Returns - /// - /// A new tensor with the values clamped over the given max value. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Int, Tensor}; - /// - /// fn example() { - /// let device = Default::default(); - /// let tensor = Tensor::::from_ints( - /// [[1, 2, 3], [4, 5, 6], [7, 8, 9]], - /// &device); - /// let tensor = tensor.clamp_max(5); - /// println!("{tensor}"); - /// // [[1, 2, 3], [4, 5, 5], [5, 5, 5]] - /// } - /// ``` - pub fn clamp_max(self, max: E) -> Self { - Self::new(K::clamp_max(self.primitive, max.elem())) - } - /// Apply element wise absolute value operation. /// /// # Example @@ -1964,109 +1199,6 @@ where ) } - /// Create a one hot tensor. 
- /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example(){ - /// let device = Default::default(); - /// let indices: Tensor = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device); - /// let one_hot: Tensor = indices.one_hot(4); - /// println!("{}", one_hot.to_data()); - /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] - /// } - /// ``` - pub fn one_hot(self, num_classes: usize) -> Tensor { - check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); - self.one_hot_fill(num_classes, 1.0, 0.0, -1) - } - - /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors. - /// - /// # Arguments - /// - /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension. - /// * `on_value`: The value to assign for active positions (corresponding to indices). - /// * `off_value`: The value to assign for inactive positions. - /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing. - /// - /// # Returns - /// - /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`. - /// - /// # Example - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Float}; - /// fn example>>() { - /// let device = B::Device::default(); - /// let indices: Tensor = Tensor::from_floats([[0., 2.], [1., -1.]], &device); - /// // One-hot encoding - /// let tensor:Tensor = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1); - /// println!("{tensor}"); - /// // [[[5.0, 0.0, 0.0], - /// // [0.0, 0.0, 5.0]], - /// // [[0.0, 5.0, 0.0], - /// // [0.0, 0.0, 5.0]]] - /// } - /// ``` - pub fn one_hot_fill( - self, - num_classes: usize, - on_value: f32, - off_value: f32, - axis: i64, - ) -> Tensor { - check!(TensorCheck::one_hot_tensor_rank::()); - // Initialize shape from the current tensor dimensions and prepare for modification - let mut shape = self.shape(); - let device = self.device(); - let rank = self.dims().len(); - - // Adjust negative axis to a positive index - let axis = if axis < 0 { - axis + rank as i64 + 1 - } else { - axis - }; - - // Ensure axis is within valid range - if axis < 0 || axis > rank as i64 { - panic!("Axis out of range. 
Accepted range is [-r-1, r] where r = rank(indices)."); - } - // Convert the input tensor to integer indices - let indices: Tensor = - Tensor::from_data(self.to_data().convert::(), &device); - // Insert the new dimension for the one-hot representation - shape.insert(axis as usize, num_classes); - // Adjust indices to valid range and handle invalid indices - let adjusted_indices = indices - .clone() - .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices - .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices - // Unsqueeze the indices tensor along the specified axis - let indices_unsqueezed: Tensor = adjusted_indices.unsqueeze_dim(axis as usize); - - // Initialize the output tensor with the off_value - let output = Tensor::full(shape.clone(), off_value, &device); - - // Prepare scatter tensor for on_value and off_value adjustments - let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device) - - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device()); - - // Scatter on_value at the appropriate indices to create the one-hot representation - output.scatter( - axis as usize, - indices_unsqueezed, - scatter_on_values, - IndexingUpdateOp::Add, - ) - } - /// Applies the matrix multiplication operation. /// /// ```math diff --git a/crates/burn-tensor/src/tensor/api/orderable.rs b/crates/burn-tensor/src/tensor/api/orderable.rs new file mode 100644 index 0000000000..4ca7a41b99 --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/orderable.rs @@ -0,0 +1,880 @@ +use burn_backend::{ + Backend, ElementComparison, ElementConversion, + tensor::{Bool, IndexingUpdateOp, Int, Ordered}, +}; +use burn_std::AsIndex; + +use crate::check; +use crate::{Tensor, check::TensorCheck}; + +impl Tensor +where + B: Backend, + K: Ordered, + K::Elem: ElementComparison, +{ + /// Create a one hot tensor. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example(){ + /// let device = Default::default(); + /// let indices: Tensor = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device); + /// let one_hot: Tensor = indices.one_hot(4); + /// println!("{}", one_hot.to_data()); + /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] + /// } + /// ``` + pub fn one_hot(self, num_classes: usize) -> Tensor { + check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); + self.one_hot_fill(num_classes, 1.0, 0.0, -1) + } + + /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors. + /// + /// # Arguments + /// + /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension. + /// * `on_value`: The value to assign for active positions (corresponding to indices). + /// * `off_value`: The value to assign for inactive positions. + /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing. + /// + /// # Returns + /// + /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`. 
+ /// + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Float}; + /// fn example>>() { + /// let device = B::Device::default(); + /// let indices: Tensor = Tensor::from_floats([[0., 2.], [1., -1.]], &device); + /// // One-hot encoding + /// let tensor:Tensor = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1); + /// println!("{tensor}"); + /// // [[[5.0, 0.0, 0.0], + /// // [0.0, 0.0, 5.0]], + /// // [[0.0, 5.0, 0.0], + /// // [0.0, 0.0, 5.0]]] + /// } + /// ``` + pub fn one_hot_fill( + self, + num_classes: usize, + on_value: f32, + off_value: f32, + axis: i64, + ) -> Tensor { + check!(TensorCheck::one_hot_tensor_rank::()); + // Initialize shape from the current tensor dimensions and prepare for modification + let mut shape = self.shape(); + let device = self.device(); + let rank = self.dims().len(); + + // Adjust negative axis to a positive index + let axis = if axis < 0 { + axis + rank as i64 + 1 + } else { + axis + }; + + // Ensure axis is within valid range + if axis < 0 || axis > rank as i64 { + panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices)."); + } + // Convert the input tensor to integer indices + let indices: Tensor = + Tensor::from_data(self.to_data().convert::(), &device); + // Insert the new dimension for the one-hot representation + shape.insert(axis as usize, num_classes); + // Adjust indices to valid range and handle invalid indices + let adjusted_indices = indices + .clone() + .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices + .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices + // Unsqueeze the indices tensor along the specified axis + let indices_unsqueezed: Tensor = adjusted_indices.unsqueeze_dim(axis as usize); + + // Initialize the output tensor with the off_value + let output = Tensor::full(shape.clone(), off_value, &device); + + // Prepare scatter tensor for on_value and off_value adjustments + let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device) + - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device()); + + // Scatter on_value at the appropriate indices to create the one-hot representation + output.scatter( + axis as usize, + indices_unsqueezed, + scatter_on_values, + IndexingUpdateOp::Add, + ) + } + /// Applies element wise greater comparison and returns a boolean tensor. + /// + /// # Panics + /// + /// If the two tensors don't have the same shape. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Shape}; + /// + /// fn example() { + /// let device = B::Device::default(); + /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); + /// let tensor2 = Tensor::::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); + /// let tensor = tensor1.greater(tensor2); + /// println!("{tensor}"); + /// // [[false, false, false], [true, true, true]] + /// } + /// ``` + pub fn greater(self, other: Self) -> Tensor { + check!(TensorCheck::binary_ops_ew("Greater", &self, &other)); + Tensor::new(K::greater(self.primitive, other.primitive)) + } + + /// Applies element wise greater-equal comparison and returns a boolean tensor. + /// + /// # Panics + /// + /// If the two tensors don't have the same shape. 
+
+    /// Applies element wise greater comparison and returns a boolean tensor.
+    ///
+    /// # Panics
+    ///
+    /// If the two tensors don't have the same shape.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
+    ///     let tensor = tensor1.greater(tensor2);
+    ///     println!("{tensor}");
+    ///     // [[false, false, false], [true, true, true]]
+    /// }
+    /// ```
+    pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {
+        check!(TensorCheck::binary_ops_ew("Greater", &self, &other));
+        Tensor::new(K::greater(self.primitive, other.primitive))
+    }
+
+    /// Applies element wise greater-equal comparison and returns a boolean tensor.
+    ///
+    /// # Panics
+    ///
+    /// If the two tensors don't have the same shape.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
+    ///     let tensor = tensor1.greater_equal(tensor2);
+    ///     println!("{tensor}");
+    ///     // [[true, false, false], [true, true, true]]
+    /// }
+    /// ```
+    pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {
+        check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other));
+        Tensor::new(K::greater_equal(self.primitive, other.primitive))
+    }
+
+    /// Applies element wise lower comparison and returns a boolean tensor.
+    ///
+    /// # Panics
+    ///
+    /// If the two tensors don't have the same shape.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
+    ///     let tensor = tensor1.lower(tensor2);
+    ///     println!("{tensor}");
+    ///     // [[false, true, true], [false, false, false]]
+    /// }
+    /// ```
+    pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {
+        check!(TensorCheck::binary_ops_ew("Lower", &self, &other));
+        Tensor::new(K::lower(self.primitive, other.primitive))
+    }
+
+    /// Applies element wise lower-equal comparison and returns a boolean tensor.
+    ///
+    /// # Panics
+    ///
+    /// If the two tensors don't have the same shape.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
+    ///     let tensor = tensor1.lower_equal(tensor2);
+    ///     println!("{tensor}");
+    ///     // [[true, true, true], [false, false, false]]
+    /// }
+    /// ```
+    pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {
+        check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other));
+        Tensor::new(K::lower_equal(self.primitive, other.primitive))
+    }
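
These four tensor-tensor comparisons pair naturally with `mask_where`; `max_pair` and `min_pair` later in this file are built exactly this way. A sketch assuming the API above (`elementwise_max` is a hypothetical helper):

```rust
use burn_tensor::Tensor;
use burn_tensor::backend::Backend;

/// Element-wise maximum via `greater` + `mask_where`:
/// where `a > b`, take `a`; otherwise keep `b`.
fn elementwise_max<B: Backend>(a: Tensor<B, 2>, b: Tensor<B, 2>) -> Tensor<B, 2> {
    let mask = a.clone().greater(b.clone());
    b.mask_where(mask, a)
}
```
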
+
+    /// Applies greater than `other` comparison and returns a boolean tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `other` - The element to compare.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor = tensor.greater_elem(3.0);
+    ///     println!("{tensor}");
+    ///     // [[false, false, false], [true, true, true]]
+    /// }
+    /// ```
+    pub fn greater_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
+        Tensor::new(K::greater_elem(self.primitive, other.elem()))
+    }
+
+    /// Applies greater-equal than `other` comparison and returns a boolean tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `other` - The element to compare.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor = tensor.greater_equal_elem(3.0);
+    ///     println!("{tensor}");
+    ///     // [[false, false, true], [true, true, true]]
+    /// }
+    /// ```
+    pub fn greater_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
+        Tensor::new(K::greater_equal_elem(self.primitive, other.elem()))
+    }
+
+    /// Applies lower than `other` comparison and returns a boolean tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `other` - The element to compare.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor = tensor.lower_elem(3.0);
+    ///     println!("{tensor}");
+    ///     // [[true, true, false], [false, false, false]]
+    /// }
+    /// ```
+    pub fn lower_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
+        Tensor::new(K::lower_elem(self.primitive, other.elem()))
+    }
+
+    /// Applies lower-equal than `other` comparison and returns a boolean tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `other` - The element to compare.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor = tensor.lower_equal_elem(3.0);
+    ///     println!("{tensor}");
+    ///     // [[true, true, true], [false, false, false]]
+    /// }
+    /// ```
+    pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
+        Tensor::new(K::lower_equal_elem(self.primitive, other.elem()))
+    }
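
The `*_elem` variants compare against a scalar converted through `ElementConversion`, which makes thresholding a one-liner. A sketch under the same assumptions (`at_least` is illustrative):

```rust
use burn_tensor::{Bool, Tensor};
use burn_tensor::backend::Backend;

/// Boolean mask of entries at or above `threshold`.
fn at_least<B: Backend>(x: Tensor<B, 2>, threshold: f32) -> Tensor<B, 2, Bool> {
    x.greater_equal_elem(threshold)
}
```
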
+
+    /// Applies the argmax function along the given dimension and returns an integer tensor.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
+    ///     let tensor = tensor.argmax(1);
+    ///     println!("{:?}", tensor.shape());
+    ///     // Shape { dims: [2, 1, 3] }
+    /// }
+    /// ```
+    pub fn argmax(self, dim: usize) -> Tensor<B, D, Int> {
+        Tensor::new(K::argmax(self.primitive, dim))
+    }
+
+    /// Find the maximum value.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor = tensor.max();
+    ///     println!("{tensor}");
+    ///     // [9.0]
+    /// }
+    /// ```
+    pub fn max(self) -> Tensor<B, 1, K> {
+        Tensor::new(K::max(self.primitive))
+    }
+
+    /// Find the maximum value along the given dimension.
+    ///
+    /// Also returns the indices.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let (tensor, index) = tensor.max_dim_with_indices(0);
+    ///     println!("{tensor}");
+    ///     // [[5.0, 9.0, 6.0]]
+    ///     println!("{index}");
+    ///     // [[1, 1, 1]]
+    /// }
+    /// ```
+    pub fn max_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
+        let dim = dim.expect_dim_index(D);
+        check!(TensorCheck::aggregate_dim::<D>("Max", dim));
+
+        let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);
+
+        let tensor = Tensor::new(tensor);
+        let index = Tensor::new(index);
+
+        (tensor, index)
+    }
+
+    /// Find the maximum absolute value.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device);
+    ///     let tensor = tensor.max_abs();
+    ///     println!("{tensor}");
+    ///     // [7.0]
+    /// }
+    /// ```
+    pub fn max_abs(self) -> Tensor<B, 1, K> {
+        Tensor::new(K::max_abs(self.primitive))
+    }
+
+    /// Finds the maximum pair wise values with another tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `other` - Other tensor to find maximum elements with
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as the input tensors containing the maximum value found
+    /// in the input tensors.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
+    ///     let tensor = tensor1.max_pair(tensor2);
+    ///     println!("{tensor}");
+    ///     // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]]
+    /// }
+    /// ```
+    pub fn max_pair(self, other: Self) -> Self {
+        let mask = self.clone().lower(other.clone());
+        self.mask_where(mask, other)
+    }
+
+    /// Find the maximum absolute value along the given dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `dim` - The dimension or axis along which to aggregate the elements,
+    ///   supports negative indexing.
+    ///
+    /// # Returns
+    ///
+    /// The returned tensor will have the same rank,
+    /// but the aggregated dimension will have size 1.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor = tensor.max_abs_dim(0);
+    ///     println!("{tensor}");
+    ///     // [[5.0, 9.0, 6.0]]
+    /// }
+    /// ```
+    pub fn max_abs_dim<I: AsIndex>(self, dim: I) -> Self {
+        let dim = dim.expect_dim_index(D);
+        check!(TensorCheck::aggregate_dim::<D>("MaxAbs", dim));
+
+        Tensor::new(K::max_abs_dim(self.primitive, dim))
+    }
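
`argmax` and `max_dim_with_indices` overlap when both the value and its position are needed; the latter produces both in one call. A sketch assuming the API above (`predict` is a hypothetical helper, and `usize` is assumed to implement `AsIndex`):

```rust
use burn_tensor::{Int, Tensor};
use burn_tensor::backend::Backend;

/// Best score and class index per row of a [batch, num_classes] matrix;
/// both outputs keep rank 2 with dim 1 reduced to size 1.
fn predict<B: Backend>(scores: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2, Int>) {
    scores.max_dim_with_indices(1)
}
```
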
+
+    /// Find the maximum absolute value along the given dimensions.
+    ///
+    /// # Arguments
+    ///
+    /// * `dims` - The dimensions or axes along which to aggregate the elements,
+    ///   supports negative indexing.
+    ///
+    /// # Returns
+    ///
+    /// The returned tensor will have the same rank,
+    /// but the aggregated dimensions will have size 1.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor = tensor.max_abs_dims(&[0, 1]);
+    ///     println!("{tensor}");
+    ///     // [[9.0]]
+    /// }
+    /// ```
+    pub fn max_abs_dims<I: AsIndex>(self, dims: &[I]) -> Self {
+        dims.iter()
+            .fold(self, |tensor, &dim| tensor.max_abs_dim(dim))
+    }
+
+    /// Applies the argmin function along the given dimension and returns an integer tensor.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = Default::default();
+    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
+    ///     let tensor = tensor.argmin(1);
+    ///     println!("{:?}", tensor.shape());
+    ///     // Shape { dims: [2, 1, 3] }
+    /// }
+    /// ```
+    pub fn argmin(self, dim: usize) -> Tensor<B, D, Int> {
+        Tensor::new(K::argmin(self.primitive, dim))
+    }
+
+    /// Find the minimum value.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor = tensor.min();
+    ///     println!("{tensor}");
+    ///     // [-2.0]
+    /// }
+    /// ```
+    pub fn min(self) -> Tensor<B, 1, K> {
+        Tensor::new(K::min(self.primitive))
+    }
+
+    /// Find the minimum value along the given dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `dim` - The dimension or axis along which to aggregate the elements;
+    ///   supports negative indexing.
+    ///
+    /// # Returns
+    ///
+    /// The returned tensor will have the same rank,
+    /// but the aggregated dimension will have size 1.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor = tensor.min_dim(0);
+    ///     println!("{tensor}");
+    ///     // [[1.0, -2.0, 3.0]]
+    /// }
+    /// ```
+    pub fn min_dim<I: AsIndex>(self, dim: I) -> Self {
+        let dim = dim.expect_dim_index(D);
+        check!(TensorCheck::aggregate_dim::<D>("Min", dim));
+        Tensor::new(K::min_dim(self.primitive, dim))
+    }
+
+    /// Find the minimum value along the given dimensions.
+    ///
+    /// # Arguments
+    ///
+    /// * `dims` - The dimensions or axes along which to aggregate the elements;
+    ///   supports negative indexing.
+    ///
+    /// # Returns
+    ///
+    /// The returned tensor will have the same rank,
+    /// but the aggregated dimensions will have size 1.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor = tensor.min_dims(&[0, 1]);
+    ///     println!("{tensor}");
+    ///     // [[-2.0]]
+    /// }
+    /// ```
+    pub fn min_dims<I: AsIndex>(self, dims: &[I]) -> Self {
+        dims.iter().fold(self, |tensor, &dim| tensor.min_dim(dim))
+    }
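
The `*_dims` reducers fold the single-axis version over each listed axis, so a full reduction keeps the rank. A sketch (`global_min` is illustrative):

```rust
use burn_tensor::Tensor;
use burn_tensor::backend::Backend;

/// Global minimum of a matrix: a [1, 1] tensor rather than a scalar.
fn global_min<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
    x.min_dims(&[0, 1])
}
```
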
+
+    /// Find the minimum value along the given dimension.
+    ///
+    /// Also returns the indices.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let (tensor, index) = tensor.min_dim_with_indices(0);
+    ///     println!("{tensor}");
+    ///     // [[5.0, -2.0, 3.0]]
+    ///     println!("{index}");
+    ///     // [[1, 0, 0]]
+    /// }
+    /// ```
+    pub fn min_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
+        let dim = dim.expect_dim_index(D);
+        check!(TensorCheck::aggregate_dim::<D>("Min", dim));
+
+        let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);
+
+        let tensor = Tensor::new(tensor);
+        let index = Tensor::new(index);
+
+        (tensor, index)
+    }
+
+    /// Finds the minimum pair wise values with another tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `other` - Other tensor to find minimum elements with
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as the input tensors containing the minimum value found
+    /// between each element of the two source tensors.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
+    ///     let tensor = tensor1.min_pair(tensor2);
+    ///     println!("{tensor}");
+    ///     // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]]
+    /// }
+    /// ```
+    pub fn min_pair(self, other: Self) -> Self {
+        let mask = other.clone().lower(self.clone());
+        self.mask_where(mask, other)
+    }
+
+    /// Clamp element wise between the given min and max values.
+    ///
+    /// # Arguments
+    ///
+    /// * `min` - The minimum value.
+    /// * `max` - The maximum value.
+    ///
+    /// # Returns
+    ///
+    /// A new tensor with the values clamped between the given min and max values.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Int, Tensor};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = Default::default();
+    ///     let tensor = Tensor::<B, 2, Int>::from_ints(
+    ///         [
+    ///             [1, 2, 3],
+    ///             [4, 5, 6],
+    ///             [7, 8, 9]
+    ///         ],
+    ///         &device);
+    ///     let tensor = tensor.clamp(2, 6);
+    ///     println!("{tensor}");
+    ///     // [[2, 2, 3], [4, 5, 6], [6, 6, 6]]
+    /// }
+    /// ```
+    pub fn clamp<E: ElementConversion>(self, min: E, max: E) -> Self {
+        Self::new(K::clamp(self.primitive, min.elem(), max.elem()))
+    }
+
+    /// Clamp element wise under a minimum value.
+    ///
+    /// # Arguments
+    ///
+    /// * `min` - The minimum value.
+    ///
+    /// # Returns
+    ///
+    /// A new tensor with the values clamped under the given min value.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Int, Tensor};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = Default::default();
+    ///     let tensor = Tensor::<B, 2, Int>::from_ints(
+    ///         [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+    ///         &device);
+    ///     let tensor = tensor.clamp_min(4);
+    ///     println!("{tensor}");
+    ///     // [[4, 4, 4], [4, 5, 6], [7, 8, 9]]
+    /// }
+    /// ```
+    pub fn clamp_min<E: ElementConversion>(self, min: E) -> Self {
+        Self::new(K::clamp_min(self.primitive, min.elem()))
+    }
+
+    /// Clamp element wise over a maximum value.
+    ///
+    /// # Arguments
+    ///
+    /// * `max` - The maximum value.
+    ///
+    /// # Returns
+    ///
+    /// A new tensor with the values clamped over the given max value.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Int, Tensor};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = Default::default();
+    ///     let tensor = Tensor::<B, 2, Int>::from_ints(
+    ///         [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+    ///         &device);
+    ///     let tensor = tensor.clamp_max(5);
+    ///     println!("{tensor}");
+    ///     // [[1, 2, 3], [4, 5, 5], [5, 5, 5]]
+    /// }
+    /// ```
+    pub fn clamp_max<E: ElementConversion>(self, max: E) -> Self {
+        Self::new(K::clamp_max(self.primitive, max.elem()))
+    }
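
The three clamp variants compose, which covers the usual bounded activations. A sketch assuming the API above (`relu6` is a hypothetical helper):

```rust
use burn_tensor::Tensor;
use burn_tensor::backend::Backend;

/// ReLU6 from the clamp primitives; equivalent to `x.clamp(0.0, 6.0)`.
fn relu6<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
    x.clamp_min(0.0).clamp_max(6.0)
}
```
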
+
+    /// Computes the cumulative minimum of elements along the given *dimension* or *axis*.
+    ///
+    /// # Arguments
+    ///
+    /// * `dim` - The dimension or axis along which to compute the cumulative minimum.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 5.0, 2.0], [4.0, 1.0, 6.0]], &device);
+    ///     let result = tensor.clone().cummin(0);
+    ///     println!("{result}");
+    ///     // [[3.0, 5.0, 2.0], [3.0, 1.0, 2.0]]
+    ///     let result = tensor.cummin(1);
+    ///     println!("{result}");
+    ///     // [[3.0, 3.0, 2.0], [4.0, 1.0, 1.0]]
+    /// }
+    /// ```
+    pub fn cummin(self, dim: usize) -> Self {
+        check!(TensorCheck::aggregate_dim::<D>("CumMin", dim));
+        Self::new(K::cummin(self.primitive, dim))
+    }
+
+    /// Computes the cumulative maximum of elements along the given *dimension* or *axis*.
+    ///
+    /// # Arguments
+    ///
+    /// * `dim` - The dimension or axis along which to compute the cumulative maximum.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]], &device);
+    ///     let result = tensor.clone().cummax(0);
+    ///     println!("{result}");
+    ///     // [[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]]
+    ///     let result = tensor.cummax(1);
+    ///     println!("{result}");
+    ///     // [[3.0, 3.0, 3.0], [4.0, 5.0, 5.0]]
+    /// }
+    /// ```
+    pub fn cummax(self, dim: usize) -> Self {
+        check!(TensorCheck::aggregate_dim::<D>("CumMax", dim));
+        Self::new(K::cummax(self.primitive, dim))
+    }
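
Cumulative extrema are handy for running statistics along a time axis. A sketch under the same assumptions (`running_peak` is illustrative):

```rust
use burn_tensor::Tensor;
use burn_tensor::backend::Backend;

/// Running peak of a [batch, time] series along the time axis (dim 1);
/// subtracting it from the series gives the drawdown.
fn running_peak<B: Backend>(series: Tensor<B, 2>) -> Tensor<B, 2> {
    series.cummax(1)
}
```
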
+
+    /// Find the maximum value along the given dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `dim` - The dimension or axis along which to aggregate the elements;
+    ///   supports negative indexing.
+    ///
+    /// # Returns
+    ///
+    /// The returned tensor will have the same rank,
+    /// but the aggregated dimension will have size 1.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor = tensor.max_dim(0);
+    ///     println!("{tensor}");
+    ///     // [[5.0, 9.0, 6.0]]
+    /// }
+    /// ```
+    pub fn max_dim<I: AsIndex>(self, dim: I) -> Self {
+        let dim = dim.expect_dim_index(D);
+        check!(TensorCheck::aggregate_dim::<D>("Max", dim));
+        Tensor::new(K::max_dim(self.primitive, dim))
+    }
+
+    /// Find the maximum value along the given dimensions.
+    ///
+    /// # Arguments
+    ///
+    /// * `dims` - The dimensions or axes along which to aggregate the elements;
+    ///   supports negative indexing.
+    ///
+    /// # Returns
+    ///
+    /// The returned tensor will have the same rank,
+    /// but the aggregated dimensions will have size 1.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
+    ///     let tensor = tensor.max_dims(&[0, 1]);
+    ///     println!("{tensor}");
+    ///     // [[9.0]]
+    /// }
+    /// ```
+    pub fn max_dims<I: AsIndex>(self, dims: &[I]) -> Self {
+        dims.iter().fold(self, |tensor, &dim| tensor.max_dim(dim))
+    }
+}
diff --git a/crates/burn-tensor/src/tensor/linalg/vector_norm.rs b/crates/burn-tensor/src/tensor/linalg/vector_norm.rs
index 3c5569b57a..aefb74ae02 100644
--- a/crates/burn-tensor/src/tensor/linalg/vector_norm.rs
+++ b/crates/burn-tensor/src/tensor/linalg/vector_norm.rs
@@ -1,3 +1,6 @@
+use burn_backend::ElementComparison;
+use burn_backend::tensor::Ordered;
+
 use crate::backend::Backend;
 use crate::tensor::{BasicOps, Tensor};
 use crate::{ElementConversion, Numeric};
@@ -200,7 +203,8 @@ pub fn max_abs_norm<B: Backend, const D: usize, K>(
     dim: usize,
 ) -> Tensor<B, D, K>
 where
-    K: BasicOps<B> + Numeric<B>,
+    K: BasicOps<B> + Ordered<B>,
+    <K as BasicOps<B>>::Elem: ElementComparison,
 {
     x.max_abs_dim(dim)
 }
@@ -220,7 +224,8 @@ pub fn min_abs_norm<B: Backend, const D: usize, K>(
     dim: usize,
 ) -> Tensor<B, D, K>
 where
-    K: BasicOps<B> + Numeric<B>,
+    K: BasicOps<B> + Ordered<B>,
+    <K as BasicOps<B>>::Elem: ElementComparison,
 {
     x.abs().min_dim(dim)
 }
diff --git a/crates/burn-vision/src/backends/cpu/connected_components.rs b/crates/burn-vision/src/backends/cpu/connected_components.rs
index f85406431d..67c01585de 100644
--- a/crates/burn-vision/src/backends/cpu/connected_components.rs
+++ b/crates/burn-vision/src/backends/cpu/connected_components.rs
@@ -2,7 +2,7 @@ use std::cmp::Ordering;
 
 use alloc::vec::Vec;
 use burn_tensor::{
-    Bool, Element, ElementConversion, Int, Shape, Tensor, TensorData,
+    Bool, Element, ElementComparison, ElementConversion, Int, Shape, Tensor, TensorData,
     backend::Backend,
     ops::{BoolTensor, IntTensor},
 };
@@ -75,7 +75,7 @@ pub(crate) struct UnionFind<I> {
     labels: Vec<I>,
 }
 
-impl<I: Element> Solver<I> for UnionFind<I> {
+impl<I: Element + ElementComparison> Solver<I> for UnionFind<I> {
     fn init(max_labels: usize) -> Self {
         let mut labels = Vec::with_capacity(max_labels);
         labels.push(0.elem());
diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs
index 1709cd7dc3..50d5e94cf2 100644
--- a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs
+++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs
@@ -18,7 +18,7 @@ use std::cmp::Ordering;
 
-use burn_tensor::{Element, ElementConversion};
+use burn_tensor::{Element, ElementComparison, ElementConversion};
 use ndarray::{Array2, Axis, s};
 
 #[allow(non_snake_case)]
@@ -29,7 +29,7 @@ use crate::Connectivity;
 
 use super::{Solver, StatsOp, max_labels};
 
-pub fn process<I: Element>(
+pub fn process<I: Element + ElementComparison>(
     img_arr: Array2<bool>,
     stats: &mut impl StatsOp<I>,
 ) -> Array2<I> {
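
Downstream generic code now spells out both the `Ordered` kind bound and the `ElementComparison` element bound, as the `max_abs_norm`/`min_abs_norm` changes above show. A sketch of a caller-side generic under that assumption (`extrema` is hypothetical, and the exact shape of `Ordered` follows this diff, not a published API):

```rust
use burn_backend::ElementComparison;
use burn_backend::tensor::Ordered;
use burn_tensor::backend::Backend;
use burn_tensor::{BasicOps, Tensor};

/// Smallest and largest values along `dim`, for any tensor kind that
/// supports ordering; the bounds mirror max_abs_norm above.
fn extrema<B: Backend, const D: usize, K>(
    x: Tensor<B, D, K>,
    dim: usize,
) -> (Tensor<B, D, K>, Tensor<B, D, K>)
where
    K: BasicOps<B> + Ordered<B>,
    <K as BasicOps<B>>::Elem: ElementComparison,
{
    (x.clone().min_dim(dim), x.max_dim(dim))
}
```
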