diff --git a/CHANGELOG.md b/CHANGELOG.md index d002e04..e62dc3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `StatefulInferer::replace_inferer` which works with a `&mut StatefulInferer`, at the cost of requiring the inferer to be of the same type. +- Fix bugs in the new wrapper setup where consumed and modified shapes + weren't respected during wrapper construction. ## [0.9.0] - 2025-09-04 diff --git a/crates/cervo-core/src/epsilon.rs b/crates/cervo-core/src/epsilon.rs index 89aaaed..b414785 100644 --- a/crates/cervo-core/src/epsilon.rs +++ b/crates/cervo-core/src/epsilon.rs @@ -259,7 +259,7 @@ where generator: NG, key: &str, ) -> Result { - let inputs = inferer.input_shapes(); + let inputs = inner.input_shapes(inferer); let (index, count) = match inputs.iter().enumerate().find(|(_, (k, _))| k == key) { Some((index, (_, shape))) => (index, shape.iter().product()), diff --git a/crates/cervo-core/src/recurrent.rs b/crates/cervo-core/src/recurrent.rs index 7cc4243..ac8ed2d 100644 --- a/crates/cervo-core/src/recurrent.rs +++ b/crates/cervo-core/src/recurrent.rs @@ -248,8 +248,8 @@ impl RecurrentTrackerWrapper { /// Create a new recurrency tracker for the model. /// pub fn new(inner: Inner, inferer: &T, info: Vec) -> Result { - let inputs = inferer.raw_input_shapes(); - let outputs = inferer.raw_output_shapes(); + let inputs = inner.input_shapes(inferer); + let outputs = inner.output_shapes(inferer); let mut offset = 0; let keys = info diff --git a/deny.toml b/deny.toml index 197263a..cfd9610 100644 --- a/deny.toml +++ b/deny.toml @@ -45,6 +45,7 @@ db-urls = ["https://github.com/rustsec/advisory-db"] ignore = [ #"RUSTSEC-0000-0000", "RUSTSEC-2024-0436", + "RUSTSEC-2025-0056", ] # Threshold for security vulnerabilities, any vulnerability with a CVSS score # lower than the range specified will be ignored. Note that ignored advisories