diff --git a/src/backend/cpu/eval.rs b/src/backend/cpu/eval.rs index 4cc8bdd5..30abad13 100644 --- a/src/backend/cpu/eval.rs +++ b/src/backend/cpu/eval.rs @@ -1,6 +1,7 @@ use super::ndarray::*; use crate::backend::cpu::kernel; use crate::core::{Dtype, Operation, Term}; +use TaggedNdArray::*; // TODO: this convenience method should live in open_hypergraphs use open_hypergraphs::layer::*; @@ -54,8 +55,6 @@ impl EvalState { } => { let (i, j) = (sources[0], sources[1]); let k = targets[0]; - - use TaggedNdArray::*; if let Ok([F32(f), F32(g), F32(h)]) = self.data[..].get_disjoint_mut([i, j, k]) { kernel::batch_matmul(f, g, h); } else { @@ -63,6 +62,19 @@ impl EvalState { } } + Add(_) => { + let (i, j) = (sources[0], sources[1]); + let k = targets[0]; + + if let Ok([F32(a), F32(b), F32(c)]) = self.data[..].get_disjoint_mut([i, j, k]) { + for i in 0..a.data.len() { + c.data[i] = a.data[i] + b.data[i]; + } + } else { + panic!("invalid types!"); + } + } + // this should be ruled out by typechecking op => { panic!("unknown operation {:?}", op); @@ -95,7 +107,42 @@ impl EvalState { #[cfg(test)] mod test { use super::*; - use crate::core::{Dtype, Operation, Shape}; + use crate::core::{Dtype, NdArrayType, Operation, Shape}; + + #[test] + fn test_add() { + let f = Operation::Add(NdArrayType { + shape: Shape(vec![2, 2]), + dtype: Dtype::F32, + }) + .term(); + + let x = NdArray { + data: vec![1., 2., 3., 4.], + shape: Shape(vec![2, 2]), + }; + let y = NdArray { + data: vec![10., 20., 30., 40.], + shape: Shape(vec![2, 2]), + }; + let expected = NdArray { + data: vec![11., 22., 33., 44.], + shape: Shape(vec![2, 2]), + }; + + let mut state = EvalState::new(f); + + // TODO: fix hack - API for EvalState? + state.data[0] = x.into(); + state.data[1] = y.into(); + + let [actual] = state.eval()[..] else { + panic!("unexpected coarity at eval time") + }; + + let tagged: TaggedNdArray = expected.into(); + assert_eq!(&tagged, actual); + } #[test] fn test_mat_mul() { diff --git a/src/core/operation.rs b/src/core/operation.rs index f4569da0..786a0c8a 100644 --- a/src/core/operation.rs +++ b/src/core/operation.rs @@ -16,18 +16,25 @@ pub enum Operation { }, /// Broadcast a value of shape x to one of shape n+x. - Broadcast { - n: Shape, - x: NdArrayType, - }, + Broadcast { n: Shape, x: NdArrayType }, /// Reshape a - Reshape { - x: NdArrayType, - y: NdArrayType, - }, + Reshape { x: NdArrayType, y: NdArrayType }, + /// Create a copy Copy(NdArrayType), + + /// Pointwise addition of two values of similar shapes + Add(NdArrayType), + + /// Pointwise subtraction of two values of similar shapes + Sub(NdArrayType), + + /// Pointwise multiplication of two values of similar shapes + Mul(NdArrayType), + + /// Pointwise negation of value + Negate(NdArrayType), } pub type Term = OpenHypergraph; @@ -79,6 +86,14 @@ impl Operation { Reshape { x, y } => (vec![x.clone()], vec![y.clone()]), Copy(x) => (vec![x.clone()], vec![x.clone(), x.clone()]), + + Add(x) => (vec![x.clone(), x.clone()], vec![x.clone()]), + + Sub(x) => (vec![x.clone(), x.clone()], vec![x.clone()]), + + Mul(x) => (vec![x.clone(), x.clone()], vec![x.clone()]), + + Negate(x) => (vec![x.clone()], vec![x.clone()]), } }