diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e60aa4..dabaf6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - ReleaseDate +- Fix bugs in the new wrapper setup where consumed and modified shapes + weren't respected during wrapper construction. + ## [0.9.1] - 2025-09-10 - Add `StatefulInferer::replace_inferer` which works with a `&mut diff --git a/crates/cervo-core/src/recurrent.rs b/crates/cervo-core/src/recurrent.rs index ac8ed2d..48167f8 100644 --- a/crates/cervo-core/src/recurrent.rs +++ b/crates/cervo-core/src/recurrent.rs @@ -109,23 +109,23 @@ where /// Create a new recurrency tracker for the model. /// pub fn new(inferer: T, info: Vec) -> Result { - let inputs = inferer.raw_input_shapes(); - let outputs = inferer.raw_output_shapes(); + let raw_inputs = inferer.raw_input_shapes(); + let raw_outputs = inferer.raw_output_shapes(); let mut offset = 0; let keys = info .iter() .map(|info| { - let inslot = inputs + let inslot = raw_inputs .iter() .position(|input| info.inkey == input.0) .with_context(|| format!("no input named {}", info.inkey))?; - let outslot = outputs + let outslot = raw_outputs .iter() .position(|output| info.outkey == output.0) .with_context(|| format!("no output named {}", info.outkey))?; - let numels = inputs[inslot].1.iter().product(); + let numels = raw_inputs[inslot].1.iter().product(); offset += numels; Ok(RecurrentPair { inslot, @@ -136,16 +136,21 @@ where }) .collect::>>()?; + let inputs = inferer.input_shapes(); + let outputs = inferer.output_shapes(); + let inputs = inputs .iter() .filter(|(k, _)| !info.iter().any(|info| &info.inkey == k)) .cloned() .collect::>(); + let outputs = outputs .iter() .filter(|(k, _)| !info.iter().any(|info| &info.outkey == k)) .cloned() .collect::>(); + Ok(Self { inner: inferer, state: RecurrentState { @@ -248,23 +253,23 @@ impl RecurrentTrackerWrapper { /// Create a new recurrency tracker for the model. /// pub fn new(inner: Inner, inferer: &T, info: Vec) -> Result { - let inputs = inner.input_shapes(inferer); - let outputs = inner.output_shapes(inferer); + let raw_inputs = inferer.raw_input_shapes(); + let raw_outputs = inferer.raw_output_shapes(); let mut offset = 0; let keys = info .iter() .map(|info| { - let inslot = inputs + let inslot = raw_inputs .iter() .position(|input| info.inkey == input.0) .with_context(|| format!("no input named {}", info.inkey))?; - let outslot = outputs + let outslot = raw_outputs .iter() .position(|output| info.outkey == output.0) .with_context(|| format!("no output named {}", info.outkey))?; - let numels = inputs[inslot].1.iter().product(); + let numels = raw_inputs[inslot].1.iter().product(); offset += numels; Ok(RecurrentPair { inslot, @@ -275,16 +280,21 @@ impl RecurrentTrackerWrapper { }) .collect::>>()?; + let inputs = inner.input_shapes(inferer); + let outputs = inner.output_shapes(inferer); + let inputs = inputs .iter() .filter(|(k, _)| !info.iter().any(|info| &info.inkey == k)) .cloned() .collect::>(); + let outputs = outputs .iter() .filter(|(k, _)| !info.iter().any(|info| &info.outkey == k)) .cloned() .collect::>(); + Ok(Self { inner, state: RecurrentState { @@ -338,6 +348,8 @@ mod tests { batcher::ScratchPadView, inferer::State, prelude::{Batcher, Inferer}, + recurrent::RecurrentTrackerWrapper, + wrapper::InfererWrapper, }; use super::RecurrentTracker; @@ -371,6 +383,7 @@ mod tests { end_called: false.into(), begin_called: false.into(), inputs: vec![ + ("epsilon".to_owned(), vec![2]), (hidden_name_in.to_owned(), vec![2, 1]), (cell_name_in.to_owned(), vec![2, 3]), ], @@ -390,15 +403,15 @@ mod tests { } fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> anyhow::Result<(), anyhow::Error> { - assert_eq!(batch.inner().input_name(0), "lstm_hidden_state"); - let hidden_value = batch.input_slot(0); + assert_eq!(batch.inner().input_name(1), "lstm_hidden_state"); + let hidden_value = batch.input_slot(1); let hidden_new = hidden_value.iter().map(|v| *v + 1.0).collect::>(); assert_eq!(batch.inner().output_name(0), "lstm_hidden_state"); batch.output_slot_mut(0).copy_from_slice(&hidden_new); - assert_eq!(batch.inner().input_name(1), "lstm_cell_state"); - let cell_value = batch.input_slot(1); + assert_eq!(batch.inner().input_name(2), "lstm_cell_state"); + let cell_value = batch.input_slot(2); let cell_new = cell_value.iter().map(|v| *v + 2.0).collect::>(); assert_eq!(batch.inner().output_name(1), "lstm_cell_state"); @@ -585,4 +598,65 @@ mod tests { assert!(agent_data.data["hidden_output"].iter().all(|v| *v == 1.0)); assert!(agent_data.data["cell_output"].iter().all(|v| *v == 2.0)); } + + #[test] + fn test_wrapper_does_not_expose_inner_hidden() { + // Imagine Recurrent>. We want to assert that + // Recurrent hides its own fields while also not exposing any + // fields from the inner epsilon wrapper. + + struct DummyEpsilonWrapper { + inputs: Vec<(String, Vec)>, + } + + impl InfererWrapper for DummyEpsilonWrapper { + fn invoke( + &self, + _inferer: &dyn Inferer, + _batch: &mut ScratchPadView<'_>, + ) -> anyhow::Result<(), anyhow::Error> { + Ok(()) + } + fn input_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec)] { + &self.inputs + } + fn output_shapes<'a>( + &'a self, + _inferer: &'a dyn Inferer, + ) -> &'a [(String, Vec)] { + _inferer.output_shapes() + } + fn begin_agent(&self, _inferer: &dyn Inferer, _id: u64) {} + fn end_agent(&self, _inferer: &dyn Inferer, _id: u64) {} + } + + let inferer = DummyInferer::default(); + let wrapper = DummyEpsilonWrapper { + inputs: vec![ + ("lstm_hidden_state".to_owned(), vec![2, 1]), + ("lstm_cell_state".to_owned(), vec![2, 3]), + ], + }; + + let recurrent = RecurrentTrackerWrapper::wrap(wrapper, &inferer).unwrap(); + + assert_eq!(recurrent.input_shapes(&inferer).len(), 0); + assert_eq!( + recurrent.output_shapes(&inferer).len(), + 2, + "only hidden and cell state are recurrent: {:?}", + recurrent.output_shapes(&inferer) + ); + + assert_eq!(recurrent.output_shapes(&inferer)[0].0, "hidden_output"); + assert_eq!(recurrent.output_shapes(&inferer)[1].0, "cell_output"); + + assert_eq!(recurrent.state.inputs.len(), 0); + assert_eq!(recurrent.state.outputs.len(), 2); + + assert_eq!(recurrent.state.keys.len(), 2); + // slots are still correct despite epsilon being hidden + assert_eq!(recurrent.state.keys[0].inslot, 1); + assert_eq!(recurrent.state.keys[1].inslot, 2); + } }