Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

<!-- next-header -->
## [Unreleased] - ReleaseDate
- Fix bugs in the new wrapper setup where consumed and modified shapes
weren't respected during wrapper construction.

## [0.9.1] - 2025-09-10

- Add `StatefulInferer::replace_inferer` which works with a `&mut
Expand Down
102 changes: 88 additions & 14 deletions crates/cervo-core/src/recurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,23 +109,23 @@ where
/// Create a new recurrency tracker for the model.
///
pub fn new(inferer: T, info: Vec<RecurrentInfo>) -> Result<Self> {
let inputs = inferer.raw_input_shapes();
let outputs = inferer.raw_output_shapes();
let raw_inputs = inferer.raw_input_shapes();
let raw_outputs = inferer.raw_output_shapes();

let mut offset = 0;
let keys = info
.iter()
.map(|info| {
let inslot = inputs
let inslot = raw_inputs
.iter()
.position(|input| info.inkey == input.0)
.with_context(|| format!("no input named {}", info.inkey))?;
let outslot = outputs
let outslot = raw_outputs
.iter()
.position(|output| info.outkey == output.0)
.with_context(|| format!("no output named {}", info.outkey))?;

let numels = inputs[inslot].1.iter().product();
let numels = raw_inputs[inslot].1.iter().product();
offset += numels;
Ok(RecurrentPair {
inslot,
Expand All @@ -136,16 +136,21 @@ where
})
.collect::<Result<TVec<RecurrentPair>>>()?;

let inputs = inferer.input_shapes();
let outputs = inferer.output_shapes();

let inputs = inputs
.iter()
.filter(|(k, _)| !info.iter().any(|info| &info.inkey == k))
.cloned()
.collect::<Vec<_>>();

let outputs = outputs
.iter()
.filter(|(k, _)| !info.iter().any(|info| &info.outkey == k))
.cloned()
.collect::<Vec<_>>();

Ok(Self {
inner: inferer,
state: RecurrentState {
Expand Down Expand Up @@ -248,23 +253,23 @@ impl<Inner: InfererWrapper> RecurrentTrackerWrapper<Inner> {
/// Create a new recurrency tracker for the model.
///
pub fn new<T: Inferer>(inner: Inner, inferer: &T, info: Vec<RecurrentInfo>) -> Result<Self> {
let inputs = inner.input_shapes(inferer);
let outputs = inner.output_shapes(inferer);
let raw_inputs = inferer.raw_input_shapes();
let raw_outputs = inferer.raw_output_shapes();

let mut offset = 0;
let keys = info
.iter()
.map(|info| {
let inslot = inputs
let inslot = raw_inputs
.iter()
.position(|input| info.inkey == input.0)
.with_context(|| format!("no input named {}", info.inkey))?;
let outslot = outputs
let outslot = raw_outputs
.iter()
.position(|output| info.outkey == output.0)
.with_context(|| format!("no output named {}", info.outkey))?;

let numels = inputs[inslot].1.iter().product();
let numels = raw_inputs[inslot].1.iter().product();
offset += numels;
Ok(RecurrentPair {
inslot,
Expand All @@ -275,16 +280,21 @@ impl<Inner: InfererWrapper> RecurrentTrackerWrapper<Inner> {
})
.collect::<Result<TVec<RecurrentPair>>>()?;

let inputs = inner.input_shapes(inferer);
let outputs = inner.output_shapes(inferer);

let inputs = inputs
.iter()
.filter(|(k, _)| !info.iter().any(|info| &info.inkey == k))
.cloned()
.collect::<Vec<_>>();

let outputs = outputs
.iter()
.filter(|(k, _)| !info.iter().any(|info| &info.outkey == k))
.cloned()
.collect::<Vec<_>>();

Ok(Self {
inner,
state: RecurrentState {
Expand Down Expand Up @@ -338,6 +348,8 @@ mod tests {
batcher::ScratchPadView,
inferer::State,
prelude::{Batcher, Inferer},
recurrent::RecurrentTrackerWrapper,
wrapper::InfererWrapper,
};

use super::RecurrentTracker;
Expand Down Expand Up @@ -371,6 +383,7 @@ mod tests {
end_called: false.into(),
begin_called: false.into(),
inputs: vec![
("epsilon".to_owned(), vec![2]),
(hidden_name_in.to_owned(), vec![2, 1]),
(cell_name_in.to_owned(), vec![2, 3]),
],
Expand All @@ -390,15 +403,15 @@ mod tests {
}

fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> anyhow::Result<(), anyhow::Error> {
assert_eq!(batch.inner().input_name(0), "lstm_hidden_state");
let hidden_value = batch.input_slot(0);
assert_eq!(batch.inner().input_name(1), "lstm_hidden_state");
let hidden_value = batch.input_slot(1);
let hidden_new = hidden_value.iter().map(|v| *v + 1.0).collect::<Vec<_>>();

assert_eq!(batch.inner().output_name(0), "lstm_hidden_state");
batch.output_slot_mut(0).copy_from_slice(&hidden_new);

assert_eq!(batch.inner().input_name(1), "lstm_cell_state");
let cell_value = batch.input_slot(1);
assert_eq!(batch.inner().input_name(2), "lstm_cell_state");
let cell_value = batch.input_slot(2);
let cell_new = cell_value.iter().map(|v| *v + 2.0).collect::<Vec<_>>();

assert_eq!(batch.inner().output_name(1), "lstm_cell_state");
Expand Down Expand Up @@ -585,4 +598,65 @@ mod tests {
assert!(agent_data.data["hidden_output"].iter().all(|v| *v == 1.0));
assert!(agent_data.data["cell_output"].iter().all(|v| *v == 2.0));
}

#[test]
fn test_wrapper_does_not_expose_inner_hidden() {
// Imagine Recurrent<Epsilon<...>>. We want to assert that
// Recurrent hides its own fields while also not exposing any
// fields from the inner epsilon wrapper.

struct DummyEpsilonWrapper {
inputs: Vec<(String, Vec<usize>)>,
}

impl InfererWrapper for DummyEpsilonWrapper {
fn invoke(
&self,
_inferer: &dyn Inferer,
_batch: &mut ScratchPadView<'_>,
) -> anyhow::Result<(), anyhow::Error> {
Ok(())
}
fn input_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
&self.inputs
}
fn output_shapes<'a>(
&'a self,
_inferer: &'a dyn Inferer,
) -> &'a [(String, Vec<usize>)] {
_inferer.output_shapes()
}
fn begin_agent(&self, _inferer: &dyn Inferer, _id: u64) {}
fn end_agent(&self, _inferer: &dyn Inferer, _id: u64) {}
}

let inferer = DummyInferer::default();
let wrapper = DummyEpsilonWrapper {
inputs: vec![
("lstm_hidden_state".to_owned(), vec![2, 1]),
("lstm_cell_state".to_owned(), vec![2, 3]),
],
};

let recurrent = RecurrentTrackerWrapper::wrap(wrapper, &inferer).unwrap();

assert_eq!(recurrent.input_shapes(&inferer).len(), 0);
assert_eq!(
recurrent.output_shapes(&inferer).len(),
2,
"only hidden and cell state are recurrent: {:?}",
recurrent.output_shapes(&inferer)
);

assert_eq!(recurrent.output_shapes(&inferer)[0].0, "hidden_output");
assert_eq!(recurrent.output_shapes(&inferer)[1].0, "cell_output");

assert_eq!(recurrent.state.inputs.len(), 0);
assert_eq!(recurrent.state.outputs.len(), 2);

assert_eq!(recurrent.state.keys.len(), 2);
// slots are still correct despite epsilon being hidden
assert_eq!(recurrent.state.keys[0].inslot, 1);
assert_eq!(recurrent.state.keys[1].inslot, 2);
}
}
Loading