Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions src/backend/cpu/eval.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::ndarray::*;
use crate::backend::cpu::kernel;
use crate::core::{Dtype, Operation, Term};
use TaggedNdArray::*;

// TODO: this convenience method should live in open_hypergraphs
use open_hypergraphs::layer::*;
Expand Down Expand Up @@ -54,15 +55,26 @@ impl EvalState {
} => {
let (i, j) = (sources[0], sources[1]);
let k = targets[0];

use TaggedNdArray::*;
if let Ok([F32(f), F32(g), F32(h)]) = self.data[..].get_disjoint_mut([i, j, k]) {
kernel::batch_matmul(f, g, h);
} else {
panic!("invalid types!");
}
}

Add(_) => {
let (i, j) = (sources[0], sources[1]);
let k = targets[0];

if let Ok([F32(a), F32(b), F32(c)]) = self.data[..].get_disjoint_mut([i, j, k]) {
for i in 0..a.data.len() {
c.data[i] = a.data[i] + b.data[i];
}
} else {
panic!("invalid types!");
}
}

// this should be ruled out by typechecking
op => {
panic!("unknown operation {:?}", op);
Expand Down Expand Up @@ -95,7 +107,42 @@ impl EvalState {
#[cfg(test)]
mod test {
use super::*;
use crate::core::{Dtype, Operation, Shape};
use crate::core::{Dtype, NdArrayType, Operation, Shape};

#[test]
fn test_add() {
let f = Operation::Add(NdArrayType {
shape: Shape(vec![2, 2]),
dtype: Dtype::F32,
})
.term();

let x = NdArray {
data: vec![1., 2., 3., 4.],
shape: Shape(vec![2, 2]),
};
let y = NdArray {
data: vec![10., 20., 30., 40.],
shape: Shape(vec![2, 2]),
};
let expected = NdArray {
data: vec![11., 22., 33., 44.],
shape: Shape(vec![2, 2]),
};

let mut state = EvalState::new(f);

// TODO: fix hack - API for EvalState?
state.data[0] = x.into();
state.data[1] = y.into();

let [actual] = state.eval()[..] else {
panic!("unexpected coarity at eval time")
};

let tagged: TaggedNdArray = expected.into();
assert_eq!(&tagged, actual);
}

#[test]
fn test_mat_mul() {
Expand Down
31 changes: 23 additions & 8 deletions src/core/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,25 @@ pub enum Operation {
},

/// Broadcast a value of shape x to one of shape n+x.
Broadcast {
n: Shape,
x: NdArrayType,
},
Broadcast { n: Shape, x: NdArrayType },

/// Reshape a
Reshape {
x: NdArrayType,
y: NdArrayType,
},
Reshape { x: NdArrayType, y: NdArrayType },

/// Create a copy
Copy(NdArrayType),

/// Pointwise addition of two values of similar shapes
Add(NdArrayType),

/// Pointwise subtraction of two values of similar shapes
Sub(NdArrayType),

/// Pointwise multiplication of two values of similar shapes
Mul(NdArrayType),

/// Pointwise negation of value
Negate(NdArrayType),
}

pub type Term = OpenHypergraph<PrimitiveType, Operation>;
Expand Down Expand Up @@ -79,6 +86,14 @@ impl Operation {
Reshape { x, y } => (vec![x.clone()], vec![y.clone()]),

Copy(x) => (vec![x.clone()], vec![x.clone(), x.clone()]),

Add(x) => (vec![x.clone(), x.clone()], vec![x.clone()]),

Sub(x) => (vec![x.clone(), x.clone()], vec![x.clone()]),

Mul(x) => (vec![x.clone(), x.clone()], vec![x.clone()]),

Negate(x) => (vec![x.clone()], vec![x.clone()]),
}
}

Expand Down