[WIP] Field-aware factorization machines #604

Open · wants to merge 26 commits into base: master

Changes from 1 commit

Commits (26)
c4a8b21  Remove nonccl build. Rewrite builds to CentOS. Unify x86_64 and ppc64…  (mdymczyk, Apr 7, 2018)
23ee96b  Bring back 1 script and some convenient make targets  (mdymczyk, May 6, 2018)
c0552da  Make full git clone on googletest  (mdymczyk, May 10, 2018)
5e39049  Move tests to a common root folder. Move python req files to python f…  (mdymczyk, Apr 30, 2018)
fa86dd7  Install custom arrow and pillow in runtime dockers for ppc64le. DRY d…  (mdymczyk, May 1, 2018)
a8e92eb  Install git for runtime on ppc64le  (mdymczyk, May 1, 2018)
1f46abb  Install git in runtime docker  (mdymczyk, May 1, 2018)
ba7b3d3  Initial FFM GPU implementation.  (mdymczyk, May 2, 2018)
46f737b  Initial Python bindings  (mdymczyk, May 11, 2018)
8aec2ab  Initial ffm prediction impl  (mdymczyk, May 11, 2018)
d443976  Fix gradient/weights computation. Pass parameters by ref to Python. T…  (mdymczyk, May 11, 2018)
b1b0c6a  Fixes logloss calc. Fixes weight index calculation.  (mdymczyk, May 11, 2018)
d27df65  Fixes label initialization - on the actual object not a copy.  (mdymczyk, May 13, 2018)
0a7514a  Pass ffm model as reference so final weights get updated and passed b…  (mdymczyk, May 14, 2018)
329d2de  Rewrite FFM data structures into pure pointers - much faster but stil…  (mdymczyk, May 15, 2018)
90911f5  Faster wTx computation but still slow and not scalable  (mdymczyk, May 16, 2018)
6ad0d25  FFM wTx using a faster kernel - still only on par with CPU  (mdymczyk, May 17, 2018)
25fb40b  FFM now runs ~2x faster on large data:  (mdymczyk, May 21, 2018)
a127759  Don't copy/allocate weights unnecessarily in the model  (mdymczyk, May 21, 2018)
d2b97a3  Fix indexing issues in FFM  (mdymczyk, May 25, 2018)
94265db  Initial validation dataset handling  (mdymczyk, May 31, 2018)
3990da8  Validation dataset  (mdymczyk, Jun 1, 2018)
f08a917  Validation data and early stopping  (mdymczyk, Jun 1, 2018)
665237a  Support GPU computation of datasets larger than GPU memory  (mdymczyk, Jun 1, 2018)
1732077  FFM CPU  (mdymczyk, Jun 9, 2018)
b32728e  Fix tests and pylint  (mdymczyk, Jun 26, 2018)
Initial ffm prediction impl
mdymczyk committed May 31, 2018
commit 8aec2ab6c7ae783f54507c85e370cdce6e318610
31 changes: 31 additions & 0 deletions src/base/ffm/ffm.cpp
@@ -11,6 +11,9 @@ namespace ffm {
template<typename T>
FFM<T>::FFM(Params const &params) : params(params), model(params) {}

template<typename T>
FFM<T>::FFM(Params const & params, T *weights) : params(params), model(params, weights) {}

template<typename T>
void FFM<T>::fit(const Dataset<T> &dataset) {
Trainer<T> trainer(dataset, this->model, this->params);
@@ -23,6 +26,12 @@ void FFM<T>::fit(const Dataset<T> &dataset) {
}
}

template<typename T>
void FFM<T>::predict(const Dataset<T> &dataset, T *predictions) {
Trainer<T> trainer(dataset, this->model, this->params);
trainer.predict(predictions);
}

template<typename T>
Dataset<T> &rowsToDataset(Row<T> *rows, Params &params) {
size_t numRows = params.numRows;
@@ -91,4 +100,26 @@ void ffm_fit_double(Row<double> *rows, double *w, Params _param) {
ffm.model.copyTo(w);
}

void ffm_predict_float(Row<float> *rows, float *predictions, float *w, Params _param) {
// TODO temporary hack, change this so it's passed from Python and never changed
size_t fields = _param.numFields;
size_t features = _param.numFeatures;

log_debug(_param.verbose, "Converting %d float rows into a dataset for predictions.", _param.numRows);
Dataset<float> dataset = rowsToDataset(rows, _param);

_param.numFeatures = features;
_param.numFields = fields;

FFM<float> ffm(_param, w);
_param.printParams();
log_debug(_param.verbose, "Running FFM predict for float.");
ffm.predict(dataset, predictions);
// TODO copy to result
}

void ffm_predict_double(Row<double> *rows, double *predictions, double *w, Params _param) {

}

} // namespace ffm
3 changes: 3 additions & 0 deletions src/base/ffm/ffm.h
@@ -16,11 +16,14 @@ template <typename T>
class FFM {
public:
FFM(Params const & params);
FFM(Params const & params, T *weights);

Model<T> model;

void fit(const Dataset<T> &dataset);

void predict(const Dataset<T> &dataset, T *predictions);

private:
const Params params;
};
15 changes: 15 additions & 0 deletions src/base/ffm/model.cpp
@@ -26,6 +26,21 @@ Model<T>::Model(Params const &params) : weights(params.numFeatures * params.numF

}

template<typename T>
Model<T>::Model(Params const &params, T *weights) : weights(params.numFeatures * params.numFields * params.k) {
this->numFeatures = params.numFeatures;
this->numFields = params.numFields;
this->k = params.k;
this->normalize = params.normalize;

float coef = 1.0f / sqrt(this->k);

for (int i = 0; i < this->weights.size(); i++) {
this->weights[i] = weights[i];
}

}

template<typename T>
void Model<T>::copyTo(T *dstWeights) {
memcpy(dstWeights, this->weights.data(), this->weights.size() * sizeof(T));
2 changes: 2 additions & 0 deletions src/base/ffm/model.h
@@ -16,6 +16,8 @@ class Model {

Model(Params const &params);

Model(Params const &params, T *weights);

void copyTo(T *dstWeights);

std::vector<T> weights;
2 changes: 2 additions & 0 deletions src/base/ffm/trainer.h
@@ -18,6 +18,8 @@ class Trainer {

void oneEpoch(bool update);

void predict(T *predictions);

bool earlyStop();

private:
5 changes: 5 additions & 0 deletions src/cpu/ffm/trainer.cpp
@@ -12,6 +12,11 @@ Trainer<T>::Trainer(const Dataset<T> &dataset, Model<T> &model, Params const &pa
// TODO implement
}

template<typename T>
void Trainer<T>::predict(T *predictions) {
// TODO implement
}

template<typename T>
void Trainer<T>::oneEpoch(bool update) {
// TODO implement
27 changes: 26 additions & 1 deletion src/gpu/ffm/trainer.cu
@@ -38,10 +38,11 @@ T wTx(Row<T> *row,

auto weights_ptr = thrust::raw_pointer_cast(weights.data());

T r = params.normalize ? row->scale : 1.0;

#pragma omp parallel for schedule(static) reduction(+: loss)
for (size_t n1 = 0; n1 < row->size; n1++) {
Node<T> *node1 = nodes[n1];
T r = params.normalize ? row->scale : 1.0;

loss += thrust::transform_reduce(nodes.begin() + n1 + 1, nodes.end(), [=]__device__(Node<T> * node2) {
size_t feature1 = node1->featureIdx;
@@ -87,6 +88,30 @@ T wTx(Row<T> *row,
return loss;
}

template<typename T>
void Trainer<T>::predict(T *predictions) {
for (int i = 0; i < params.nGpus; i++) {
log_verbose(this->params.verbose, "Copying weights of size %zu to GPU %d for predictions", this->model.weights.size(), i);
thrust::device_vector<T> localWeights(this->model.weights.begin(), this->model.weights.end());

int record = 0;
while (trainDataBatcher[i]->hasNext()) {
log_verbose(this->params.verbose, "Getting batch of size %zu on GPU %d for predictions", this->params.batchSize, i);
DatasetBatch<T> batch = trainDataBatcher[i]->nextBatch(this->params.batchSize);

T loss = 0;
// TODO parallelize somehow
while (batch.hasNext()) {
Row<T> *row = batch.nextRow();

T t = wTx(row, localWeights, this->params);

predictions[record++] = 1 / (1 + exp(-t));
}
}
}
}

template<typename T>
// TODO return loss
void Trainer<T>::oneEpoch(bool update) {
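
For reference, the quantity that wTx accumulates above and that predict pushes through a sigmoid is the standard field-aware factorization machine score. The formula below is a sketch as usually defined (e.g. in libffm), not something restated in this PR, and the exact point at which the normalization factor enters is an assumption:

$$
\phi(\mathbf{w}, \mathbf{x}) = r \sum_{j_1 < j_2} \langle \mathbf{w}_{j_1, f_{j_2}}, \mathbf{w}_{j_2, f_{j_1}} \rangle \, x_{j_1} x_{j_2},
\qquad
\hat{y} = \frac{1}{1 + e^{-\phi}}
$$

Here $\mathbf{w}_{j,f}$ is the $k$-dimensional latent vector of feature $j$ paired with field $f$, and $r$ is row->scale when params.normalize is set (1.0 otherwise), i.e. the value hoisted out of the loop in the hunk above. The line predictions[record++] = 1 / (1 + exp(-t)) in Trainer<T>::predict is the $\hat{y}$ step.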
3 changes: 3 additions & 0 deletions src/include/solver/ffm_api.h
@@ -39,4 +39,7 @@ typedef struct Params {
void ffm_fit_float(Row<float> *rows, float *w, Params _param);
void ffm_fit_double(Row<double> *rows, double *w, Params _param);

void ffm_predict_float(Row<float> *rows, float *predictions, float *w, Params _param);
void ffm_predict_double(Row<double> *rows, double *predictions, double *w, Params _param);

}
43 changes: 38 additions & 5 deletions src/interface_py/h2o4gpu/solvers/ffm.py
@@ -84,6 +84,9 @@ def __init__(
self.row_arr_holder = []
self.node_arr_holder = []

self.learned_params = None
self.predictions = None

@classmethod
def _get_param_names(cls):
# TODO implement
@@ -113,7 +116,7 @@ def fit(self, X, y):

params.numRows = np.shape(X)[0]

rows, featureIdx, fieldIdx = self._numpy_to_ffm_rows(X, y, lib)
rows, featureIdx, fieldIdx = self._numpy_to_ffm_rows(lib, X, y)

weights = np.zeros(params.k * featureIdx * fieldIdx, dtype=self.dtype)

@@ -126,11 +129,12 @@ def fit(self, X, y):
self.row_arr_holder = []
self.node_arr_holder = []

self.learned_params = params
self.weights = weights
return self


def _numpy_to_ffm_rows(self, X, y, lib):
def _numpy_to_ffm_rows(self, lib, X, y=None):
(node_creator, node_arr_creator, row_creator, row_arr_creator) = \
(lib.floatNode, lib.NodeFloatArray, lib.floatRow, lib.RowFloatArray) if self.dtype == np.float32 \
else (lib.doubleNode, lib.NodeDoubleArray, lib.doubleRow, lib.RowDoubleArray)
@@ -152,13 +156,42 @@ def _numpy_to_ffm_rows(self, X, y, lib):
feature_idx = max(feature_idx, node.featureIdx + 1)
field_idx = max(field_idx, node.fieldIdx + 1)
# Scale is being set automatically on the C++ side
row = row_creator(int(y[r]), 1.0, nr_nodes, node_arr)
row = row_creator( 0 if y is None else int(y[r]) , 1.0, nr_nodes, node_arr)
row_arr.__setitem__(r, row)
return row_arr, feature_idx, field_idx

def predict(self, X):
# TODO implement
pass
lib = self._load_lib()

params = lib.params_ffm()
params.verbose = self.verbose
params.learningRate = self.learning_rate
params.regLambda = self.reg_lambda
params.nIter = self.max_iter
params.batchSize = self.batch_size
params.k = self.k
params.normalize = self.normalize
params.autoStop = self.auto_stop
params.nGpus = self.nGpus
params.numFields = self.learned_params.numFields
params.numFeatures = self.learned_params.numFeatures

params.numRows = np.shape(X)[0]

rows, featureIdx, fieldIdx = self._numpy_to_ffm_rows(lib, X)

self.predictions = np.zeros(params.numRows, dtype=self.dtype)

if self.dtype == np.float32:
lib.ffm_predict_float(rows, self.predictions, self.weights, params)
else:
lib.ffm_predict_double(rows, self.predictions, self.weights, params)

# Cleans up the memory
self.row_arr_holder = []
self.node_arr_holder = []

return self

def transform(self, X, y=None):
# TODO implement
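
To show how the new predict path is meant to be exercised end to end, here is a minimal usage sketch in Python. It is not part of the PR: the FFMH2O class name, the import path, the constructor keywords, and the row format (a list of (field, feature, value) triples per sample) are assumptions inferred from this file rather than confirmed by the diff. Note that, as committed here, predict returns self and writes the probabilities into self.predictions.

# Hypothetical usage sketch; class name, import path and input format are assumed.
import numpy as np
from h2o4gpu.solvers.ffm import FFMH2O  # assumed export of the solver defined in this file

# Assumed row format: each sample is a list of (field, feature, value) triples,
# which _numpy_to_ffm_rows converts into the C++ Node/Row structures.
X = [
    [(0, 1, 1.0), (1, 5, 1.0)],
    [(0, 2, 1.0), (1, 5, 1.0)],
]
y = np.array([1, 0])

model = FFMH2O(k=4, max_iter=10, batch_size=1000, normalize=True)
model.fit(X, y)           # trains weights and keeps the fitted Params in self.learned_params
model.predict(X)          # fills self.predictions via lib.ffm_predict_float / _double
print(model.predictions)  # per-row probabilities from the sigmoid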
4 changes: 2 additions & 2 deletions src/swig/solver/ffm.i
@@ -8,8 +8,8 @@

%rename("params_ffm") ffm::Params;

%apply (float *INPLACE_ARRAY1) {float *w};
%apply (double *INPLACE_ARRAY1) {double *w};
%apply (float *INPLACE_ARRAY1) {float *predictions, float *w};
%apply (double *INPLACE_ARRAY1) {double *predictions, double *w};

%include "../../include/data/ffm/data.h"
%include "../../include/solver/ffm_api.h"