Skip to content

Commit

Permalink
fix examples
Browse files (browse the repository at this point in the history)
  • Loading branch information
tvdboom committed Nov 29, 2023
1 parent 3b8e4a8 commit b756dbc
Show file tree
Hide file tree
Showing 14 changed files with 8,244 additions and 2,572 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
args: ["--show-source", "--statistics"]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.6.1
rev: v1.7.1
hooks:
- id: mypy
additional_dependencies: [pip==23.3.1, types-requests, pandas-stubs, beartype]
Expand Down
160 changes: 95 additions & 65 deletions atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@
import pandas as pd
import ray
from beartype import beartype
from beartype.roar import (
BeartypeCallHintParamViolation, BeartypeCallHintReturnViolation,
)
from joblib.memory import Memory
from joblib.parallel import Parallel, delayed
from mlflow.data import from_pandas
Expand Down Expand Up @@ -68,13 +65,13 @@
from atom.utils.types import (
HT, Backend, Bool, DataFrame, Engine, FHSelector, Float, FloatZeroToOneExc,
Int, IntLargerEqualZero, MetricConstructor, NJobs, Pandas,
PredictionMethod, Predictor, RowSelector, Scalar, Scorer, Sequence, Stages,
TargetSelector, Verbose, Warnings, XSelector, YSelector, dataframe_t,
float_t, int_t,
PredictionMethods, PredictionMethodsTS, Predictor, RowSelector, Scalar,
Scorer, Sequence, Stages, TargetSelector, Verbose, Warnings, XSelector,
YSelector, dataframe_t, float_t, int_t,
)
from atom.utils.utils import (
ClassMap, DataConfig, Goal, PlotCallback, ShapExplanation, Task,
TrialsCallback, adjust_verbosity, bk, check_dependency, check_empty,
TrialsCallback, adjust_verbosity, bk, cache, check_dependency, check_empty,
check_is_fitted, check_scaling, composed, crash, estimator_has_attr,
fit_and_score, flt, get_cols, get_custom_scorer, has_task, it, lst, merge,
method_to_log, rnd, sign, time_to_str, to_df, to_pandas, to_series,
Expand Down Expand Up @@ -266,16 +263,17 @@ def __setstate__(self, state: dict[str, Any]):
self.__dict__.update(state)

def __getattr__(self, item: str) -> Any:
if item in dir(self.branch) and not item.startswith("_"):
return getattr(self.branch, item) # Get attr from branch
elif item in self.branch.columns:
return self.branch.dataset[item] # Get column
elif item in DF_ATTRS:
return getattr(self.branch.dataset, item) # Get attr from dataset
else:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{item}'."
)
if "_branch" in self.__dict__:
if item in dir(self.branch) and not item.startswith("_"):
return getattr(self.branch, item) # Get attr from branch
elif item in self.branch.columns:
return self.branch.dataset[item] # Get column
elif item in DF_ATTRS:
return getattr(self.branch.dataset, item) # Get attr from dataset

raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{item}'."
)

def __contains__(self, item: str) -> bool:
return item in self.dataset
Expand Down Expand Up @@ -433,26 +431,28 @@ def _get_est(self, params: dict[str, Any]) -> Predictor:
"""
# Separate the params for the estimator from those in sub-estimators
base_params, sub_params = {}, {}
for name, value in params.items():
if "__" not in name:
base_params[name] = value
else:
sub_params[name] = value
if params:
for name, value in params.items():
if "__" not in name:
base_params[name] = value
else:
sub_params[name] = value

estimator = self._inherit(self._est_class(**base_params))
estimator.set_params(**sub_params)

if self.task is Task.multilabel_classification:
if not self.native_multilabel:
estimator = ClassifierChain(estimator)
elif self.task.is_multioutput and not self.native_multioutput:
if self.task.is_classification:
estimator = MultiOutputClassifier(estimator)
elif self.task.is_regression:
estimator = MultiOutputRegressor(estimator)
elif hasattr(self, "_estimators") and self._goal.name not in self._estimators:
# Forecasting task with a regressor
estimator = make_reduction(estimator)
if hasattr(self, "task"):
if self.task is Task.multilabel_classification:
if not self.native_multilabel:
estimator = ClassifierChain(estimator)
elif self.task.is_multioutput and not self.native_multioutput:
if self.task.is_classification:
estimator = MultiOutputClassifier(estimator)
elif self.task.is_regression:
estimator = MultiOutputRegressor(estimator)
elif hasattr(self, "_estimators") and self._goal.name not in self._estimators:
# Forecasting task with a regressor
estimator = make_reduction(estimator)

return self._inherit(estimator)

Expand Down Expand Up @@ -617,11 +617,12 @@ def _final_output(self) -> str:

return out

@cache
def _get_pred(
self,
rows: RowSelector,
target: TargetSelector | None = None,
attr: PredictionMethod = "predict",
attr: PredictionMethods | Literal["thresh"] = "predict",

Check warning on line 625 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Invalid type hints definitions and usages

'Literal' may be parameterized with literal ints, byte and unicode strings, bools, Enum values, None, other literal types, or type aliases to other literal types
) -> tuple[Pandas, Pandas]:
"""Get the true and predicted values for a column.
Expand All @@ -640,12 +641,8 @@ def _get_pred(
If None, all columns are returned.
attr: str, default="predict"
Attribute used to get predictions. Choose from:
- "predict": Use the `predict` method.
- "predict_proba": Use the `predict_proba` method.
- "decision_function": Use the `decision_function` method.
- "thresh": Use `decision_function` or `predict_proba`.
Method used to get predictions. Use "thresh" to get
`decision_function` or `predict_proba` in that order.
Returns
-------
Expand All @@ -658,7 +655,7 @@ def _get_pred(
"""
# Select method to use for predictions
if attr == "thresh":
for attribute in PredictionMethod.__args__:
for attribute in PredictionMethods.__args__:
if hasattr(self.estimator, attribute):
attr = attribute
break
Expand Down Expand Up @@ -760,7 +757,7 @@ def _score_from_pred(
Calculated score.
"""
func = lambda x, y: scorer._score_func(x, y, **scorer._kwargs, **kwargs)
func = lambda y1, y2: scorer._score_func(y1, y2, **scorer._kwargs, **kwargs)

Check notice on line 760 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _kwargs of a class

Check notice on line 760 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _score_func of a class

# Forecasting models can have first prediction NaN
if self.task.is_forecast and all(x.isna()[0] for x in get_cols(y_pred)):
Expand Down Expand Up @@ -1176,10 +1173,7 @@ def fit(self, X: DataFrame | None = None, y: Pandas | None = None):
)

for ds in ("train", "test"):
out = [
f"{metric.name}: {self._get_score(metric, ds)}"
for metric in self._metric
]
out = [f"{met.name}: {self._get_score(met, ds)}" for met in self._metric]
self._log(f"T{ds[1:]} evaluation --> {' '.join(out)}", 1)

# Get duration and print to log
Expand Down Expand Up @@ -1759,13 +1753,15 @@ def clear(self):
affected attributes are:
- [In-training validation][] scores
- [Cached predictions][predicting]
- [Shap values][shap]
- [App instance][self-create_app]
- [Dashboard instance][self-create_dashboard]
- Calculated [holdout data sets][data-sets]
"""
self._evals = defaultdict(list)
self._get_pred.clear_cache()
self._shap_explanation = None
self.__dict__.pop("app", None)
self.__dict__.pop("dashboard", None)
Expand Down Expand Up @@ -2449,7 +2445,7 @@ def _prediction(
metric: MetricConstructor = ...,
sample_weight: Sequence[Scalar] | None = ...,
verbose: Int | None = ...,
method: str = ...,
method: PredictionMethods = ...,
) -> Pandas: ...

def _prediction(

Check warning on line 2451 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Redeclared names without usages

Redeclared '_prediction' defined above without usage
Expand All @@ -2459,7 +2455,7 @@ def _prediction(
metric: MetricConstructor = None,
sample_weight: Sequence[Scalar] | None = None,
verbose: Int | None = None,
method: str = "predict",
method: PredictionMethods = "predict",
) -> Float | Pandas:
"""Get predictions on new data or existing rows.
Expand Down Expand Up @@ -2510,6 +2506,31 @@ def _prediction(
"""

def get_transform_X_y(X: XSelector, y: YSelector) -> tuple[DataFrame, Pandas]:

Check notice on line 2509 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase

Check notice on line 2509 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Function name should be lowercase

Check notice on line 2509 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Shadowing names from outer scopes

Shadows name 'y' from outer scope

Check notice on line 2509 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Shadowing names from outer scopes

Shadows name 'X' from outer scope
"""Get X and y from the pipeline transformation.
Parameters
----------
X: dataframe-like
Feature set.
y: int, str or sequence
Target column.
Returns
-------
dataframe
Transformed feature set.
series or dataframe
Transformed target column.
"""
if isinstance(out := self.transform(X, y, verbose=verbose), tuple):
return out
else:
return out, y

def assign_prediction_columns() -> list[str]:
"""Assign column names for the prediction methods.
Expand All @@ -2525,24 +2546,22 @@ def assign_prediction_columns() -> list[str]:
return self.mapping.get(self.target, np.unique(self.y).astype(str))

try:
if isinstance(out := self.transform(X, y, verbose=verbose), tuple):
Xt, yt = out
if isinstance(X, dataframe_t):
# Dataframe must go first since we can expect
# prediction calls from dataframes with reset indices
Xt, yt = get_transform_X_y(X, y)

Check notice on line 2552 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
else:
Xt, yt = out, y
except (
BeartypeCallHintParamViolation,
BeartypeCallHintReturnViolation,
ValueError,
):
Xt, yt = self.branch._get_rows(X, return_X_y=True)
Xt, yt = self.branch._get_rows(X, return_X_y=True)

Check notice on line 2554 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

Check notice on line 2554 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_rows of a class

if self.scaler:
Xt = self.scaler.transform(Xt)
if self.scaler:
Xt = self.scaler.transform(Xt)

Check notice on line 2557 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
except Exception:
Xt, yt = get_transform_X_y(X, y)

Check notice on line 2559 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

if method != "score":
pred = np.array(self.memory.cache(getattr(self.estimator, method))(Xt))

if np.array(pred).ndim < 3:
if pred.ndim < 3:
data = to_pandas(
data=pred,
index=Xt.index,
Expand All @@ -2557,7 +2576,7 @@ def assign_prediction_columns() -> list[str]:
columns=assign_prediction_columns(),
)
else:
# Convert to (n_samples * n_classes, n_targets)'
# Convert to (n_samples * n_classes, n_targets)
data = bk.DataFrame(
data=pred.reshape(-1, pred.shape[2]),
index=bk.MultiIndex.from_tuples(
Expand Down Expand Up @@ -2806,6 +2825,8 @@ class ForecastModel(BaseModel):
@overload
def _prediction(
self,
y: YSelector | None = None,
X: RowSelector | XSelector | None = None,

Check notice on line 2829 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
metric: MetricConstructor = None,
verbose: Int | None = None,
method: Literal["score"] = ...,

Check warning on line 2832 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Invalid type hints definitions and usages

'Literal' may be parameterized with literal ints, byte and unicode strings, bools, Enum values, None, other literal types, or type aliases to other literal types
Expand All @@ -2815,17 +2836,21 @@ def _prediction(
@overload
def _prediction(
self,
y: YSelector | None = None,
X: RowSelector | XSelector | None = None,

Check notice on line 2840 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
metric: MetricConstructor = None,
verbose: Int | None = None,
method: str = ...,
method: PredictionMethodsTS = ...,
**kwargs,
) -> Pandas: ...

def _prediction(

Check warning on line 2847 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Redeclared names without usages

Redeclared '_prediction' defined above without usage
self,
y: YSelector | None = None,
X: RowSelector | XSelector | None = None,

Check notice on line 2850 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
metric: MetricConstructor = None,
verbose: Int | None = None,
method: str = "predict",
method: PredictionMethodsTS = "predict",
**kwargs,
) -> Float | Pandas:
"""Get predictions on new data or existing rows.
Expand All @@ -2836,6 +2861,12 @@ def _prediction(
Parameters
----------
y: sequence or dataframe-like
Ground truth observations.
X: dataframe-like or None, default=None
Exogenous time series corresponding to fh.
metric: str, func, scorer or None, default=None
Metric to calculate. Choose from any of sklearn's scorers,
a function with signature metric(y_true, y_pred) or a scorer
Expand All @@ -2859,8 +2890,7 @@ def _prediction(
called.
"""
if (X := kwargs.get("X")) is not None and (y := kwargs.get("y")) is not None:
Xt, yt = self.transform(X, y, verbose=verbose)
Xt, yt = self.transform(X, y, verbose=verbose)

Check notice on line 2893 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

if method != "score":
return self.memory.cache(getattr(self.estimator, method))(**kwargs)
Expand Down
Loading

0 comments on commit b756dbc

Please sign in to comment.