Skip to content

Commit

Permalink
infer shapes after attach data
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengp0 committed Jun 27, 2024
1 parent c8f70bd commit a2eed13
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/mrtool/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ def __init__(
self.cov_names.extend(cov_model.covs)
self.num_covs = len(self.cov_names)

# place holder for the limetr objective
self.lt: LimeTr
self.beta_soln: NDArray
self.gamma_soln: NDArray
self.u_soln: NDArray
self.w_soln: NDArray
self.re_soln: NDArray

def _infer_shape(self) -> None:
# add random effects
if not any([cov_model.use_re for cov_model in self.cov_models]):
self.cov_models[0].use_re = True
Expand Down Expand Up @@ -83,14 +92,6 @@ def __init__(
[cov_model.num_regularizations for cov_model in self.cov_models]
)

# place holder for the limetr objective
self.lt: LimeTr
self.beta_soln: NDArray
self.gamma_soln: NDArray
self.u_soln: NDArray
self.w_soln: NDArray
self.re_soln: NDArray

def attach_data(self, data=None):
"""Attach data to cov_model."""
data = self.data if data is None else data
Expand Down Expand Up @@ -239,6 +240,7 @@ def fit_model(self, **fit_options):
"""
if not all([cov_model.has_data() for cov_model in self.cov_models]):
self.attach_data()
self._infer_shape()

# dimensions
n = self.data.study_sizes
Expand Down

0 comments on commit a2eed13

Please sign in to comment.