Skip to content

Commit

Permalink
Revert functional changes during rename
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 13, 2024
1 parent 088a477 commit 9268d3c
Show file tree
Hide file tree
Showing 14 changed files with 26 additions and 52 deletions.
2 changes: 0 additions & 2 deletions pymc_extras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,3 @@
if len(_log.handlers) == 0:
handler = logging.StreamHandler()
_log.addHandler(handler)

__all__ = ["fit", "MarginalModel", "marginalize", "as_model"]
9 changes: 5 additions & 4 deletions pymc_extras/model/marginal/graph_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from pymc import SymbolicRandomVariable
from pytensor.compile import SharedVariable
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import Constant, Variable, ancestors
from pytensor.graph.basic import io_toposort
from pytensor.tensor import TensorType, TensorVariable
Expand All @@ -17,6 +16,8 @@
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list
from pytensor.tensor.type_other import NoneTypeT

from pymc_extras.model.marginal.distributions import MarginalRV


def static_shape_ancestors(vars):
"""Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
Expand Down Expand Up @@ -62,7 +63,7 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs):


def get_support_axes(op) -> tuple[tuple[int, ...], ...]:
if hasattr(op, "support_axes"):
if isinstance(op, MarginalRV):
return op.support_axes
else:
# For vanilla RVs, the support axes are the last ndim_supp
Expand Down Expand Up @@ -145,7 +146,7 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)
var_dims[node.outputs[0]] = output_dims

elif (isinstance(node.op, OpFromGraph) and hasattr(node.op, "support_axes")) or (
elif isinstance(node.op, MarginalRV) or (
isinstance(node.op, SymbolicRandomVariable) and node.op.extended_signature is None
):
# MarginalRV and SymbolicRandomVariables without signature are a wild-card,
Expand All @@ -159,7 +160,7 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
)

support_axes = iter(get_support_axes(op))
if hasattr(op, "support_axes"):
if isinstance(op, MarginalRV):
# The first output is the marginalized variable for which we don't compute support axes
support_axes = itertools.chain(((),), support_axes)
for i, (out, inner_out) in enumerate(zip(node.outputs, inner_outputs)):
Expand Down
2 changes: 2 additions & 0 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pymc_extras.statespace.core.representation import PytensorRepresentation
from pymc_extras.statespace.filters import (
KalmanSmoother,
SquareRootFilter,
StandardFilter,
UnivariateFilter,
)
Expand Down Expand Up @@ -50,6 +51,7 @@
FILTER_FACTORY = {
"standard": StandardFilter,
"univariate": UnivariateFilter,
"cholesky": SquareRootFilter,
}


Expand Down
2 changes: 2 additions & 0 deletions pymc_extras/statespace/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace
from pymc_extras.statespace.filters.kalman_filter import (
SquareRootFilter,
StandardFilter,
UnivariateFilter,
)
Expand All @@ -9,5 +10,6 @@
"StandardFilter",
"UnivariateFilter",
"KalmanSmoother",
"SquareRootFilter",
"LinearGaussianStateSpace",
]
4 changes: 2 additions & 2 deletions pymc_extras/statespace/models/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,8 @@ def make_SARIMA_transition_matrix(
0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \end{bmatrix}
When ARIMA differences and seasonal differences are mixed, the seasonal differences will be written in terms of the
highest ARIMA difference order, and recovery of the level state will require the use of all the ARIMA
differences, as well as the seasonal differences. In addition, the seasonal differences are needed to back out the ARIMA
highest ARIMA difference order, and recovery of the level state will require the use of all the ARIMA differences,
as well as the seasonal differences. In addition, the seasonal differences are needed to back out the ARIMA
differences from :math:`x_t^\star`. Here is the differencing block for a SARIMA(0,2,0)x(0,2,0,4) -- the identites
of the states is left an exercise for the motivated reader:
Expand Down
5 changes: 0 additions & 5 deletions tests/statespace/test_coord_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,12 @@
from pymc_extras.statespace.utils.constants import (
FILTER_OUTPUT_DIMS,
FILTER_OUTPUT_NAMES,
JITTER_DEFAULT,
LONG_MATRIX_NAMES,
MISSING_FILL,
SHORT_NAME_TO_LONG,
SMOOTHER_OUTPUT_NAMES,
TIME_DIM,
)
from pymc_extras.statespace.utils.data_tools import (
NO_FREQ_INFO_WARNING,
NO_TIME_INDEX_WARNING,
register_data_with_pymc,
)
from tests.statespace.utilities.test_helpers import load_nile_test_data

Expand Down
6 changes: 1 addition & 5 deletions tests/statespace/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@
)
from pymc_extras.statespace.utils.constants import (
ALL_STATE_DIM,
JITTER_DEFAULT,
LONG_MATRIX_NAMES,
MISSING_FILL,
OBS_STATE_DIM,
SHORT_NAME_TO_LONG,
TIME_DIM,
)
from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import
Expand All @@ -41,7 +37,7 @@

filter_names = [
"standard",
# "cholesky",
"cholesky",
"univariate",
]

Expand Down
19 changes: 10 additions & 9 deletions tests/statespace/test_kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
StandardFilter,
UnivariateFilter,
)
from pymc_extras.statespace.filters.kalman_filter import BaseFilter
from pymc_extras.statespace.filters.kalman_filter import BaseFilter, SquareRootFilter
from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import
rng,
)
Expand All @@ -30,17 +30,18 @@
RTOL = 1e-6 if floatX.endswith("64") else 1e-3

standard_inout = initialize_filter(StandardFilter())
# cholesky_inout = initialize_filter(CholeskyFilter())
cholesky_inout = initialize_filter(SquareRootFilter())
univariate_inout = initialize_filter(UnivariateFilter())

f_standard = pytensor.function(*standard_inout, on_unused_input="ignore")
# f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore")

filter_funcs = [f_standard, f_univariate]
filter_funcs = [f_standard, f_cholesky, f_univariate]

filter_names = [
"StandardFilter",
"CholeskyFilter",
"UnivariateFilter",
]

Expand Down Expand Up @@ -229,8 +230,8 @@ def test_last_smoother_is_last_filtered(filter_func, output_idx, rng):
@pytest.mark.skipif(floatX == "float32", reason="Tests are too sensitive for float32")
def test_filters_match_statsmodel_output(filter_func, filter_name, n_missing, rng):
fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing)
# if filter_name == "CholeskyFilter":
# P0 = np.linalg.cholesky(P0)
if filter_name == "CholeskyFilter":
P0 = np.linalg.cholesky(P0)
inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
outputs = filter_func(*inputs)

Expand Down Expand Up @@ -278,8 +279,8 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob
pytest.skip("Univariate filter not stable at half precision without measurement error")

fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing)
# if filter_name == "CholeskyFilter":
# P0 = np.linalg.cholesky(P0)
if filter_name == "CholeskyFilter":
P0 = np.linalg.cholesky(P0)

H *= int(obs_noise)
inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
Expand All @@ -301,7 +302,7 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob

@pytest.mark.parametrize(
"filter",
[StandardFilter],
[StandardFilter, SquareRootFilter],
ids=["standard"],
)
def test_kalman_filter_jax(filter):
Expand Down
6 changes: 0 additions & 6 deletions tests/statespace/test_statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,8 @@
from pymc_extras.statespace.models.utilities import make_default_coords
from pymc_extras.statespace.utils.constants import (
FILTER_OUTPUT_NAMES,
JITTER_DEFAULT,
LONG_MATRIX_NAMES,
MATRIX_NAMES,
MISSING_FILL,
NEVER_TIME_VARYING,
SHORT_NAME_TO_LONG,
SMOOTHER_OUTPUT_NAMES,
VECTOR_VALUED,
)
from tests.statespace.utilities.shared_fixtures import (
rng,
Expand Down
4 changes: 0 additions & 4 deletions tests/statespace/test_statespace_JAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@

from pymc_extras.statespace.utils.constants import (
FILTER_OUTPUT_NAMES,
JITTER_DEFAULT,
LONG_MATRIX_NAMES,
MATRIX_NAMES,
MISSING_FILL,
SHORT_NAME_TO_LONG,
SMOOTHER_OUTPUT_NAMES,
)
from tests.statespace.test_statespace import ( # pylint: disable=unused-import
Expand Down
3 changes: 0 additions & 3 deletions tests/statespace/test_structural.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
ALL_STATE_AUX_DIM,
ALL_STATE_DIM,
AR_PARAM_DIM,
JITTER_DEFAULT,
LONG_MATRIX_NAMES,
MISSING_FILL,
OBS_STATE_AUX_DIM,
OBS_STATE_DIM,
SHOCK_AUX_DIM,
Expand Down
2 changes: 0 additions & 2 deletions tests/statespace/utilities/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@

from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother
from pymc_extras.statespace.utils.constants import (
JITTER_DEFAULT,
MATRIX_NAMES,
MISSING_FILL,
SHORT_NAME_TO_LONG,
)
from tests.statespace.utilities.statsmodel_local_level import LocalLinearTrend
Expand Down
6 changes: 3 additions & 3 deletions tests/test_blackjax_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from numpy import dtype
from xarray.core.utils import Frozen

jax = pytest.importorskip("jax")
pytest.importorskip("blackjax")

from pymc_extras.inference.smc.sampling import (
arviz_from_particles,
blackjax_particles_from_pymc_population,
Expand All @@ -28,9 +31,6 @@
sample_smc_blackjax,
)

jax = pytest.importorskip("jax")
pytest.importorskip("blackjax")


def two_gaussians_model():
n = 4
Expand Down
8 changes: 1 addition & 7 deletions tests/test_find_map.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import Literal

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytest

from pymc_extras.inference.find_map import find_MAP, scipy_optimize_funcs_from_loss
from pymc_extras.inference.find_map import GradientBackend, find_MAP, scipy_optimize_funcs_from_loss

pytest.importorskip("jax")

Expand All @@ -16,10 +14,6 @@ def rng():
return np.random.default_rng(seed)


# Define GradientBackend type alias
GradientBackend = Literal["jax", "pytensor"]


@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
def test_jax_functions_from_graph(gradient_backend: GradientBackend):
x = pt.tensor("x", shape=(2,))
Expand Down

0 comments on commit 9268d3c

Please sign in to comment.