Add int tensor cast #3289
Conversation
Sorry for the late response! Been a little busy recently 😅 This was on the roadmap so I will absolutely have a look when I get a bit of time.

This PR has been marked as stale because it has not been updated for over a month.

@laggui Any update?
Thanks for pinging me, this completely slipped through the cracks with the last release sprint 😅
I think it's off to a good start, see my comments below.
For the unsigned ndarray int types, it's probably best to remove the `Signed` requirement if we expand to unsigned types, and fix the errors at the operation level. Maybe the restriction was for `sign`, but we could narrow it to `Neg`? Not 100% sure.
> I also get this error for all usage of `execute_with_int_dtype` and I don't understand it because the conversion from `NdArrayTensor<I>` to `NdArrayTensorInt` is implemented.

That's because the backend int primitive was not changed, so the traits still expect `NdArrayTensor<I>` as output.
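The point about the backend primitive can be sketched as follows; all names below are illustrative stand-ins, not Burn's actual trait definitions. Until the backend's associated int primitive type is switched to the new enum, every ops trait method still expects the old concrete tensor type, regardless of any conversion impls:

```rust
// Illustrative stand-ins, not Burn's real definitions.
struct NdArrayTensor<E> {
    data: Vec<E>,
}

// New enum wrapping the supported int dtypes.
enum NdArrayTensorInt {
    I32(NdArrayTensor<i32>),
    I64(NdArrayTensor<i64>),
}

trait Backend {
    // The ops traits use this associated type as input/output, so if it
    // still points at `NdArrayTensor<I>`, returning `NdArrayTensorInt`
    // is a type mismatch even when a conversion exists.
    type IntTensorPrimitive;
}

struct NdArray;

impl Backend for NdArray {
    // Before: type IntTensorPrimitive = NdArrayTensor<i64>;
    // After switching the primitive to the enum, ops can return it directly.
    type IntTensorPrimitive = NdArrayTensorInt;
}
```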
```rust
pub enum IntDType {
    I64,
    I32,
    I16,
    I8,
    U64,
    U32,
    U16,
    U8,
}
```
I know this follows the current float cast API, but it might be a good time to revisit this since we're expanding it.
We could probably just stick to `DType` with runtime validation at the high-level method. Would reduce some duplication of variants.

I think after that it would be nice to move to a more idiomatic API using generic types, i.e. `tensor.cast::<f16>()` instead of `tensor.cast(DType::F16)` (which would also work for int types). We already have the `DType` and `Element` trait in place. But I'm getting ahead of the current PR 😄 I'll take care of this change after.
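A rough sketch of what that generic API could look like; the `DType`, `Element`, and `F16` definitions here are simplified placeholders, not Burn's actual ones:

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    F32,
    F16,
    I32,
}

// Simplified stand-in for an Element trait: each element type knows
// its runtime dtype.
trait Element {
    fn dtype() -> DType;
}

// Marker type for f16 (placeholder; a real one would come from e.g. the
// `half` crate).
struct F16;
impl Element for F16 {
    fn dtype() -> DType {
        DType::F16
    }
}

struct Tensor {
    dtype: DType,
}

impl Tensor {
    // Generic front-end: resolves the target dtype at compile time...
    fn cast<E: Element>(self) -> Tensor {
        self.cast_dtype(E::dtype())
    }

    // ...and forwards to the existing runtime-dispatch entry point, where
    // validation (e.g. float vs int kinds) could live.
    fn cast_dtype(self, dtype: DType) -> Tensor {
        Tensor { dtype }
    }
}
```

Under this scheme `tensor.cast::<F16>()` and `tensor.cast_dtype(DType::F16)` stay equivalent, and the int variants fold into the same `DType`-driven path.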
I have replaced

```rust
fn float_gather(
    dim: usize,
    tensor: FloatTensor<Self>,
    indices: IntTensor<Self>,
) -> FloatTensor<Self> {
    execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::gather(
        dim, tensor, indices
    ))
}
```

Do I need to write the `Into` trait? But it isn't required to convert
Ahhh that's because the current

```rust
execute_with_float_dtype!(tensor, |tensor| {
    execute_with_int_dtype!(indices, |indices| {
        NdArrayMathOps::gather(dim, tensor, indices)
    })
})
```

(might not work with the current macros implementation, but this is just a representation)

Right now code duplication is OK since this is a work in progress, but we'll probably want something along the lines of:

```rust
generate_dispatch_macro!(execute_with_float_dtype, NdArrayTensorFloat, F32(f32), F64(f64));
generate_dispatch_macro!(execute_with_int_dtype, NdArrayTensorInt, I8(i8), I16(i16), I32(i32), I64(i64));
```

to have a single macro written for type dispatch, and float/int dispatch can be generated.
Chaining two

Here are the remaining errors:

So it leads to errors like

Maybe implementing a macro to convert

And using the `Neg` trait like you suggested results in

I had to add a macro for the
```rust
pub trait Signum: Zero + One {
    fn signum(self) -> Self;
}

macro_rules! impl_signum_signed {
    ($($t:ty),*) => {
        $(
            impl Signum for $t {
                fn signum(self) -> Self {
                    if self > Self::zero() {
                        Self::one()
                    } else if self < Self::zero() {
                        -Self::one()
                    } else {
                        Self::zero()
                    }
                }
            }
        )*
    };
}

macro_rules! impl_signum_unsigned {
    ($($t:ty),*) => {
        $(
            impl Signum for $t {
                fn signum(self) -> Self {
                    if self > Self::zero() {
                        Self::one()
                    } else {
                        Self::zero()
                    }
                }
            }
        )*
    };
}

impl_signum_signed!(i32, i64, f32, f64); // for all required signed types
impl_signum_unsigned!(u8, u16, u32, u64); // for all required unsigned types
```
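To sanity-check the two-macro approach in isolation, here is a self-contained variant with minimal stand-ins for the `Zero`/`One` bounds (the real ones presumably come from num-traits), covering a couple of types from each macro:

```rust
// Minimal stand-ins for num-traits' Zero/One so the sketch compiles alone.
trait Zero {
    fn zero() -> Self;
}
trait One {
    fn one() -> Self;
}

macro_rules! impl_zero_one {
    ($($t:ty),*) => {
        $(
            impl Zero for $t {
                fn zero() -> Self { 0 as $t }
            }
            impl One for $t {
                fn one() -> Self { 1 as $t }
            }
        )*
    };
}
impl_zero_one!(i32, i64, u8, u64);

trait Signum: Zero + One {
    fn signum(self) -> Self;
}

// Per-type impls avoid the conflicting blanket impls: the compiler cannot
// rule out a type implementing both Neg and an "unsigned" marker trait,
// but each concrete type below gets exactly one Signum impl.
macro_rules! impl_signum_signed {
    ($($t:ty),*) => {
        $(
            impl Signum for $t {
                fn signum(self) -> Self {
                    if self > Self::zero() { Self::one() }
                    else if self < Self::zero() { -Self::one() }
                    else { Self::zero() }
                }
            }
        )*
    };
}

macro_rules! impl_signum_unsigned {
    ($($t:ty),*) => {
        $(
            impl Signum for $t {
                fn signum(self) -> Self {
                    if self > Self::zero() { Self::one() } else { Self::zero() }
                }
            }
        )*
    };
}

impl_signum_signed!(i32, i64);
impl_signum_unsigned!(u8, u64);
```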
Thank you! Your solution for the sign op produces a conflicting implementations error.

I used ChatGPT and it suggested two macros, one for each type. Is that okay with you?

I don't know how to add this to

Also I noticed that
Yeah that makes sense. I wrote that without thinking too much, but there is no mutual exclusivity for the traits, so that's why the compiler complains 😅 Just edited my previous comment with an example of what this could look like for a

Yeah that's exactly what we need, similar to the problems you faced before that 🙂 Gotta handle the conversion

Hmm weird, I think it should if you return
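One hypothetical way to handle that conversion, assuming the enum wraps the concrete tensor types (illustrative names again, not Burn's actual definitions), is a `From` impl per variant, so ops can finish with a uniform `.into()`:

```rust
// Illustrative stand-ins for the concrete tensor and the dtype enum.
struct NdArrayTensor<E> {
    data: Vec<E>,
}

enum NdArrayTensorInt {
    I32(NdArrayTensor<i32>),
    I64(NdArrayTensor<i64>),
}

// One From impl per supported element type; a small macro could generate
// these alongside the dispatch macros.
impl From<NdArrayTensor<i32>> for NdArrayTensorInt {
    fn from(t: NdArrayTensor<i32>) -> Self {
        NdArrayTensorInt::I32(t)
    }
}

impl From<NdArrayTensor<i64>> for NdArrayTensorInt {
    fn from(t: NdArrayTensor<i64>) -> Self {
        NdArrayTensorInt::I64(t)
    }
}
```

With these in place, a kernel that returns concrete `NdArrayTensor<i64>` indices can be wrapped at the op boundary with `indices.into()`.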
I get another type mismatch for indices here, I tried to use

In the function below:

```rust
fn max_pool2d_with_indices(
    x: FloatTensor<Self>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
    module_op!(inp(x), opt(), E, |x| {
        let (output, indices) =
            max_pool2d_with_indices::<E, I>(x, kernel_size, stride, padding, dilation);
        MaxPool2dWithIndices::new(output.into(), indices)
    })
}
```
Pull Request Template

Checklist

- The `cargo run-checks` command has been executed.

Changes

As my project is blocked by both #3262 and a TensorFlow-generated ONNX issue, I decided to try implementing this feature.

My implementation is based on the existing `float_cast` function, so things might be different for the integer ones. It is pretty much a work in progress, so there might be errors, and I look forward to receiving any feedback.

Regarding the current state of the implementation, I have added `int_cast` to all backends, but I am struggling to implement it for the ndarray backend.

I have added the implementation of `IntNdArrayElement` (crates/burn-ndarray/src/element.rs), but since all uxx are unsigned, it errors with:

And the `Signed` trait cannot be removed since it produces other errors for `IntTensorOps`.

I also get this error for all usage of `execute_with_int_dtype`, and I don't understand it because the conversion from `NdArrayTensor<I>` to `NdArrayTensorInt` is implemented.

I also noticed some inconsistencies in the file naming conventions between the backends. For example, in the ndarray backend it is named `int_tensor.rs`, whereas in cubecl it's `int_ops.rs`. What about unifying the naming?

Testing

Do I need to run tests for each backend, and where?