Skip to content

Commit

Permalink
switch to use NDArray
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengp0 committed Jun 25, 2024
1 parent 216297b commit 8b5d8b0
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 40 deletions.
17 changes: 17 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.4.2
    hooks:
      - id: ruff
        args: [ --fix ]
      - id: ruff-format
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.10.0
    hooks:
      - id: mypy
        files: ^src
9 changes: 5 additions & 4 deletions src/mrtool/core/cov_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import xspline
from numpy.typing import NDArray

from . import utils
from .data import MRData
Expand Down Expand Up @@ -438,7 +439,7 @@ def has_data(self):
return True

def create_spline(
self, data: MRData, spline_knots: np.ndarray = None
self, data: MRData, spline_knots: NDArray | None = None
) -> xspline.XSpline:
"""Create spline given current spline parameters.
Parameters
Expand Down Expand Up @@ -525,7 +526,7 @@ def create_spline(

return spline

def create_design_mat(self, data) -> tuple[np.ndarray, np.ndarray]:
def create_design_mat(self, data) -> tuple[NDArray, NDArray]:
"""Create design matrix.
Parameters
----------
Expand Down Expand Up @@ -564,7 +565,7 @@ def create_z_mat(self, data):
"Cannot use create_z_mat directly in CovModel class."
)

def create_constraint_mat(self) -> tuple[np.ndarray, np.ndarray]:
def create_constraint_mat(self) -> tuple[NDArray, NDArray]:
"""Create constraint matrix.
Returns:
tuple{numpy.ndarray, numpy.ndarray}:
Expand Down Expand Up @@ -679,7 +680,7 @@ def create_constraint_mat(self) -> tuple[np.ndarray, np.ndarray]:

return c_mat, c_val

def create_regularization_mat(self) -> tuple[np.ndarray, np.ndarray]:
def create_regularization_mat(self) -> tuple[NDArray, NDArray]:
"""Create constraint matrix.
Returns:
tuple{numpy.ndarray, numpy.ndarray}:
Expand Down
21 changes: 11 additions & 10 deletions src/mrtool/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import numpy as np
import pandas as pd
from numpy.typing import NDArray

from .utils import empty_array, expand_array, is_numeric_array, to_list

Expand All @@ -20,11 +21,11 @@
class MRData:
"""Data for simple linear mixed effects model."""

obs: np.ndarray = field(default_factory=empty_array)
obs_se: np.ndarray = field(default_factory=empty_array)
covs: dict[str, np.ndarray] = field(default_factory=dict)
study_id: np.ndarray = field(default_factory=empty_array)
data_id: np.ndarray = field(default_factory=empty_array)
obs: NDArray = field(default_factory=empty_array)
obs_se: NDArray = field(default_factory=empty_array)
covs: dict[str, NDArray] = field(default_factory=dict)
study_id: NDArray = field(default_factory=empty_array)
data_id: NDArray = field(default_factory=empty_array)
cov_scales: dict[str, float] = field(init=False, default_factory=dict)

def __post_init__(self):
Expand Down Expand Up @@ -121,7 +122,7 @@ def _get_study_structure(self):
)
self._sort_by_study_id()

def _sort_data(self, index: np.ndarray):
def _sort_data(self, index: NDArray):
"""Sort the object.
Parameters
Expand Down Expand Up @@ -166,7 +167,7 @@ def _remove_nan_in_covs(self):
index = index | cov_index
self._remove_data(index)

def _remove_data(self, index: np.ndarray):
def _remove_data(self, index: NDArray):
"""Remove the data point by index.
Parameters
Expand All @@ -186,7 +187,7 @@ def _remove_data(self, index: np.ndarray):
self.study_id = self.study_id[keep_index]
self.data_id = self.data_id[keep_index]

def _get_data(self, index: np.ndarray) -> "MRData":
def _get_data(self, index: NDArray) -> "MRData":
"""Get the data point by index.
Parameters
Expand Down Expand Up @@ -383,7 +384,7 @@ def _assert_has_studies(self, studies: list[Any] | Any):
f"MRData object do not contain studies: {missing_studies}."
)

def get_covs(self, covs: list[str] | str) -> np.ndarray:
def get_covs(self, covs: list[str] | str) -> NDArray:
"""Get covariate matrix.
Parameters
Expand All @@ -393,7 +394,7 @@ def get_covs(self, covs: list[str] | str) -> np.ndarray:
Returns
-------
np.ndarray
NDArray
Covariates matrix, in the column fashion.
"""
Expand Down
18 changes: 11 additions & 7 deletions src/mrtool/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

from copy import deepcopy
from typing import Sequence

import numpy as np
import pandas as pd
Expand All @@ -22,7 +23,10 @@ class MRBRT:
"""MR-BRT Object"""

def __init__(
self, data: MRData, cov_models: list[CovModel], inlier_pct: float = 1.0
self,
data: MRData,
cov_models: Sequence[CovModel],
inlier_pct: float = 1.0,
):
"""Constructor of MRBRT.
Expand Down Expand Up @@ -80,12 +84,12 @@ def __init__(
)

# place holder for the limetr objective
self.lt = None
self.beta_soln = None
self.gamma_soln = None
self.u_soln = None
self.w_soln = None
self.re_soln = None
self.lt: LimeTr
self.beta_soln: NDArray
self.gamma_soln: NDArray
self.u_soln: NDArray
self.w_soln: NDArray
self.re_soln: NDArray

def attach_data(self, data=None):
"""Attach data to cov_model."""
Expand Down
33 changes: 17 additions & 16 deletions src/mrtool/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import pandas as pd
from numpy.typing import NDArray


def get_cols(df, cols):
Expand Down Expand Up @@ -124,7 +125,7 @@ def sizes_to_indices(sizes):
Returns
-------
list[np.ndarray]
list[NDArray]
list the indices.
"""
Expand Down Expand Up @@ -325,29 +326,29 @@ def avg_integral(mat, spline=None, use_spline_intercept=False):
# random knots
def sample_knots(
num_knots: int,
knot_bounds: np.ndarray,
min_dist: float | np.ndarray,
knot_bounds: NDArray,
min_dist: float | NDArray,
num_samples: int = 1,
) -> np.ndarray:
) -> NDArray:
"""Sample knot vectors given a set of rules.
Parameters
----------
num_knots : int
Number of interior knots.
knot_bounds : np.ndarray, shape(2,) or shape(`num_knots`,2)
knot_bounds : NDArray, shape(2,) or shape(`num_knots`,2)
Lower and upper bounds for knots. If shape(2,), boundary knots
placed at `knot_bounds[0]` and `knot_bounds[1]`. If
shape(`num_knots`,2), boundary knots placed at
`knot_bounds[0, 0]` and `knot_bounds[-1, 1]`.
min_dist : float or np.ndarray, shape(`num_knots`+1,)
min_dist : float or NDArray, shape(`num_knots`+1,)
Minimum distances between knots.
num_samples : int, optional
Number of knot vectors to sample. Default is 1.
Returns
-------
np.ndarray, shape(`num_samples`,`num_knots`+2)
NDArray, shape(`num_samples`,`num_knots`+2)
Sampled knot vectors.
"""
Expand Down Expand Up @@ -380,7 +381,7 @@ def _check_nums(num_name: str, num_val: int) -> None:
raise ValueError(f"{num_name} must be at least 1")


def _check_knot_bounds(num_knots: int, knot_bounds: np.ndarray) -> np.ndarray:
def _check_knot_bounds(num_knots: int, knot_bounds: NDArray) -> NDArray:
"""Check knot_bounds."""
try:
knot_bounds = np.asarray(knot_bounds, dtype=float)
Expand All @@ -399,7 +400,7 @@ def _check_knot_bounds(num_knots: int, knot_bounds: np.ndarray) -> np.ndarray:
return knot_bounds


def _check_min_dist(num_knots: int, min_dist: float | np.ndarray) -> np.ndarray:
def _check_min_dist(num_knots: int, min_dist: float | NDArray) -> NDArray:
"""Check knot min_dist."""
if np.isscalar(min_dist):
min_dist = np.tile(min_dist, num_knots + 1)
Expand All @@ -415,8 +416,8 @@ def _check_min_dist(num_knots: int, min_dist: float | np.ndarray) -> np.ndarray:


def _check_feasibility(
num_knots: int, knot_bounds: np.ndarray, min_dist: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
num_knots: int, knot_bounds: NDArray, min_dist: NDArray
) -> tuple[NDArray, NDArray]:
"""Check knot feasibility and get left and right boundaries."""
if np.sum(min_dist) > knot_bounds[-1, 1] - knot_bounds[0, 0]:
raise ValueError("min_dist cannot exceed knot_bounds")
Expand Down Expand Up @@ -561,7 +562,7 @@ def to_list(obj: Any) -> list[Any]:
return [obj]


def is_numeric_array(array: np.ndarray) -> bool:
def is_numeric_array(array: NDArray) -> bool:
"""Check if an array is numeric.
Parameters
Expand Down Expand Up @@ -590,8 +591,8 @@ def is_numeric_array(array: np.ndarray) -> bool:


def expand_array(
array: np.ndarray, shape: tuple[int], value: Any, name: str
) -> np.ndarray:
array: NDArray, shape: tuple[int], value: Any, name: str
) -> NDArray:
"""Expand array when it is empty.
Parameters
Expand All @@ -608,7 +609,7 @@ def expand_array(
Returns
-------
np.ndarray
NDArray
Expanded array.
"""
Expand All @@ -630,7 +631,7 @@ def expand_array(
def ravel_dict(x: dict) -> dict:
"""Ravel dictionary."""
assert all([isinstance(k, str) for k in x.keys()])
assert all([isinstance(v, np.ndarray) for v in x.values()])
assert all([isinstance(v, NDArray) for v in x.values()])
new_x = {}
for k, v in x.items():
if v.size == 1:
Expand Down
6 changes: 3 additions & 3 deletions src/mrtool/cov_selection/covfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
power_step_size: float = 0.5,
inlier_pct: float = 1.0,
alpha: float = 0.05,
beta_gprior: dict[str, np.ndarray] = None,
beta_gprior: dict[str, np.ndarray] | None = None,
beta_gprior_std: float = 1.0,
bias_zero: bool = False,
use_re: dict | None = None,
Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(
self.power_step_size = power_step_size
self.powers = np.arange(*self.power_range, self.power_step_size)

self.num_covs = len(pre_selected_covs) + len(covs)
self.num_covs = len(self.all_covs)
if len(covs) == 0:
warnings.warn(
"There is no covariates to select, will return the pre-selected covariates."
Expand All @@ -117,7 +117,7 @@ def create_model(
self,
covs: list[str],
prior_type: str = "Laplace",
laplace_std: float = None,
laplace_std: float | None = None,
) -> MRBRT:
"""Create Gaussian or Laplace model.
Expand Down

0 comments on commit 8b5d8b0

Please sign in to comment.