Skip to content

Commit

Permalink
add prophet
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Dec 19, 2023
1 parent bafd417 commit 3efd838
Show file tree
Hide file tree
Showing 17 changed files with 832 additions and 83 deletions.
23 changes: 13 additions & 10 deletions atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def __dir__(self) -> list[str]:
if "_branch" in self.__dict__:
attrs += [x for x in dir(self.branch) if not x.startswith("_")]
attrs += list(DF_ATTRS)
attrs += list(self.columns)
attrs += [c for c in self.columns if re.fullmatch(r"\w+$", c)]
return attrs

def __getattr__(self, item: str) -> Any:
Expand Down Expand Up @@ -694,7 +694,7 @@ def _score_from_est(
X: DataFrame,

Check notice on line 694 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: Pandas,
**kwargs,
) -> float:
) -> Float:
"""Calculate the metric score from an estimator.
Parameters
Expand Down Expand Up @@ -737,7 +737,7 @@ def _score_from_pred(
y_true: Pandas,
y_pred: Pandas,
**kwargs,
) -> float:
) -> Float:
"""Calculate the metric score from predicted values.
Since sklearn metrics don't support multiclass-multioutput
Expand Down Expand Up @@ -770,12 +770,15 @@ def _score_from_pred(
if self.task.is_forecast and all(x.isna()[0] for x in get_cols(y_pred)):
y_true, y_pred = y_true.iloc[1:], y_pred.iloc[1:]

if self.task is Task.multiclass_multioutput_classification:
# Get the mean of the scores over the target columns
scores = [scorer._sign * func(y_true[c], y_pred[c]) for c in y_pred.columns]
return float(np.mean(scores, axis=0))
else:
return float(scorer._sign * func(y_true, y_pred))
try:
if self.task is Task.multiclass_multioutput_classification:
# Get the mean of the scores over the target columns
scores = [scorer._sign * func(y_true[c], y_pred[c]) for c in y_pred.columns]

Check notice on line 776 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _sign of a class
return np.mean(scores, axis=0)
else:
return scorer._sign * func(y_true, y_pred)

Check notice on line 779 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _sign of a class
except ValueError:
return np.NaN # Some forecast models predict NaN

def _get_score(
self,
Expand Down Expand Up @@ -885,7 +888,7 @@ def fit_model(
estimator: Predictor,

Check notice on line 888 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Shadowing names from outer scopes

Shadows name 'estimator' from outer scope
train_idx: np.ndarray,
val_idx: np.ndarray,
) -> tuple[Predictor, list[float]]:
) -> tuple[Predictor, list[Float]]:
"""Fit the model. Function for parallelization.
Divide the training set in a (sub) train and validation
Expand Down
26 changes: 13 additions & 13 deletions atom/baserunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __dir__(self) -> list[str]:
attrs += [x for x in dir(self.branch) if not x.startswith("_")]
attrs += list(DF_ATTRS)
attrs += [b.name.lower() for b in self._branches]
attrs += list(self.columns)
attrs += [c for c in self.columns if re.fullmatch(r"\w+$", c)]
if isinstance(self._models, ClassMap):
attrs += [m.name.lower() for m in self._models]
return attrs
Expand Down Expand Up @@ -163,25 +163,25 @@ def sp(self) -> int | list[int] | None:
Read more about seasonality in the [user guide][seasonality].
"""
return self._sp
return self._config.sp

@sp.setter
def sp(self, sp: Seasonality):
"""Convert seasonal period to integer value."""
if sp is None:
self._sp = None
self._config.sp = None
elif sp == "index":
if not hasattr(self.dataset.index, "freqstr"):
raise ValueError(
f"Invalid value for the seasonal period, got {sp}. "
f"The dataset's index has no attribute freqstr."
)
else:
self._sp = self._get_sp(self.dataset.index.freqstr)
self._config.sp = self._get_sp(self.dataset.index.freqstr)
elif sp == "infer":
self._sp = self.get_seasonal_period()
self._config.sp = self.get_seasonal_period()
else:
self._sp = flt([self._get_sp(x) for x in lst(sp)])
self._config.sp = flt([self._get_sp(x) for x in lst(sp)])

@property
def og(self) -> Branch:
Expand Down Expand Up @@ -901,15 +901,15 @@ def available_models(self) -> pd.DataFrame:
- **fullname:** Name of the model's class.
- **estimator:** Class of the model's underlying estimator.
- **module:** The estimator's module.
- **handles_missing:** Whether the model can handle `NaN` values
without preprocessing.
- **handles_missing:** Whether the model can handle missing
(`NaN`) values without preprocessing. If False, consider using
the [Imputer][] class before training the models.
- **needs_scaling:** Whether the model requires feature scaling.
- **accepts_sparse:** Whether the model accepts sparse matrices.
- **uses_exogenous:** Whether the model uses exogenous variables.
- **in_sample_prediction:** Whether the model can do predictions
on the training set.
If True, [automated feature scaling][] is applied.
- **accepts_sparse:** Whether the model accepts [sparse input][sparse-datasets].
- **uses_exogenous:** Whether the model uses [exogenous variables][].
- **multiple_seasonality:** Whether the model can handle more than
one [seasonality periods][seasonality].
one [seasonality period][seasonality].
- **native_multilabel:** Whether the model has native support
for [multilabel][] tasks.
- **native_multioutput:** Whether the model has native support
Expand Down
1 change: 1 addition & 0 deletions atom/basetrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def _prepare_parameters(self):
"BATS": "tbats",
"CatB": "catboost",
"LGB": "lightgbm",
"MSTL": "statsforecast",
"TBATS": "tbats",
"XGB": "xgboost",
}
Expand Down
11 changes: 8 additions & 3 deletions atom/basetransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,14 @@ def _inherit(self, obj: T_Estimator) -> T_Estimator:
"""
signature = sign(obj.__init__) # type: ignore[misc]
for p in ("sp", "n_jobs", "random_state"):
if p in signature and getattr(obj, p, "<!>") == signature[p]._default:
setattr(obj, p, getattr(self, p, signature[p]._default))
for p in ("n_jobs", "random_state"):
if p in signature and obj.get_params()[p] == signature[p]._default:

Check notice on line 375 in atom/basetransformer.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _default of a class
obj.set_params(**{p: getattr(self, p)})

# Add seasonal period to the estimator
if self._config.sp:
if "sp" in signature and obj.get_params()["sp"] == signature["sp"]._default:

Check notice on line 380 in atom/basetransformer.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _default of a class
obj.set_params(sp=self._config.sp)

return obj

Expand Down
11 changes: 9 additions & 2 deletions atom/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
from atom.models.custom import CustomModel
from atom.models.ensembles import Stacking, Voting
from atom.models.ts import (
ARIMA, BATS, ETS, STL, TBATS, AutoARIMA, Croston, ExponentialSmoothing,
NaiveForecaster, PolynomialTrend, Theta,
ARIMA, BATS, ETS, MSTL, SARIMAX, STL, TBATS, VAR, VARMAX, AutoARIMA,
Croston, DynamicFactor, ExponentialSmoothing, NaiveForecaster,
PolynomialTrend, Prophet, Theta,
)
from atom.utils.types import Predictor
from atom.utils.utils import ClassMap
Expand All @@ -43,6 +44,7 @@
Croston,
DecisionTree,
Dummy,
DynamicFactor,
ElasticNet,
ETS,
ExponentialSmoothing,
Expand All @@ -60,23 +62,28 @@
LinearDiscriminantAnalysis,
LinearSVM,
LogisticRegression,
MSTL,
MultiLayerPerceptron,
MultinomialNB,
NaiveForecaster,
OrdinaryLeastSquares,
OrthogonalMatchingPursuit,
PassiveAggressive,
Perceptron,
Prophet,
PolynomialTrend,
QuadraticDiscriminantAnalysis,
RadiusNearestNeighbors,
RandomForest,
Ridge,
SARIMAX,
STL,
StochasticGradientDescent,
SupportVectorMachine,
TBATS,
Theta,
VAR,
VARMAX,
XGBoost,
key="acronym",
)
Expand Down
8 changes: 4 additions & 4 deletions atom/models/classreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ class CatBoost(ClassRegModel):
"""

acronym = "CatB"
handles_missing = False
handles_missing = True
needs_scaling = True
accepts_sparse = True
native_multilabel = False
Expand Down Expand Up @@ -1640,7 +1640,7 @@ class LightGBM(ClassRegModel):
"""

acronym = "LGB"
handles_missing = False
handles_missing = True
needs_scaling = True
accepts_sparse = True
native_multilabel = False
Expand Down Expand Up @@ -2165,7 +2165,7 @@ def _trial_to_est(self, params: dict[str, Any]) -> dict[str, Any]:
hidden_layer_sizes = [
value
for param in [p for p in sorted(params) if p.startswith("hidden_layer")]
if (value := params.pop(param)) # Neurons should be more than zero
if (value := params.pop(param)) # Neurons should be >0
]

if hidden_layer_sizes:
Expand Down Expand Up @@ -3078,7 +3078,7 @@ class XGBoost(ClassRegModel):
"""

acronym = "XGB"
handles_missing = False
handles_missing = True
needs_scaling = True
accepts_sparse = True
native_multilabel = False
Expand Down
Loading

0 comments on commit 3efd838

Please sign in to comment.