Skip to content

Commit

Permalink
fix examples 2
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Nov 29, 2023
1 parent b756dbc commit 5fca7a1
Show file tree
Hide file tree
Showing 12 changed files with 78 additions and 89 deletions.
10 changes: 9 additions & 1 deletion .github/workflows/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,12 @@ jobs:
pip install -U pytest pytest-xdist nbmake scikeras tensorflow
pip install -e .[full]
- name: Run example notebooks
run: pytest --nbmake -n=auto --nbmake-timeout=600 --ignore=./examples/webapp/ --ignore=./examples/accelerating_cuml.ipynb ./examples/
run: |
pytest \
-n=auto \
--nbmake \
--nbmake-timeout=600 \
--ignore=./examples/webapp/ \
--ignore=./examples/accelerating_cuml.ipynb \
--ignore=./examples/ray_backend.ipynb \
./examples/
6 changes: 3 additions & 3 deletions atom/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,12 +548,12 @@ def eda(
rows_c = [(self.branch._get_rows(v), k) for k, v in rows.items()]

Check notice on line 548 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_rows of a class

if len(rows_c) == 1:

Check warning on line 550 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unbound local variables

Local variable 'rows_c' might be referenced before assignment
self.report = self.memory.cache(sv.analyze)(
self.report = sv.analyze(
source=rows_c[0],
target_feat=self.branch._get_target(target, only_columns=True),

Check notice on line 553 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_target of a class
)
elif len(rows_c) == 2:
self.report = self.memory.cache(sv.compare)(
self.report = sv.compare(

Check notice on line 556 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute report defined outside __init__
source=rows_c[0],
compare=rows_c[1],
target_feat=self.branch._get_target(target, only_columns=True),

Check notice on line 559 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_target of a class
Expand Down Expand Up @@ -1222,7 +1222,7 @@ def _add_transformer(
name = transformer_c.__class__.__name__
while name in self.pipeline:
counter += 1
name = f"{transformer_c.__class__.__name__}{counter}"
name = f"{transformer_c.__class__.__name__.lower()}-{counter}"

self.branch.pipeline.steps.append((name, transformer_c))

Expand Down
99 changes: 51 additions & 48 deletions atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
from atom.utils.constants import DF_ATTRS
from atom.utils.types import (
HT, Backend, Bool, DataFrame, Engine, FHSelector, Float, FloatZeroToOneExc,
Int, IntLargerEqualZero, MetricConstructor, NJobs, Pandas,
Int, IntLargerEqualZero, MetricConstructor, MetricFunction, NJobs, Pandas,
PredictionMethods, PredictionMethodsTS, Predictor, RowSelector, Scalar,
Scorer, Sequence, Stages, TargetSelector, Verbose, Warnings, XSelector,
YSelector, dataframe_t, float_t, int_t,
Expand Down Expand Up @@ -2431,7 +2431,7 @@ def _prediction(
self,
X: RowSelector | XSelector,

Check notice on line 2432 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: YSelector | None = ...,
metric: MetricConstructor = ...,
metric: str | MetricFunction | Scorer | None = ...,
sample_weight: Sequence[Scalar] | None = ...,
verbose: Int | None = ...,
method: Literal["score"] = ...,

Check warning on line 2437 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Invalid type hints definitions and usages

'Literal' may be parameterized with literal ints, byte and unicode strings, bools, Enum values, None, other literal types, or type aliases to other literal types
Expand All @@ -2442,7 +2442,7 @@ def _prediction(
self,
X: RowSelector | XSelector,

Check notice on line 2443 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: YSelector | None = ...,
metric: MetricConstructor = ...,
metric: str | MetricFunction | Scorer | None = ...,
sample_weight: Sequence[Scalar] | None = ...,
verbose: Int | None = ...,
method: PredictionMethods = ...,
Expand All @@ -2452,7 +2452,7 @@ def _prediction(
self,
X: RowSelector | XSelector,

Check notice on line 2453 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: YSelector | None = None,
metric: MetricConstructor = None,
metric: str | MetricFunction | Scorer | None = None,
sample_weight: Sequence[Scalar] | None = None,
verbose: Int | None = None,
method: PredictionMethods = "predict",
Expand All @@ -2465,7 +2465,7 @@ def _prediction(
Parameters
----------
X: hashable, range, slice, sequence or dataframe-like
X: hashable, segment, sequence or dataframe-like
[Selection of rows][row-and-column-selection] or feature
set with shape=(n_samples, n_features) to make predictions
on.
Expand Down Expand Up @@ -2619,7 +2619,7 @@ def decision_function(
Parameters
----------
X: hashable, range, slice, sequence or dataframe-like
X: hashable, segment, sequence or dataframe-like
[Selection of rows][row-and-column-selection] or feature
set with shape=(n_samples, n_features) to make predictions
on.
Expand Down Expand Up @@ -2656,7 +2656,7 @@ def predict(
Parameters
----------
X: hashable, range, slice, sequence or dataframe-like
X: hashable, segment, sequence or dataframe-like
[Selection of rows][row-and-column-selection] or feature
set with shape=(n_samples, n_features) to make predictions
on.
Expand Down Expand Up @@ -2692,7 +2692,7 @@ def predict_log_proba(
Parameters
----------
X: hashable, range, slice, sequence or dataframe-like
X: hashable, segment, sequence or dataframe-like
[Selection of rows][row-and-column-selection] or feature
set with shape=(n_samples, n_features) to make predictions
on.
Expand Down Expand Up @@ -2728,7 +2728,7 @@ def predict_proba(
Parameters
----------
X: hashable, range, slice, sequence or dataframe-like
X: hashable, segment, sequence or dataframe-like
[Selection of rows][row-and-column-selection] or feature
set with shape=(n_samples, n_features) to make predictions
on.
Expand All @@ -2754,7 +2754,7 @@ def score(
X: RowSelector | XSelector,

Check notice on line 2754 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: YSelector | None = None,
*,
metric: MetricConstructor = None,
metric: str | MetricFunction | Scorer | None = None,
sample_weight: Sequence[Scalar] | None = None,
verbose: Int | None = None,
) -> Float:
Expand All @@ -2773,7 +2773,7 @@ def score(
Parameters
----------
X: hashable, range, slice, sequence or dataframe-like
X: hashable, segment, sequence or dataframe-like
[Selection of rows][row-and-column-selection] or feature
set with shape=(n_samples, n_features) to make predictions
on.
Expand Down Expand Up @@ -2825,9 +2825,9 @@ class ForecastModel(BaseModel):
@overload
def _prediction(
self,
y: YSelector | None = None,
X: RowSelector | XSelector | None = None,
metric: MetricConstructor = None,
y: RowSelector | YSelector | None = None,
X: XSelector | None = None,

Check notice on line 2829 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
metric: str | MetricFunction | Scorer | None = None,
verbose: Int | None = None,
method: Literal["score"] = ...,

Check warning on line 2832 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Invalid type hints definitions and usages

'Literal' may be parameterized with literal ints, byte and unicode strings, bools, Enum values, None, other literal types, or type aliases to other literal types
**kwargs,
Expand All @@ -2836,19 +2836,19 @@ def _prediction(
@overload
def _prediction(
self,
y: YSelector | None = None,
X: RowSelector | XSelector | None = None,
metric: MetricConstructor = None,
y: RowSelector | YSelector | None = None,
X: XSelector | None = None,

Check notice on line 2840 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
metric: str | MetricFunction | Scorer | None = None,
verbose: Int | None = None,
method: PredictionMethodsTS = ...,
**kwargs,
) -> Pandas: ...

def _prediction(

Check warning on line 2847 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Redeclared names without usages

Redeclared '_prediction' defined above without usage
self,
y: YSelector | None = None,
X: RowSelector | XSelector | None = None,
metric: MetricConstructor = None,
y: RowSelector | YSelector | None = None,
X: XSelector | None = None,

Check notice on line 2850 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
metric: str | MetricFunction | Scorer | None = None,
verbose: Int | None = None,

Check notice on line 2852 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unused local symbols

Parameter 'verbose' value is not used
method: PredictionMethodsTS = "predict",
**kwargs,
Expand All @@ -2861,11 +2861,11 @@ def _prediction(
Parameters
----------
y: sequence or dataframe-like
y: int, str, dict, sequence, dataframe or None, default=None
Ground truth observations.
X: dataframe-like or None, default=None
Exogenous time series corresponding to fh.
X: hashable, segment, sequence, dataframe-like or None, default=None
Exogenous time series corresponding to `fh`.
metric: str, func, scorer or None, default=None
Metric to calculate. Choose from any of sklearn's scorers,
Expand All @@ -2890,10 +2890,13 @@ def _prediction(
called.
"""
Xt, yt = self.transform(X, y, verbose=verbose)
Xt, yt = X, y # self.transform(X, y, verbose=verbose) TODO: Fix pipeline ts

Check notice on line 2893 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

if method != "score":
return self.memory.cache(getattr(self.estimator, method))(**kwargs)
if "y" in sign(func := getattr(self.estimator, method)):
return self.memory.cache(func)(y=yt, X=Xt, **kwargs)
else:
return self.memory.cache(func)(X=Xt, **kwargs)
else:
if metric is None:
scorer = self._metric[0]
Expand Down Expand Up @@ -2925,8 +2928,8 @@ def predict(
The forecasting horizon encoding the time stamps to
forecast at.
X: dataframe-like or None, default=None
Exogenous time series corresponding to fh.
X: hashable, segment, sequence, dataframe-like or None, default=None
Exogenous time series corresponding to `fh`.
verbose: int or None, default=None
Verbosity level for the transformers in the pipeline. If None,
Expand Down Expand Up @@ -2965,8 +2968,8 @@ def predict_interval(
The forecasting horizon encoding the time stamps to
forecast at.
X: dataframe-like or None, default=None
Exogenous time series corresponding to fh.
X: hashable, segment, sequence, dataframe-like or None, default=None
Exogenous time series corresponding to `fh`.
coverage: float or sequence, default=0.9
Nominal coverage(s) of predictive interval(s).
Expand Down Expand Up @@ -3014,8 +3017,8 @@ def predict_proba(
The forecasting horizon encoding the time stamps to
forecast at.
X: dataframe-like or None, default=None
Exogenous time series corresponding to fh.
X: hashable, segment, sequence, dataframe-like or None, default=None
Exogenous time series corresponding to `fh`.
marginal: bool, default=True
Whether returned distribution is marginal by time index.
Expand Down Expand Up @@ -3062,8 +3065,8 @@ def predict_quantiles(
The forecasting horizon encoding the time stamps to
forecast at.
X: dataframe-like or None, default=None
Exogenous time series corresponding to fh.
X: hashable, segment, sequence, dataframe-like or None, default=None
Exogenous time series corresponding to `fh`.
alpha: float or list of float, default=[0.05, 0.95]
A probability or list of, at which quantile forecasts are
Expand Down Expand Up @@ -3093,7 +3096,7 @@ def predict_quantiles(
@composed(crash, method_to_log, beartype)
def predict_residuals(
self,
y: Sequence[Any] | DataFrame,
y: RowSelector | YSelector,
X: XSelector | None = None,

Check notice on line 3100 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
*,
verbose: Int | None = None,
Expand All @@ -3108,11 +3111,11 @@ def predict_residuals(
Parameters
----------
y: sequence or dataframe-like
Ground truth observations to compute residuals to.
y: int, str, dict, sequence or dataframe
Ground truth observations.
X: dataframe-like or None, default=None
Exogenous time series corresponding to fh.
X: hashable, segment, sequence, dataframe-like or None, default=None
Exogenous time series corresponding to `y`.
verbose: int or None, default=None
Verbosity level for the transformers in the pipeline. If None,
Expand All @@ -3131,7 +3134,7 @@ def predict_residuals(
@composed(crash, method_to_log, beartype)
def predict_var(
self,
fh: FHSelector,
fh: RowSelector | FHSelector,
X: XSelector | None = None,

Check notice on line 3138 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
*,
cov: Bool = False,
Expand All @@ -3151,11 +3154,11 @@ def predict_var(
The forecasting horizon encoding the time stamps to
forecast at.
X: dataframe-like or None, default=None
Exogenous time series corresponding to fh.
X: hashable, segment, sequence, dataframe-like or None, default=None
Exogenous time series corresponding to `fh`.
cov: bool, default=False
Whether to computes covariance matrix forecast or marginal
Whether to compute covariance matrix forecast or marginal
variance forecasts.
verbose: int or None, default=None
Expand All @@ -3181,11 +3184,11 @@ def predict_var(
@composed(crash, method_to_log, beartype)
def score(
self,
y: Sequence[Any] | DataFrame,
X: DataFrame | None = None,
y: RowSelector | YSelector,
X: XSelector | None = None,

Check notice on line 3188 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
fh: FHSelector | None = None,
*,
metric: MetricConstructor = None,
metric: str | MetricFunction | Scorer | None = None,
verbose: Int | None = None,
) -> Float:
"""Get a metric score on new data.
Expand All @@ -3203,11 +3206,11 @@ def score(
Parameters
----------
y: sequence or dataframe-like
y: int, str, dict, sequence or dataframe
Ground truth observations.
X: dataframe-like or None, default=None
Exogenous time series corresponding to fh.
X: hashable, segment, sequence, dataframe-like or None, default=None
Exogenous time series corresponding to `fh`.
fh: int, sequence or [ForecastingHorizon][] or None, default=None
The forecasting horizon encoding the time stamps to
Expand Down
4 changes: 2 additions & 2 deletions atom/plots/predictionplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def plot_confusion_matrix(
models: int, str, Model, segment, sequence or None, default=None
Models to plot. If None, all models are selected.
rows: hashable, range, slice or sequence, default="test"
rows: hashable, segment or sequence, default="test"
[Selection of rows][row-and-column-selection] on which to
calculate the confusion matrix.
Expand Down Expand Up @@ -2678,7 +2678,7 @@ def plot_probabilities(
models: int, str, Model, segment, sequence or None, default=None
Models to plot. If None, all models are selected.
rows: hashable, range, slice or sequence, default="test"
rows: hashable, segment or sequence, default="test"
[Selection of rows][row-and-column-selection] on which to
calculate the metric.
Expand Down
12 changes: 2 additions & 10 deletions atom/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from atom.utils.constants import __version__
from atom.utils.types import (
Bool, DataFrame, Estimator, Float, Index, IndexSelector, Int,
IntLargerEqualZero, MetricConstructor, Model, Pandas, Predictor, Scalar,
IntLargerEqualZero, MetricFunction, Model, Pandas, Predictor, Scalar,
Scorer, Segment, Sequence, Series, Transformer, TReturn, TReturns, Verbose,
XSelector, YSelector, YTypes, dataframe_t, int_t, pandas_t, segment_t,
sequence_t, series_t,
Expand Down Expand Up @@ -2054,7 +2054,7 @@ def check_attr(attr: str) -> bool:
return True


def get_custom_scorer(metric: MetricConstructor) -> Scorer:
def get_custom_scorer(metric: str | MetricFunction | Scorer) -> Scorer:
"""Get a scorer from a str, func or scorer.
Scorers used by ATOM have a name attribute.
Expand Down Expand Up @@ -2780,14 +2780,6 @@ def wrap_methods(f: Callable) -> Callable:
- Check if the instance is fitted before transforming.
- Convert output to pyarrow dtypes if specified in config.
Parameters
----------
f: callable
Function to decorate.
check_fitted: bool
Whether to check if the instance is fitted.
"""

@wraps(f)
Expand Down
3 changes: 2 additions & 1 deletion docs_sources/dependencies.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ packages are necessary for its correct functioning.
* **[gplearn](https://gplearn.readthedocs.io/en/stable/index.html)** (>=0.4.2)
* **[imbalanced-learn](https://imbalanced-learn.readthedocs.io/en/stable/api.html)** (>=0.11.0)
* **[ipython](https://ipython.readthedocs.io/en/stable/)** (>=8.11.0)
* **[ipywidgets](https://pypi.org/project/ipywidgets/)** (>=8.1.1)
* **[featuretools](https://www.featuretools.com/)** (>=1.28.0)
* **[joblib](https://joblib.readthedocs.io/en/latest/)** (>=1.3.1)
* **[matplotlib](https://matplotlib.org/)** (>=3.7.2)
Expand Down Expand Up @@ -63,7 +64,7 @@ additional libraries. You can install all the optional dependencies using
* **[lightgbm](https://lightgbm.readthedocs.io/en/latest/)** (>=4.1.0)
* **[pmdarima](http://alkaline-ml.com/pmdarima/)** (>=2.0.3)
* **[schemdraw](https://schemdraw.readthedocs.io/en/latest/index.html)** (>=0.16)
* **[sweetviz](https://github.com/fbdesignpro/sweetviz)** (>=2.2.1)
* **[sweetviz](https://github.com/fbdesignpro/sweetviz)** (>=2.3.1)
* **[wordcloud](http://amueller.github.io/word_cloud/)** (>=1.9.2)
* **[xgboost](https://xgboost.readthedocs.io/en/latest/)** (>=2.0.0)

Expand Down
Loading

0 comments on commit 5fca7a1

Please sign in to comment.