diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d3ec58..7032d2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - ReleaseDate + +- Breaking: `Inferer.begin_agent` and `Inferer.end_agent` now take + `&self`, changed from a mutable reference. + +### Wrapper rework + +To support a wider variety of uses, we have implemented a new category +of wrappers that do not require ownership of the inferer. This allows +for more flexible usage patterns, where the inferer policy can be +replaced in a live application without losing any state kept in +wrappers. + +This change is currently non-breaking and is implemented separately +from the old wrapper system. + ## [0.8.0] - 2025-05-28 - Added a new `RecurrentTracker` wrapper to handle recurrent inputs/outputs if the recurrent data is only needed durign network diff --git a/Cargo.lock b/Cargo.lock index 3a86349..fe8a3e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -741,12 +741,11 @@ dependencies = [ [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" dependencies = [ - "overload", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -799,12 +798,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "parking_lot" version = "0.12.3" @@ -1332,9 +1325,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tracing-core" -version = "0.1.30" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24eb03ba0eab1fd845050058ce5e616558e8f8d8fca633e6b163fe25c797213a" +checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" dependencies = [ "once_cell", "valuable", @@ -1342,20 +1335,20 @@ dependencies = [ [[package]] name = "tracing-log" -version = "0.1.3" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" dependencies = [ - "lazy_static", "log", + "once_cell", "tracing-core", ] [[package]] name = "tracing-subscriber" -version = "0.3.16" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "nu-ansi-term", "sharded-slab", diff --git a/crates/cervo-cli/src/commands/benchmark.rs b/crates/cervo-cli/src/commands/benchmark.rs index 7abc430..1ad90b0 100644 --- a/crates/cervo-cli/src/commands/benchmark.rs +++ b/crates/cervo-cli/src/commands/benchmark.rs @@ -1,7 +1,9 @@ use anyhow::{bail, Result}; use cervo::asset::AssetData; -use cervo::core::prelude::{Batcher, Inferer, InfererExt, State}; -use cervo::core::recurrent::{RecurrentInfo, RecurrentTracker}; +use cervo::core::epsilon::EpsilonInjectorWrapper; +use cervo::core::prelude::{Batcher, Inferer, State}; +use cervo::core::recurrent::{RecurrentInfo, RecurrentTrackerWrapper}; +use cervo::core::wrapper::{BaseWrapper, InfererWrapper, InfererWrapperExt}; use clap::Parser; use clap::ValueEnum; use serde::Serialize; @@ -167,7 +169,7 @@ struct Record { total: f64, } -fn execute_load_metrics( +fn execute_load_metrics( batch_size: usize, data: HashMap>, count: usize, @@ -222,85 +224,112 @@ pub fn build_inputs_from_desc( .collect() } -fn do_run(mut inferer: impl Inferer, batch_size: usize, config: &Args) -> Result { - let shapes = inferer.input_shapes().to_vec(); - let observations = build_inputs_from_desc(batch_size as u64, &shapes); - for id in 0..batch_size { - inferer.begin_agent(id as u64); - } - let res = execute_load_metrics(batch_size, observations, config.count, &mut inferer)?; - for id in 0..batch_size { - inferer.end_agent(id as u64); +fn do_run( + wrapper: impl InfererWrapper + 'static, + inferer: impl Inferer + 'static, + config: &Args, +) -> Result> { + let mut model = wrapper.wrap(Box::new(inferer) as Box); + + let mut records = Vec::with_capacity(config.batch_sizes.len()); + for batch_size in config.batch_sizes.clone() { + let mut reader = File::open(&config.file)?; + let inferer = if cervo::nnef::is_nnef_tar(&config.file) { + cervo::nnef::builder(&mut reader).build_fixed(&[batch_size])? + } else { + match config.file.extension().and_then(|ext| ext.to_str()) { + Some("onnx") => cervo::onnx::builder(&mut reader).build_fixed(&[batch_size])?, + Some("crvo") => AssetData::deserialize(&mut reader)?.load_fixed(&[batch_size])?, + Some(other) => bail!("unknown file type {:?}", other), + None => bail!("missing file extension {:?}", config.file), + } + }; + + model = model + .with_new_inferer(Box::new(inferer) as Box) + .map_err(|(_, e)| e)?; + + let shapes = model.input_shapes().to_vec(); + let observations = build_inputs_from_desc(batch_size as u64, &shapes); + for id in 0..batch_size { + model.begin_agent(id as u64); + } + let res = execute_load_metrics(batch_size, observations, config.count, &mut model)?; + + // Print Text + if matches!(config.output, OutputFormat::Text) { + println!( + "Batch Size {}: {:.2} ms ± {:.2} per element, {:.2} ms total", + res.batch_size, res.mean, res.stddev, res.total, + ); + } + + records.push(res); + for id in 0..batch_size { + model.end_agent(id as u64); + } } - Ok(res) + Ok(records) } fn run_apply_epsilon_config( - inferer: impl Inferer, - batch_size: usize, + wrapper: impl InfererWrapper + 'static, + inferer: impl Inferer + 'static, config: &Args, -) -> Result { +) -> Result> { if let Some(epsilon) = config.with_epsilon.as_ref() { - let inferer = inferer.with_default_epsilon(epsilon)?; - do_run(inferer, batch_size, config) + let wrapper = EpsilonInjectorWrapper::wrap(wrapper, &inferer, epsilon)?; + do_run(wrapper, inferer, config) } else { - do_run(inferer, batch_size, config) + do_run(wrapper, inferer, config) } } -fn run_apply_recurrent(inferer: impl Inferer, batch_size: usize, config: &Args) -> Result { +fn run_apply_recurrent( + wrapper: impl InfererWrapper + 'static, + inferer: impl Inferer + 'static, + config: &Args, +) -> Result> { if let Some(recurrent) = config.recurrent.as_ref() { if matches!(recurrent, RecurrentConfig::None) { - run_apply_epsilon_config(inferer, batch_size, config) + run_apply_epsilon_config(wrapper, inferer, config) } else { - let inferer = match recurrent { + let wrapper = match recurrent { RecurrentConfig::None => unreachable!(), - RecurrentConfig::Auto => RecurrentTracker::wrap(inferer), + RecurrentConfig::Auto => RecurrentTrackerWrapper::wrap(wrapper, &inferer), RecurrentConfig::Mapped(map) => { let infos = map .iter() .cloned() .map(|(inkey, outkey)| RecurrentInfo { inkey, outkey }) .collect::>(); - RecurrentTracker::new(inferer, infos) + RecurrentTrackerWrapper::new(wrapper, &inferer, infos) } }?; - run_apply_epsilon_config(inferer, batch_size, config) + run_apply_epsilon_config(wrapper, inferer, config) } } else { - run_apply_epsilon_config(inferer, batch_size, config) + run_apply_epsilon_config(wrapper, inferer, config) } } pub(super) fn run(config: Args) -> Result<()> { - let mut records: Vec = Vec::new(); - for batch_size in config.batch_sizes.clone() { - let mut reader = File::open(&config.file)?; - let inferer = if cervo::nnef::is_nnef_tar(&config.file) { - cervo::nnef::builder(&mut reader).build_fixed(&[batch_size])? - } else { - match config.file.extension().and_then(|ext| ext.to_str()) { - Some("onnx") => cervo::onnx::builder(&mut reader).build_fixed(&[batch_size])?, - Some("crvo") => AssetData::deserialize(&mut reader)?.load_fixed(&[batch_size])?, - Some(other) => bail!("unknown file type {:?}", other), - None => bail!("missing file extension {:?}", config.file), - } - }; - - let record = run_apply_recurrent(inferer, batch_size, &config)?; - - // Print Text - if matches!(config.output, OutputFormat::Text) { - println!( - "Batch Size {}: {:.2} ms ± {:.2} per element, {:.2} ms total", - record.batch_size, record.mean, record.stddev, record.total, - ); + let mut reader = File::open(&config.file)?; + let inferer = if cervo::nnef::is_nnef_tar(&config.file) { + cervo::nnef::builder(&mut reader).build_basic()? + } else { + match config.file.extension().and_then(|ext| ext.to_str()) { + Some("onnx") => cervo::onnx::builder(&mut reader).build_basic()?, + Some("crvo") => AssetData::deserialize(&mut reader)?.load_basic()?, + Some(other) => bail!("unknown file type {:?}", other), + None => bail!("missing file extension {:?}", config.file), } + }; + + let records = run_apply_recurrent(BaseWrapper, inferer, &config)?; - records.push(record); - } // Print JSON if matches!(config.output, OutputFormat::Json) { let json = serde_json::to_string_pretty(&records)?; diff --git a/crates/cervo-cli/src/commands/run.rs b/crates/cervo-cli/src/commands/run.rs index 8ab9c4b..af58066 100644 --- a/crates/cervo-cli/src/commands/run.rs +++ b/crates/cervo-cli/src/commands/run.rs @@ -84,15 +84,7 @@ pub(super) fn run(config: Args) -> Result<()> { let elapsed = if let Some(epsilon) = config.with_epsilon.as_ref() { let inferer = inferer.with_default_epsilon(epsilon)?; - // TODO[TSolberg]: Issue #31. - let shapes = inferer - .raw_input_shapes() - .iter() - .filter(|(k, _)| k.as_str() != epsilon) - .cloned() - .collect::>(); - - let observations = build_inputs_from_desc(config.batch_size as u64, &shapes); + let observations = build_inputs_from_desc(config.batch_size as u64, inferer.input_shapes()); if config.print_input { print_input(&observations); diff --git a/crates/cervo-core/src/epsilon.rs b/crates/cervo-core/src/epsilon.rs index 0009675..89aaaed 100644 --- a/crates/cervo-core/src/epsilon.rs +++ b/crates/cervo-core/src/epsilon.rs @@ -8,7 +8,7 @@ Utilities for filling noise inputs for an inference model. use std::cell::RefCell; -use crate::{batcher::ScratchPadView, inferer::Inferer}; +use crate::{batcher::ScratchPadView, inferer::Inferer, prelude::InfererWrapper}; use anyhow::{bail, Result}; use perchance::PerchanceContext; use rand::thread_rng; @@ -112,6 +112,13 @@ impl NoiseGenerator for HighQualityNoiseGenerator { } } +struct EpsilonInjectorState { + count: usize, + index: usize, + generator: NG, + + inputs: Vec<(String, Vec)>, +} /// The [`EpsilonInjector`] wraps an inferer to add noise values as one of the input data points. This is useful for /// continuous action policies where you might have trained your agent to follow a stochastic policy trained with the /// reparametrization trick. @@ -120,11 +127,8 @@ impl NoiseGenerator for HighQualityNoiseGenerator { /// wrapper. pub struct EpsilonInjector { inner: T, - count: usize, - index: usize, - generator: NG, - inputs: Vec<(String, Vec)>, + state: EpsilonInjectorState, } impl EpsilonInjector @@ -169,11 +173,12 @@ where Ok(Self { inner: inferer, - index, - count, - generator, - - inputs, + state: EpsilonInjectorState { + index, + count, + generator, + inputs, + }, }) } } @@ -188,15 +193,15 @@ where } fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> { - let total_count = self.count * batch.len(); - let output = batch.input_slot_mut(self.index); - self.generator.generate(total_count, output); + let total_count = self.state.count * batch.len(); + let output = batch.input_slot_mut(self.state.index); + self.state.generator.generate(total_count, output); self.inner.infer_raw(batch) } fn input_shapes(&self) -> &[(String, Vec)] { - &self.inputs + &self.state.inputs } fn raw_input_shapes(&self) -> &[(String, Vec)] { @@ -207,11 +212,105 @@ where self.inner.raw_output_shapes() } - fn begin_agent(&mut self, id: u64) { + fn begin_agent(&self, id: u64) { self.inner.begin_agent(id); } - fn end_agent(&mut self, id: u64) { + fn end_agent(&self, id: u64) { self.inner.end_agent(id); } } + +pub struct EpsilonInjectorWrapper { + inner: Inner, + state: EpsilonInjectorState, +} + +impl EpsilonInjectorWrapper { + /// Wraps the provided `inferer` to automatically generate noise for the input named by `key`. + /// + /// This function will use [`HighQualityNoiseGenerator`] as the noise source. + /// + /// # Errors + /// + /// Will return an error if the provided key doesn't match an input on the model. + pub fn wrap( + inner: Inner, + inferer: &dyn Inferer, + key: &str, + ) -> Result> { + Self::with_generator(inner, inferer, HighQualityNoiseGenerator::default(), key) + } +} + +impl EpsilonInjectorWrapper +where + Inner: InfererWrapper, + NG: NoiseGenerator, +{ + /// Create a new injector for the provided `key`, using the custom `generator` as the noise source. + /// + /// # Errors + /// + /// Will return an error if the provided key doesn't match an input on the model. + pub fn with_generator( + inner: Inner, + inferer: &dyn Inferer, + generator: NG, + key: &str, + ) -> Result { + let inputs = inferer.input_shapes(); + + let (index, count) = match inputs.iter().enumerate().find(|(_, (k, _))| k == key) { + Some((index, (_, shape))) => (index, shape.iter().product()), + None => bail!("model has no input key {:?}", key), + }; + + let inputs = inputs + .iter() + .filter(|(k, _)| *k != key) + .map(|(k, v)| (k.to_owned(), v.to_owned())) + .collect::>(); + + Ok(Self { + inner, + state: EpsilonInjectorState { + index, + count, + generator, + inputs, + }, + }) + } +} + +impl InfererWrapper for EpsilonInjectorWrapper +where + Inner: InfererWrapper, + NG: NoiseGenerator, +{ + fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> { + self.inner.invoke(inferer, batch)?; + let total_count = self.state.count * batch.len(); + let output = batch.input_slot_mut(self.state.index); + self.state.generator.generate(total_count, output); + + self.inner.invoke(inferer, batch) + } + + fn input_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec)] { + self.state.inputs.as_ref() + } + + fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec)] { + self.inner.output_shapes(inferer) + } + + fn begin_agent(&self, inferer: &dyn Inferer, id: u64) { + self.inner.begin_agent(inferer, id); + } + + fn end_agent(&self, inferer: &dyn Inferer, id: u64) { + self.inner.end_agent(inferer, id); + } +} diff --git a/crates/cervo-core/src/inferer.rs b/crates/cervo-core/src/inferer.rs index 22e4f48..27e3a01 100644 --- a/crates/cervo-core/src/inferer.rs +++ b/crates/cervo-core/src/inferer.rs @@ -112,8 +112,8 @@ pub trait Inferer { /// Retrieve the name and shapes of the model outputs. fn raw_output_shapes(&self) -> &[(String, Vec)]; - fn begin_agent(&mut self, id: u64); - fn end_agent(&mut self, id: u64); + fn begin_agent(&self, id: u64); + fn end_agent(&self, id: u64); } /// Helper trait to provide helper functions for loadable models. @@ -242,12 +242,12 @@ impl Inferer for Box { self.as_ref().raw_output_shapes() } - fn begin_agent(&mut self, id: u64) { - self.as_mut().begin_agent(id); + fn begin_agent(&self, id: u64) { + self.as_ref().begin_agent(id); } - fn end_agent(&mut self, id: u64) { - self.as_mut().end_agent(id); + fn end_agent(&self, id: u64) { + self.as_ref().end_agent(id); } } @@ -268,11 +268,11 @@ impl Inferer for Box { self.as_ref().raw_output_shapes() } - fn begin_agent(&mut self, id: u64) { - self.as_mut().begin_agent(id); + fn begin_agent(&self, id: u64) { + self.as_ref().begin_agent(id); } - fn end_agent(&mut self, id: u64) { - self.as_mut().end_agent(id); + fn end_agent(&self, id: u64) { + self.as_ref().end_agent(id); } } diff --git a/crates/cervo-core/src/inferer/basic.rs b/crates/cervo-core/src/inferer/basic.rs index 7292311..fefb8d8 100644 --- a/crates/cervo-core/src/inferer/basic.rs +++ b/crates/cervo-core/src/inferer/basic.rs @@ -94,6 +94,6 @@ impl Inferer for BasicInferer { &self.model_api.outputs } - fn begin_agent(&mut self, _id: u64) {} - fn end_agent(&mut self, _id: u64) {} + fn begin_agent(&self, _id: u64) {} + fn end_agent(&self, _id: u64) {} } diff --git a/crates/cervo-core/src/inferer/dynamic.rs b/crates/cervo-core/src/inferer/dynamic.rs index 5cc0f65..afcdd9a 100644 --- a/crates/cervo-core/src/inferer/dynamic.rs +++ b/crates/cervo-core/src/inferer/dynamic.rs @@ -110,6 +110,6 @@ impl Inferer for DynamicInferer { &self.model_api.outputs } - fn begin_agent(&mut self, _id: u64) {} - fn end_agent(&mut self, _id: u64) {} + fn begin_agent(&self, _id: u64) {} + fn end_agent(&self, _id: u64) {} } diff --git a/crates/cervo-core/src/inferer/fixed.rs b/crates/cervo-core/src/inferer/fixed.rs index 1b96f53..57a80e8 100644 --- a/crates/cervo-core/src/inferer/fixed.rs +++ b/crates/cervo-core/src/inferer/fixed.rs @@ -111,8 +111,8 @@ impl Inferer for FixedBatchInferer { &self.model_api.outputs } - fn begin_agent(&mut self, _id: u64) {} - fn end_agent(&mut self, _id: u64) {} + fn begin_agent(&self, _id: u64) {} + fn end_agent(&self, _id: u64) {} } struct BatchedModel { diff --git a/crates/cervo-core/src/inferer/memoizing.rs b/crates/cervo-core/src/inferer/memoizing.rs index a3ace8e..fd3a365 100644 --- a/crates/cervo-core/src/inferer/memoizing.rs +++ b/crates/cervo-core/src/inferer/memoizing.rs @@ -175,6 +175,6 @@ impl Inferer for MemoizingDynamicInferer { &self.model_api.outputs } - fn begin_agent(&mut self, _id: u64) {} - fn end_agent(&mut self, _id: u64) {} + fn begin_agent(&self, _id: u64) {} + fn end_agent(&self, _id: u64) {} } diff --git a/crates/cervo-core/src/lib.rs b/crates/cervo-core/src/lib.rs index 2dd9776..9bd80bb 100644 --- a/crates/cervo-core/src/lib.rs +++ b/crates/cervo-core/src/lib.rs @@ -17,6 +17,7 @@ pub mod epsilon; pub mod inferer; mod model_api; pub mod recurrent; +pub mod wrapper; /// Most core utilities are re-exported here. pub mod prelude { @@ -31,4 +32,5 @@ pub mod prelude { pub use super::model_api::ModelApi; pub use super::recurrent::{RecurrentInfo, RecurrentTracker}; + pub use super::wrapper::{InfererWrapper, InfererWrapperExt, IntoStateful, StatefulInferer}; } diff --git a/crates/cervo-core/src/model_api.rs b/crates/cervo-core/src/model_api.rs index 2d31f1d..e0d531f 100644 --- a/crates/cervo-core/src/model_api.rs +++ b/crates/cervo-core/src/model_api.rs @@ -56,6 +56,8 @@ impl ModelApi { Ok(Self { outputs, inputs }) } + // Note[TS]: Clippy wants us to use name...clone_into(&name) but that's illegal. + #[allow(clippy::assigning_clones)] pub fn for_typed_model(model: &TypedModel) -> TractResult { let mut inputs: Vec<(String, Vec)> = Default::default(); diff --git a/crates/cervo-core/src/recurrent.rs b/crates/cervo-core/src/recurrent.rs index ff64eb6..7cc4243 100644 --- a/crates/cervo-core/src/recurrent.rs +++ b/crates/cervo-core/src/recurrent.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use crate::{batcher::ScratchPadView, inferer::Inferer}; +use crate::{batcher::ScratchPadView, inferer::Inferer, prelude::InfererWrapper}; use anyhow::{Context, Result}; use itertools::Itertools; use parking_lot::RwLock; @@ -18,11 +18,7 @@ struct RecurrentPair { offset: usize, } -/// The [`RecurrentTracker`] wraps an inferer to manage states that -/// are input/output in a recurrent fashion, instead of roundtripping -/// them to the high-level code. -pub struct RecurrentTracker { - inner: T, +struct RecurrentState { keys: TVec, per_agent_states: RwLock>>, agent_state_size: usize, @@ -31,6 +27,53 @@ pub struct RecurrentTracker { outputs: Vec<(String, Vec)>, } +impl RecurrentState { + fn apply(&self, batch: &mut ScratchPadView<'_>) { + for pair in &self.keys { + let (ids, indata) = batch.input_slot_mut_with_id(pair.inslot); + + let mut offset = 0; + let states = self.per_agent_states.read(); + for id in ids { + // if None, leave as zeros and pray + if let Some(state) = states.get(id) { + indata[offset..offset + pair.numels] + .copy_from_slice(&state[pair.offset..pair.offset + pair.numels]); + } else { + indata[offset..offset + pair.numels].fill(0.0); + } + offset += pair.numels; + } + } + } + + fn extract(&self, batch: &mut ScratchPadView<'_>) { + for pair in &self.keys { + let (ids, outdata) = batch.output_slot_mut_with_id(pair.outslot); + + let mut offset = 0; + let mut states = self.per_agent_states.write(); + for id in ids { + // if None, leave as zeros and pray + if let Some(state) = states.get_mut(id) { + state[pair.offset..pair.offset + pair.numels] + .copy_from_slice(&outdata[offset..offset + pair.numels]); + } + + offset += pair.numels; + } + } + } +} + +/// The [`RecurrentTracker`] wraps an inferer to manage states that +/// are input/output in a recurrent fashion, instead of roundtripping +/// them to the high-level code. +pub struct RecurrentTracker { + inner: T, + state: RecurrentState, +} + impl RecurrentTracker where T: Inferer, @@ -105,11 +148,13 @@ where .collect::>(); Ok(Self { inner: inferer, - keys, - agent_state_size: offset, - inputs, - outputs, - per_agent_states: Default::default(), + state: RecurrentState { + keys, + agent_state_size: offset, + inputs, + outputs, + per_agent_states: Default::default(), + }, }) } } @@ -123,40 +168,11 @@ where } fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> { - for pair in &self.keys { - let (ids, indata) = batch.input_slot_mut_with_id(pair.inslot); - - let mut offset = 0; - let states = self.per_agent_states.read(); - for id in ids { - // if None, leave as zeros and pray - if let Some(state) = states.get(id) { - indata[offset..offset + pair.numels] - .copy_from_slice(&state[pair.offset..pair.offset + pair.numels]); - } else { - indata[offset..offset + pair.numels].fill(0.0); - } - offset += pair.numels; - } - } + self.state.apply(batch); self.inner.infer_raw(batch)?; - for pair in &self.keys { - let (ids, outdata) = batch.output_slot_mut_with_id(pair.outslot); - - let mut offset = 0; - let mut states = self.per_agent_states.write(); - for id in ids { - // if None, leave as zeros and pray - if let Some(state) = states.get_mut(id) { - state[pair.offset..pair.offset + pair.numels] - .copy_from_slice(&outdata[offset..offset + pair.numels]); - } - - offset += pair.numels; - } - } + self.state.extract(batch); Ok(()) } @@ -170,29 +186,154 @@ where } fn input_shapes(&self) -> &[(String, Vec)] { - &self.inputs + &self.state.inputs } fn output_shapes(&self) -> &[(String, Vec)] { - &self.outputs + &self.state.outputs } - fn begin_agent(&mut self, id: u64) { - self.per_agent_states - .write() - .insert(id, vec![0.0; self.agent_state_size].into_boxed_slice()); + fn begin_agent(&self, id: u64) { + self.state.per_agent_states.write().insert( + id, + vec![0.0; self.state.agent_state_size].into_boxed_slice(), + ); self.inner.begin_agent(id); } - fn end_agent(&mut self, id: u64) { - self.per_agent_states.write().remove(&id); + fn end_agent(&self, id: u64) { + self.state.per_agent_states.write().remove(&id); self.inner.end_agent(id); } } +/// A wrapper that adds recurrent state tracking to an inner model. +/// +/// This is an alternative to using [`RecurrentTracker`] which allows separate +/// state tracking from the inferer itself. +pub struct RecurrentTrackerWrapper { + inner: Inner, + state: RecurrentState, +} + +impl RecurrentTrackerWrapper { + /// Wraps the provided `inferer` to automatically track any keys that are both inputs/outputs. + pub fn wrap(inner: Inner, inferer: &T) -> Result> { + let inputs = inferer.raw_input_shapes(); + let outputs = inferer.raw_output_shapes(); + + let mut keys = vec![]; + + for (inkey, inshape) in inputs { + for (outkey, outshape) in outputs { + if inkey == outkey && inshape == outshape { + keys.push(RecurrentInfo { + inkey: inkey.clone(), + outkey: outkey.clone(), + }); + } + } + } + + if keys.is_empty() { + let inkeys = inputs.iter().map(|(k, _)| k).join(", "); + let outkeys = outputs.iter().map(|(k, _)| k).join(", "); + anyhow::bail!( + "Unable to find a matching key between inputs [{inkeys}] and outputs [{outkeys}]" + ); + } + Self::new(inner, inferer, keys) + } + + /// Create a new recurrency tracker for the model. + /// + pub fn new(inner: Inner, inferer: &T, info: Vec) -> Result { + let inputs = inferer.raw_input_shapes(); + let outputs = inferer.raw_output_shapes(); + + let mut offset = 0; + let keys = info + .iter() + .map(|info| { + let inslot = inputs + .iter() + .position(|input| info.inkey == input.0) + .with_context(|| format!("no input named {}", info.inkey))?; + let outslot = outputs + .iter() + .position(|output| info.outkey == output.0) + .with_context(|| format!("no output named {}", info.outkey))?; + + let numels = inputs[inslot].1.iter().product(); + offset += numels; + Ok(RecurrentPair { + inslot, + outslot, + numels, + offset: offset - numels, + }) + }) + .collect::>>()?; + + let inputs = inputs + .iter() + .filter(|(k, _)| !info.iter().any(|info| &info.inkey == k)) + .cloned() + .collect::>(); + let outputs = outputs + .iter() + .filter(|(k, _)| !info.iter().any(|info| &info.outkey == k)) + .cloned() + .collect::>(); + Ok(Self { + inner, + state: RecurrentState { + keys, + agent_state_size: offset, + inputs, + outputs, + per_agent_states: Default::default(), + }, + }) + } +} + +impl InfererWrapper for RecurrentTrackerWrapper { + fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> { + self.state.apply(batch); + self.inner.invoke(inferer, batch)?; + self.state.extract(batch); + + Ok(()) + } + + fn input_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec)] { + self.state.inputs.as_ref() + } + + fn output_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec)] { + self.state.outputs.as_ref() + } + + fn begin_agent(&self, inferer: &dyn Inferer, id: u64) { + self.state.per_agent_states.write().insert( + id, + vec![0.0; self.state.agent_state_size].into_boxed_slice(), + ); + self.inner.begin_agent(inferer, id); + } + + fn end_agent(&self, inferer: &dyn Inferer, id: u64) { + self.state.per_agent_states.write().remove(&id); + self.inner.end_agent(inferer, id); + } +} + #[cfg(test)] mod tests { + use std::sync::atomic::{AtomicBool, Ordering}; + use crate::{ batcher::ScratchPadView, inferer::State, @@ -202,8 +343,8 @@ mod tests { use super::RecurrentTracker; struct DummyInferer { - end_called: bool, - begin_called: bool, + end_called: AtomicBool, + begin_called: AtomicBool, inputs: Vec<(String, Vec)>, outputs: Vec<(String, Vec)>, } @@ -227,8 +368,8 @@ mod tests { cell_name_out: &str, ) -> Self { Self { - end_called: false, - begin_called: false, + end_called: false.into(), + begin_called: false.into(), inputs: vec![ (hidden_name_in.to_owned(), vec![2, 1]), (cell_name_in.to_owned(), vec![2, 3]), @@ -282,44 +423,44 @@ mod tests { &self.outputs } - fn begin_agent(&mut self, _id: u64) { - self.begin_called = true; + fn begin_agent(&self, _id: u64) { + self.begin_called.store(true, Ordering::Relaxed); } - fn end_agent(&mut self, _id: u64) { - self.end_called = true; + fn end_agent(&self, _id: u64) { + self.end_called.store(true, Ordering::Relaxed); } } #[test] fn begin_end_forwarded() { let inferer = DummyInferer::default(); - let mut recurrent = RecurrentTracker::wrap(inferer).unwrap(); + let recurrent = RecurrentTracker::wrap(inferer).unwrap(); recurrent.begin_agent(10); - assert!(recurrent.inner.begin_called); + assert!(recurrent.inner.begin_called.load(Ordering::Relaxed)); recurrent.end_agent(10); - assert!(recurrent.inner.end_called); + assert!(recurrent.inner.end_called.into_inner()); } #[test] fn begin_creates_state() { let inferer = DummyInferer::default(); - let mut recurrent = RecurrentTracker::wrap(inferer).unwrap(); + let recurrent = RecurrentTracker::wrap(inferer).unwrap(); recurrent.begin_agent(10); - assert!(recurrent.per_agent_states.read().contains_key(&10)); + assert!(recurrent.state.per_agent_states.read().contains_key(&10)); } #[test] fn end_removes_state() { let inferer = DummyInferer::default(); - let mut recurrent = RecurrentTracker::wrap(inferer).unwrap(); + let recurrent = RecurrentTracker::wrap(inferer).unwrap(); recurrent.begin_agent(10); recurrent.end_agent(10); - assert!(!recurrent.per_agent_states.read().contains_key(&10)); + assert!(!recurrent.state.per_agent_states.read().contains_key(&10)); } #[test] @@ -333,7 +474,7 @@ mod tests { fn test_infer() { let inferer = DummyInferer::default(); let mut batcher = Batcher::new(&inferer); - let mut recurrent = RecurrentTracker::wrap(inferer).unwrap(); + let recurrent = RecurrentTracker::wrap(inferer).unwrap(); recurrent.begin_agent(10); batcher.push(10, State::empty()).unwrap(); @@ -345,7 +486,7 @@ mod tests { fn test_infer_output() { let inferer = DummyInferer::default(); let mut batcher = Batcher::new(&inferer); - let mut recurrent = RecurrentTracker::wrap(inferer).unwrap(); + let recurrent = RecurrentTracker::wrap(inferer).unwrap(); recurrent.begin_agent(10); batcher.push(10, State::empty()).unwrap(); @@ -363,7 +504,7 @@ mod tests { fn test_infer_twice_output() { let inferer = DummyInferer::default(); let mut batcher = Batcher::new(&inferer); - let mut recurrent = RecurrentTracker::wrap(inferer).unwrap(); + let recurrent = RecurrentTracker::wrap(inferer).unwrap(); recurrent.begin_agent(10); batcher.push(10, State::empty()).unwrap(); @@ -385,7 +526,7 @@ mod tests { fn test_infer_twice_reuse_id() { let inferer = DummyInferer::default(); let mut batcher = Batcher::new(&inferer); - let mut recurrent = RecurrentTracker::wrap(inferer).unwrap(); + let recurrent = RecurrentTracker::wrap(inferer).unwrap(); recurrent.begin_agent(10); batcher.push(10, State::empty()).unwrap(); @@ -411,7 +552,7 @@ mod tests { fn test_infer_multiple_agents() { let inferer = DummyInferer::default(); let mut batcher = Batcher::new(&inferer); - let mut recurrent = RecurrentTracker::wrap(inferer).unwrap(); + let recurrent = RecurrentTracker::wrap(inferer).unwrap(); recurrent.begin_agent(10); recurrent.begin_agent(20); diff --git a/crates/cervo-core/src/wrapper.rs b/crates/cervo-core/src/wrapper.rs new file mode 100644 index 0000000..35a90f0 --- /dev/null +++ b/crates/cervo-core/src/wrapper.rs @@ -0,0 +1,255 @@ +/*! +Inferer wrappers with state separated from the inferer. + +This allows separation of stateful logic from the inner inferer, +allowing the inner inferer to be swapped out while maintaining +state in the wrappers. + +This is an alternative to the old layered inferer setup, which +tightly coupled the inner inferer with the wrapper state. + +```rust,ignore +let inferer = ...; +// the root needs [`BaseCase`] passed as a base case. +let wrappers = RecurrentTrackerWrapper::new(BaseCase, inferer); +let wrapped = StatefulInferer::new(wrappers, infere); +// or +let wrapped = inferer.into_stateful(wrappers); +// or +let wrapped = wrappers.wrap(inferer); +``` +*/ + +use crate::batcher::ScratchPadView; +use crate::inferer::{ + BasicInferer, DynamicInferer, FixedBatchInferer, Inferer, MemoizingDynamicInferer, +}; + +/// A trait for wrapping an inferer with additional functionality. +/// +/// This works similar to the old layered inferer setup, but allows +/// separation of wrapper state from the inner inferer. This allows +/// swapping out the inner inferer while maintaining state in the +/// wrappers. +pub trait InfererWrapper { + /// Returns the input shapes after this wrapper has been applied. + fn input_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec)]; + + /// Returns the output shapes after this wrapper has been applied. + fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec)]; + + /// Invokes the inner inferer, applying any additional logic before + /// and after the call. + fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()>; + + /// Called when starting inference for a new agent. + fn begin_agent(&self, inferer: &dyn Inferer, id: u64); + + /// Called when finishing inference for an agent. + fn end_agent(&self, inferer: &dyn Inferer, id: u64); +} + +/// A no-op inferer wrapper that just calls the inner inferer directly. This is the base-case of wrapper stack. +pub struct BaseWrapper; + +impl InfererWrapper for BaseWrapper { + /// Returns the input shapes after this wrapper has been applied. + fn input_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec)] { + inferer.input_shapes() + } + + /// Returns the output shapes after this wrapper has been applied. + fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec)] { + inferer.output_shapes() + } + + /// Invokes the inner inferer. + fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> { + inferer.infer_raw(batch) + } + + fn begin_agent(&self, inferer: &dyn Inferer, id: u64) { + inferer.begin_agent(id); + } + + fn end_agent(&self, inferer: &dyn Inferer, id: u64) { + inferer.end_agent(id); + } +} + +impl InfererWrapper for Box { + fn input_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec)] { + self.as_ref().input_shapes(inferer) + } + + fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec)] { + self.as_ref().output_shapes(inferer) + } + + fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> { + self.as_ref().invoke(inferer, batch) + } + + fn begin_agent(&self, inferer: &dyn Inferer, id: u64) { + self.as_ref().begin_agent(inferer, id); + } + + fn end_agent(&self, inferer: &dyn Inferer, id: u64) { + self.as_ref().end_agent(inferer, id); + } +} + +/// An inferer that maintains state in wrappers around an inferer. +/// +/// This is an alternative to direct wrapping of an inferer, which +/// allows the inner inferer to be swapped out while maintaining +/// state in the wrappers. +pub struct StatefulInferer { + wrapper_stack: WrapStack, + inferer: Inf, +} + +impl StatefulInferer { + pub fn new(wrapper_stack: WrapStack, inferer: Inf) -> Self { + Self { + wrapper_stack, + inferer, + } + } + + /// Replace the inner inferer with a new inferer while maintaining + /// any state in wrappers. + /// + /// Requires that the shapes of the policies are compatible, but + /// they may be different concrete inferer implementations. If + /// this check fails, will return self unchanged. + pub fn with_new_inferer( + self, + new_inferer: NewInf, + ) -> Result, (Self, anyhow::Error)> { + if let Err(e) = Self::check_compatible_shapes(&self.inferer, &new_inferer) { + return Err((self, e)); + } + Ok(StatefulInferer { + wrapper_stack: self.wrapper_stack, + inferer: new_inferer, + }) + } + + /// Validate that [`Old`] and [`New`] are compatible with each + /// other. + pub fn check_compatible_shapes( + old: &Old, + new: &New, + ) -> Result<(), anyhow::Error> { + let old_in = old.raw_input_shapes(); + let new_in = new.raw_input_shapes(); + + let old_out = old.raw_output_shapes(); + let new_out = new.raw_output_shapes(); + + for (i, (o, n)) in old_in.iter().zip(new_in).enumerate() { + if o != n { + if o.0 != n.0 { + return Err(anyhow::format_err!( + "name mismatch for input {i}: '{}' != '{}'", + o.0, + n.0, + )); + } + + return Err(anyhow::format_err!( + "shape mismatch for input '{}': {:?} != {:?}", + o.0, + o.1, + n.1, + )); + } + } + + for (i, (o, n)) in old_out.iter().zip(new_out).enumerate() { + if o != n { + if o.0 != n.0 { + return Err(anyhow::format_err!( + "name mismatch for output {i}: '{}' != '{}'", + o.0, + n.0, + )); + } + + return Err(anyhow::format_err!( + "shape mismatch for output {}: {:?} != {:?}", + o.0, + o.1, + n.1, + )); + } + } + + Ok(()) + } + + /// Returns the input shapes after all wrappers have been applied. + pub fn input_shapes(&self) -> &[(String, Vec)] { + self.wrapper_stack.input_shapes(&self.inferer) + } + + /// Returns the output shapes after all wrappers have been applied. + pub fn output_shapes(&self) -> &[(String, Vec)] { + self.wrapper_stack.output_shapes(&self.inferer) + } +} + +/// See [`Inferer`] for documentation. +impl Inferer for StatefulInferer { + fn select_batch_size(&self, max_count: usize) -> usize { + self.inferer.select_batch_size(max_count) + } + + fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> anyhow::Result<(), anyhow::Error> { + self.wrapper_stack.invoke(&self.inferer, batch) + } + + fn raw_input_shapes(&self) -> &[(String, Vec)] { + self.inferer.raw_input_shapes() + } + + fn raw_output_shapes(&self) -> &[(String, Vec)] { + self.inferer.raw_output_shapes() + } + + fn begin_agent(&self, id: u64) { + self.wrapper_stack.begin_agent(&self.inferer, id); + } + + fn end_agent(&self, id: u64) { + self.wrapper_stack.end_agent(&self.inferer, id); + } +} + +/// Extension trait to allow easy wrapping of an inferer with a wrapper stack. +pub trait IntoStateful: Inferer + Sized { + /// Construct a [`StatefulInferer`] by wrapping this concrete + /// inferer with the given wrapper stack. + fn into_stateful( + self, + wrapper_stack: WrapStack, + ) -> StatefulInferer { + StatefulInferer::new(wrapper_stack, self) + } +} + +impl IntoStateful for BasicInferer {} +impl IntoStateful for DynamicInferer {} +impl IntoStateful for MemoizingDynamicInferer {} +impl IntoStateful for FixedBatchInferer {} + +/// Extension trait to allow easy wrapping of an inferer with a wrapper stack. +pub trait InfererWrapperExt: InfererWrapper + Sized { + /// Construct a [`StatefulInferer`] by wrapping an inner inferer with this wrapper. + fn wrap(self, inferer: Inf) -> StatefulInferer { + StatefulInferer::new(self, inferer) + } +} + +impl InfererWrapperExt for T {} diff --git a/crates/cervo-core/tests/batcher.rs b/crates/cervo-core/tests/batcher.rs index f64467a..132ae39 100644 --- a/crates/cervo-core/tests/batcher.rs +++ b/crates/cervo-core/tests/batcher.rs @@ -40,8 +40,8 @@ where &self.out_shapes } - fn begin_agent(&mut self, _id: u64) {} - fn end_agent(&mut self, _id: u64) {} + fn begin_agent(&self, _id: u64) {} + fn end_agent(&self, _id: u64) {} } #[test] diff --git a/crates/cervo-onnx/tests/infer-lstm.rs b/crates/cervo-onnx/tests/infer-lstm.rs index 6e2c092..6c27812 100644 --- a/crates/cervo-onnx/tests/infer-lstm.rs +++ b/crates/cervo-onnx/tests/infer-lstm.rs @@ -20,7 +20,7 @@ fn test_infer_once_recurrent() { ) .unwrap(); - let mut instance = RecurrentTracker::new( + let instance = RecurrentTracker::new( instance, vec![ RecurrentInfo { @@ -61,7 +61,7 @@ fn test_infer_once_recurrent_batched() { ) .unwrap(); - let mut instance = RecurrentTracker::new( + let instance = RecurrentTracker::new( instance, vec![ RecurrentInfo { @@ -101,7 +101,7 @@ fn test_infer_once_recurrent_batched_not_loaded() { ) .unwrap(); - let mut instance = RecurrentTracker::new( + let instance = RecurrentTracker::new( instance, vec![ RecurrentInfo { @@ -140,7 +140,7 @@ fn test_infer_once_recurrent_fixed_batch() { ) .unwrap(); - let mut instance = RecurrentTracker::new( + let instance = RecurrentTracker::new( instance, vec![ RecurrentInfo { diff --git a/crates/cervo-runtime/src/runtime.rs b/crates/cervo-runtime/src/runtime.rs index d3bf465..d0f34ec 100644 --- a/crates/cervo-runtime/src/runtime.rs +++ b/crates/cervo-runtime/src/runtime.rs @@ -319,8 +319,8 @@ mod tests { &[] } - fn begin_agent(&mut self, _id: u64) {} - fn end_agent(&mut self, _id: u64) {} + fn begin_agent(&self, _id: u64) {} + fn end_agent(&self, _id: u64) {} } #[test] diff --git a/crates/cervo-runtime/src/state.rs b/crates/cervo-runtime/src/state.rs index ac4d6b7..36f3f8c 100644 --- a/crates/cervo-runtime/src/state.rs +++ b/crates/cervo-runtime/src/state.rs @@ -168,8 +168,8 @@ mod tests { &[] } - fn begin_agent(&mut self, _id: u64) {} - fn end_agent(&mut self, _id: u64) {} + fn begin_agent(&self, _id: u64) {} + fn end_agent(&self, _id: u64) {} } #[test]