Fix modelbuilder load #229
base: main
Changes from 8 commits
e553f47
1aa690b
57ccdb8
c89022b
f6cdcca
0a6f1dc
37c4161
c2ba4bd
1198415
4370eca
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -74,8 +74,8 @@ def __init__( | |
sampler_config = self.default_sampler_config if sampler_config is None else sampler_config | ||
self.sampler_config = sampler_config | ||
model_config = self.default_model_config if model_config is None else model_config | ||
|
||
self.model_config = model_config # parameters for priors etc. | ||
self.model_coords = None | ||
self.model = None # Set by build_model | ||
self.idata: Optional[az.InferenceData] = None # idata is generated during fitting | ||
self.is_fitted_ = False | ||
|
@@ -172,7 +172,7 @@ def default_sampler_config(self) -> Dict: | |
-------- | ||
>>> @classmethod | ||
>>> def default_sampler_config(self): | ||
>>> Return { | ||
>>> return { | ||
>>> 'draws': 1_000, | ||
>>> 'tune': 1_000, | ||
>>> 'chains': 1, | ||
|
@@ -187,13 +187,10 @@ def default_sampler_config(self) -> Dict: | |
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def generate_and_preprocess_model_data( | ||
self, X: Union[pd.DataFrame, pd.Series], y: pd.Series | ||
) -> None: | ||
def preprocess_model_data(self, X: Union[pd.DataFrame, pd.Series], y: pd.Series = None) -> None: | ||
Signature doesn't match the return value from the docstring; this will result in a mypy error in child classes about violating the Liskov substitution principle if the method is implemented the way the docstring suggests. |
||
""" | ||
Applies preprocessing to the data before fitting the model. | ||
if validate is True, it will check if the data is valid for the model. | ||
sets self.model_coords based on provided dataset | ||
|
||
Parameters: | ||
X : array, shape (n_obs, n_features) | ||
|
@@ -202,17 +199,16 @@ def generate_and_preprocess_model_data( | |
Examples | ||
-------- | ||
>>> @classmethod | ||
>>> def generate_and_preprocess_model_data(self, X, y): | ||
>>> def preprocess_model_data(self, X, y): | ||
>>> x = np.linspace(start=1, stop=50, num=100) | ||
>>> y = 5 * x + 3 + np.random.normal(0, 1, len(x)) * np.random.rand(100)*10 + np.random.rand(100)*6.4 | ||
>>> X = pd.DataFrame(x, columns=['x']) | ||
>>> y = pd.Series(y, name='y') | ||
>>> self.X = X | ||
>>> self.y = y | ||
>>> return X, y | ||
|
||
Returns | ||
------- | ||
None | ||
pd.DataFrame, pd.Series | ||
|
||
""" | ||
raise NotImplementedError | ||
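The review comment on `preprocess_model_data` above points out that the `-> None` annotation conflicts with the docstring, which documents a `pd.DataFrame, pd.Series` return. A minimal sketch of a consistent signature follows, assuming stand-in classes: `ModelBuilderSketch` and `LinearSketch` are hypothetical names used only for illustration, not the project's classes.

```python
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Union

import pandas as pd


class ModelBuilderSketch(ABC):
    """Hypothetical stand-in for the base class in this diff."""

    @abstractmethod
    def preprocess_model_data(
        self, X: Union[pd.DataFrame, pd.Series], y: Optional[pd.Series] = None
    ) -> Tuple[pd.DataFrame, Optional[pd.Series]]:
        """Return the preprocessed (X, y) pair, as the docstring describes."""
        raise NotImplementedError


class LinearSketch(ModelBuilderSketch):
    """Hypothetical child; its override keeps the parent's annotations."""

    def preprocess_model_data(
        self, X: Union[pd.DataFrame, pd.Series], y: Optional[pd.Series] = None
    ) -> Tuple[pd.DataFrame, Optional[pd.Series]]:
        # Normalise a Series input to a single-column DataFrame.
        X_prep = X.to_frame() if isinstance(X, pd.Series) else X.copy()
        return X_prep, y
```

Because the parent and child annotations agree, a type checker has no return-type mismatch to report for the override.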
|
@@ -258,6 +254,23 @@ def build_model( | |
""" | ||
raise NotImplementedError | ||
|
||
def save_model_coords(self, X: Union[pd.DataFrame, pd.Series], y: pd.Series): | ||
This should be an abstract method, since it doesn't actually implement anything (it assigns None to a property that is already initialised to None).
Yes, this method does not do anything. I don't know exactly how Python abstract methods work, but I hoped that this way the specific implementation would be optional, needed only in those cases where the model coords actually have to be set. For non-hierarchical models, this method does not need to be touched.
The reason I didn't include it is exactly what you mention: a lot of models won't use the coords at all, yet this would force them to carry a property that is never used, along with its setter method. In my opinion it is best not to include it at all and to let developers of a specific child model add it, as was done for DelayedSaturatedMMM in pymc-marketing.
Why does save_model_coords require both X and y as inputs?
I did that because I thought there might be some highly unusual cases in which y is also required. I would have no problem if this function took only X.
In that case y should be an optional parameter, with a default value of None. |
||
"""Creates the model coords. | ||
|
||
Parameters: | ||
X : array, shape (n_obs, n_features) | ||
y : array, shape (n_obs,) | ||
|
||
Examples | ||
-------- | ||
def set_model_coords(self, X, y): | ||
group_dim1 = X['group1'].unique() | ||
group_dim2 = X['group2'].unique() | ||
|
||
self.model_coords = {'group1':group_dim1, 'group2':group_dim2} | ||
""" | ||
self.model_coords = None | ||
|
||
def sample_model(self, **kwargs): | ||
""" | ||
Sample from the PyMC model. | ||
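The thread on `save_model_coords` turns on whether the hook should be abstract or optional. The sketch below illustrates the optional variant the PR author describes: a concrete no-op default in the base class, with `y` optional as the reviewers suggest, overridden only by models that need coords. `BaseSketch` and `HierarchicalSketch` are hypothetical names, and `drop_duplicates().tolist()` is a swapped-in alternative to the `unique()` call in the docstring example.

```python
from typing import Dict, List, Optional

import pandas as pd


class BaseSketch:
    """Hypothetical base: the hook is concrete, so overriding is optional."""

    def __init__(self) -> None:
        self.model_coords: Optional[Dict[str, List]] = None

    def save_model_coords(self, X: pd.DataFrame, y: Optional[pd.Series] = None) -> None:
        # Default: leave model_coords as None (non-hierarchical models).
        return None


class HierarchicalSketch(BaseSketch):
    """Hypothetical child that needs coords and overrides the hook."""

    def save_model_coords(self, X: pd.DataFrame, y: Optional[pd.Series] = None) -> None:
        # Store plain Python lists rather than arrays.
        self.model_coords = {"group1": X["group1"].drop_duplicates().tolist()}


X = pd.DataFrame({"group1": ["a", "b", "a"], "x": [1.0, 2.0, 3.0]})
flat, nested = BaseSketch(), HierarchicalSketch()
flat.save_model_coords(X)
nested.save_model_coords(X)
```

Marking the hook with `@abstractmethod` instead would require every subclass to define it, which is the trade-off the reviewers raise.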
|
@@ -339,6 +352,7 @@ def set_idata_attrs(self, idata=None): | |
idata.attrs["version"] = self.version | ||
idata.attrs["sampler_config"] = json.dumps(self.sampler_config) | ||
idata.attrs["model_config"] = json.dumps(self._serializable_model_config) | ||
idata.attrs["model_coords"] = json.dumps(self.model_coords) | ||
# Only classes with non-dataset parameters will implement save_input_params | ||
if hasattr(self, "_save_input_params"): | ||
self._save_input_params(idata) | ||
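One point worth checking in `idata.attrs["model_coords"] = json.dumps(self.model_coords)`: the call works for `None` and for plain lists, but `Series.unique()`, as used in the `save_model_coords` docstring example, typically returns a NumPy array, which the standard `json` module cannot serialise. The sketch below shows one possible conversion; `coords_to_json` is a hypothetical helper and not part of this PR.

```python
import json
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd


def coords_to_json(model_coords: Optional[Dict[str, Any]]) -> str:
    """Hypothetical helper: serialise coords, converting arrays to lists."""
    if model_coords is None:
        return json.dumps(None)  # serialises to "null", which loads back as None
    return json.dumps({k: np.asarray(v).tolist() for k, v in model_coords.items()})


X = pd.DataFrame({"group1": ["a", "b", "a"], "group2": [1, 2, 2]})
raw_coords = {"group1": X["group1"].unique(), "group2": X["group2"].unique()}

try:
    json.dumps(raw_coords)  # array values are rejected by the json module
    raw_ok = True
except TypeError:
    raw_ok = False

# Round trip: what load() would read back from idata.attrs.
restored = json.loads(coords_to_json(raw_coords))
```

After the round trip the coords are lists rather than arrays, so code that consumes `model_coords` after `load` should not assume array methods are available.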
|
@@ -432,18 +446,12 @@ def load(cls, fname: str): | |
model_config=model_config, | ||
sampler_config=json.loads(idata.attrs["sampler_config"]), | ||
) | ||
model.model_coords = json.loads(idata.attrs["model_coords"]) | ||
model.idata = idata | ||
dataset = idata.fit_data.to_dataframe() | ||
X = dataset.drop(columns=[model.output_var]) | ||
y = dataset[model.output_var] | ||
model.build_model(X, y) | ||
# All previously used data is in idata. | ||
|
||
if model.id != idata.attrs["id"]: | ||
raise ValueError( | ||
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'" | ||
) | ||
|
||
return model | ||
|
||
def fit( | ||
|
@@ -462,7 +470,7 @@ def fit( | |
|
||
Parameters | ||
---------- | ||
X : array-like if sklearn is available, otherwise array, shape (n_obs, n_features) | ||
X : pd.DataFrame (n_obs, n_features) | ||
The training input samples. | ||
y : array-like if sklearn is available, otherwise array, shape (n_obs,) | ||
The target values (real numbers). | ||
|
@@ -492,26 +500,15 @@ def fit( | |
if y is None: | ||
y = np.zeros(X.shape[0]) | ||
y = pd.DataFrame({self.output_var: y}) | ||
self.generate_and_preprocess_model_data(X, y.values.flatten()) | ||
self.build_model(self.X, self.y) | ||
X_prep, y_prep = self.preprocess_model_data(X, y.values.flatten()) | ||
self.save_model_coords(X_prep, y_prep) | ||
self.build_model(X_prep, y_prep) | ||
|
||
sampler_config = self.sampler_config.copy() | ||
sampler_config["progressbar"] = progressbar | ||
sampler_config["random_seed"] = random_seed | ||
sampler_config.update(**kwargs) | ||
self.idata = self.sample_model(**sampler_config) | ||
|
||
X_df = pd.DataFrame(X, columns=X.columns) | ||
combined_data = pd.concat([X_df, y], axis=1) | ||
assert all(combined_data.columns), "All columns must have non-empty names" | ||
with warnings.catch_warnings(): | ||
warnings.filterwarnings( | ||
"ignore", | ||
category=UserWarning, | ||
message="The group fit_data is not defined in the InferenceData scheme", | ||
) | ||
self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore | ||
|
||
return self.idata # type: ignore | ||
|
||
def predict( | ||
|
@@ -526,7 +523,7 @@ def predict( | |
|
||
Parameters | ||
--------- | ||
X_pred : array-like if sklearn is available, otherwise array, shape (n_pred, n_features) | ||
X_pred : pd.DataFrame (n_pred, n_features) | ||
The input data used for prediction. | ||
extend_idata : Boolean determining whether the predictions should be added to inference data object. | ||
Defaults to True. | ||
|
@@ -545,9 +542,12 @@ def predict( | |
>>> prediction_data = pd.DataFrame({'input':x_pred}) | ||
>>> pred_mean = model.predict(prediction_data) | ||
""" | ||
|
||
synth_y = pd.Series(np.zeros(len(X_pred))) | ||
X_pred_prep, y_synth_prep = self.preprocess_model_data(X_pred, synth_y) | ||
if self.model is None: | ||
self.build_model(X_pred_prep, y_synth_prep) | ||
posterior_predictive_samples = self.sample_posterior_predictive( | ||
X_pred, extend_idata, combined=False, **kwargs | ||
X_pred_prep, extend_idata, combined=False, **kwargs | ||
) | ||
|
||
if self.output_var not in posterior_predictive_samples: | ||
|
@@ -652,6 +652,7 @@ def get_params(self, deep=True): | |
return { | ||
"model_config": self.model_config, | ||
"sampler_config": self.sampler_config, | ||
"model_coords": self.model_coords, | ||
} | ||
|
||
def set_params(self, **params): | ||
|
@@ -660,6 +661,7 @@ def set_params(self, **params): | |
""" | ||
self.model_config = params["model_config"] | ||
self.sampler_config = params["sampler_config"] | ||
self.model_coords = params["model_coords"] | ||
|
||
@property | ||
@abstractmethod | ||
|
@@ -682,7 +684,11 @@ def predict_proba( | |
**kwargs, | ||
) -> xr.DataArray: | ||
"""Alias for `predict_posterior`, for consistency with scikit-learn probabilistic estimators.""" | ||
return self.predict_posterior(X_pred, extend_idata, combined, **kwargs) | ||
synth_y = pd.Series(np.zeros(len(X_pred))) | ||
X_pred_prep, y_synth_prep = self.preprocess_model_data(X_pred, synth_y) | ||
preprocess_model_data should not be defined in a way that forces creation of an entire dummy dataset for y when we only want to process X. Suggested signature: preprocess_model_data(X_pred: pd.DataFrame, y_pred: Optional[pd.Series] = None), with logic that first checks whether y was provided. |
||
if self.model is None: | ||
self.build_model(X_pred_prep, y_synth_prep) | ||
return self.predict_posterior(X_pred_prep, extend_idata, combined, **kwargs) | ||
|
||
def predict_posterior( | ||
self, | ||
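The review comment on `predict_proba` asks that preprocessing `X` not require a synthetic `y`. A sketch of that logic under stated assumptions: the function is written standalone rather than as a method, and the index reset is a purely illustrative preprocessing step, not what any real model in the project does.

```python
from typing import Optional, Tuple

import pandas as pd


def preprocess_model_data(
    X_pred: pd.DataFrame, y_pred: Optional[pd.Series] = None
) -> Tuple[pd.DataFrame, Optional[pd.Series]]:
    """Preprocess X always; touch y only when the caller supplied it."""
    X_prep = X_pred.reset_index(drop=True)  # illustrative step only
    if y_pred is None:
        return X_prep, None  # prediction path: no synthetic zeros needed
    if len(y_pred) != len(X_prep):
        raise ValueError("X and y must have the same number of rows")
    return X_prep, y_pred.reset_index(drop=True)


X_new = pd.DataFrame({"input": [0.1, 0.2, 0.3]}, index=[10, 11, 12])
X_only, y_none = preprocess_model_data(X_new)  # predict-style call
X_fit, y_fit = preprocess_model_data(X_new, pd.Series([1.0, 2.0, 3.0]))  # fit-style call
```

With this shape, `predict`, `predict_proba`, and `predict_posterior` could drop the `synth_y = pd.Series(np.zeros(len(X_pred)))` line, provided `build_model` also tolerates a missing `y`.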
|
@@ -710,9 +716,13 @@ def predict_posterior( | |
Posterior predictive samples for each input X_pred | ||
""" | ||
|
||
X_pred = self._validate_data(X_pred) | ||
synth_y = pd.Series(np.zeros(len(X_pred))) | ||
X_pred_prep, y_synth_prep = self.preprocess_model_data(X_pred, synth_y) | ||
if self.model is None: | ||
self.build_model(X_pred_prep, y_synth_prep) | ||
|
||
posterior_predictive_samples = self.sample_posterior_predictive( | ||
X_pred, extend_idata, combined, **kwargs | ||
X_pred_prep, extend_idata, combined=False, **kwargs | ||
) | ||
|
||
if self.output_var not in posterior_predictive_samples: | ||
|
I think this property should be included in child classes only if they use coords. Otherwise it is forced on every class that inherits from MB, even those that are not hierarchical.

Thanks for your suggestion. The suggestion does not force any non-hierarchical model to populate the `coords` attribute. If a model has no coords, this can stay `None`.