Skip to content

Commit

Permalink
Add transforming adaptation for stan
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Oct 25, 2024
1 parent 9235e21 commit 4498fbf
Show file tree
Hide file tree
Showing 7 changed files with 557 additions and 470 deletions.
18 changes: 16 additions & 2 deletions python/nutpie/compile_stan.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from functools import partial
import json
import tempfile
from dataclasses import dataclass, replace
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, Callable

import numpy as np
import pandas as pd
from numpy.typing import NDArray

from nutpie import _lib
from nutpie.sample import CompiledModel
from nutpie.transform_adapter import make_transform_adapter


class _NumpyArrayEncoder(json.JSONEncoder):
Expand All @@ -28,6 +30,7 @@ class CompiledStanModel(CompiledModel):
library: Any
model: Any
model_name: Optional[str] = None
_transform_adapt_args: dict | None = None

def with_data(self, *, seed=None, **updates):
if self.data is None:
Expand All @@ -42,7 +45,15 @@ def with_data(self, *, seed=None, **updates):
else:
data_json = None

model = _lib.StanModel(self.library, seed, data_json)
kwargs = self._transform_adapt_args
if kwargs is None:
kwargs = {}
make_adapter = partial(
make_transform_adapter(**kwargs),
logp_fn=None,
)

model = _lib.StanModel(self.library, seed, data_json, make_adapter)
coords = self._coords
if coords is None:
coords = {}
Expand Down Expand Up @@ -75,6 +86,9 @@ def with_dims(self, **dims):
dims_new.update(dims)
return replace(self, dims=dims_new)

def with_transform_adapt(self, **kwargs):
return replace(self, _transform_adapt_args=kwargs).with_data()

def _make_model(self, init_mean):
if self.model is None:
return self.with_data().model
Expand Down
Loading

0 comments on commit 4498fbf

Please sign in to comment.