Conversation

@A2va commented Jun 12, 2025

Pull Request Template

Checklist

  • Confirmed that cargo run-checks command has been executed.
  • Made sure the book is up to date with changes in this PR.

Changes

As my project is blocked by both #3262 and a TensorFlow-generated ONNX issue, I decided to try implementing this feature.

My implementation is based on the existing float_cast function, so some things might need to differ for the integer case. It is very much a work in progress, so there may be errors, and I look forward to any feedback.

Regarding the current state of the implementation, I have added int_cast to all backends, but I am struggling to implement it for the ndarray backend.
I have added the IntNdArrayElement implementations (crates/burn-ndarray/src/element.rs), but since all the uxx types are unsigned, it errors with:

the trait bound `u64: Signed` is not satisfied
the following other types implement trait `Signed`:
  f32
  f64
  i128
  i16
  i32
  i64
  i8
  isize

And the Signed trait bound cannot simply be removed, since that produces other errors for IntTensorOps.

I also get this error for every usage of execute_with_int_dtype, and I don't understand it, because the conversion from NdArrayTensor<I> to NdArrayTensorInt is implemented:

error[E0308]: mismatched types
   --> crates\burn-ndarray\src\tensor.rs:312:14
    |
311 |           match ($lhs, $rhs) {
    |                 ------------ this expression has type `(tensor::NdArrayTensor<I>, tensor::NdArrayTensor<I>)`
312 |               ($crate::NdArrayTensorInt::I64(lhs), $crate::NdArrayTensorInt::I64(rhs)) => {
    |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `NdArrayTensor<I>`, found `NdArrayTensorInt`
    |
   ::: crates\burn-ndarray\src\ops\int_tensor.rs:74:9
    |
74  | /         execute_with_int_dtype!((tensor, source), |tensor, source| {
75  | |             NdArrayMathOps::mask_where(tensor, mask, source)
76  | |         })
    | |__________- in this macro invocation
    |
    = note: expected struct `tensor::NdArrayTensor<I>`
                 found enum `tensor::NdArrayTensorInt`
    = note: this error originates in the macro `execute_with_int_dtype` (in Nightly builds, run with -Z macro-backtrace for more info)
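The conversions in question look roughly like this, one impl per dtype (a sketch based on the variant names in the error, not copied from the PR):

impl From<NdArrayTensor<i64>> for NdArrayTensorInt {
    fn from(tensor: NdArrayTensor<i64>) -> Self {
        Self::I64(tensor)
    }
}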

I also noticed some inconsistencies in the file naming conventions between the backends. For example, in the ndarray backend it is named int_tensor.rs, whereas in cubecl it's int_ops.rs.
What about unifying the naming?

Testing

Do I need to run tests for each backend, and where?

A2va marked this pull request as draft on June 12, 2025.
@laggui (Member) commented Jun 19, 2025

Sorry for the late response! Been a little busy recently 😅 this was on the roadmap so I will absolutely have a look when I get a bit of time

@github-actions bot commented Jul 21, 2025

This PR has been marked as stale because it has not been updated for over a month.

github-actions added the "stale" label on Jul 21, 2025.
@A2va (Author) commented Jul 25, 2025

@laggui Any update?

github-actions removed the "stale" label on Jul 26, 2025.
@laggui (Member) left a comment

Thanks for pinging me, this completely slipped through the cracks with the last release sprint 😅

I think it's off to a good start, see my comments below.

For the unsigned ndarray int types, it's probably best to remove the Signed requirement if we expand to unsigned types, and fix the errors at the operation level. Maybe the restriction was for sign, but we could narrow it to Neg? Not 100% sure.
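A minimal sketch of that direction, assuming IntNdArrayElement currently carries the Signed bound:

// Drop Signed from the element trait and push any remaining requirement
// (Neg, or a dedicated sign trait) down to the individual ops that need it.
pub trait IntNdArrayElement: NdArrayElement {} // was: NdArrayElement + Signed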

> I also get this error for every usage of execute_with_int_dtype, and I don't understand it, because the conversion from NdArrayTensor<I> to NdArrayTensorInt is implemented.

That's because the backend int primitive was not changed, so the traits still expect NdArrayTensor<I> as output.
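Concretely, that means updating the associated type on the Backend impl in crates/burn-ndarray/src/backend.rs (a sketch with abbreviated bounds; the later compiler output confirms the file and type names):

impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for NdArray<E, I, Q> {
    // before: type IntTensorPrimitive = NdArrayTensor<I>;
    type IntTensorPrimitive = NdArrayTensorInt;
    // ...
}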

Comment on lines +444 to +453
pub enum IntDType {
I64,
I32,
I16,
I8,
U64,
U32,
U16,
U8,
}
@laggui (Member) commented Jul 28, 2025

I know this follows the current float cast API, but it might be a good time to revisit this since we're expanding it.

We could probably just stick to DType with runtime validation at the high-level method. Would reduce some duplication of variants.

I think after that it would be nice to move to a more idiomatic API using generic types, i.e. tensor.cast::<f16>() instead of tensor.cast(DType::F16) (which would also work for int types). We already have the DType and Element trait in place. But I'm getting ahead of the current PR 😄 I'll take care of this change after.
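A self-contained sketch of what that could look like, with toy stand-ins for burn's actual types (DType, Element, and cast_dtype here are illustrative only):

// Keep the runtime dtype dispatch, but expose it behind a typed method
// so callers can write tensor.cast::<f64>() instead of passing a DType.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    F32,
    F64,
}

trait Element {
    fn dtype() -> DType;
}

impl Element for f32 {
    fn dtype() -> DType {
        DType::F32
    }
}

impl Element for f64 {
    fn dtype() -> DType {
        DType::F64
    }
}

struct Tensor {
    dtype: DType,
}

impl Tensor {
    // Current style: dtype as a runtime value.
    fn cast_dtype(mut self, dtype: DType) -> Self {
        self.dtype = dtype; // real code would also convert the storage
        self
    }

    // Proposed style: dtype derived from the type parameter.
    fn cast<E: Element>(self) -> Self {
        self.cast_dtype(E::dtype())
    }
}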

@A2va (Author) commented Aug 5, 2025

I have replaced NdArrayTensor<I> with IntTensor<Self> in all integer operations; however, this causes some errors. For example, here:

fn float_gather(
    dim: usize,
    tensor: FloatTensor<Self>,
    indices: IntTensor<Self>,
) -> FloatTensor<Self> {
    execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::gather(
        dim, tensor, indices
    ))
}
error[E0308]: mismatched types
   --> crates\burn-ndarray\src\ops\tensor.rs:155:26
    |
154 |         execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::gather(
    |                                                    ---------------------- arguments to this function are incorrect
155 |             dim, tensor, indices
    |                          ^^^^^^^ expected `NdArrayTensor<_>`, found `NdArrayTensorInt`
    |
    = note: expected struct `tensor::base::NdArrayTensor<_>`
                 found enum `tensor::int::NdArrayTensorInt`
note: associated function defined here
   --> crates\burn-ndarray\src\ops\base.rs:515:12
    |
515 |     pub fn gather<I: NdArrayElement>(
    |            ^^^^^^
...
518 |         mut indices: NdArrayTensor<I>,
    |         -----------------------------

Do I need to write the Into trait? But it isn't required to convert FloatTensor<Self> into NdArrayTensor<E>...

@laggui (Member) commented Aug 6, 2025

Ahhh, that's because the current execute_with_float_dtype! only handles the float type dispatch correctly (and with your changes, the equivalent int macro does too). But this is a mixed-type dispatch (float + int), which is not handled. We need to dispatch both float and int types for such operations. In this case, probably combine both:

execute_with_float_dtype!(tensor, |tensor| {
    execute_with_int_dtype!(indices, |indices| {
        NdArrayMathOps::gather(dim, tensor, indices)
    })
})

(might not work with the current macro implementations, but this is just a representation)

Right now the code duplication is OK since this is a work in progress, but we'll probably want something along the lines of:

generate_dispatch_macro!(execute_with_float_dtype, NdArrayTensorFloat, F32(f32), F64(f64));
generate_dispatch_macro!(execute_with_int_dtype, NdArrayTensorInt, I8(i8), I16(i16), I32(i32), I64(i64));

to have a single macro written for the type dispatch, from which the float/int dispatch macros can be generated.
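For illustration, here is a self-contained reduction of the dispatch pattern such a generator would stamp out, cut down to two variants (the real enum covers all the signed and unsigned widths):

struct NdArrayTensor<E>(Vec<E>);

enum NdArrayTensorInt {
    I64(NdArrayTensor<i64>),
    I32(NdArrayTensor<i32>),
}

// Hand-written version of what generate_dispatch_macro would emit:
// unwrap the enum, run the op on the concrete tensor, re-wrap the result.
macro_rules! execute_with_int_dtype {
    ($tensor:expr, |$t:ident| $op:expr) => {
        match $tensor {
            NdArrayTensorInt::I64($t) => NdArrayTensorInt::I64($op),
            NdArrayTensorInt::I32($t) => NdArrayTensorInt::I32($op),
        }
    };
}

fn abs_all(tensor: NdArrayTensorInt) -> NdArrayTensorInt {
    execute_with_int_dtype!(tensor, |t| NdArrayTensor(t.0.iter().map(|x| x.abs()).collect()))
}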

@A2va (Author) commented Aug 10, 2025

Chaining two execute_with_* macros seems to work; at least it compiles.

Here are the remaining errors:

  1. Type annotations needed for the bitwise operations and the pow functions.
  2. Some type mismatches in module.rs: module_op correctly resolves x to the concrete type, but not indices.
fn max_pool2d_with_indices_backward(
    x: FloatTensor<Self>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    output_grad: FloatTensor<Self>,
    indices: IntTensor<Self>,
) -> MaxPool2dBackward<NdArray<E, I, Q>> {
    module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
        let output = max_pool2d_backward::<E, I>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            output_grad,
            indices,
        );
        MaxPool2dBackward::new(output.into())
    })
}

So it leads to errors like:

error[E0271]: type mismatch resolving `<NdArray<E, I, Q> as Backend>::IntTensorPrimitive == NdArrayTensor<I>`
   --> crates\burn-ndarray\src\ops\module.rs:190:13
    |
190 |             MaxPool2dWithIndices::new(output.into(), indices)
    |             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ type mismatch resolving `<NdArray<E, I, Q> as Backend>::IntTensorPrimitive == NdArrayTensor<I>`
    |
note: expected this to be `tensor::base::NdArrayTensor<I>`
   --> crates\burn-ndarray\src\backend.rs:51:31
    |
51  |     type IntTensorPrimitive = NdArrayTensorInt;
    |                               ^^^^^^^^^^^^^^^^
    = note: expected struct `tensor::base::NdArrayTensor<I>`
                 found enum `tensor::int::NdArrayTensorInt`

Maybe implementing a macro to convert NdArrayTensorInt to NdArrayTensor<I> could help.
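A minimal sketch of such a macro (variant names taken from the IntDType list above): unlike execute_with_int_dtype, the closure's result is returned as-is rather than re-wrapped in the enum, which is exactly what mixed-type call sites like module_op need.

macro_rules! dispatch_int_tensor {
    ($tensor:expr, |$t:ident| $body:expr) => {
        match $tensor {
            NdArrayTensorInt::I64($t) => $body,
            NdArrayTensorInt::I32($t) => $body,
            NdArrayTensorInt::I16($t) => $body,
            NdArrayTensorInt::I8($t) => $body,
            NdArrayTensorInt::U64($t) => $body,
            NdArrayTensorInt::U32($t) => $body,
            NdArrayTensorInt::U16($t) => $body,
            NdArrayTensorInt::U8($t) => $body,
        }
    };
}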

  3. Signed trait: I removed the Signed trait bound from IntNdArrayElement, but sign_op requires this trait, so it fails:
error[E0277]: the trait bound `u8: Signed` is not satisfied
  --> crates\burn-ndarray\src\tensor\int.rs:137:86
   |
137 |             $crate::NdArrayTensorInt::U8(tensor) => $crate::NdArrayTensorInt::U8($op(tensor)),
   |                                                                                      ^^^^^^ the trait `Signed` is not implemented for `u8`
   |
  ::: crates\burn-ndarray\src\ops\int_tensor.rs:449:41
   |
449 |         execute_with_int_dtype!(tensor, NdArrayMathOps::sign_op)
   |         --------------------------------------------------------
   |         |                               |
   |         |                               required by a bound introduced by this call
   |         in this macro invocation
   |
   = help: the following other types implement trait `Signed`:
             f32
             f64
             i128
             i16
             i32
             i64
             i8
             isize
note: required by a bound in `ops::base::NdArrayMathOps::<E>::sign_op`
  --> crates\burn-ndarray\src\ops\base.rs:803:12
   |
801 |     pub(crate) fn sign_op(tensor: NdArrayTensor<E>) -> NdArrayTensor<E>
   |                   ------- required by a bound in this associated function
802 |     where
803 |         E: Signed,
   |            ^^^^^^ required by this bound in `NdArrayMathOps::<E>::sign_op`
   = note: this error originates in the macro `execute_with_int_dtype` (in Nightly builds, run with -Z macro-backtrace for more info)

And using the Neg trait like you suggested results in:

error[E0308]: `if` and `else` have incompatible types
   --> crates\burn-ndarray\src\ops\base.rs:816:25
    |
326 |   impl<E> NdArrayMathOps<E>
    |        - found this type parameter
...
813 |                       } else if x < zero {
    |  ____________________________-
814 | |                         -one
    | |                         ---- expected because of this
815 | |                     } else {
816 | |                         zero
    | |                         ^^^^ expected associated type, found type parameter `E`
817 | |                     }
    | |_____________________- `if` and `else` have incompatible types

I had to add a macro for the int_into_float and float_into_int operations; tell me if it needs any changes.

@laggui (Member) commented Aug 15, 2025

  1. I am not entirely sure of the context, but it's probably due to the inflexibility of the macros 😅
  2. Definitely caused by the macro not handling the types.
  3. For unsigned types the sign is always >= 0 (so 0 or 1). We can simply have a Signum trait that computes the sign of a value and implement it for both signed and unsigned types, e.g.:
use num_traits::{One, Zero};

pub trait Signum: Zero + One {
    fn signum(self) -> Self;
}

macro_rules! impl_signum_signed {
    ($($t:ty),*) => {
        $(
            impl Signum for $t {
                fn signum(self) -> Self {
                    if self > Self::zero() {
                        Self::one()
                    } else if self < Self::zero() {
                        -Self::one()
                    } else {
                        Self::zero()
                    }
                }
            }
        )*
    };
}

macro_rules! impl_signum_unsigned {
    ($($t:ty),*) => {
        $(
            impl Signum for $t {
                fn signum(self) -> Self {
                    if self > Self::zero() { Self::one() } else { Self::zero() }
                }
            }
        )*
    };
}

impl_signum_signed!(i32, i64, f32, f64); // for all required signed types
impl_signum_unsigned!(u8, u16, u32, u64); // for all required unsigned types

@A2va (Author) commented Aug 15, 2025

Thank you! Your solution for the sign op produces a conflicting implementations error:

error[E0119]: conflicting implementations of trait `SignOf`
  --> crates\burn-ndarray\src\element.rs:54:1
   |
42 | impl<E: NdArrayElement + num_traits::Signed> SignOf for E {
   | --------------------------------------------------------- first implementation here
...
54 | impl<E: NdArrayElement + num_traits::Unsigned> SignOf for E {
   | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ conflicting implementation

I used ChatGPT and it suggested two macros, one for each type. Is that okay with you?

> Definitely caused by the macro not handling the types.

I don't know how to add this to the module_op macro, but couldn't we add a function/macro that converts the NdArrayTensorInt enum to NdArrayTensor<uxx>? That's all we need to do.

Also, I noticed that new_tensor_int is not used in bool_into_int, contrary to new_tensor_float in bool_into_float. Does that mean I have to use it? There is no error anyway.

A2va changed the title from "Add int tensor cast - WIP" to "Add int tensor cast" on Aug 15, 2025.
@laggui (Member) commented Aug 15, 2025

> Thank you! Your solution for the sign op produces a conflicting implementations error.
>
> I used ChatGPT and it suggested two macros, one for each type. Is that okay with you?

Yeah, that makes sense. I wrote that without thinking too much, but there is no mutual exclusivity for the traits, so that's why the compiler complains 😅 I just edited my previous comment with an example of what this could look like for a Signum trait, and we can use x.signum() in the sign_op implementation.
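A sketch of what sign_op could look like with that trait (assuming the NdArrayTensor array field and new constructor used by the other ndarray ops):

// Replace the num_traits::Signed bound with the custom Signum trait;
// the unsigned impls simply never produce a negative sign.
pub(crate) fn sign_op(tensor: NdArrayTensor<E>) -> NdArrayTensor<E>
where
    E: Signum + Clone,
{
    let array = tensor.array.mapv(|x| x.signum()).into_shared();
    NdArrayTensor::new(array)
}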

> > Definitely caused by the macro not handling the types.
>
> I don't know how to add this to the module_op macro, but couldn't we add a function/macro that converts the NdArrayTensorInt enum to NdArrayTensor<uxx>? That's all we need to do.

Yeah that's exactly what we need, similar to the problems you faced before that 🙂 gotta handle the conversion

> Also, I noticed that new_tensor_int is not used in bool_into_int, contrary to new_tensor_float in bool_into_float. Does that mean I have to use it? There is no error anyway.

Hmm, weird. I think it should be if you return IntTensor<Self> (the primitive, now an enum) and not NdArrayTensor<I> 🤔 that should be a type mismatch.

@A2va (Author) commented Aug 17, 2025

I get another type mismatch for indices here. I tried to use indices.into(), but since I isn't known at this point, it also errors out. I have added a dispatch_int_tensor macro and it works great for max_pool2d_with_indices_backward, but only because indices there was an IntTensor.

In the function below, indices is an NdArrayTensor<I>:

fn max_pool2d_with_indices(
    x: FloatTensor<Self>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
    module_op!(inp(x), opt(), E, |x| {
        let (output, indices) =
            max_pool2d_with_indices::<E, I>(x, kernel_size, stride, padding, dilation);
        MaxPool2dWithIndices::new(output.into(), indices)
    })
}
