Update MyPy 14 #210

Merged: 11 commits, Dec 26, 2024
Changes from all commits
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -12,14 +12,14 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
-rev: v0.8.3
+rev: v0.8.4
hooks:
- id: ruff
args: ["--fix", "--output-format=full"]
- id: ruff-format
args: ["--line-length=100"]
- repo: https://github.com/pre-commit/mirrors-mypy
-rev: v1.13.0
+rev: v1.14.0
hooks:
- id: mypy
args: [--ignore-missing-imports]
15 changes: 0 additions & 15 deletions mypy.ini

This file was deleted.

6 changes: 2 additions & 4 deletions pymc_bart/bart.py
@@ -132,7 +132,7 @@ def __new__(
alpha: float = 0.95,
beta: float = 2.0,
response: str = "constant",
-split_prior: Optional[npt.NDArray[np.float64]] = None,
+split_prior: Optional[npt.NDArray] = None,
split_rules: Optional[list[SplitRule]] = None,
separate_trees: Optional[bool] = False,
**kwargs,
@@ -203,9 +203,7 @@ def get_moment(cls, rv, size, *rv_inputs):
return mean


-def preprocess_xy(
-X: TensorLike, Y: TensorLike
-) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
+def preprocess_xy(X: TensorLike, Y: TensorLike) -> tuple[npt.NDArray, npt.NDArray]:
if isinstance(Y, (Series, DataFrame)):
Y = Y.to_numpy()
if isinstance(X, (Series, DataFrame)):
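The bart.py changes above only loosen the array annotations: npt.NDArray[np.float64] becomes the unparameterized npt.NDArray, which leaves the dtype unconstrained for the type checker. A minimal illustration of the practical difference (example code only, with hypothetical function names; not part of this repository or this PR):

import numpy as np
import numpy.typing as npt


def scale_pinned(x: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
    # Only float64 arrays satisfy the checker for this signature.
    return x * 2.0


def scale_any(x: npt.NDArray) -> npt.NDArray:
    # Unparameterized NDArray: arrays of any dtype pass type checking.
    return x * 2.0


ints = np.arange(3)   # int64 array
scale_any(ints)       # accepted by mypy
scale_pinned(ints)    # flagged by mypy as an incompatible argument type; fine at runtime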
73 changes: 42 additions & 31 deletions pymc_bart/pgbart.py
@@ -16,6 +16,8 @@

import numpy as np
import numpy.typing as npt
+import pymc as pm
+import pytensor.tensor as pt
from numba import njit
from pymc.initial_point import PointType
from pymc.model import Model, modelcontext
@@ -120,15 +122,15 @@ class PGBART(ArrayStepShared):
"tune": (bool, []),
}

-def __init__( # noqa: PLR0915
+def __init__( # noqa: PLR0912, PLR0915
self,
-vars=None, # pylint: disable=redefined-builtin
+vars: list[pm.Distribution] | None = None,
num_particles: int = 10,
batch: tuple[float, float] = (0.1, 0.1),
model: Optional[Model] = None,
initial_point: PointType | None = None,
-compile_kwargs: dict | None = None, # pylint: disable=unused-argument
-):
+compile_kwargs: dict | None = None,
+) -> None:
model = modelcontext(model)
if initial_point is None:
initial_point = model.initial_point()
@@ -137,6 +139,10 @@ def __init__( # noqa: PLR0915
else:
vars = [model.rvs_to_values.get(var, var) for var in vars]
vars = inputvars(vars)

+if vars is None:
+raise ValueError("Unable to find variables to sample")

value_bart = vars[0]
self.bart = model.values_to_rvs[value_bart].owner.op

@@ -325,7 +331,7 @@ def normalize(self, particles: list[ParticleTree]) -> float:
return wei / wei.sum()

def resample(
-self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
+self, particles: list[ParticleTree], normalized_weights: npt.NDArray
) -> list[ParticleTree]:
"""
Use systematic resample for all but the first particle
@@ -347,7 +353,7 @@ def resample(
return particles

def get_particle_tree(
-self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
+self, particles: list[ParticleTree], normalized_weights: npt.NDArray
) -> tuple[ParticleTree, Tree]:
"""
Sample a new particle and associated tree
@@ -359,7 +365,7 @@ def get_particle_tree(

return new_particle, new_particle.tree

-def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
+def systematic(self, normalized_weights: npt.NDArray) -> npt.NDArray[np.int_]:
"""
Systematic resampling.

@@ -395,7 +401,7 @@ def update_weight(self, particle: ParticleTree, odim: int) -> None:
particle.log_weight = new_likelihood

@staticmethod
-def competence(var, has_grad):
+def competence(var: pm.Distribution, has_grad: bool) -> Competence:
"""PGBART is only suitable for BART distributions."""
dist = getattr(var.owner, "op", None)
if isinstance(dist, BARTRV):
@@ -406,12 +412,12 @@ def competence(var, has_grad):
class RunningSd:
"""Welford's online algorithm for computing the variance/standard deviation"""

-def __init__(self, shape: tuple) -> None:
+def __init__(self, shape: tuple[int, ...]) -> None:
self.count = 0 # number of data points
self.mean = np.zeros(shape) # running mean
self.m_2 = np.zeros(shape) # running second moment

-def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
+def update(self, new_value: npt.NDArray) -> Union[float, npt.NDArray]:
self.count = self.count + 1
self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
return fast_mean(std)
@@ -420,10 +426,10 @@ def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray
@njit
def _update(
count: int,
-mean: npt.NDArray[np.float64],
-m_2: npt.NDArray[np.float64],
-new_value: npt.NDArray[np.float64],
-) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
+mean: npt.NDArray,
+m_2: npt.NDArray,
+new_value: npt.NDArray,
+) -> tuple[npt.NDArray, npt.NDArray, Union[float, npt.NDArray]]:
delta = new_value - mean
mean += delta / count
delta2 = new_value - mean
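For context on the Welford reference in the RunningSd docstring above: the algorithm carries a running mean and a running sum of squared deviations (m_2), so the variance or standard deviation can be refreshed in a single pass per observation without storing the history; the delta/delta2 recurrence visible in _update is exactly that step. A plain NumPy sketch of one update (illustrative only, not the project's njit-compiled code):

import numpy as np


def welford_step(count, mean, m_2, new_value):
    """One Welford update; count, mean and m_2 are carried between calls."""
    count += 1
    delta = new_value - mean
    mean = mean + delta / count
    delta2 = new_value - mean
    m_2 = m_2 + delta * delta2
    std = np.sqrt(m_2 / count)  # population std; divide by (count - 1) for the sample estimate
    return count, mean, m_2, std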
@@ -434,7 +440,7 @@ def _update(


class SampleSplittingVariable:
-def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
+def __init__(self, alpha_vec: npt.NDArray) -> None:
"""
Sample splitting variables proportional to `alpha_vec`.

@@ -547,16 +553,16 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d


def draw_leaf_value(
-y_mu_pred: npt.NDArray[np.float64],
-x_mu: npt.NDArray[np.float64],
+y_mu_pred: npt.NDArray,
+x_mu: npt.NDArray,
m: int,
-norm: npt.NDArray[np.float64],
+norm: npt.NDArray,
shape: int,
response: str,
-) -> tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
+) -> tuple[npt.NDArray, Optional[npt.NDArray]]:
"""Draw Gaussian distributed leaf values."""
linear_params = None
-mu_mean = np.empty(shape)
+mu_mean: npt.NDArray
if y_mu_pred.size == 0:
return np.zeros(shape), linear_params

@@ -571,7 +577,7 @@ def draw_leaf_value(


@njit
-def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
+def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]:
"""Use Numba to speed up the computation of the mean."""
if ari.ndim == 1:
count = ari.shape[0]
@@ -590,11 +596,11 @@ def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float

@njit
def fast_linear_fit(
-x: npt.NDArray[np.float64],
-y: npt.NDArray[np.float64],
+x: npt.NDArray,
+y: npt.NDArray,
m: int,
-norm: npt.NDArray[np.float64],
-) -> tuple[npt.NDArray[np.float64], list[npt.NDArray[np.float64]]]:
+norm: npt.NDArray,
+) -> tuple[npt.NDArray, list[npt.NDArray]]:
n = len(x)
y = y / m + np.expand_dims(norm, axis=1)

@@ -678,17 +684,17 @@ def update(self):

@njit
def inverse_cdf(
-single_uniform: npt.NDArray[np.float64], normalized_weights: npt.NDArray[np.float64]
+single_uniform: npt.NDArray, normalized_weights: npt.NDArray
) -> npt.NDArray[np.int_]:
"""
Inverse CDF algorithm for a finite distribution.

Parameters
----------
-single_uniform: npt.NDArray[np.float64]
+single_uniform: npt.NDArray
Ordered points in [0,1]

-normalized_weights: npt.NDArray[np.float64])
+normalized_weights: npt.NDArray)
Normalized weights

Returns
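For context on the resampling helpers above (systematic and inverse_cdf): systematic resampling spreads a single uniform draw into n evenly spaced, ordered points in [0, 1) and pushes each point through the cumulative weights, so indices are selected in proportion to the normalized weights with low variance. A generic NumPy sketch of the idea (not the project's njit implementation):

import numpy as np


def systematic_resample(normalized_weights, rng=None):
    """Return indices sampled in proportion to the weights (systematic scheme)."""
    rng = np.random.default_rng() if rng is None else rng
    n = len(normalized_weights)
    points = (rng.uniform() + np.arange(n)) / n  # one draw, spread into n ordered points
    cumulative = np.cumsum(normalized_weights)
    cumulative[-1] = 1.0  # guard against floating-point round-off
    indices = np.zeros(n, dtype=np.int64)
    i = j = 0
    while i < n:
        if points[i] < cumulative[j]:
            indices[i] = j
            i += 1
        else:
            j += 1
    return indices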
@@ -711,7 +717,7 @@


@njit
-def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]:
+def jitter_duplicated(array: npt.NDArray, std: float) -> npt.NDArray:
"""
Jitter duplicated values.
"""
@@ -727,12 +733,17 @@


@njit
-def are_whole_number(array: npt.NDArray[np.float64]) -> np.bool_:
+def are_whole_number(array: npt.NDArray) -> np.bool_:
"""Check if all values in array are whole numbers"""
return np.all(np.mod(array[~np.isnan(array)], 1) == 0)


-def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin
+def logp(
+point,
+out_vars: list[pm.Distribution],
+vars: list[pm.Distribution],
+shared: list[pt.TensorVariable],
+):
"""Compile PyTensor function of the model and the input and output variables.

Parameters