|
# SteppyToolkitError must be imported *outside* the guarded block: the
# ``except`` handler below raises it, so it has to be bound even when one of
# the optional dependencies fails to import.
from toolkit.utils import SteppyToolkitError

try:
    import lightgbm as lgb
    import numpy as np
    import pandas as pd
    from sklearn.externals import joblib
    from steppy.base import BaseTransformer
    from steppy.utils import get_logger
except ImportError as e:
    # Translate a missing optional dependency into a toolkit-specific error
    # that points at the requirements file, keeping the original ImportError
    # as the chained cause.
    msg = 'SteppyToolkitError: you have missing modules. Install requirements specific to lightgbm_transformers. ' \
          'Use this file: toolkit/lightgbm_transformers/requirements.txt'
    raise SteppyToolkitError(msg) from e

logger = get_logger()
| 16 | + |
| 17 | + |
class LightGBM(BaseTransformer):
    """Steppy transformer that trains and applies a LightGBM booster.

    Accepts three dictionaries that reflect the LightGBM API:
        - booster_parameters -> parameters of the Booster
          See: https://lightgbm.readthedocs.io/en/latest/Parameters.html
        - dataset_parameters -> parameters of the lightgbm.Dataset class
          See: https://lightgbm.readthedocs.io/en/latest/Python-API.html#data-structure-api
        - training_parameters -> parameters of the lightgbm.train function
          See: https://lightgbm.readthedocs.io/en/latest/Python-API.html#training-api

    Each dictionary is optional; ``None`` is stored as an empty dictionary.
    """

    def __init__(self,
                 booster_parameters=None,
                 dataset_parameters=None,
                 training_parameters=None):
        """Validate and store the three parameter dictionaries.

        Args:
            booster_parameters: dict passed as ``params`` to ``lgb.train``, or None.
            dataset_parameters: dict unpacked into every ``lgb.Dataset`` call, or None.
            training_parameters: dict unpacked into the ``lgb.train`` call, or None.

        Raises:
            SteppyToolkitError: if any supplied argument is neither None nor a dict.
        """
        super().__init__()
        logger.info('initializing LightGBM transformer')

        named_parameters = (
            ('booster_parameters', booster_parameters),
            ('dataset_parameters', dataset_parameters),
            ('training_parameters', training_parameters),
        )
        for parameter_name, value in named_parameters:
            # Fail early with a clear message instead of a confusing error
            # later, when the value is unpacked with ``**`` inside ``fit``.
            if value is not None and not isinstance(value, dict):
                raise SteppyToolkitError(
                    'LightGBM transformer: {} must be dict, got {} instead'.format(parameter_name, type(value)))

        self.booster_parameters = booster_parameters or {}
        self.dataset_parameters = dataset_parameters or {}
        self.training_parameters = training_parameters or {}

    def fit(self, X, y, X_valid, y_valid):
        """Train a booster on (X, y) while evaluating on (X_valid, y_valid).

        Both targets are validated and converted to numpy arrays first. The
        trained booster is stored in ``self.estimator``.

        Args:
            X: training features; must expose ``.shape`` (it is logged).
            y: training target, 1-D ``numpy.ndarray``, ``pandas.Series`` or list.
            X_valid: validation features; must expose ``.shape`` (it is logged).
            y_valid: validation target, same accepted types as ``y``.

        Returns:
            self, so calls can be chained.

        Raises:
            SteppyToolkitError: if a target has an unsupported type.
            AssertionError: if an array/Series target is not 1-D.
        """
        self._check_target_shape_and_type(y, 'y')
        self._check_target_shape_and_type(y_valid, 'y_valid')
        y = self._format_target(y)
        y_valid = self._format_target(y_valid)

        logger.info('LightGBM transformer, train data shape {}'.format(X.shape))
        logger.info('LightGBM transformer, validation data shape {}'.format(X_valid.shape))
        logger.info('LightGBM transformer, train labels shape {}'.format(y.shape))
        logger.info('LightGBM transformer, validation labels shape {}'.format(y_valid.shape))

        data_train = lgb.Dataset(data=X,
                                 label=y,
                                 **self.dataset_parameters)
        data_valid = lgb.Dataset(data=X_valid,
                                 label=y_valid,
                                 **self.dataset_parameters)
        # The training set is also passed as a validation set, so metrics are
        # reported for both under the names given in ``valid_names``.
        self.estimator = lgb.train(params=self.booster_parameters,
                                   train_set=data_train,
                                   valid_sets=[data_train, data_valid],
                                   valid_names=['data_train', 'data_valid'],
                                   **self.training_parameters)
        return self

    def transform(self, X, y=None):
        """Predict with the stored booster.

        Args:
            X: features passed to ``self.estimator.predict``.
            y: unused; kept for signature compatibility.

        Returns:
            dict with a single key ``'prediction'`` holding the predict output.
        """
        prediction = self.estimator.predict(X)
        return {'prediction': prediction}

    def load(self, filepath):
        """Restore the booster from ``filepath`` with joblib and return self."""
        # NOTE: joblib.load unpickles the file; only load files from trusted sources.
        self.estimator = joblib.load(filepath)
        return self

    def persist(self, filepath):
        """Dump the stored booster to ``filepath`` with joblib."""
        joblib.dump(self.estimator, filepath)

    def _check_target_shape_and_type(self, target, name):
        """Validate the type and dimensionality of a target.

        Args:
            target: object to validate.
            name: argument name used in error messages (e.g. ``'y'``).

        Raises:
            SteppyToolkitError: if ``target`` is not a ``pandas.Series``,
                ``numpy.ndarray`` or list.
            AssertionError: if a Series/ndarray target is not 1-D.
        """
        if not isinstance(target, (pd.Series, np.ndarray, list)):
            msg = '"{}" must be "numpy.ndarray" or "Pandas.Series" or "list", got {} instead.'.format(
                name, type(target))
            raise SteppyToolkitError(msg)
        # Lists carry no ``.shape``; only array-like targets are checked for dimensionality.
        if not isinstance(target, list):
            assert len(target.shape) == 1, '"{}" must be 1-D. It is {}-D instead.'.format(name, len(target.shape))

    def _format_target(self, target):
        """Return ``target`` as a ``numpy.ndarray``.

        A Series yields its ``.values``, an ndarray is returned unchanged and a
        list is wrapped with ``np.array``.

        Raises:
            TypeError: if ``target`` is none of the supported types.
        """
        if isinstance(target, pd.Series):
            return target.values
        elif isinstance(target, np.ndarray):
            return target
        elif isinstance(target, list):
            return np.array(target)
        else:
            raise TypeError(
                '"target" must be "numpy.ndarray" or "Pandas.Series" or "list", got {} instead.'.format(
                    type(target)))
0 commit comments