Skip to content

Commit

Permalink
increase coverage 2
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Mar 5, 2024
1 parent 04430a0 commit ee89f46
Show file tree
Hide file tree
Showing 29 changed files with 16,581 additions and 16,882 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ Example steps taken by ATOM's pipeline:
* [50+ plots to analyze the data and model performance](https://tvdboom.github.io/ATOM/latest/user_guide/plots/#available-plots)
* [Avoid refactoring to test new pipelines](https://tvdboom.github.io/ATOM/latest/user_guide/data_management/#branches)
* [Native support for GPU training](https://tvdboom.github.io/ATOM/latest/user_guide/accelerating/#gpu-acceleration)
* [Integration with Polars, PySpark and PyArrow](https://tvdboom.github.io/ATOM/latest/user_guide/data_management/#data-engines)
* [25+ example notebooks to get you started](https://tvdboom.github.io/ATOM/latest/examples/accelerating_cuml/)
* [Full integration with multilabel and multioutput datasets](https://tvdboom.github.io/ATOM/latest/user_guide/data_management/#multioutput-tasks)
* [Native support for sparse datasets](https://tvdboom.github.io/ATOM/latest/user_guide/data_management/#sparse-datasets)
Expand Down
2 changes: 2 additions & 0 deletions atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,8 @@ def _score_from_est(
data=_check_response_method(estimator, scorer._response_method)(X),
index=y.index,
)
if isinstance(y_pred, pd.DataFrame) and self.task is Task.binary_classification:
y_pred = y_pred.iloc[:, 1] # Return probability of the positive class

return self._score_from_pred(scorer, y, y_pred, **kwargs)

Expand Down
16 changes: 1 addition & 15 deletions atom/data_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2621,11 +2621,6 @@ def fit(self, X: XConstructor, y: YConstructor | None = None) -> Self:
random_state=kwargs.pop("random_state", self.random_state),
**kwargs,
)
else:
raise ValueError(
f"Invalid value for the strategy parameter, got {self.strategy}. "
f"Choose from: {', '.join(strategies)}."
)

num_cols = Xt.select_dtypes(include="number")

Expand Down Expand Up @@ -2889,11 +2884,6 @@ def transform(
}

for strat in lst(self.strategy):
if strat not in ["zscore", *strategies]:
raise ValueError(
"Invalid value for the strategy parameter. "
f"Choose from: zscore, {', '.join(strategies)}."
)
if strat != "zscore" and str(self.method) != "drop":
raise ValueError(
"Invalid value for the method parameter. Only the zscore "
Expand Down Expand Up @@ -2986,12 +2976,8 @@ def transform(
yt = yt[outlier_rows]

else:
# Replace the columns in X and y with the new values from objective
# Replace the columns in X with the new values from objective
Xt.update(objective)
if isinstance(yt, pd.Series) and yt.name in objective:
yt.update(objective[str(yt.name)])
elif isinstance(yt, pd.DataFrame):
yt.update(objective)

return variable_return(self._convert(Xt), self._convert(yt))

Expand Down
4 changes: 0 additions & 4 deletions atom/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,10 +955,6 @@ def fit(self, X: XConstructor, y: YConstructor | None = None) -> Self:
)
self._estimator = estimator(**self.kwargs)

if hasattr(self._estimator, "set_output"):
# transform="pandas" fails for sparse output
self._estimator.set_output(transform="default")

self._log("Fitting Vectorizer...", 1)
self._estimator.fit(Xt[self._corpus])

Expand Down
5 changes: 0 additions & 5 deletions atom/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,6 @@ def _fit(
else:
cloned = clone(transformer)

# Attach internal attrs otherwise wiped by clone
for attr in ("_cols", "_train_only"):
if hasattr(transformer, attr):
setattr(cloned, attr, getattr(transformer, attr))

with adjust(cloned, verbose=self._verbose):
# Fit or load the current estimator from cache
# Type ignore because routed_params is never None but
Expand Down
8 changes: 8 additions & 0 deletions atom/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@
)


__all__ = [
"DirectClassifier", "DirectForecaster", "DirectRegressor",
"SuccessiveHalvingClassifier", "SuccessiveHalvingForecaster",
"SuccessiveHalvingRegressor", "TrainSizingClassifier",
"TrainSizingForecaster", "TrainSizingRegressor",
]


class Direct(BaseEstimator, BaseTrainer):
"""Direct training approach.
Expand Down
17 changes: 7 additions & 10 deletions atom/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2023,17 +2023,14 @@ def check_is_fitted(
Whether the estimator is fitted.
"""
if not _is_fitted(obj, attributes):
if exception:
raise NotFittedError(
f"This {type(obj).__name__} instance is not yet fitted. "
f"Call {'run' if hasattr(obj, 'run') else 'fit'} with "
"appropriate arguments before using this object."
)
else:
return False
if not (is_fitted := _is_fitted(obj, attributes)) and exception:
raise NotFittedError(
f"This {type(obj).__name__} instance is not yet fitted. "
f"Call {'run' if hasattr(obj, 'run') else 'fit'} with "
"appropriate arguments before using this object."
)

return True
return is_fitted


def get_custom_scorer(metric: str | MetricFunction | Scorer) -> Scorer:
Expand Down
2 changes: 1 addition & 1 deletion docs_sources/dependencies.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ packages are necessary for its correct functioning.
* **[optuna](https://optuna.org/)** (>=3.4.0)
* **[pandas](https://pandas.pydata.org/)** (>=2.1.2)
* **[plotly](https://plotly.com/python/)** (>=5.18.0)
* **[scikit-learn](https://scikit-learn.org/stable/)** (>=1.4.0)
* **[scikit-learn](https://scikit-learn.org/stable/)** (>=1.4.1.post1)
* **[scipy](https://www.scipy.org/)** (>=1.10.1)
* **[shap](https://github.com/slundberg/shap/)** (>=0.43.0)
* **[sktime[forecasting]](http://www.sktime.net/en/latest/)** (>=0.26.0)
Expand Down
Loading

0 comments on commit ee89f46

Please sign in to comment.