src/hl_ops/binary.rs: 7 additions, 0 deletions
@@ -400,6 +400,7 @@ pub(super) mod tests {
#[test]
fn test_add() {
test_binary(27, 27, |a, b| a + b, |a, b| (&a + &b).unwrap());
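// Mismatched (2, 1) and (1, 3) shapes exercise the broadcasting path of the add op.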
test_binary((2, 1), (1, 3), |a, b| a + b, |a, b| (&a + &b).unwrap());
}

#[test]
@@ -410,6 +411,12 @@ pub(super) mod tests {
#[test]
fn test_mul() {
test_binary(27, 27, |a, b| a * b, |a, b| (&a * &b).unwrap());
test_binary(
(2, 1, 3),
(1, 4, 1),
|a, b| a * b,
|a, b| (&a * &b).unwrap(),
);
}

#[test]
src/hl_ops/movement.rs: 24 additions, 0 deletions
@@ -542,6 +542,12 @@ mod tests {
let out2 = inp.unsqueeze(3);
assert_eq!(out1.dims(), &[2, 1, 2, 3]);
assert_eq!(out2.dims(), &[2, 2, 3, 1]);
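// expand broadcasts the unit dim of a (1, 3) tensor up to (2, 3); squeeze(1)
// drops the unit dim of a (2, 1, 3) tensor.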
test_unary(
(1, 3),
|a| a.expand((2, 3)),
|a| a.broadcast_as((2, 3)).unwrap(),
);
test_unary((2, 1, 3), |a| a.squeeze(1), |a| a.reshape((2, 3)).unwrap());
}

#[test]
@@ -571,6 +577,24 @@ mod tests {
);
}

#[test]
fn test_gather_and_inverse_permutation() {
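// gather pulls elements of the flattened data tensor at the given indexes;
// inverse_permutation returns inv such that inv[perm[i]] == i.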
let mut cx = Graph::new();
let data = cx.tensor((2, 3));
let indexes = cx.tensor(4).as_dtype(DType::Int);
let gathered = data.gather(indexes).output();
let perm = cx.tensor(6).as_dtype(DType::Int);
let inv = perm.inverse_permutation(0).output();
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
rt.set_data(data.id, vec![0., 1., 2., 3., 4., 5.].into());
rt.set_data(indexes.id, vec![5, 0, 3, 2].into());
rt.set_data(perm.id, vec![3, 2, 4, 1, 5, 0].into());
rt.execute(&cx.dyn_map);
assert_eq!(*rt.get_f32(gathered.id), vec![5., 0., 3., 2.]);
assert_eq!(*rt.get_f32(inv.id), vec![5., 3., 1., 0., 2., 4.]);
}

// // #[test]
// // fn test_cumsum() {
// // let mut cx = Graph::new();
src/hl_ops/reduction.rs: 24 additions, 0 deletions
@@ -59,10 +59,16 @@ impl GraphTensor {
#[cfg(test)]
mod tests {
use crate::hl_ops::unary::tests::test_unary;
use candle_core::{Device, Tensor};

#[test]
fn test_sum() {
test_unary((2, 3), |a| a.sum(1), |a| a.sum(1).unwrap());
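// Reducing over dims (0, 2) is checked against summing dim 2 then dim 0 in candle.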
test_unary(
(2, 3, 4),
|a| a.sum((0, 2)),
|a| a.sum(2).unwrap().sum(0).unwrap(),
);
}

#[test]
@@ -73,5 +79,23 @@ mod tests {
#[test]
fn test_mean() {
test_unary((2, 3), |a| a.mean(1), |a| a.mean(1).unwrap());
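// Mean over dims (0, 2) of a (2, 3, 4) tensor: the double sum divided by 2 * 4 = 8 elements.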
test_unary(
(2, 3, 4),
|a| a.mean((0, 2)),
|a| a.sum(2).unwrap().sum(0).unwrap() / 8.0,
);
}

#[test]
fn test_prod() {
test_unary(
(2, 3),
|a| a.prod(1),
|a| {
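// Host-side reference: the product over each row of the 2D input.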
let v = a.to_vec2::<f32>().unwrap();
let out: Vec<f32> = v.iter().map(|row| row.iter().product()).collect();
Tensor::from_vec(out, v.len(), &Device::Cpu).unwrap()
},
);
}
}
src/hl_ops/unary.rs: 44 additions, 0 deletions
@@ -278,6 +278,45 @@ pub(super) mod tests {
use itertools::Itertools;
use ordered_float::NotNan;

fn cumsum_ref_2d(a: Tensor) -> Tensor {
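// Host-side reference: running sum along each row of a 2D tensor.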
let v = a.to_vec2::<f32>().unwrap();
let mut out = vec![vec![0.0; v[0].len()]; v.len()];
for (i, row) in v.iter().enumerate() {
let mut acc = 0.0;
for (j, val) in row.iter().enumerate() {
acc += val;
out[i][j] = acc;
}
}
Tensor::new(out, a.device()).unwrap()
}

fn cummax_ref_2d(a: Tensor) -> Tensor {
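// Host-side reference: running maximum along each row of a 2D tensor.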
let v = a.to_vec2::<f32>().unwrap();
let mut out = vec![vec![0.0; v[0].len()]; v.len()];
for (i, row) in v.iter().enumerate() {
let mut acc = f32::NEG_INFINITY;
for (j, val) in row.iter().enumerate() {
acc = acc.max(*val);
out[i][j] = acc;
}
}
Tensor::new(out, a.device()).unwrap()
}

fn cumprod_ref_2d(a: Tensor) -> Tensor {
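// Host-side reference: running product along each row of a 2D tensor.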
let v = a.to_vec2::<f32>().unwrap();
let mut out = vec![vec![0.0; v[0].len()]; v.len()];
for (i, row) in v.iter().enumerate() {
let mut acc = 1.0;
for (j, val) in row.iter().enumerate() {
acc *= val;
out[i][j] = acc;
}
}
Tensor::new(out, a.device()).unwrap()
}

pub fn test_unary(
shape: impl ToShape,
func: fn(GraphTensor) -> GraphTensor,
@@ -346,6 +385,7 @@ pub(super) mod tests {
#[test]
fn test_softmax() {
test_unary(27, |a| a.softmax(0), |a| softmax(&a, 0).unwrap());
test_unary((4, 5), |a| a.softmax(1), |a| softmax(&a, 1).unwrap());
}

#[test]
@@ -373,11 +413,15 @@ pub(super) mod tests {
},
);
}

#[test]
fn test_cumsum() {
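// cumsum is checked against candle directly; cummax and cumprod use the
// host-side references defined at the top of this module.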
test_unary(27, |a| a.cumsum(0), |a| a.cumsum(0).unwrap());
test_unary((27, 63), |a| a.cumsum(1), |a| a.cumsum(1).unwrap());
test_unary((27, 63), |a| a.cumsum(0), |a| a.cumsum(0).unwrap());
test_unary((2, 3), |a| a.cumsum(1), cumsum_ref_2d);
test_unary((2, 3), |a| a.cummax(1), cummax_ref_2d);
test_unary((2, 3), |a| a.cumprod(1), cumprod_ref_2d);
}
#[test]
fn test_argmax() {
src/shape/symbolic.rs: 21 additions, 0 deletions
@@ -1033,4 +1033,25 @@ mod tests {
let x = x.simplify();
assert_eq!(x.len(), 15); // Should be 11 if we can re-enable mul-div-associative-rev
}

#[test]
fn test_simplify_preserves_eval() {
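// Simplifying an expression must not change its value: evaluate the original
// and simplified forms under the same variable bindings and compare.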
let x = Expression::from('x');
let y = Expression::from('y');
let expr = ((x + 3) * 2) - (x * 2) + (y % 5);
let simplified = expr.simplify();
let env = [('x', 7), ('y', 11)].into_iter().collect();
assert_eq!(expr.exec(&env).unwrap(), simplified.exec(&env).unwrap());
let x = Expression::from('x');
let y = Expression::from('y');
let z = Expression::from('z');
let expr = (x + y) * (y - x);
let substituted = expr.substitute('x', z + 1).substitute('y', z - 1);
let simplified = substituted.simplify();
let env = [('z', 10)].into_iter().collect();
assert_eq!(
substituted.exec(&env).unwrap(),
simplified.exec(&env).unwrap()
);
}
}
src/shape/tracker.rs: 71 additions, 0 deletions
@@ -276,6 +276,77 @@ mod tests {
println!("Val: {:?}", tracker.valid_expression());
}

#[test]
fn test_permute_and_expand() {
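// A contiguous (2, 3, 4) tracker starts with row-major strides [12, 4, 1];
// permute reorders dims and strides together, and expand_dim inserts a
// broadcast dimension with stride 0.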
let mut tracker = ShapeTracker::new((2, 3, 4));
assert!(tracker.is_contiguous());
assert_eq!(
tracker.strides.as_slice(),
&[
Expression::from(12),
Expression::from(4),
Expression::from(1)
]
);
tracker.permute(&[1, 2, 0]);
assert_eq!(
tracker.dims.as_slice(),
&[
Expression::from(3),
Expression::from(4),
Expression::from(2)
]
);
assert_eq!(
tracker.strides.as_slice(),
&[
Expression::from(4),
Expression::from(1),
Expression::from(12)
]
);
assert!(!tracker.is_contiguous());
tracker.expand_dim(1, 1);
assert_eq!(
tracker.dims.as_slice(),
&[
Expression::from(3),
Expression::from(1),
Expression::from(4),
Expression::from(2)
]
);
assert_eq!(
tracker.strides.as_slice(),
&[
Expression::from(4),
Expression::from(0),
Expression::from(1),
Expression::from(12)
]
);
let removed = tracker.remove_dim(1);
assert_eq!(removed, Expression::from(1));
assert_eq!(
tracker.dims.as_slice(),
&[
Expression::from(3),
Expression::from(4),
Expression::from(2)
]
);
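// Expanding a size-1 dim to 2 broadcasts it: the dim grows while its stride drops to 0.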
let mut tracker = ShapeTracker::new((1, 3));
tracker.expand((2, 3));
assert_eq!(
tracker.dims.as_slice(),
&[Expression::from(2), Expression::from(3)]
);
assert_eq!(
tracker.strides.as_slice(),
&[Expression::from(0), Expression::from(1)]
);
}

// #[test]
// fn test_merge_dims() {
// let mut tracker = ShapeTracker::new((10, 5, 3));