4 changes: 2 additions & 2 deletions burn-book/src/advanced/backend-extension/README.md
@@ -34,7 +34,7 @@ pub trait Backend: burn::tensor::backend::Backend {
 You can then implement your new custom backend trait for any backend that you want to support:
 
 ```rust, ignore
-impl<E: TchElement> Backend for burn_tch::LibTorch<E> {
+impl<E: TchElement> Backend for burn_tch::LibTorch<E,F,I> {
     fn my_new_function(tensor: TchTensor<E, 2>) -> TchTensor<E, 2> {
         // My Tch implementation
     }
@@ -63,7 +63,7 @@ impl<B: Backend> Backend for burn_autodiff::Autodiff<B> {
     }
 }
 
-impl<E: TchElement> Backend for burn_autodiff::Autodiff<burn_tch::LibTorch<E>> {
+impl<E: TchElement> Backend for burn_autodiff::Autodiff<burn_tch::LibTorch<E,F,I>> {
     fn my_new_function(tensor: AutodiffTensor<E, 2>) -> AutodiffTensor<E, 2> {
         // My own backward implementation, generic over a backend implementation.
         //
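For orientation, here is a hedged sketch of how downstream code can stay generic over the extended trait. The exact trait signature of `my_new_function` is not shown in this excerpt, so the primitive types below are an assumption, not the book's actual example.

```rust
// Sketch only: assumes the extension trait declares
// `fn my_new_function(tensor: Self::FloatTensorPrimitive) -> Self::FloatTensorPrimitive`.
fn apply_my_op<B: Backend>(tensor: B::FloatTensorPrimitive) -> B::FloatTensorPrimitive {
    // Works for burn_tch::LibTorch<E,F,I> and burn_autodiff::Autodiff<burn_tch::LibTorch<E,F,I>>
    // alike, since both implement the custom `Backend` trait above.
    B::my_new_function(tensor)
}
```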
6 changes: 3 additions & 3 deletions crates/burn-backend/src/backend/base.rs
@@ -6,8 +6,8 @@ use serde::{Deserialize, Serialize};
 use thiserror::Error;
 
 use crate::element::Element;
-use crate::ops::*;
 use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
+use crate::{ElementComparison, ops::*};
 use crate::{QTensorPrimitive, TensorData, TensorMetadata};
 
 use super::DeviceOps;
@@ -83,12 +83,12 @@ pub trait Backend:
     /// Tensor primitive to be used for all float operations.
     type FloatTensorPrimitive: TensorMetadata + 'static;
     /// Default float element type.
-    type FloatElem: Element;
+    type FloatElem: Element + ElementComparison;
Member:

iirc ElementComparison was added specially for sorting ops, which have a default cpu implementation that operates on TensorData elements. We could likely remove the stricter ElementComparison bound, and only require ElementEquality, except for the data sorting implementation

    /// Compare two elements
    fn compare<E: ElementComparison>(a: &E, b: &E, descending: bool) -> Ordering {
        if descending { b.cmp(a) } else { a.cmp(b) }
    }

Would have to check, but I believe everything else should still compile and work.

Contributor Author:

> iirc ElementComparison was added specially for sorting ops, which have a default cpu implementation that operates on TensorData elements. We could likely remove the stricter ElementComparison bound, and only require ElementEquality, except for the data sorting implementation

Right now, {Float,Int}TensorOps bundles methods that rely on both Orderable and Numeric, hence the added constraint. How would removing the bound work? Would we remove the ordering ops from the {Float,Int}TensorOps traits and leave them to what is defined under Orderable and the TensorData sort implementation? Or would we take an approach similar to what was done for Numeric and define float/int ordering traits implemented directly on the end element types (just moving the existing methods to a new home)? Or were you thinking of another approach entirely?


     /// Tensor primitive to be used for all int operations.
     type IntTensorPrimitive: TensorMetadata + 'static;
     /// Int element type.
-    type IntElem: Element;
+    type IntElem: Element + ElementComparison;
 
     /// Tensor primitive to be used for all bool operations.
     type BoolTensorPrimitive: TensorMetadata + 'static;
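To make the trade-off in the review discussion above concrete, here is a self-contained sketch of the split being discussed: an equality-only bound for element-wise data comparison, with the ordering bound kept only where sorting needs it. The trait shapes follow the diff; the free functions are illustrative only, not the crate's API.

```rust
use core::cmp::Ordering;

/// Equality-only bound: enough for comparing tensor data element-wise.
pub trait ElementEquality {
    fn eq(&self, other: &Self) -> bool;
}

/// Ordering bound: only needed by the TensorData sorting fallback.
pub trait ElementComparison {
    fn cmp(&self, other: &Self) -> Ordering;
}

/// The sorting path keeps the stricter bound...
fn sort_values<E: ElementComparison>(values: &mut [E], descending: bool) {
    values.sort_unstable_by(|a, b| if descending { b.cmp(a) } else { a.cmp(b) });
}

/// ...while element-wise comparison only needs equality.
fn count_mismatches<E: ElementEquality>(a: &[E], b: &[E]) -> usize {
    a.iter().zip(b.iter()).filter(|(x, y)| !x.eq(y)).count()
}
```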
14 changes: 7 additions & 7 deletions crates/burn-backend/src/backend/ops/sort.rs
@@ -34,7 +34,7 @@ pub fn sort<B: Backend, K: TensorKind<B> + BasicOps<B>>(
     descending: bool,
 ) -> K::Primitive
 where
-    <K as BasicOps<B>>::Elem: Element,
+    <K as BasicOps<B>>::Elem: ElementComparison,
 {
     let device = K::device(&tensor);
     let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";
@@ -51,7 +51,7 @@ pub fn sort_data<B: Backend, K: TensorKind<B> + BasicOps<B>>(
     descending: bool,
 ) -> K::Primitive
 where
-    <K as BasicOps<B>>::Elem: Element,
+    <K as BasicOps<B>>::Elem: ElementComparison,
 {
     let dims = data.shape.clone();
     let data_slice = data.as_mut_slice().unwrap();
@@ -92,7 +92,7 @@ pub fn sort_with_indices<B: Backend, K: TensorKind<B> + BasicOps<B>>(
     descending: bool,
 ) -> (K::Primitive, IntTensor<B>)
 where
-    <K as BasicOps<B>>::Elem: Element,
+    <K as BasicOps<B>>::Elem: ElementComparison,
 {
     let device = K::device(&tensor);
     let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";
@@ -109,7 +109,7 @@ fn sort_data_with_indices<B: Backend, K: TensorKind<B> + BasicOps<B>>(
     descending: bool,
 ) -> (K::Primitive, IntTensor<B>)
 where
-    <K as BasicOps<B>>::Elem: Element,
+    <K as BasicOps<B>>::Elem: Element + ElementComparison,
 {
     let dims = data.shape.clone();
     let mut indices_data = dim_indices::<B>(&dims, dim);
@@ -191,7 +191,7 @@ pub fn argsort<B: Backend, K: TensorKind<B> + BasicOps<B>>(
     descending: bool,
 ) -> IntTensor<B>
 where
-    <K as BasicOps<B>>::Elem: Element,
+    <K as BasicOps<B>>::Elem: ElementComparison,
 {
     let device = K::device(&tensor);
     let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";
@@ -209,7 +209,7 @@ fn argsort_data<B: Backend, K: TensorKind<B> + BasicOps<B>>(
     descending: bool,
 ) -> IntTensor<B>
 where
-    <K as BasicOps<B>>::Elem: Element,
+    <K as BasicOps<B>>::Elem: ElementComparison,
 {
     let dims = data.shape.clone();
     let mut indices_data = dim_indices::<B>(&dims, dim);
@@ -252,7 +252,7 @@ fn sort_slice<B: Backend, K: BasicOps<B>>(
     permute_both: bool,
     descending: bool,
 ) where
-    <K as BasicOps<B>>::Elem: Element,
+    <K as BasicOps<B>>::Elem: ElementComparison,
 {
     let ndims = dims.len();
     let strides = compute_strides(dims);
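These helpers implement the CPU fallback: tensor data is read back synchronously, sorted element-wise via the `ElementComparison` bound, and turned back into a tensor. As a rough, one-dimensional illustration of the argsort part (the real code handles arbitrary dimensions through strides in `sort_slice`):

```rust
// Illustrative sketch only: sort index positions by the values they point at,
// relying on the ElementComparison bound for the element ordering.
fn argsort_1d<E: ElementComparison>(values: &[E], descending: bool) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..values.len()).collect();
    indices.sort_unstable_by(|&i, &j| {
        if descending {
            values[j].cmp(&values[i])
        } else {
            values[i].cmp(&values[j])
        }
    });
    indices
}
```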
11 changes: 7 additions & 4 deletions crates/burn-backend/src/data/compare.rs
@@ -4,7 +4,7 @@ use burn_std::{DType, bf16, f16};
 use num_traits::{Float, ToPrimitive};
 
 use super::TensorData;
-use crate::element::Element;
+use crate::{ElementComparison, element::Element};
 
 /// The tolerance used to compare to floating point numbers.
 ///
@@ -269,7 +269,7 @@ impl TensorData {
         let mut num_diff = 0;
         let max_num_diff = 5;
         for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() {
-            if a.cmp(&b).is_ne() {
+            if !a.eq(&b) {
                 // Only print the first 5 different values.
                 if num_diff < max_num_diff {
                     message += format!("\n => Position {i}: {a} != {b}").as_str();
@@ -362,7 +362,7 @@ impl TensorData {
     ///
     /// If any value is not within the half-open range bounded inclusively below
     /// and exclusively above (`start..end`).
-    pub fn assert_within_range<E: Element>(&self, range: core::ops::Range<E>) {
+    pub fn assert_within_range<E: Element + ElementComparison>(&self, range: core::ops::Range<E>) {
         for elem in self.iter::<E>() {
             if elem.cmp(&range.start).is_lt() || elem.cmp(&range.end).is_ge() {
                 panic!("Element ({elem:?}) is not within range {range:?}");
@@ -379,7 +379,10 @@ impl TensorData {
     /// # Panics
     ///
     /// If any value is not within the half-open range bounded inclusively (`start..=end`).
-    pub fn assert_within_range_inclusive<E: Element>(&self, range: core::ops::RangeInclusive<E>) {
+    pub fn assert_within_range_inclusive<E: Element + ElementComparison>(
+        &self,
+        range: core::ops::RangeInclusive<E>,
+    ) {
         let start = range.start();
         let end = range.end();
 
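A quick usage sketch of the range assertion under the tightened bound. Values are arbitrary, and the array constructor is assumed to be the usual `From` impl on `TensorData`; `f32` satisfies both `Element` and `ElementComparison`.

```rust
let data = TensorData::from([0.1f32, 0.5, 0.9]);
// Passes: every element lies in the half-open range [0.0, 1.0).
data.assert_within_range(0.0f32..1.0);
```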
13 changes: 12 additions & 1 deletion crates/burn-backend/src/element/base.rs
@@ -14,7 +14,7 @@ pub trait Element:
     ToElement
     + ElementRandom
     + ElementConversion
-    + ElementComparison
+    + ElementEquality
     + ElementLimits
     + bytemuck::CheckedBitPattern
     + bytemuck::NoUninit
@@ -63,6 +63,12 @@ pub trait ElementRandom {
     fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self;
 }
 
+/// Element trait for equality of a tensor.
+pub trait ElementEquality {
+    /// Returns whether `self` and `other` are equal.
+    fn eq(&self, other: &Self) -> bool;
+}
+
 /// Element ordering trait.
 pub trait ElementComparison {
     /// Returns and [Ordering] between `self` and `other`.
@@ -104,6 +110,11 @@ macro_rules! make_element {
             $dtype
         }
     }
+    impl ElementEquality for $type {
+        fn eq(&self, other: &Self) -> bool {
+            self == other
+        }
+    }
 
     impl ElementConversion for $type {
         #[inline(always)]
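For a concrete element type, the `make_element!` addition above expands to roughly the following (hand-expanded sketch for `f32`; the macro also generates the other `Element` sub-trait impls):

```rust
impl ElementEquality for f32 {
    fn eq(&self, other: &Self) -> bool {
        // Plain `==` on the primitive type; for floats this follows IEEE-754
        // semantics, so `NaN` is never equal to itself.
        self == other
    }
}
```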