Skip to content

Commit

Permalink
fixes for sklearn 1.4
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcovdBoom committed Jan 22, 2024
1 parent d2277d1 commit 1999447
Show file tree
Hide file tree
Showing 10 changed files with 178 additions and 115 deletions.
2 changes: 1 addition & 1 deletion atom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
from atom.utils.constants import __version__


sklearn.set_config(transform_output="pandas")
sklearn.set_config(transform_output="pandas", enable_metadata_routing=True)
4 changes: 2 additions & 2 deletions atom/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,7 +1349,7 @@ def add(
"""Add a transformer to the pipeline.
If the transformer is not fitted, it is fitted on the complete
training set. Afterwards, the data set is transformed and the
training set. Afterward, the data set is transformed and the
estimator is added to atom's pipeline. If the estimator is
a sklearn Pipeline, every estimator is merged independently
with atom.
Expand Down Expand Up @@ -1639,7 +1639,7 @@ def encode(
max_onehot: IntLargerTwo | None = 10,
ordinal: dict[str, Sequence[Any]] | None = None,
infrequent_to_value: FloatLargerZero | None = None,
value: str = "rare",
value: str = "infrequent",
**kwargs,
):
"""Perform encoding of categorical features.
Expand Down
56 changes: 29 additions & 27 deletions atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
)
from sklearn.utils import resample
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import _check_response_method

Check notice on line 53 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _check_response_method of a class
from sktime.forecasting.base import ForecastingHorizon
from sktime.forecasting.compose import make_reduction
from sktime.forecasting.model_evaluation import evaluate
Expand Down Expand Up @@ -643,7 +644,7 @@ def _get_pred(
self,
rows: RowSelector,
target: TargetSelector | None = None,
attr: PredictionMethods | Literal["thresh"] = "predict",
method: PredictionMethods | Sequence[PredictionMethods] = "predict",
) -> tuple[Pandas, Pandas]:
"""Get the true and predicted values for a column.
Expand All @@ -661,9 +662,10 @@ def _get_pred(
Target column to look at. Only for [multioutput tasks][].
If None, all columns are returned.
attr: str, default="predict"
Method used to get predictions. Use "thresh" to get
`decision_function` or `predict_proba` in that order.
method: str or sequence, default="predict"
Response method(s) used to get predictions. If sequence,
the order provided states the order in which the methods
are tried.
Returns
-------
Expand All @@ -674,12 +676,7 @@ def _get_pred(
Predicted values.
"""
# Select method to use for predictions
if attr == "thresh":
for attribute in PredictionMethods.__args__:
if hasattr(self.estimator, attribute):
attr = attribute
break
method_caller = _check_response_method(self.estimator, method).__name__

X, y = self.branch._get_rows(rows, return_X_y=True)

Check notice on line 681 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

Check notice on line 681 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_rows of a class

Expand All @@ -695,11 +692,16 @@ def _get_pred(
self.estimator.get_tags().get("capability:insample")
and (not self.estimator.get_tags()["requires-fh-in-fit"] or rows == "test")
):
y_pred = self._prediction(fh=X.index, X=check_empty(X), verbose=0, method=attr)
y_pred = self._prediction(
fh=X.index,
X=check_empty(X),
verbose=0,
method=method_caller,
)
else:
y_pred = bk.Series([np.NaN] * len(X), index=X.index)
else:
y_pred = self._prediction(X.index, verbose=0, method=attr)
y_pred = self._prediction(X.index, verbose=0, method=method_caller)

if self.task.is_multioutput:
if target is not None:
Expand Down Expand Up @@ -843,21 +845,21 @@ def _get_score(
Metric score on the selected data set.
"""
if scorer.__class__.__name__ == "_ThresholdScorer":
y_true, y_pred = self._get_pred(rows, attr="thresh")
elif scorer.__class__.__name__ == "_ProbaScorer":
y_true, y_pred = self._get_pred(rows, attr="predict_proba")
else:
if threshold and self.task.is_binary and hasattr(self, "predict_proba"):
y_true, y_pred = self._get_pred(rows, attr="predict_proba")
if isinstance(y_pred, dataframe_t):
# Update every target column with its corresponding threshold
for i, value in enumerate(threshold):
y_pred.iloc[:, i] = (y_pred.iloc[:, i] > value).astype("int")
else:
y_pred = (y_pred > threshold[0]).astype("int")
if (
scorer._response_method == "predict"

Check notice on line 849 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _response_method of a class
and threshold
and self.task.is_binary
and hasattr(self.estimator, "predict_proba")
):
y_true, y_pred = self._get_pred(rows, method="predict_proba")
if isinstance(y_pred, dataframe_t):
# Update every target column with its corresponding threshold
for i, value in enumerate(threshold):
y_pred.iloc[:, i] = (y_pred.iloc[:, i] > value).astype("int")
else:
y_true, y_pred = self._get_pred(rows, attr="predict")
y_pred = (y_pred > threshold[0]).astype("int")
else:
y_true, y_pred = self._get_pred(rows, method=scorer._response_method)

Check notice on line 862 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _response_method of a class

kwargs = {}
if "sample_weight" in sign(scorer._score_func):

Check notice on line 865 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _score_func of a class
Expand Down Expand Up @@ -2917,7 +2919,7 @@ def score(
metric for [multi-metric runs][]).
sample_weight: sequence or None, default=None
Sample weights corresponding to y.
Sample weights corresponding to `y`.
verbose: int or None, default=None
Verbosity level for the transformers in the pipeline. If
Expand Down
Loading

0 comments on commit 1999447

Please sign in to comment.