
Commit: refactor modules
tvdboom committed Aug 30, 2023
1 parent c4cf394 commit dd136a2
Showing 192 changed files with 15,698 additions and 15,611 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -28,7 +28,7 @@

**General Information** | |
--- | ---
- **Repository** | [![Project Status: Active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active) [![Conda Recipe](https://img.shields.io/badge/recipe-atom--ml-green.svg)](https://anaconda.org/conda-forge/atom-ml) [![License: MIT](https://img.shields.io/github/license/tvdboom/ATOM)](https://opensource.org/licenses/MIT) [![Downloads](https://pepy.tech/badge/atom-ml)](https://pepy.tech/project/atom-ml)
+ **Repository** | [![Project Status: Active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active) [![Conda Recipe](https://img.shields.io/badge/recipe-atom--ml-green.svg)](https://anaconda.org/conda-forge/atom-ml) [![License: MIT](https://img.shields.io/github/license/tvdboom/ATOM)](https://opensource.org/licenses/MIT) [![Downloads](https://static.pepy.tech/badge/atom-ml)](https://pepy.tech/project/atom-ml)
**Release** | [![pdm-managed](https://img.shields.io/badge/pdm-managed-blueviolet)](https://pdm.fming.dev) [![PyPI version](https://img.shields.io/pypi/v/atom-ml)](https://pypi.org/project/atom-ml/) [![Conda Version](https://img.shields.io/conda/vn/conda-forge/atom-ml.svg)](https://anaconda.org/conda-forge/atom-ml) [![DOI](https://zenodo.org/badge/195069958.svg)](https://zenodo.org/badge/latestdoi/195069958)
**Compatibility** | [![Python 3.8\|3.9\|3.10\|3.11](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11-blue?logo=python)](https://www.python.org) [![Conda Platforms](https://img.shields.io/conda/pn/conda-forge/atom-ml.svg)](https://anaconda.org/conda-forge/atom-ml)
**Build status** | [![Build Status](https://github.com/tvdboom/ATOM/workflows/ATOM/badge.svg)](https://github.com/tvdboom/ATOM/actions) [![Azure Pipelines](https://dev.azure.com/conda-forge/feedstock-builds/_apis/build/status/atom-ml-feedstock?branchName=master)](https://dev.azure.com/conda-forge/feedstock-builds/_build/latest?definitionId=10822&branchName=master) [![codecov](https://codecov.io/gh/tvdboom/ATOM/branch/master/graph/badge.svg)](https://codecov.io/gh/tvdboom/ATOM)
17 changes: 8 additions & 9 deletions atom/api.py
@@ -18,8 +18,8 @@
from atom.atom import ATOM
from atom.basetransformer import BaseTransformer
from atom.utils.types import (
- BACKEND, BOOL, ENGINE, GOAL, INDEX_SELECTOR, INT, PREDICTOR, SCALAR,
- TARGET,
+ BACKEND, BOOL, ENGINE, INDEX_SELECTOR, INT, PREDICTOR, SCALAR,
+ TARGET, WARNINGS,
)


@@ -160,7 +160,6 @@ class ATOMClassifier(BaseTransformer, ATOM):
y: int, str, dict, sequence or dataframe, default=-1
Target column corresponding to X.
- - If None: y is ignored.
- If int: Position of the target column in X.
- If str: Name of the target column in X.
- If sequence: Target array with shape=(n_samples,) or
@@ -336,7 +335,7 @@ def __init__(
engine: ENGINE = {"data": "numpy", "estimator": "sklearn"},
backend: BACKEND = "loky",
verbose: Literal[0, 1, 2] = 0,

- warnings: BOOL | str = False,
+ warnings: BOOL | WARNINGS = False,
logger: str | Logger | None = None,
experiment: str | None = None,
random_state: INT | None = None,
@@ -353,7 +352,7 @@ def __init__(
random_state=random_state,
)

self.goal: GOAL = "class"
self.goal = "class"
ATOM.__init__(
self,
arrays=arrays,
@@ -555,7 +554,7 @@ def __init__(
engine: ENGINE = {"data": "numpy", "estimator": "sklearn"},
backend: BACKEND = "loky",
verbose: Literal[0, 1, 2] = 0,

- warnings: BOOL | str = False,
+ warnings: BOOL | WARNINGS = False,
logger: str | Logger | None = None,
experiment: str | None = None,
random_state: INT | None = None,
@@ -572,7 +571,7 @@ def __init__(
random_state=random_state,
)

self.goal: GOAL = "fc"
self.goal = "fc"
ATOM.__init__(
self,
arrays=arrays,
@@ -790,7 +789,7 @@ def __init__(
engine: ENGINE = {"data": "numpy", "estimator": "sklearn"},
backend: BACKEND = "loky",
verbose: Literal[0, 1, 2] = 0,

- warnings: BOOL | str = False,
+ warnings: BOOL | WARNINGS = False,
logger: str | Logger | None = None,
experiment: str | None = None,
random_state: INT | None = None,
@@ -807,7 +806,7 @@ def __init__(
random_state=random_state,
)

self.goal: GOAL = "reg"
self.goal = "reg"
ATOM.__init__(
self,
arrays=arrays,
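The recurring change in atom/api.py is the swap from loose built-in hints to named aliases imported from atom.utils.types (`warnings: BOOL | str` becomes `warnings: BOOL | WARNINGS`, and the explicit `GOAL` annotation on `self.goal` is dropped). Below is a minimal sketch of what the new `WARNINGS` alias presumably looks like; the alias is defined in atom/utils/types.py, which is not part of the hunks on this page, and the literal values are an assumption based on the actions accepted by Python's `warnings` filter.

```python
# Hypothetical sketch, not ATOM's actual definition (atom/utils/types.py is not shown here).
from typing import Literal, Union

BOOL = bool  # assumption: the real alias may also admit numpy booleans
WARNINGS = Literal["default", "error", "ignore", "always", "module", "once"]


def set_warnings(warnings: Union[BOOL, WARNINGS] = False) -> None:
    """With a Literal alias, a static checker can reject misspelled filter actions."""
    print(f"warnings set to {warnings!r}")


set_warnings("ignore")      # accepted by a type checker
# set_warnings("silence")   # flagged: not one of the WARNINGS literals
```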
103 changes: 55 additions & 48 deletions atom/atom.py
@@ -36,9 +36,7 @@
)
from atom.models import MODELS
from atom.nlp import TextCleaner, TextNormalizer, Tokenizer, Vectorizer
- from atom.plots import (
- DataPlot, FeatureSelectorPlot, HTPlot, PredictionPlot, ShapPlot,
- )
+ from atom.plots import ATOMPlot
from atom.training import (
DirectClassifier, DirectForecaster, DirectRegressor,
SuccessiveHalvingClassifier, SuccessiveHalvingForecaster,
@@ -47,9 +45,10 @@
)
from atom.utils.constants import MISSING_VALUES, __version__
from atom.utils.types import (
- BOOL, DATAFRAME, DATASET, FEATURES, INDEX, INDEX_SELECTOR, INT,
- METRIC_SELECTOR, PANDAS, PREDICTOR, RUNNER, SCALAR, SEQUENCE, SERIES,
- SLICE, TARGET, TRANSFORMER, TS_INDEX_TYPES,
+ BOOL, DATAFRAME, DATASET, DISCRETIZER_STRATS, ESTIMATOR, FEATURES, INDEX,
+ INDEX_SELECTOR, INT, METRIC_SELECTOR, PANDAS, PREDICTOR, PRUNER_STRATS,
+ RUNNER, SCALAR, SCALER_STRATS, SEQUENCE, SERIES, SLICE, STRAT_NUM, TARGET,
+ TRANSFORMER, TS_INDEX_TYPES,
)
from atom.utils.utils import (
ClassMap, DataConfig, check_dependency, check_is_fitted, check_scaling,
@@ -60,7 +59,7 @@


@typechecked
- class ATOM(BaseRunner, FeatureSelectorPlot, DataPlot, HTPlot, PredictionPlot, ShapPlot):
+ class ATOM(BaseRunner, ATOMPlot):
"""ATOM base class.
The ATOM class is a convenient wrapper for all data cleaning,
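Two related hunks above replace the five plot mixins (`DataPlot`, `FeatureSelectorPlot`, `HTPlot`, `PredictionPlot`, `ShapPlot`) with a single `ATOMPlot` base, both in the import and in the `class ATOM(...)` statement. The sketch below shows the presumed shape of that consolidation; the stand-in classes are placeholders, since atom/plots itself is not shown on this page.

```python
# Placeholder mixins standing in for ATOM's real plot classes (not shown in this diff).
class DataPlot: ...
class FeatureSelectorPlot: ...
class HTPlot: ...
class PredictionPlot: ...
class ShapPlot: ...


class ATOMPlot(DataPlot, FeatureSelectorPlot, HTPlot, PredictionPlot, ShapPlot):
    """Presumed aggregate mixin: one base exposes every plotting method."""


class BaseRunner: ...


class ATOM(BaseRunner, ATOMPlot):  # mirrors the new class statement in the hunk above
    ...
```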
@@ -160,7 +159,7 @@ def __repr__(self) -> str:

return out

- def __iter__(self) -> TRANSFORMER:
+ def __iter__(self) -> TRANSFORMER | None:
yield from self.pipeline.values

# Utility properties =========================================== >>
@@ -545,7 +544,7 @@ def inverse_transform(
y: TARGET | None = None,
*,
verbose: INT | None = None,
- ) -> PANDAS | tuple[DATAFRAME, SERIES]:
+ ) -> PANDAS | tuple[DATAFRAME, PANDAS]:
"""Inversely transform new data through the pipeline.
Transformers that are only applied on the training set are
@@ -898,7 +897,7 @@ def get_data(new_t: str) -> SERIES:
get_data(r[0]) for r in t if r[1] <= column.min() and r[2] >= column.max()
)

if self.engine["data"] == "pyarrow":
if self.engine.get("data") == "pyarrow":
self.branch.dataset = self.branch.dataset.astype(
{name: to_pyarrow(col) for name, col in self.branch._data.items()}

)
@@ -986,7 +985,7 @@ def transform(
y: TARGET | None = None,
*,
verbose: INT | None = None,
- ) -> PANDAS | tuple[DATAFRAME, SERIES]:
+ ) -> PANDAS | tuple[DATAFRAME, PANDAS]:
"""Transform new data through the pipeline.
Transformers that are only applied on the training set are
@@ -1068,7 +1067,7 @@ def _add_transformer(
self,
transformer: TRANSFORMER,
columns: SLICE | None = None,
- train_only: bool = False,
+ train_only: BOOL = False,
**fit_params,
):
"""Add a transformer to the pipeline.
@@ -1106,9 +1105,6 @@ def _add_transformer(
"new branch to continue the pipeline."
)

- if not hasattr(transformer, "transform"):
- raise AttributeError("Added transformers should have a transform method!")

# Add BaseTransformer params to the estimator if left to default
transformer = self._inherit(transformer)

@@ -1160,7 +1156,7 @@ def add(
transformer: TRANSFORMER,
*,
columns: SLICE | None = None,
- train_only: bool = False,
+ train_only: BOOL = False,
**fit_params,
):
"""Add a transformer to the pipeline.
@@ -1249,9 +1245,8 @@ def apply(
):
"""Apply a function to the dataset.
- The function should have signature `func(dataset, **kw_args) ->
- dataset`. This method is useful for stateless transformations
- such as taking the log, doing custom scaling, etc...
+ This method is useful for stateless transformations such as
+ taking the log, doing custom scaling, etc...
!!! note
This approach is preferred over changing the dataset directly
@@ -1265,7 +1260,8 @@ def apply(
Parameters
----------
func: callable
- Function to apply.
+ Function to apply with signature `func(dataset, **kw_args) ->
+ dataset`.
inverse_func: callable or None, default=None
Inverse function of `func`. If None, the inverse_transform
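The docstring change above moves the `func(dataset, **kw_args) -> dataset` signature description from the summary into the `func` parameter entry. A small sketch of a callable with that signature follows; the column name is made up for illustration, and the call at the end assumes only the parameters documented above.

```python
# Illustrative only - "price" is a hypothetical column name.
import numpy as np
import pandas as pd


def add_log_price(dataset: pd.DataFrame, column: str = "price") -> pd.DataFrame:
    """Stateless transformation matching func(dataset, **kw_args) -> dataset."""
    dataset = dataset.copy()
    dataset[f"{column}_log"] = np.log1p(dataset[column])
    return dataset


# Presumed usage through the method documented above: atom.apply(add_log_price)
```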
@@ -1336,13 +1332,13 @@ def balance(self, strategy: str = "adasyn", **kwargs):
def clean(
self,
*,
- convert_dtypes: bool = True,
+ convert_dtypes: BOOL = True,
drop_dtypes: str | SEQUENCE | None = None,
drop_chars: str | None = None,
- strip_categorical: bool = True,
- drop_duplicates: bool = False,
- drop_missing_target: bool = True,
- encode_target: bool = True,
+ strip_categorical: BOOL = True,
+ drop_duplicates: BOOL = False,
+ drop_missing_target: BOOL = True,
+ encode_target: BOOL = True,
**kwargs,
):
"""Applies standard data cleaning steps on the dataset.
Expand Down Expand Up @@ -1382,7 +1378,7 @@ def clean(
@composed(crash, method_to_log)
def discretize(
self,
strategy: str = "quantile",
strategy: DISCRETIZER_STRATS = "quantile",
*,
bins: INT | SEQUENCE | dict = 5,
labels: SEQUENCE | dict | None = None,
@@ -1467,7 +1463,7 @@ def encode(
@composed(crash, method_to_log)
def impute(
self,
- strat_num: SCALAR | Literal["drop", "mean", "knn", "most_frequent"] = "drop",
+ strat_num: STRAT_NUM = "drop",
strat_cat: Literal["drop", "most_frequent"] | str = "drop",

*,
max_nan_rows: SCALAR | None = None,
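In the impute hunk above, the inline hint `SCALAR | Literal["drop", "mean", "knn", "most_frequent"]` is replaced by the `STRAT_NUM` alias. Presumably the alias simply names that same union; a sketch under that assumption follows (the real definition lives in atom/utils/types.py, outside this page).

```python
# Assumed definition, reconstructed from the removed inline hint; not verified
# against atom/utils/types.py.
from typing import Literal, Union

SCALAR = Union[int, float]
STRAT_NUM = Union[SCALAR, Literal["drop", "mean", "knn", "most_frequent"]]
```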
@@ -1539,11 +1535,11 @@ def normalize(
@composed(crash, method_to_log)
def prune(
self,
strategy: str | SEQUENCE = "zscore",
strategy: PRUNER_STRATS | SEQUENCE = "zscore",
*,
method: SCALAR | Literal["drop", "minmax"] = "drop",

max_sigma: SCALAR = 3,
- include_target: bool = False,
+ include_target: BOOL = False,
**kwargs,
):
"""Prune outliers from the training set.
@@ -1581,7 +1577,12 @@ def prune(
setattr(self.branch, strat.lower(), getattr(pruner, strat.lower()))

@composed(crash, method_to_log)
def scale(self, strategy: str = "standard", include_binary: bool = False, **kwargs):
def scale(
self,
strategy: SCALER_STRATS = "standard",
include_binary: BOOL = False,
**kwargs,
):
"""Scale the data.
Apply one of sklearn's scalers. Categorical columns are ignored.
@@ -1611,19 +1612,19 @@ def scale(self, strategy: str = "standard", include_binary: bool = False, **kwar
def textclean(
self,
*,
- decode: bool = True,
- lower_case: bool = True,
- drop_email: bool = True,
+ decode: BOOL = True,
+ lower_case: BOOL = True,
+ drop_email: BOOL = True,
regex_email: str | None = None,
- drop_url: bool = True,
+ drop_url: BOOL = True,
regex_url: str | None = None,
- drop_html: bool = True,
+ drop_html: BOOL = True,
regex_html: str | None = None,
- drop_emoji: bool = True,
+ drop_emoji: BOOL = True,
regex_emoji: str | None = None,
- drop_number: bool = True,
+ drop_number: BOOL = True,
regex_number: str | None = None,
- drop_punctuation: bool = True,
+ drop_punctuation: BOOL = True,
**kwargs,
):
"""Applies standard text cleaning to the corpus.
@@ -1664,10 +1665,10 @@ def textclean(
def textnormalize(
self,
*,
- stopwords: bool | str = True,
+ stopwords: BOOL | str = True,
custom_stopwords: SEQUENCE | None = None,
- stem: bool | str = False,
- lemmatize: bool = True,
+ stem: BOOL | str = False,
+ lemmatize: BOOL = True,
**kwargs,
):
"""Normalize the corpus.
@@ -1727,7 +1728,13 @@ def tokenize(
self.branch.quadgrams = tokenizer.quadgrams

@composed(crash, method_to_log)
def vectorize(self, strategy: str = "bow", *, return_sparse: bool = True, **kwargs):
def vectorize(
self,
strategy: Literal["bow", "tfidf", "hashing"] = "bow",

+ *,
+ return_sparse: BOOL = True,
+ **kwargs,
+ ):
"""Vectorize the corpus.
Transform the corpus into meaningful vectors of numbers. The
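The vectorize signature above narrows `strategy` from a plain `str` to `Literal["bow", "tfidf", "hashing"]`. These names plausibly map onto scikit-learn's text vectorizers as sketched below; the actual wiring sits in atom/nlp.py's `Vectorizer`, which is not part of the hunks shown here.

```python
# Illustrative mapping only, not ATOM's internal implementation.
from sklearn.feature_extraction.text import (
    CountVectorizer, HashingVectorizer, TfidfVectorizer,
)

VECTORIZERS = {
    "bow": CountVectorizer,        # bag-of-words token counts
    "tfidf": TfidfVectorizer,      # counts re-weighted by inverse document frequency
    "hashing": HashingVectorizer,  # stateless hashing trick, no fitted vocabulary
}
```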
@@ -1766,7 +1773,7 @@ def feature_extraction(
fmt: str | SEQUENCE | None = None,
*,
encoding_type: str = "ordinal",
- drop_columns: bool = True,
+ drop_columns: BOOL = True,
**kwargs,
):
"""Extract features from datetime columns.
@@ -1831,7 +1838,7 @@ def feature_grouping(
group: dict[str, str | SEQUENCE],
*,
operators: str | SEQUENCE | None = None,
- drop_columns: bool = True,
+ drop_columns: BOOL = True,
**kwargs,
):
"""Extract statistics from similar features.
@@ -1862,7 +1869,7 @@ def feature_selection(
self,
strategy: str | None = None,
*,
- solver: str | Callable | None = None,
+ solver: str | ESTIMATOR | None = None,
n_features: SCALAR | None = None,
min_repeated: SCALAR | None = 2,
max_repeated: SCALAR | None = 1.0,
@@ -2005,7 +2012,7 @@ def run(
n_trials: INT | dict | SEQUENCE = 0,
ht_params: dict | None = None,
n_bootstrap: INT | SEQUENCE = 0,
- parallel: bool = False,
+ parallel: BOOL = False,
errors: Literal["raise", "skip", "keep"] = "skip",

**kwargs,
):
@@ -2061,7 +2068,7 @@ def successive_halving(
n_trials: INT | dict | SEQUENCE = 0,
ht_params: dict | None = None,
n_bootstrap: INT | dict | SEQUENCE = 0,
- parallel: bool = False,
+ parallel: BOOL = False,
errors: Literal["raise", "skip", "keep"] = "skip",

**kwargs,
):
@@ -2124,7 +2131,7 @@ def train_sizing(
n_trials: INT | dict | SEQUENCE = 0,
ht_params: dict | None = None,
n_bootstrap: INT | dict | SEQUENCE = 0,
- parallel: bool = False,
+ parallel: BOOL = False,
errors: Literal["raise", "skip", "keep"] = "skip",

**kwargs,
):
(Diffs for the remaining changed files are not shown.)
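Judging only from the hunks visible on this page, the refactor is confined to type hints, module structure, and docstrings, so a typical session should presumably look the same as before. The sketch below is hedged accordingly: the dataset, the model acronyms "LR" and "RF", and any keyword names not shown in the hunks are assumptions based on atom-ml's documentation, not something this commit page confirms.

```python
# Hedged usage sketch - not part of the commit; verify names against the atom-ml docs.
from sklearn.datasets import load_breast_cancer

from atom import ATOMClassifier

X, y = load_breast_cancer(return_X_y=True, as_frame=True)

atom = ATOMClassifier(X, y, verbose=2, random_state=1)
atom.impute(strat_num="knn")      # parameter shown in the impute hunk above
atom.scale(strategy="standard")   # parameter shown in the scale hunk above
atom.run(models=["LR", "RF"])     # assumed model acronyms
```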

0 comments on commit dd136a2
