Skip to content

Commit 664a907

Browse files
committed
unstack fixes
1 parent 23c7f0b commit 664a907

File tree

1 file changed

+4
-3
lines changed
  • dfdx-core/src/tensor_ops/unstack

1 file changed

+4
-3
lines changed

dfdx-core/src/tensor_ops/unstack/mod.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::{shapes::*, tensor::*};
2+
use std::vec::Vec;
23

34
mod cpu_kernel;
45
#[cfg(feature = "cuda")]
@@ -21,15 +22,15 @@ mod webgpu_kernel;
2122
/// # use dfdx_core::prelude::*;
2223
/// # let dev: Cpu = Default::default();
2324
/// let stack: Tensor<Rank3<2, 3, 4>, f32, _> = dev.zeros();
24-
/// let [a, b]: [Tensor<Rank2<3, 4>, f32, _>; 2] = stack.unstack();
25+
/// let ([a, b], _tape): ([Tensor<Rank2<3, 4>, f32, _>; 2], _) = stack.unstack();
2526
/// ```
2627
///
2728
/// Unstacking to a vec:
2829
/// ```rust
2930
/// # use dfdx_core::prelude::*;
3031
/// # let dev: Cpu = Default::default();
31-
/// let stack: Tensor<(usize, Const::<3>, Const::<4>>, f32, _> = dev.zeros_like(&(2, Const, Const));
32-
/// let unstack: Vec<Tensor<Rank2<3, 4>, f32, _>> = stack.unstack();
32+
/// let stack: Tensor<(usize, Const::<3>, Const::<4>), f32, _> = dev.zeros_like(&(2, Const, Const));
33+
/// let (unstack, _tape): (Vec<Tensor<Rank2<3, 4>, f32, _>>, _) = stack.unstack();
3334
/// ```
3435
pub trait TryUnstack<Head: Dim>: Sized {
3536
type Unstacked;

0 commit comments

Comments
 (0)