Skip to content

Commit

Permalink
Update MyPy 14 (#210)
Browse files Browse the repository at this point in the history
* move mypy config

* some fixes

* some fixes

* some fixes

* some fixes

* some fixes

* some fixes

* remove reference np.float64

* remove unnesserary casting

* fix type

* fix import
  • Loading branch information
juanitorduz authored Dec 26, 2024
1 parent 8b536b9 commit b84ba1c
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 100 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.3
rev: v0.8.4
hooks:
- id: ruff
args: ["--fix", "--output-format=full"]
- id: ruff-format
args: ["--line-length=100"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.13.0
rev: v1.14.0
hooks:
- id: mypy
args: [--ignore-missing-imports]
Expand Down
15 changes: 0 additions & 15 deletions mypy.ini

This file was deleted.

6 changes: 2 additions & 4 deletions pymc_bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __new__(
alpha: float = 0.95,
beta: float = 2.0,
response: str = "constant",
split_prior: Optional[npt.NDArray[np.float64]] = None,
split_prior: Optional[npt.NDArray] = None,
split_rules: Optional[list[SplitRule]] = None,
separate_trees: Optional[bool] = False,
**kwargs,
Expand Down Expand Up @@ -203,9 +203,7 @@ def get_moment(cls, rv, size, *rv_inputs):
return mean


def preprocess_xy(
X: TensorLike, Y: TensorLike
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
def preprocess_xy(X: TensorLike, Y: TensorLike) -> tuple[npt.NDArray, npt.NDArray]:
if isinstance(Y, (Series, DataFrame)):
Y = Y.to_numpy()
if isinstance(X, (Series, DataFrame)):
Expand Down
73 changes: 42 additions & 31 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import numpy as np
import numpy.typing as npt
import pymc as pm
import pytensor.tensor as pt
from numba import njit
from pymc.initial_point import PointType
from pymc.model import Model, modelcontext
Expand Down Expand Up @@ -120,15 +122,15 @@ class PGBART(ArrayStepShared):
"tune": (bool, []),
}

def __init__( # noqa: PLR0915
def __init__( # noqa: PLR0912, PLR0915
self,
vars=None, # pylint: disable=redefined-builtin
vars: list[pm.Distribution] | None = None,
num_particles: int = 10,
batch: tuple[float, float] = (0.1, 0.1),
model: Optional[Model] = None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None, # pylint: disable=unused-argument
):
compile_kwargs: dict | None = None,
) -> None:
model = modelcontext(model)
if initial_point is None:
initial_point = model.initial_point()
Expand All @@ -137,6 +139,10 @@ def __init__( # noqa: PLR0915
else:
vars = [model.rvs_to_values.get(var, var) for var in vars]
vars = inputvars(vars)

if vars is None:
raise ValueError("Unable to find variables to sample")

value_bart = vars[0]
self.bart = model.values_to_rvs[value_bart].owner.op

Expand Down Expand Up @@ -325,7 +331,7 @@ def normalize(self, particles: list[ParticleTree]) -> float:
return wei / wei.sum()

def resample(
self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
self, particles: list[ParticleTree], normalized_weights: npt.NDArray
) -> list[ParticleTree]:
"""
Use systematic resample for all but the first particle
Expand All @@ -347,7 +353,7 @@ def resample(
return particles

def get_particle_tree(
self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
self, particles: list[ParticleTree], normalized_weights: npt.NDArray
) -> tuple[ParticleTree, Tree]:
"""
Sample a new particle and associated tree
Expand All @@ -359,7 +365,7 @@ def get_particle_tree(

return new_particle, new_particle.tree

def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
def systematic(self, normalized_weights: npt.NDArray) -> npt.NDArray[np.int_]:
"""
Systematic resampling.
Expand Down Expand Up @@ -395,7 +401,7 @@ def update_weight(self, particle: ParticleTree, odim: int) -> None:
particle.log_weight = new_likelihood

@staticmethod
def competence(var, has_grad):
def competence(var: pm.Distribution, has_grad: bool) -> Competence:
"""PGBART is only suitable for BART distributions."""
dist = getattr(var.owner, "op", None)
if isinstance(dist, BARTRV):
Expand All @@ -406,12 +412,12 @@ def competence(var, has_grad):
class RunningSd:
"""Welford's online algorithm for computing the variance/standard deviation"""

def __init__(self, shape: tuple) -> None:
def __init__(self, shape: tuple[int, ...]) -> None:
self.count = 0 # number of data points
self.mean = np.zeros(shape) # running mean
self.m_2 = np.zeros(shape) # running second moment

def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
def update(self, new_value: npt.NDArray) -> Union[float, npt.NDArray]:
self.count = self.count + 1
self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
return fast_mean(std)
Expand All @@ -420,10 +426,10 @@ def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray
@njit
def _update(
count: int,
mean: npt.NDArray[np.float64],
m_2: npt.NDArray[np.float64],
new_value: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
mean: npt.NDArray,
m_2: npt.NDArray,
new_value: npt.NDArray,
) -> tuple[npt.NDArray, npt.NDArray, Union[float, npt.NDArray]]:
delta = new_value - mean
mean += delta / count
delta2 = new_value - mean
Expand All @@ -434,7 +440,7 @@ def _update(


class SampleSplittingVariable:
def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
def __init__(self, alpha_vec: npt.NDArray) -> None:
"""
Sample splitting variables proportional to `alpha_vec`.
Expand Down Expand Up @@ -547,16 +553,16 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d


def draw_leaf_value(
y_mu_pred: npt.NDArray[np.float64],
x_mu: npt.NDArray[np.float64],
y_mu_pred: npt.NDArray,
x_mu: npt.NDArray,
m: int,
norm: npt.NDArray[np.float64],
norm: npt.NDArray,
shape: int,
response: str,
) -> tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
) -> tuple[npt.NDArray, Optional[npt.NDArray]]:
"""Draw Gaussian distributed leaf values."""
linear_params = None
mu_mean = np.empty(shape)
mu_mean: npt.NDArray
if y_mu_pred.size == 0:
return np.zeros(shape), linear_params

Expand All @@ -571,7 +577,7 @@ def draw_leaf_value(


@njit
def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]:
"""Use Numba to speed up the computation of the mean."""
if ari.ndim == 1:
count = ari.shape[0]
Expand All @@ -590,11 +596,11 @@ def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float

@njit
def fast_linear_fit(
x: npt.NDArray[np.float64],
y: npt.NDArray[np.float64],
x: npt.NDArray,
y: npt.NDArray,
m: int,
norm: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.float64], list[npt.NDArray[np.float64]]]:
norm: npt.NDArray,
) -> tuple[npt.NDArray, list[npt.NDArray]]:
n = len(x)
y = y / m + np.expand_dims(norm, axis=1)

Expand Down Expand Up @@ -678,17 +684,17 @@ def update(self):

@njit
def inverse_cdf(
single_uniform: npt.NDArray[np.float64], normalized_weights: npt.NDArray[np.float64]
single_uniform: npt.NDArray, normalized_weights: npt.NDArray
) -> npt.NDArray[np.int_]:
"""
Inverse CDF algorithm for a finite distribution.
Parameters
----------
single_uniform: npt.NDArray[np.float64]
single_uniform: npt.NDArray
Ordered points in [0,1]
normalized_weights: npt.NDArray[np.float64])
normalized_weights: npt.NDArray)
Normalized weights
Returns
Expand All @@ -711,7 +717,7 @@ def inverse_cdf(


@njit
def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]:
def jitter_duplicated(array: npt.NDArray, std: float) -> npt.NDArray:
"""
Jitter duplicated values.
"""
Expand All @@ -727,12 +733,17 @@ def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray


@njit
def are_whole_number(array: npt.NDArray[np.float64]) -> np.bool_:
def are_whole_number(array: npt.NDArray) -> np.bool_:
"""Check if all values in array are whole numbers"""
return np.all(np.mod(array[~np.isnan(array)], 1) == 0)


def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin
def logp(
point,
out_vars: list[pm.Distribution],
vars: list[pm.Distribution],
shared: list[pt.TensorVariable],
):
"""Compile PyTensor function of the model and the input and output variables.
Parameters
Expand Down
Loading

0 comments on commit b84ba1c

Please sign in to comment.