Skip to content

Commit

Permalink
draft for dual_coefs
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Sep 20, 2024
1 parent 28641bb commit 1a4733e
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 87 deletions.
7 changes: 1 addition & 6 deletions cpp/bench/sg/svc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,13 @@ struct SvcParams {
BlobsParams blobs;
raft::distance::kernels::KernelParams kernel;
ML::SVM::SvmParameter svm_param;
ML::SVM::SvmModel<D> model;
};

template <typename D>
class SVC : public BlobsFixture<D, D> {
public:
SVC(const std::string& name, const SvcParams<D>& p)
: BlobsFixture<D, D>(name, p.data, p.blobs),
kernel(p.kernel),
model(p.model),
svm_param(p.svm_param)
: BlobsFixture<D, D>(name, p.data, p.blobs), kernel(p.kernel), svm_param(p.svm_param)
{
std::vector<std::string> kernel_names{"linear", "poly", "rbf", "tanh"};
std::ostringstream oss;
Expand Down Expand Up @@ -101,7 +97,6 @@ std::vector<SvcParams<D>> getInputs()

// SvmParameter{C, cache_size, max_iter, nochange_steps, tol, verbosity, epsilon, svmType}
p.svm_param = ML::SVM::SvmParameter{1, 200, 100, 100, 1e-3, CUML_LEVEL_INFO, 0, ML::SVM::C_SVC};
p.model = ML::SVM::SvmModel<D>{0, 0, 0, nullptr, {}, nullptr, 0, nullptr};

std::vector<Triplets> rowcols = {{50000, 2, 2}, {2048, 100000, 2}, {50000, 1000, 2}};

Expand Down
13 changes: 4 additions & 9 deletions cpp/bench/sg/svr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,13 @@ struct SvrParams {
RegressionParams regression;
raft::distance::kernels::KernelParams kernel;
ML::SVM::SvmParameter svm_param;
ML::SVM::SvmModel<D>* model;
};

template <typename D>
class SVR : public RegressionFixture<D> {
public:
SVR(const std::string& name, const SvrParams<D>& p)
: RegressionFixture<D>(name, p.data, p.regression),
kernel(p.kernel),
model(p.model),
svm_param(p.svm_param)
: RegressionFixture<D>(name, p.data, p.regression), kernel(p.kernel), svm_param(p.svm_param)
{
std::vector<std::string> kernel_names{"linear", "poly", "rbf", "tanh"};
std::ostringstream oss;
Expand All @@ -69,16 +65,16 @@ class SVR : public RegressionFixture<D> {
this->data.y.data(),
this->svm_param,
this->kernel,
*(this->model));
this->model);
this->handle->sync_stream(this->stream);
ML::SVM::svmFreeBuffers(*this->handle, *(this->model));
ML::SVM::svmFreeBuffers(*this->handle, this->model);
});
}

private:
raft::distance::kernels::KernelParams kernel;
ML::SVM::SvmParameter svm_param;
ML::SVM::SvmModel<D>* model;
ML::SVM::SvmModel<D> model;
};

template <typename D>
Expand All @@ -103,7 +99,6 @@ std::vector<SvrParams<D>> getInputs()
// epsilon, svmType})
p.svm_param =
ML::SVM::SvmParameter{1, 200, 200, 100, 1e-3, CUML_LEVEL_INFO, 0.1, ML::SVM::EPSILON_SVR};
p.model = new ML::SVM::SvmModel<D>{0, 0, 0, 0};

std::vector<Triplets> rowcols = {{50000, 2, 2}, {1024, 10000, 10}, {3000, 200, 200}};

Expand Down
12 changes: 10 additions & 2 deletions cpp/include/cuml/svm/svm_model.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,8 @@
*/
#pragma once

#include <rmm/device_buffer.hpp>

namespace ML {
namespace SVM {

Expand All @@ -25,6 +27,11 @@ struct SupportStorage {
int* indptr = nullptr;
int* indices = nullptr;
math_t* data = nullptr;
/*
rmm::device_buffer indptr_bf;
rmm::device_buffer indices_bf;
rmm::device_buffer data_bf;
*/
};

/**
Expand All @@ -39,7 +46,8 @@ struct SvmModel {

//! Non-zero dual coefficients ( dual_coef[i] = \f$ y_i \alpha_i \f$).
//! Size [n_support] elements of math_t (n_support * sizeof(math_t) bytes when stored in the device buffer).
math_t* dual_coefs;
// math_t* dual_coefs;
rmm::device_buffer dual_coefs;

//! Support vector storage - can contain either CSR or dense
SupportStorage<math_t> support_matrix;
Expand Down
11 changes: 5 additions & 6 deletions cpp/src/svm/results.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class Results {
*/
void Get(const math_t* alpha,
const math_t* f,
math_t** dual_coefs,
rmm::device_buffer* dual_coefs,
int* n_support,
int** idx,
SupportStorage<math_t>* support_matrix,
Expand All @@ -130,7 +130,7 @@ class Results {
*idx = GetSupportVectorIndices(val_tmp.data(), *n_support);
*support_matrix = CollectSupportVectorMatrix(*idx, *n_support);
} else {
*dual_coefs = nullptr;
dual_coefs->resize(0, stream);
*idx = nullptr;
*support_matrix = {};
}
Expand Down Expand Up @@ -205,14 +205,13 @@ class Results {
* unallocated on entry, on exit size [n_support]
* @param [out] n_support number of support vectors
*/
// NOTE(review): this span is a rendered diff, so removed and added lines are
// interleaved. The first signature (math_t** dual_coefs) is the removed one; the
// second (rmm::device_buffer* dual_coefs) is its replacement.
// Purpose: select the non-zero entries of val_tmp, store their count in
// *n_support, and copy the selected values into dual_coefs.
void GetDualCoefs(const math_t* val_tmp, math_t** dual_coefs, int* n_support)
void GetDualCoefs(const math_t* val_tmp, rmm::device_buffer* dual_coefs, int* n_support)
{
// Return only the non-zero coefficients
auto select_op = [] __device__(math_t a) { return 0 != a; };
// The selected values land in val_selected; the value returned by SelectByCoef
// is stored as the number of support vectors.
*n_support = SelectByCoef(val_tmp, n_rows, val_tmp, select_op, val_selected.data());
// Removed version: allocate *n_support elements through rmm_alloc and copy the
// selection into the raw pointer handed back to the caller.
*dual_coefs = (math_t*)rmm_alloc.allocate_async(
*n_support * sizeof(math_t), rmm::CUDA_ALLOCATION_ALIGNMENT, stream);
raft::copy(*dual_coefs, val_selected.data(), *n_support, stream);
// Added version: size the device buffer in bytes (*n_support * sizeof(math_t)),
// then copy the selection into its storage viewed as math_t.
dual_coefs->resize(*n_support * sizeof(math_t), stream);
raft::copy((math_t*)dual_coefs->data(), val_selected.data(), *n_support, stream);
// Synchronize the stream before returning — presumably so the copy is complete
// when the caller reads the result.
handle.sync_stream(stream);
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/svm/smosolver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ void SmoSolver<math_t>::Solve(MatrixViewType matrix,
int n_cols,
math_t* y,
const math_t* sample_weight,
math_t** dual_coefs,
rmm::device_buffer* dual_coefs,
int* n_support,
SupportStorage<math_t>* support_matrix,
int** idx,
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/svm/smosolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class SmoSolver {
int n_cols,
math_t* y,
const math_t* sample_weight,
math_t** dual_coefs,
rmm::device_buffer* dual_coefs,
int* n_support,
SupportStorage<math_t>* support_matrix,
int** idx,
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/svm/svc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ SVC<math_t>::SVC(raft::handle_t& handle,
param(SvmParameter{C, cache_size, max_iter, nochange_steps, tol, verbosity}),
kernel_params(kernel_params)
{
model.n_support = 0;
model.dual_coefs = nullptr;
model.n_support = 0;
model.dual_coefs.resize(0, handle.get_stream());
model.support_matrix = {};
model.support_idx = nullptr;
model.unique_labels = nullptr;
Expand All @@ -162,7 +162,7 @@ void SVC<math_t>::fit(
math_t* input, int n_rows, int n_cols, math_t* labels, const math_t* sample_weight)
{
model.n_cols = n_cols;
if (model.dual_coefs) svmFreeBuffers(handle, model);
if (!model.dual_coefs.is_empty()) svmFreeBuffers(handle, model);
svcFit(handle, input, n_rows, n_cols, labels, param, kernel_params, model, sample_weight);
}

Expand Down
7 changes: 2 additions & 5 deletions cpp/src/svm/svc_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ void svcPredictX(const raft::handle_t& handle,
&one,
K.data(),
transpose_kernel ? model.n_support : n_batch,
model.dual_coefs,
(math_t*)model.dual_coefs.data(),
1,
&null,
y.data() + i,
Expand Down Expand Up @@ -357,9 +357,7 @@ void svmFreeBuffers(const raft::handle_t& handle, SvmModel<math_t>& m)
{
cudaStream_t stream = handle.get_stream();
rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource();
if (m.dual_coefs)
rmm_alloc.deallocate_async(
m.dual_coefs, m.n_support * sizeof(math_t), rmm::CUDA_ALLOCATION_ALIGNMENT, stream);
m.dual_coefs.resize(0, stream);
if (m.support_idx)
rmm_alloc.deallocate_async(
m.support_idx, m.n_support * sizeof(int), rmm::CUDA_ALLOCATION_ALIGNMENT, stream);
Expand Down Expand Up @@ -394,7 +392,6 @@ void svmFreeBuffers(const raft::handle_t& handle, SvmModel<math_t>& m)
if (m.unique_labels)
rmm_alloc.deallocate_async(
m.unique_labels, m.n_classes * sizeof(math_t), rmm::CUDA_ALLOCATION_ALIGNMENT, stream);
m.dual_coefs = nullptr;
m.support_idx = nullptr;
m.unique_labels = nullptr;
}
Expand Down
69 changes: 51 additions & 18 deletions cpp/src/svm/svm_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,12 @@ cumlError_t cumlSpSvcFit(cumlHandle_t handle,

ML::SVM::SvmModel<float> model;

rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource();

cumlError_t status;
raft::handle_t* handle_ptr;
std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle);
cudaStream_t stream = handle_ptr->get_stream();
if (status == CUML_SUCCESS) {
try {
ML::SVM::svcFit(*handle_ptr,
Expand All @@ -78,13 +81,21 @@ cumlError_t cumlSpSvcFit(cumlHandle_t handle,
kernel_param,
model,
static_cast<float*>(nullptr));
*n_support = model.n_support;
*b = model.b;
*dual_coefs = model.dual_coefs;
*n_support = model.n_support;
*b = model.b;
// Export the dual coefficients through the `dual_coefs` out-parameter.
// model.dual_coefs is a device buffer held by the local `model`, so the caller
// receives a separately allocated copy of the same byte size.
if (model.dual_coefs.size() > 0) {
  *dual_coefs = static_cast<float*>(rmm_alloc.allocate_async(
    model.dual_coefs.size(), rmm::CUDA_ALLOCATION_ALIGNMENT, stream));
  // The copy destination must be the freshly allocated buffer (*dual_coefs).
  // Passing `dual_coefs` itself would overwrite the caller's pointer variable
  // and leave the new allocation uninitialized.
  RAFT_CUDA_TRY(cudaMemcpyAsync(
    *dual_coefs, model.dual_coefs.data(), model.dual_coefs.size(), cudaMemcpyDefault, stream));
} else {
  // Nothing to export: report an empty result.
  *dual_coefs = nullptr;
}
*x_support = model.support_matrix.data;
*support_idx = model.support_idx;
*n_classes = model.n_classes;
*unique_labels = model.unique_labels;

}
// TODO: Implement this
// catch (const MLCommon::Exception& e)
Expand Down Expand Up @@ -138,9 +149,12 @@ cumlError_t cumlDpSvcFit(cumlHandle_t handle,

ML::SVM::SvmModel<double> model;

rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource();

cumlError_t status;
raft::handle_t* handle_ptr;
std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle);
cudaStream_t stream = handle_ptr->get_stream();
if (status == CUML_SUCCESS) {
try {
ML::SVM::svcFit(*handle_ptr,
Expand All @@ -152,9 +166,16 @@ cumlError_t cumlDpSvcFit(cumlHandle_t handle,
kernel_param,
model,
static_cast<double*>(nullptr));
*n_support = model.n_support;
*b = model.b;
*dual_coefs = model.dual_coefs;
*n_support = model.n_support;
*b = model.b;
// Export the dual coefficients through the `dual_coefs` out-parameter.
// model.dual_coefs is a device buffer held by the local `model`, so the caller
// receives a separately allocated copy of the same byte size.
if (model.dual_coefs.size() > 0) {
  *dual_coefs = static_cast<double*>(rmm_alloc.allocate_async(
    model.dual_coefs.size(), rmm::CUDA_ALLOCATION_ALIGNMENT, stream));
  // The copy destination must be the freshly allocated buffer (*dual_coefs).
  // Passing `dual_coefs` itself would overwrite the caller's pointer variable
  // and leave the new allocation uninitialized.
  RAFT_CUDA_TRY(cudaMemcpyAsync(
    *dual_coefs, model.dual_coefs.data(), model.dual_coefs.size(), cudaMemcpyDefault, stream));
} else {
  // Nothing to export: report an empty result.
  *dual_coefs = nullptr;
}
*x_support = model.support_matrix.data;
*support_idx = model.support_idx;
*n_classes = model.n_classes;
Expand Down Expand Up @@ -191,25 +212,31 @@ cumlError_t cumlSpSvcPredict(cumlHandle_t handle,
float buffer_size,
int predict_class)
{
cumlError_t status;
raft::handle_t* handle_ptr;
std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle);
cudaStream_t stream = handle_ptr->get_stream();

raft::distance::kernels::KernelParams kernel_param;
kernel_param.kernel = (raft::distance::kernels::KernelType)kernel;
kernel_param.degree = degree;
kernel_param.gamma = gamma;
kernel_param.coef0 = coef0;

ML::SVM::SvmModel<float> model;
model.n_support = n_support;
model.b = b;
model.dual_coefs = dual_coefs;
model.n_support = n_support;
model.b = b;
if (n_support > 0) {
model.dual_coefs.resize(n_support * sizeof(float), stream);
RAFT_CUDA_TRY(cudaMemcpyAsync(
model.dual_coefs.data(), dual_coefs, n_support * sizeof(float), cudaMemcpyDefault, stream));
}

model.support_matrix = {.data = x_support};
model.support_idx = nullptr;
model.n_classes = n_classes;
model.unique_labels = unique_labels;

cumlError_t status;
raft::handle_t* handle_ptr;
std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle);
if (status == CUML_SUCCESS) {
try {
ML::SVM::svcPredict(
Expand Down Expand Up @@ -246,25 +273,31 @@ cumlError_t cumlDpSvcPredict(cumlHandle_t handle,
double buffer_size,
int predict_class)
{
cumlError_t status;
raft::handle_t* handle_ptr;
std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle);
cudaStream_t stream = handle_ptr->get_stream();

raft::distance::kernels::KernelParams kernel_param;
kernel_param.kernel = (raft::distance::kernels::KernelType)kernel;
kernel_param.degree = degree;
kernel_param.gamma = gamma;
kernel_param.coef0 = coef0;

ML::SVM::SvmModel<double> model;
model.n_support = n_support;
model.b = b;
model.dual_coefs = dual_coefs;
model.n_support = n_support;
model.b = b;
if (n_support > 0) {
model.dual_coefs.resize(n_support * sizeof(double), stream);
RAFT_CUDA_TRY(cudaMemcpyAsync(
model.dual_coefs.data(), dual_coefs, n_support * sizeof(double), cudaMemcpyDefault, stream));
}

model.support_matrix = {.data = x_support};
model.support_idx = nullptr;
model.n_classes = n_classes;
model.unique_labels = unique_labels;

cumlError_t status;
raft::handle_t* handle_ptr;
std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle);
if (status == CUML_SUCCESS) {
try {
ML::SVM::svcPredict(
Expand Down
Loading

0 comments on commit 1a4733e

Please sign in to comment.