Skip to content

Commit

Permalink
fix inherit for meta-estimators
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcovdBoom committed Jan 15, 2024
1 parent 4f0d673 commit 7b810c2
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 15 deletions.
5 changes: 2 additions & 3 deletions atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,7 @@ def _get_est(self, params: dict[str, Any]) -> Predictor:
else:
sub_params[name] = value

estimator = self._inherit(self._est_class(**base_params))
estimator.set_params(**sub_params)
estimator = self._est_class(**base_params).set_params(**sub_params)

if hasattr(self, "task"):
if self.task is Task.multilabel_classification:
Expand Down Expand Up @@ -1021,7 +1020,7 @@ def fit_model(
random_state=trial.number + (self.random_state or 0),
)
else: # Custom cross-validation generator
splitter = self._inherit(cv)
splitter = cv

args = [self.og.X_train]
if "y" in sign(splitter.split) and cols is not None:

Check warning on line 1026 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unbound local variables

Local variable 'splitter' might be referenced before assignment
Expand Down
33 changes: 21 additions & 12 deletions atom/basetransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,16 @@
from joblib.memory import Memory
from pandas._typing import Axes

Check notice on line 32 in atom/basetransformer.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _typing of a class
from ray.util.joblib import register_ray
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_memory

from atom.utils.constants import SHARED_PARAMS
from atom.utils.types import (
Backend, Bool, DataFrame, Engine, Estimator, Int, IntLargerEqualZero,
Pandas, Sequence, Severity, Verbose, Warnings, XSelector, YSelector,
bool_t, dataframe_t, int_t, sequence_t,
)
from atom.utils.utils import crash, flt, n_cols, sign, to_df, to_pandas
from atom.utils.utils import crash, flt, lst, n_cols, sign, to_df, to_pandas


T_Estimator = TypeVar("T_Estimator", bound=Estimator)
Expand Down Expand Up @@ -357,28 +359,35 @@ def _inherit(self, obj: T_Estimator) -> T_Estimator:
Utility method to set the sp (seasonal period), n_jobs and
random_state parameters of an estimator (if available) equal
to that of this instance.
to that of this instance. If `obj` is a meta-estimator, it
also sets the parameters to the base estimator. Note that
the parameters are only changed when the value is equal to
the constructor's default value.
Parameters
----------
obj: Estimator
Object in which to change the parameters.
Instance for which to change the parameters.
Returns
-------
Estimator
Same object with changed parameters.
"""
signature = sign(obj.__init__) # type: ignore[misc]
for p in ("n_jobs", "random_state"):
if p in signature and getattr(obj, p, "<!>") == signature[p]._default:
setattr(obj, p, getattr(self, p))

# Add seasonal period to the estimator
if hasattr(self, "_config") and self._config.sp:
if "sp" in signature and getattr(obj, "sp", "<!>") == signature["sp"]._default:
obj.sp = self._config.sp if self.multiple_seasonality else flt(self._config.sp)
signature = sign(obj.__class__)
for p, value in obj.get_params().items():
if p in SHARED_PARAMS and (p not in signature or value == signature[p]._default):

Check notice on line 380 in atom/basetransformer.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _default of a class
# Some estimators like XGB use kwargs, so param
# isn't in signature. In that case, always override
obj.set_params(**{p: getattr(self, p)})
elif isinstance(value, BaseEstimator):
obj.set_params(**{p: self._inherit(value)})
elif p == "sp" and hasattr(self, "_config") and self._config.sp:
if self.multiple_seasonality:
obj.set_params(**{p: self._config.sp})
else:
obj.set_params(**{p: lst(self._config.sp)[0]})

return obj

Expand Down
3 changes: 3 additions & 0 deletions atom/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
# Default string values considered missing
DEFAULT_MISSING = ["", "?", "NA", "nan", "NaN", "NaT", "none", "None", "inf", "-inf"]

# Shared parameters between estimators
SHARED_PARAMS = ("n_jobs", "random_state")

# Attributes shared between atom and a dataframe
DF_ATTRS = (
"size",
Expand Down

0 comments on commit 7b810c2

Please sign in to comment.