diff --git a/CHANGELOG.md b/CHANGELOG.md index b091711..d002e04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - ReleaseDate + +- Add `StatefulInferer::replace_inferer` which works with a `&mut + StatefulInferer`, at the cost of requiring the inferer to be of the + same type. + ## [0.9.0] - 2025-09-04 - Breaking: `Inferer.begin_agent` and `Inferer.end_agent` now take diff --git a/crates/cervo-core/src/wrapper.rs b/crates/cervo-core/src/wrapper.rs index 35a90f0..80a9f80 100644 --- a/crates/cervo-core/src/wrapper.rs +++ b/crates/cervo-core/src/wrapper.rs @@ -110,6 +110,8 @@ pub struct StatefulInferer { } impl StatefulInferer { + /// Construct a new [`StatefulInferer`] by wrapping the given + /// inferer with the given wrapper stack. pub fn new(wrapper_stack: WrapStack, inferer: Inf) -> Self { Self { wrapper_stack, @@ -121,8 +123,8 @@ impl StatefulInferer { /// any state in wrappers. /// /// Requires that the shapes of the policies are compatible, but - /// they may be different concrete inferer implementations. If - /// this check fails, will return self unchanged. + /// they may be different inferer types. If this check fails, will + /// return self unchanged. pub fn with_new_inferer( self, new_inferer: NewInf, @@ -136,6 +138,22 @@ impl StatefulInferer { }) } + /// Replace the inner inferer with a new inferer while maintaining + /// any state in wrappers. + /// + /// Requires that the shapes of the policies are compatible If + /// this check fails, will not change self. Compared to + /// [`with_new_inferer`], also requires that the new inferer has + /// the same type as the old one. + pub fn replace_inferer(&mut self, new_inferer: Inf) -> anyhow::Result<()> { + if let Err(e) = Self::check_compatible_shapes(&self.inferer, &new_inferer) { + Err(e) + } else { + self.inferer = new_inferer; + Ok(()) + } + } + /// Validate that [`Old`] and [`New`] are compatible with each /// other. pub fn check_compatible_shapes(