Getting elements from tensor #2723
Hi, I have found some differences between Burn and PyTorch tensors.

import numpy as np
import torch

# An array of shape [4, 2, 3]
data = np.array([[[1,2,3],
[1,2,3]],
[[4,5,6],
[4,5,6]],
[[1,2,3],
[1,2,3]],
[[4,5,6],
[4,5,6]]])
data2 = torch.FloatTensor(data)
data2[:, 0, :].shape  # The shape is [4, 3]
data2[:, 0:1, :].shape  # The shape is [4, 1, 3]

With a Burn tensor, however, I have to use slice (I cannot find a way to index directly), so the shape is always [4, 1, 3]:

// An array of shape [4, 2, 3]
let data = [[[1,2,3],
[1,2,3]],
[[4,5,6],
[4,5,6]],
[[1,2,3],
[1,2,3]],
[[4,5,6],
[4,5,6]]];
let data2 = Tensor::<NdArray, 3>::from_floats(data, &device);
data2.clone().slice([None, Some((0, 1)), None]) // The shape is [4, 1, 3]
// Cannot use data2.clone().slice([0..2, 0, 0..2])

How can I subset a tensor the way PyTorch does?

Here is another example: how do I get a single element of a tensor by its index?

let tensor = Tensor::<NdArray, 1>::from_floats([1,2,3]);
tensor[2] // I don't want to use tensor.slice(...) because it returns a tensor
Replies: 1 comment
Since const generic expressions are not yet available on stable Rust, operations typically don't change the rank, so slicing behaves like PyTorch's keepdim=True. You can use squeeze or reshape afterward to achieve the intended effect.
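
To illustrate why the rank stays fixed, here is a minimal, self-contained sketch in plain Rust. It is a toy model, not Burn's implementation or API: the `Toy` type and its `narrow` and `squeeze` methods are hypothetical names made up for this example. The rank is a const generic parameter, so a slice-like operation can only return the same rank, and dropping a dimension needs the caller to spell out the new rank, because the type system on stable Rust cannot compute `D - 1`.

```rust
// Toy model only: hypothetical names, NOT Burn's types or methods.
#[derive(Debug, Clone, PartialEq)]
struct Toy<const D: usize> {
    shape: [usize; D],
    data: Vec<f32>, // row-major storage
}

impl<const D: usize> Toy<D> {
    fn new(shape: [usize; D], data: Vec<f32>) -> Self {
        assert_eq!(shape.iter().product::<usize>(), data.len());
        Self { shape, data }
    }

    /// Keeps indices `start..end` along `dim`. The return type is still
    /// `Toy<D>`, so a range of length 1 leaves a size-1 dimension behind.
    fn narrow(&self, dim: usize, start: usize, end: usize) -> Toy<D> {
        assert!(start < end && end <= self.shape[dim]);
        let outer: usize = self.shape[..dim].iter().product();
        let inner: usize = self.shape[dim + 1..].iter().product();
        let mut data = Vec::with_capacity(outer * (end - start) * inner);
        for o in 0..outer {
            let base = o * self.shape[dim] * inner;
            data.extend_from_slice(&self.data[base + start * inner..base + end * inner]);
        }
        let mut shape = self.shape;
        shape[dim] = end - start;
        Toy { shape, data }
    }

    /// Drops a size-1 dimension. The caller states the output rank `D2`
    /// explicitly; the relation `D2 == D - 1` is only checked at runtime.
    fn squeeze<const D2: usize>(self, dim: usize) -> Toy<D2> {
        assert_eq!(D2 + 1, D, "output rank must be input rank minus one");
        assert_eq!(self.shape[dim], 1, "only size-1 dimensions can be squeezed");
        let mut shape = [0usize; D2];
        let mut j = 0;
        for (i, &s) in self.shape.iter().enumerate() {
            if i != dim {
                shape[j] = s;
                j += 1;
            }
        }
        Toy { shape, data: self.data }
    }
}

/// The [4, 2, 3] array from the question.
fn example() -> Toy<3> {
    let rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
    let mut data = Vec::new();
    for block in 0..4 {
        let row = rows[block % 2];
        data.extend_from_slice(&row); // first row of the block
        data.extend_from_slice(&row); // identical second row
    }
    Toy::new([4, 2, 3], data)
}

fn main() {
    let kept = example().narrow(1, 0, 1);
    println!("{:?}", kept.shape); // [4, 1, 3], like data2[:, 0:1, :]
    let squeezed = kept.squeeze::<2>(1);
    println!("{:?}", squeezed.shape); // [4, 3], like data2[:, 0, :]
}
```

The explicit `::<2>` on `squeeze` is the price of encoding the rank in the type. For the exact signatures of slice, squeeze and reshape in Burn itself, check the documentation for the version you are using.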