Skip to content

Commit

Permalink
refactor: Remove unnecessary stats structs and add some transform stats
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Nov 28, 2024
1 parent 2a55602 commit 0b2913e
Show file tree
Hide file tree
Showing 17 changed files with 933 additions and 932 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
name = "nuts-rs"
version = "0.13.0"
authors = [
"Adrian Seyboldt <[email protected]>",
"PyMC Developers <[email protected]>",
"Adrian Seyboldt <[email protected]>",
"PyMC Developers <[email protected]>",
]
edition = "2021"
license = "MIT"
Expand All @@ -22,12 +22,12 @@ rand = { version = "0.8.5", features = ["small_rng"] }
rand_distr = "0.4.3"
multiversion = "0.7.2"
itertools = "0.13.0"
thiserror = "1.0.43"
thiserror = "2.0.3"
arrow = { version = "53.1.0", default-features = false, features = ["ffi"] }
rand_chacha = "0.3.1"
anyhow = "1.0.72"
faer = { version = "0.19.4", default-features = false, features = ["std"] }
pulp = "0.18.21"
pulp = "0.19.6"
rayon = "1.10.0"

[dev-dependencies]
Expand Down
77 changes: 43 additions & 34 deletions src/adapt_strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ use crate::{
state::State,
stepsize::AcceptanceRateCollector,
stepsize_adapt::{
DualAverageSettings, Stats as StepSizeStats, StatsBuilder as StepSizeStatsBuilder,
Strategy as StepSizeStrategy,
DualAverageSettings, StatsBuilder as StepSizeStatsBuilder, Strategy as StepSizeStrategy,
},
NutsError,
};
Expand Down Expand Up @@ -63,20 +62,18 @@ impl<S: Debug + Default> Default for EuclideanAdaptOptions<S> {
}

impl<M: Math, A: MassMatrixAdaptStrategy<M>> SamplerStats<M> for GlobalStrategy<M, A> {
type Stats = CombinedStats<StepSizeStats, A::Stats>;
type Builder = CombinedStatsBuilder<StepSizeStatsBuilder, A::Builder>;
type Builder = GlobalStrategyBuilder<A::Builder>;
type StatOptions = <A as SamplerStats<M>>::StatOptions;

fn current_stats(&self, math: &mut M) -> Self::Stats {
CombinedStats {
stats1: self.step_size.current_stats(math),
stats2: self.mass_matrix.current_stats(math),
}
}

fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder {
CombinedStatsBuilder {
stats1: SamplerStats::<M>::new_builder(&self.step_size, settings, dim),
stats2: self.mass_matrix.new_builder(settings, dim),
fn new_builder(
&self,
options: Self::StatOptions,
settings: &impl Settings,
dim: usize,
) -> Self::Builder {
GlobalStrategyBuilder {
step_size: SamplerStats::<M>::new_builder(&self.step_size, (), settings, dim),
mass_matrix: self.mass_matrix.new_builder(options, settings, dim),
}
}
}
Expand Down Expand Up @@ -218,33 +215,37 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
fn is_tuning(&self) -> bool {
self.tuning
}
}

#[derive(Debug, Clone)]
pub struct CombinedStats<D1, D2> {
pub stats1: D1,
pub stats2: D2,
fn last_num_steps(&self) -> u64 {
self.step_size.last_n_steps
}
}

#[derive(Clone)]
pub struct CombinedStatsBuilder<B1, B2> {
pub stats1: B1,
pub stats2: B2,
pub struct GlobalStrategyBuilder<B> {
pub step_size: StepSizeStatsBuilder,
pub mass_matrix: B,
}

impl<S1, S2, B1, B2> StatTraceBuilder<CombinedStats<S1, S2>> for CombinedStatsBuilder<B1, B2>
impl<M: Math, A> StatTraceBuilder<M, GlobalStrategy<M, A>> for GlobalStrategyBuilder<A::Builder>
where
B1: StatTraceBuilder<S1>,
B2: StatTraceBuilder<S2>,
A: MassMatrixAdaptStrategy<M>,
{
fn append_value(&mut self, value: CombinedStats<S1, S2>) {
self.stats1.append_value(value.stats1);
self.stats2.append_value(value.stats2);
fn append_value(&mut self, math: Option<&mut M>, value: &GlobalStrategy<M, A>) {
let math = math.expect("Smapler stats need math");
self.step_size.append_value(Some(math), &value.step_size);
self.mass_matrix
.append_value(Some(math), &value.mass_matrix);
}

fn finalize(self) -> Option<StructArray> {
let Self { stats1, stats2 } = self;
match (stats1.finalize(), stats2.finalize()) {
let Self {
step_size,
mass_matrix,
} = self;
match (
StatTraceBuilder::<M, _>::finalize(step_size),
mass_matrix.finalize(),
) {
(None, None) => None,
(Some(stats1), None) => Some(stats1),
(None, Some(stats2)) => Some(stats2),
Expand All @@ -266,8 +267,14 @@ where
}

fn inspect(&self) -> Option<StructArray> {
let Self { stats1, stats2 } = self;
match (stats1.inspect(), stats2.inspect()) {
let Self {
step_size,
mass_matrix,
} = self;
match (
StatTraceBuilder::<M, _>::inspect(step_size),
mass_matrix.inspect(),
) {
(None, None) => None,
(Some(stats1), None) => Some(stats1),
(None, Some(stats2)) => Some(stats2),
Expand Down Expand Up @@ -374,6 +381,7 @@ pub mod test_logps {

#[derive(Error, Debug)]
pub enum NormalLogpError {}

impl LogpError for NormalLogpError {
fn is_recoverable(&self) -> bool {
false
Expand Down Expand Up @@ -438,6 +446,7 @@ pub mod test_logps {
_rng: &mut R,
_untransformed_positions: impl Iterator<Item = &'a [f64]>,
_untransformed_gradients: impl Iterator<Item = &'a [f64]>,
_untransformed_logp: impl Iterator<Item = &'a f64>,
_params: &'a mut Self::TransformParams,
) -> Result<(), Self::LogpError> {
unimplemented!()
Expand Down
Loading

0 comments on commit 0b2913e

Please sign in to comment.