diff --git a/src/hl_ops/binary.rs b/src/hl_ops/binary.rs index fdfa933a..4ee928d7 100644 --- a/src/hl_ops/binary.rs +++ b/src/hl_ops/binary.rs @@ -400,6 +400,7 @@ pub(super) mod tests { #[test] fn test_add() { test_binary(27, 27, |a, b| a + b, |a, b| (&a + &b).unwrap()); + test_binary((2, 1), (1, 3), |a, b| a + b, |a, b| (&a + &b).unwrap()); } #[test] @@ -410,6 +411,12 @@ pub(super) mod tests { #[test] fn test_mul() { test_binary(27, 27, |a, b| a * b, |a, b| (&a * &b).unwrap()); + test_binary( + (2, 1, 3), + (1, 4, 1), + |a, b| a * b, + |a, b| (&a * &b).unwrap(), + ); } #[test] diff --git a/src/hl_ops/movement.rs b/src/hl_ops/movement.rs index 303f9a8f..d46c5932 100644 --- a/src/hl_ops/movement.rs +++ b/src/hl_ops/movement.rs @@ -542,6 +542,12 @@ mod tests { let out2 = inp.unsqueeze(3); assert_eq!(out1.dims(), &[2, 1, 2, 3]); assert_eq!(out2.dims(), &[2, 2, 3, 1]); + test_unary( + (1, 3), + |a| a.expand((2, 3)), + |a| a.broadcast_as((2, 3)).unwrap(), + ); + test_unary((2, 1, 3), |a| a.squeeze(1), |a| a.reshape((2, 3)).unwrap()); } #[test] @@ -571,6 +577,24 @@ mod tests { ); } + #[test] + fn test_gather_and_inverse_permutation() { + let mut cx = Graph::new(); + let data = cx.tensor((2, 3)); + let indexes = cx.tensor(4).as_dtype(DType::Int); + let gathered = data.gather(indexes).output(); + let perm = cx.tensor(6).as_dtype(DType::Int); + let inv = perm.inverse_permutation(0).output(); + cx.build_search_space::(); + let mut rt = cx.search(NativeRuntime::default(), 1); + rt.set_data(data.id, vec![0., 1., 2., 3., 4., 5.].into()); + rt.set_data(indexes.id, vec![5, 0, 3, 2].into()); + rt.set_data(perm.id, vec![3, 2, 4, 1, 5, 0].into()); + rt.execute(&cx.dyn_map); + assert_eq!(*rt.get_f32(gathered.id), vec![5., 0., 3., 2.]); + assert_eq!(*rt.get_f32(inv.id), vec![5., 3., 1., 0., 2., 4.]); + } + // // #[test] // // fn test_cumsum() { // // let mut cx = Graph::new(); diff --git a/src/hl_ops/reduction.rs b/src/hl_ops/reduction.rs index 84686a0c..c21cb465 100644 --- a/src/hl_ops/reduction.rs +++ b/src/hl_ops/reduction.rs @@ -59,10 +59,16 @@ impl GraphTensor { #[cfg(test)] mod tests { use crate::hl_ops::unary::tests::test_unary; + use candle_core::{Device, Tensor}; #[test] fn test_sum() { test_unary((2, 3), |a| a.sum(1), |a| a.sum(1).unwrap()); + test_unary( + (2, 3, 4), + |a| a.sum((0, 2)), + |a| a.sum(2).unwrap().sum(0).unwrap(), + ); } #[test] @@ -73,5 +79,23 @@ mod tests { #[test] fn test_mean() { test_unary((2, 3), |a| a.mean(1), |a| a.mean(1).unwrap()); + test_unary( + (2, 3, 4), + |a| a.mean((0, 2)), + |a| a.sum(2).unwrap().sum(0).unwrap() / 8.0, + ); + } + + #[test] + fn test_prod() { + test_unary( + (2, 3), + |a| a.prod(1), + |a| { + let v = a.to_vec2::().unwrap(); + let out: Vec = v.iter().map(|row| row.iter().product()).collect(); + Tensor::from_vec(out, v.len(), &Device::Cpu).unwrap() + }, + ); } } diff --git a/src/hl_ops/unary.rs b/src/hl_ops/unary.rs index 4cf0bb00..9ad1834b 100644 --- a/src/hl_ops/unary.rs +++ b/src/hl_ops/unary.rs @@ -278,6 +278,45 @@ pub(super) mod tests { use itertools::Itertools; use ordered_float::NotNan; + fn cumsum_ref_2d(a: Tensor) -> Tensor { + let v = a.to_vec2::().unwrap(); + let mut out = vec![vec![0.0; v[0].len()]; v.len()]; + for (i, row) in v.iter().enumerate() { + let mut acc = 0.0; + for (j, val) in row.iter().enumerate() { + acc += val; + out[i][j] = acc; + } + } + Tensor::new(out, a.device()).unwrap() + } + + fn cummax_ref_2d(a: Tensor) -> Tensor { + let v = a.to_vec2::().unwrap(); + let mut out = vec![vec![0.0; v[0].len()]; v.len()]; + for (i, row) in v.iter().enumerate() { + let mut acc = f32::NEG_INFINITY; + for (j, val) in row.iter().enumerate() { + acc = acc.max(*val); + out[i][j] = acc; + } + } + Tensor::new(out, a.device()).unwrap() + } + + fn cumprod_ref_2d(a: Tensor) -> Tensor { + let v = a.to_vec2::().unwrap(); + let mut out = vec![vec![0.0; v[0].len()]; v.len()]; + for (i, row) in v.iter().enumerate() { + let mut acc = 1.0; + for (j, val) in row.iter().enumerate() { + acc *= val; + out[i][j] = acc; + } + } + Tensor::new(out, a.device()).unwrap() + } + pub fn test_unary( shape: impl ToShape, func: fn(GraphTensor) -> GraphTensor, @@ -346,6 +385,7 @@ pub(super) mod tests { #[test] fn test_softmax() { test_unary(27, |a| a.softmax(0), |a| softmax(&a, 0).unwrap()); + test_unary((4, 5), |a| a.softmax(1), |a| softmax(&a, 1).unwrap()); } #[test] @@ -373,11 +413,15 @@ pub(super) mod tests { }, ); } + #[test] fn test_cumsum() { test_unary(27, |a| a.cumsum(0), |a| a.cumsum(0).unwrap()); test_unary((27, 63), |a| a.cumsum(1), |a| a.cumsum(1).unwrap()); test_unary((27, 63), |a| a.cumsum(0), |a| a.cumsum(0).unwrap()); + test_unary((2, 3), |a| a.cumsum(1), cumsum_ref_2d); + test_unary((2, 3), |a| a.cummax(1), cummax_ref_2d); + test_unary((2, 3), |a| a.cumprod(1), cumprod_ref_2d); } #[test] fn test_argmax() { diff --git a/src/shape/symbolic.rs b/src/shape/symbolic.rs index cc92d936..f372160d 100644 --- a/src/shape/symbolic.rs +++ b/src/shape/symbolic.rs @@ -1033,4 +1033,25 @@ mod tests { let x = x.simplify(); assert_eq!(x.len(), 15); // Should be 11 if we can re-enable mul-div-associative-rev } + + #[test] + fn test_simplify_preserves_eval() { + let x = Expression::from('x'); + let y = Expression::from('y'); + let expr = ((x + 3) * 2) - (x * 2) + (y % 5); + let simplified = expr.simplify(); + let env = [('x', 7), ('y', 11)].into_iter().collect(); + assert_eq!(expr.exec(&env).unwrap(), simplified.exec(&env).unwrap()); + let x = Expression::from('x'); + let y = Expression::from('y'); + let z = Expression::from('z'); + let expr = (x + y) * (y - x); + let substituted = expr.substitute('x', z + 1).substitute('y', z - 1); + let simplified = substituted.simplify(); + let env = [('z', 10)].into_iter().collect(); + assert_eq!( + substituted.exec(&env).unwrap(), + simplified.exec(&env).unwrap() + ); + } } diff --git a/src/shape/tracker.rs b/src/shape/tracker.rs index d7909fd4..541b68a6 100644 --- a/src/shape/tracker.rs +++ b/src/shape/tracker.rs @@ -276,6 +276,77 @@ mod tests { println!("Val: {:?}", tracker.valid_expression()); } + #[test] + fn test_permute_and_expand() { + let mut tracker = ShapeTracker::new((2, 3, 4)); + assert!(tracker.is_contiguous()); + assert_eq!( + tracker.strides.as_slice(), + &[ + Expression::from(12), + Expression::from(4), + Expression::from(1) + ] + ); + tracker.permute(&[1, 2, 0]); + assert_eq!( + tracker.dims.as_slice(), + &[ + Expression::from(3), + Expression::from(4), + Expression::from(2) + ] + ); + assert_eq!( + tracker.strides.as_slice(), + &[ + Expression::from(4), + Expression::from(1), + Expression::from(12) + ] + ); + assert!(!tracker.is_contiguous()); + tracker.expand_dim(1, 1); + assert_eq!( + tracker.dims.as_slice(), + &[ + Expression::from(3), + Expression::from(1), + Expression::from(4), + Expression::from(2) + ] + ); + assert_eq!( + tracker.strides.as_slice(), + &[ + Expression::from(4), + Expression::from(0), + Expression::from(1), + Expression::from(12) + ] + ); + let removed = tracker.remove_dim(1); + assert_eq!(removed, Expression::from(1)); + assert_eq!( + tracker.dims.as_slice(), + &[ + Expression::from(3), + Expression::from(4), + Expression::from(2) + ] + ); + let mut tracker = ShapeTracker::new((1, 3)); + tracker.expand((2, 3)); + assert_eq!( + tracker.dims.as_slice(), + &[Expression::from(2), Expression::from(3)] + ); + assert_eq!( + tracker.strides.as_slice(), + &[Expression::from(0), Expression::from(1)] + ); + } + // #[test] // fn test_merge_dims() { // let mut tracker = ShapeTracker::new((10, 5, 3));