Skip to content
This repository was archived by the owner on Jun 22, 2022. It is now read-only.

Commit 89ca705

Browse files
author
kamil-kaczmarek
committed
Merge branch 'Ninoko-master'
2 parents 78a1a09 + 7c7aaac commit 89ca705

File tree

14 files changed

+462
-127
lines changed

14 files changed

+462
-127
lines changed

Diff for: setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515

1616
setup(name='steppy-toolkit',
1717
packages=find_packages(),
18-
version='0.1.9',
18+
version='0.1.10',
1919
description='Set of tools to make your work with steppy faster and more effective.',
2020
long_description=long_description,
2121
url='https://github.com/minerva-ml/steppy-toolkit',
22-
download_url='https://github.com/minerva-ml/steppy-toolkit/archive/0.1.9.tar.gz',
22+
download_url='https://github.com/minerva-ml/steppy-toolkit/archive/0.1.10.tar.gz',
2323
author='Kamil A. Kaczmarek, Jakub Czakon',
2424
2525
keywords=['machine-learning', 'reproducibility', 'pipeline', 'tools'],

Diff for: toolkit/catboost_transformers/__init__.py

Whitespace-only changes.

Diff for: toolkit/catboost_transformers/models.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
try:
2+
import catboost as ctb
3+
from catboost import CatBoostClassifier
4+
from steppy.base import BaseTransformer
5+
from steppy.utils import get_logger
6+
7+
from toolkit.sklearn_transformers.models import MultilabelEstimators
8+
from toolkit.utils import SteppyToolkitError
9+
except ImportError as e:
10+
msg = 'SteppyToolkitError: you have missing modules. Install requirements specific to catboost_transformers.' \
11+
'Use this file: toolkit/catboost_transformers/requirements.txt'
12+
raise SteppyToolkitError(msg) from e
13+
14+
logger = get_logger()
15+
16+
17+
class CatboostClassifierMultilabel(MultilabelEstimators):
18+
@property
19+
def estimator(self):
20+
return CatBoostClassifier
21+
22+
23+
class CatBoost(BaseTransformer):
24+
def __init__(self, **kwargs):
25+
super().__init__()
26+
self.estimator = ctb.CatBoostClassifier(**kwargs)
27+
28+
def fit(self,
29+
X, y,
30+
X_valid, y_valid,
31+
feature_names=None,
32+
categorical_features=None,
33+
**kwargs):
34+
35+
logger.info('Catboost, train data shape {}'.format(X.shape))
36+
logger.info('Catboost, validation data shape {}'.format(X_valid.shape))
37+
logger.info('Catboost, train labels shape {}'.format(y.shape))
38+
logger.info('Catboost, validation labels shape {}'.format(y_valid.shape))
39+
40+
categorical_indeces = self._get_categorical_indices(feature_names, categorical_features)
41+
self.estimator.fit(X, y,
42+
eval_set=(X_valid, y_valid),
43+
cat_features=categorical_indeces)
44+
return self
45+
46+
def transform(self, X, **kwargs):
47+
prediction = self.estimator.predict_proba(X)[:, 1]
48+
return {'prediction': prediction}
49+
50+
def load(self, filepath):
51+
self.estimator.load_model(filepath)
52+
return self
53+
54+
def persist(self, filepath):
55+
self.estimator.save_model(filepath)
56+
57+
def _get_categorical_indices(self, feature_names, categorical_features):
58+
if categorical_features:
59+
return [feature_names.index(feature) for feature in categorical_features]
60+
else:
61+
return None

Diff for: toolkit/catboost_transformers/requirements.txt

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
catboost
2+
steppy

Diff for: toolkit/lightgbm_transformers/__init__.py

Whitespace-only changes.

Diff for: toolkit/lightgbm_transformers/models.py

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
try:
2+
import lightgbm as lgb
3+
import numpy as np
4+
import pandas as pd
5+
from sklearn.externals import joblib
6+
from steppy.base import BaseTransformer
7+
from steppy.utils import get_logger
8+
9+
from toolkit.utils import SteppyToolkitError
10+
except ImportError as e:
11+
msg = 'SteppyToolkitError: you have missing modules. Install requirements specific to lightgbm_transformers.' \
12+
'Use this file: toolkit/lightgbm_transformers/requirements.txt'
13+
raise SteppyToolkitError(msg) from e
14+
15+
logger = get_logger()
16+
17+
18+
class LightGBM(BaseTransformer):
19+
"""
20+
Accepts three dictionaries that reflects LightGBM API:
21+
- booster_parameters -> parameters of the Booster
22+
See: https://lightgbm.readthedocs.io/en/latest/Parameters.html
23+
- dataset_parameters -> parameters of the lightgbm.Dataset class
24+
See: https://lightgbm.readthedocs.io/en/latest/Python-API.html#data-structure-api
25+
- training_parameters -> parameters of the lightgbm.train function
26+
See: https://lightgbm.readthedocs.io/en/latest/Python-API.html#training-api
27+
"""
28+
def __init__(self,
29+
booster_parameters=None,
30+
dataset_parameters=None,
31+
training_parameters=None):
32+
super().__init__()
33+
logger.info('initializing LightGBM transformer')
34+
if booster_parameters is not None:
35+
isinstance(booster_parameters, dict), 'LightGBM transformer: booster_parameters must be dict, ' \
36+
'got {} instead'.format(type(booster_parameters))
37+
if dataset_parameters is not None:
38+
isinstance(dataset_parameters, dict), 'LightGBM transformer: dataset_parameters must be dict, ' \
39+
'got {} instead'.format(type(dataset_parameters))
40+
if training_parameters is not None:
41+
isinstance(training_parameters, dict), 'LightGBM transformer: training_parameters must be dict, ' \
42+
'got {} instead'.format(type(training_parameters))
43+
44+
self.booster_parameters = booster_parameters or {}
45+
self.dataset_parameters = dataset_parameters or {}
46+
self.training_parameters = training_parameters or {}
47+
48+
def fit(self, X, y, X_valid, y_valid):
49+
self._check_target_shape_and_type(y, 'y')
50+
self._check_target_shape_and_type(y_valid, 'y_valid')
51+
y = self._format_target(y)
52+
y_valid = self._format_target(y_valid)
53+
54+
logger.info('LightGBM transformer, train data shape {}'.format(X.shape))
55+
logger.info('LightGBM transformer, validation data shape {}'.format(X_valid.shape))
56+
logger.info('LightGBM transformer, train labels shape {}'.format(y.shape))
57+
logger.info('LightGBM transformer, validation labels shape {}'.format(y_valid.shape))
58+
59+
data_train = lgb.Dataset(data=X,
60+
label=y,
61+
**self.dataset_parameters)
62+
data_valid = lgb.Dataset(data=X_valid,
63+
label=y_valid,
64+
**self.dataset_parameters)
65+
self.estimator = lgb.train(params=self.booster_parameters,
66+
train_set=data_train,
67+
valid_sets=[data_train, data_valid],
68+
valid_names=['data_train', 'data_valid'],
69+
**self.training_parameters)
70+
return self
71+
72+
def transform(self, X, y=None):
73+
prediction = self.estimator.predict(X)
74+
return {'prediction': prediction}
75+
76+
def load(self, filepath):
77+
self.estimator = joblib.load(filepath)
78+
return self
79+
80+
def persist(self, filepath):
81+
joblib.dump(self.estimator, filepath)
82+
83+
def _check_target_shape_and_type(self, target, name):
84+
if not any([isinstance(target, obj_type) for obj_type in [pd.Series, np.ndarray, list]]):
85+
msg = '"target" must be "numpy.ndarray" or "Pandas.Series" or "list", got {} instead.'.format(type(target))
86+
raise SteppyToolkitError(msg)
87+
if not isinstance(target, list):
88+
assert len(target.shape) == 1, '"{}" must be 1-D. It is {}-D instead.'.format(name, len(target.shape))
89+
90+
def _format_target(self, target):
91+
if isinstance(target, pd.Series):
92+
return target.values
93+
elif isinstance(target, np.ndarray):
94+
return target
95+
elif isinstance(target, list):
96+
return np.array(target)
97+
else:
98+
raise TypeError(
99+
'"target" must be "numpy.ndarray" or "Pandas.Series" or "list", got {} instead.'.format(
100+
type(target)))

Diff for: toolkit/lightgbm_transformers/requirements.txt

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
attrdict
2+
lightgbm
3+
numpy
4+
pandas
5+
sklearn
6+
steppy

Diff for: toolkit/misc.py

-114
This file was deleted.

0 commit comments

Comments
 (0)