-
-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix modelbuilder load #229
base: main
Are you sure you want to change the base?
Changes from all commits
e553f47
1aa690b
57ccdb8
c89022b
f6cdcca
0a6f1dc
37c4161
c2ba4bd
1198415
4370eca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,12 +13,14 @@ | |
# limitations under the License. | ||
|
||
|
||
# Modified by Stijn de Boer | ||
|
||
import hashlib | ||
import json | ||
import warnings | ||
from abc import abstractmethod | ||
from pathlib import Path | ||
from typing import Any, Dict, List, Optional, Union | ||
from typing import Any, Dict, List, Optional, Union, Tuple | ||
|
||
import arviz as az | ||
import numpy as np | ||
|
@@ -74,8 +76,8 @@ def __init__( | |
sampler_config = self.default_sampler_config if sampler_config is None else sampler_config | ||
self.sampler_config = sampler_config | ||
model_config = self.default_model_config if model_config is None else model_config | ||
|
||
self.model_config = model_config # parameters for priors etc. | ||
self.model_coords = None | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this property should be defined in child classes only IF they utilise coords. Otherwise it is forced on every class that inherits from MB (ModelBuilder), even those that are not hierarchical in nature. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for your suggestion. The suggestion does not force any non-hierarchical model to populate the |
||
self.model = None # Set by build_model | ||
self.idata: Optional[az.InferenceData] = None # idata is generated during fitting | ||
self.is_fitted_ = False | ||
|
@@ -172,7 +174,7 @@ def default_sampler_config(self) -> Dict: | |
-------- | ||
>>> @classmethod | ||
>>> def default_sampler_config(self): | ||
>>> Return { | ||
>>> return { | ||
>>> 'draws': 1_000, | ||
>>> 'tune': 1_000, | ||
>>> 'chains': 1, | ||
|
@@ -187,32 +189,29 @@ def default_sampler_config(self) -> Dict: | |
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def generate_and_preprocess_model_data( | ||
self, X: Union[pd.DataFrame, pd.Series], y: pd.Series | ||
) -> None: | ||
def preprocess_model_data( | ||
self, X: Union[pd.DataFrame, pd.Series], y: Optional[pd.Series] = None | ||
) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.Series]]: | ||
""" | ||
Applies preprocessing to the data before fitting the model. | ||
if validate is True, it will check if the data is valid for the model. | ||
sets self.model_coords based on provided dataset | ||
|
||
Parameters: | ||
X : array, shape (n_obs, n_features) | ||
y : array, shape (n_obs,) | ||
|
||
Examples | ||
-------- | ||
>>> @classmethod | ||
>>> def generate_and_preprocess_model_data(self, X, y): | ||
>>> x = np.linspace(start=1, stop=50, num=100) | ||
>>> y = 5 * x + 3 + np.random.normal(0, 1, len(x)) * np.random.rand(100)*10 + np.random.rand(100)*6.4 | ||
>>> X = pd.DataFrame(x, columns=['x']) | ||
>>> y = pd.Series(y, name='y') | ||
>>> self.X = X | ||
>>> self.y = y | ||
>>> def preprocess_model_data(self, X: DataFrame | Series, y: Series = None): | ||
>>> X_prep = X.copy() | ||
>>> X_prep['x'] = (X_prep['x'] - X_prep['x'].mean())/X_prep['x'].std() | ||
>>> if y is None: | ||
>>> return X_prep | ||
>>> return X_prep, y.copy() | ||
|
||
Returns | ||
------- | ||
None | ||
Union[pd.DataFrame, Tuple[pd.DataFrame, pd.Series]] | ||
|
||
""" | ||
raise NotImplementedError | ||
|
@@ -258,6 +257,28 @@ def build_model( | |
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def set_model_coords_from_data( | ||
self, X: Union[pd.DataFrame, pd.Series], y: Optional[pd.Series] = None | ||
) -> None: | ||
"""Creates the model coords. | ||
|
||
Parameters: | ||
X : array, shape (n_obs, n_features) | ||
y : array, shape (n_obs,) | ||
|
||
Examples | ||
-------- | ||
def extract_model_coords_from_data(self, X): | ||
group_dim1 = X['group1'].unique() | ||
group_dim2 = X['group2'].unique() | ||
self.model_coords = {'group1':group_dim1, 'group2':group_dim2} | ||
|
||
Returns | ||
------- | ||
Dict[str, List[Union[str, int]] | ||
""" | ||
|
||
def sample_model(self, **kwargs): | ||
""" | ||
Sample from the PyMC model. | ||
|
@@ -374,7 +395,7 @@ def save(self, fname: str) -> None: | |
>>> model.fit(data) | ||
>>> model.save('model_results.nc') # This will call the overridden method in MyModel | ||
""" | ||
if self.idata is not None and "posterior" in self.idata: | ||
if self.is_fitted: | ||
file = Path(str(fname)) | ||
self.idata.to_netcdf(str(file)) | ||
else: | ||
|
@@ -433,17 +454,11 @@ def load(cls, fname: str): | |
sampler_config=json.loads(idata.attrs["sampler_config"]), | ||
) | ||
model.idata = idata | ||
dataset = idata.fit_data.to_dataframe() | ||
X = dataset.drop(columns=[model.output_var]) | ||
y = dataset[model.output_var] | ||
model.build_model(X, y) | ||
# All previously used data is in idata. | ||
|
||
model.set_model_coords_from_idata() | ||
if model.id != idata.attrs["id"]: | ||
raise ValueError( | ||
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'" | ||
) | ||
|
||
return model | ||
|
||
def fit( | ||
|
@@ -462,7 +477,7 @@ def fit( | |
|
||
Parameters | ||
---------- | ||
X : array-like if sklearn is available, otherwise array, shape (n_obs, n_features) | ||
X : pd.DataFrame (n_obs, n_features) | ||
The training input samples. | ||
y : array-like if sklearn is available, otherwise array, shape (n_obs,) | ||
The target values (real numbers). | ||
|
@@ -492,26 +507,15 @@ def fit( | |
if y is None: | ||
y = np.zeros(X.shape[0]) | ||
y = pd.DataFrame({self.output_var: y}) | ||
self.generate_and_preprocess_model_data(X, y.values.flatten()) | ||
self.build_model(self.X, self.y) | ||
X_prep, y_prep = self.preprocess_model_data(X, y.values.flatten()) | ||
self.set_model_coords_from_data(X) | ||
self.build_model(X_prep, y_prep) | ||
|
||
sampler_config = self.sampler_config.copy() | ||
sampler_config["progressbar"] = progressbar | ||
sampler_config["random_seed"] = random_seed | ||
sampler_config.update(**kwargs) | ||
self.idata = self.sample_model(**sampler_config) | ||
|
||
X_df = pd.DataFrame(X, columns=X.columns) | ||
combined_data = pd.concat([X_df, y], axis=1) | ||
assert all(combined_data.columns), "All columns must have non-empty names" | ||
with warnings.catch_warnings(): | ||
warnings.filterwarnings( | ||
"ignore", | ||
category=UserWarning, | ||
message="The group fit_data is not defined in the InferenceData scheme", | ||
) | ||
self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore | ||
|
||
return self.idata # type: ignore | ||
|
||
def predict( | ||
|
@@ -526,7 +530,7 @@ def predict( | |
|
||
Parameters | ||
--------- | ||
X_pred : array-like if sklearn is available, otherwise array, shape (n_pred, n_features) | ||
X_pred : pd.DataFrame (n_pred, n_features) | ||
The input data used for prediction. | ||
extend_idata : Boolean determining whether the predictions should be added to inference data object. | ||
Defaults to True. | ||
|
@@ -545,9 +549,13 @@ def predict( | |
>>> prediction_data = pd.DataFrame({'input':x_pred}) | ||
>>> pred_mean = model.predict(prediction_data) | ||
""" | ||
X_pred_prep = self.preprocess_model_data(X_pred) | ||
if self.model is None: | ||
synth_y = pd.Series(np.zeros(len(X_pred))) | ||
self.build_model(X_pred_prep, synth_y) | ||
|
||
posterior_predictive_samples = self.sample_posterior_predictive( | ||
X_pred, extend_idata, combined=False, **kwargs | ||
X_pred_prep, extend_idata, combined=False, **kwargs | ||
) | ||
|
||
if self.output_var not in posterior_predictive_samples: | ||
|
@@ -682,7 +690,11 @@ def predict_proba( | |
**kwargs, | ||
) -> xr.DataArray: | ||
"""Alias for `predict_posterior`, for consistency with scikit-learn probabilistic estimators.""" | ||
return self.predict_posterior(X_pred, extend_idata, combined, **kwargs) | ||
synth_y = pd.Series(np.zeros(len(X_pred))) | ||
X_pred_prep, y_synth_prep = self.preprocess_model_data(X_pred, synth_y) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. preprocess_model_data should not be defined in a way that forces the creation of an entire dummy dataset for y when we only want to process X, so: preprocess_model_data(X_pred: pd.DataFrame, y_pred: Optional[pd.Series] = None), plus logic that first checks whether y was provided. |
||
if self.model is None: | ||
self.build_model(X_pred_prep, y_synth_prep) | ||
return self.predict_posterior(X_pred_prep, extend_idata, combined, **kwargs) | ||
|
||
def predict_posterior( | ||
self, | ||
|
@@ -710,9 +722,13 @@ def predict_posterior( | |
Posterior predictive samples for each input X_pred | ||
""" | ||
|
||
X_pred = self._validate_data(X_pred) | ||
synth_y = pd.Series(np.zeros(len(X_pred))) | ||
X_pred_prep, y_synth_prep = self.preprocess_model_data(X_pred, synth_y) | ||
if self.model is None: | ||
self.build_model(X_pred_prep, y_synth_prep) | ||
|
||
posterior_predictive_samples = self.sample_posterior_predictive( | ||
X_pred, extend_idata, combined, **kwargs | ||
X_pred_prep, extend_idata, combined=False, **kwargs | ||
) | ||
|
||
if self.output_var not in posterior_predictive_samples: | ||
|
@@ -746,3 +762,13 @@ def id(self) -> str: | |
hasher.update(self.version.encode()) | ||
hasher.update(self._model_type.encode()) | ||
return hasher.hexdigest()[:16] | ||
|
||
@property | ||
def is_fitted(self): | ||
return self.idata is not None and "posterior" in self.idata | ||
|
||
def set_model_coords_from_idata(self): | ||
az_coords = self.idata.posterior.coords.variables | ||
self.model_coords = { | ||
k: list(az_coords[k].to_numpy()) for k in az_coords if not k in ["chain", "draw"] | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't keep this info in the files; it's in the release notes and commit history.