Skip to content

Commit

Permalink
fix: Some rebase issues
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Nov 14, 2024
1 parent ea15908 commit 7b4dd2e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
5 changes: 2 additions & 3 deletions python/nutpie/compiled_pyfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,8 @@ def make_expand_func(seed1, seed2, chain):
make_expand_func,
self._variables,
self.n_dim,
self._make_initial_points,
make_transform_adapter,
make_adapter,
init_point_func=self._make_initial_points,
transform_adapter=make_adapter,
)


Expand Down
8 changes: 4 additions & 4 deletions src/pyfunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use pyo3::{
Bound, Py, PyAny, PyErr, Python,
};
use rand::Rng;
use rand_distr::{Distribution, StandardNormal, Uniform};
use rand_distr::{Distribution, Uniform};
use smallvec::SmallVec;
use thiserror::Error;

Expand Down Expand Up @@ -76,7 +76,7 @@ impl PyVariable {
pub struct PyModel {
make_logp_func: Arc<Py<PyAny>>,
make_expand_func: Arc<Py<PyAny>>,
init_point_func: Arc<Option<Py<PyAny>>>,
init_point_func: Option<Arc<Py<PyAny>>>,
variables: Arc<Vec<PyVariable>>,
transform_adapter: Option<PyTransformAdapt>,
ndim: usize,
Expand All @@ -85,7 +85,7 @@ pub struct PyModel {
#[pymethods]
impl PyModel {
#[new]
#[pyo3(signature = (make_logp_func, make_expand_func, variables, ndim, transform_adapter=None))]
#[pyo3(signature = (make_logp_func, make_expand_func, variables, ndim, *, init_point_func=None, transform_adapter=None))]
fn new<'py>(
make_logp_func: Py<PyAny>,
make_expand_func: Py<PyAny>,
Expand All @@ -97,7 +97,7 @@ impl PyModel {
Self {
make_logp_func: Arc::new(make_logp_func),
make_expand_func: Arc::new(make_expand_func),
init_point_func: Arc::new(init_point_func),
init_point_func: init_point_func.map(|x| x.into()),
variables: Arc::new(variables),
ndim,
transform_adapter: transform_adapter.map(PyTransformAdapt::new),
Expand Down
5 changes: 2 additions & 3 deletions src/pymc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use pyo3::{
types::{PyAnyMethods, PyList},
Bound, Py, PyAny, PyObject, PyResult, Python,
};
use rand::{distributions::Uniform, prelude::Distribution};

use thiserror::Error;

Expand Down Expand Up @@ -232,7 +231,7 @@ pub(crate) struct PyMcModel {
dim: usize,
density: LogpFunc,
expand: ExpandFunc,
init_func: Py<PyAny>,
init_func: Arc<Py<PyAny>>,
var_sizes: Vec<usize>,
var_names: Vec<String>,
}
Expand All @@ -252,7 +251,7 @@ impl PyMcModel {
dim,
density,
expand,
init_func,
init_func: init_func.into(),
var_names: var_names.extract()?,
var_sizes: var_sizes.extract()?,
})
Expand Down

0 comments on commit 7b4dd2e

Please sign in to comment.