
Commit

revise code and comments
lijinf2 committed Jul 19, 2023
1 parent 752b47d commit 2821807
Showing 3 changed files with 25 additions and 21 deletions.
32 changes: 16 additions & 16 deletions cpp/src/glm/qn/glm_base_mg.cuh
@@ -68,24 +68,24 @@ inline void linearBwdMG(const raft::handle_t& handle,
}
}

-/**
- * @brief Aggregates local gradient vectors and loss values from local training data. This
- * class is the multi-node-multi-gpu version of GLMWithData.
- *
- * The implementation overrides existing GLMWithData::() function. The purpose is to
- * aggregate local gradient vectors and loss values from distributed X, y, where X represents the
- * input vectors and y represents labels.
- *
- * GLMWithData::() currently invokes three functions: linearFwd, getLossAndDz and linearBwd.
- * linearFwd multiplies local input vectors with the coefficient vector (i.e. coef_), so does not
- * require communication. getLossAndDz calculates local loss so requires allreduce to obtain a
- * global loss. linearBwd calculates local gradient vector so requires allreduce to obtain a
- * global gradient vector. The global loss and the global gradient vector will be used in
- * min_lbfgs to update coefficient. The update runs individually on every GPU and when finished,
- * all GPUs have the same value of coefficient.
- */
template <typename T, class GLMObjective>
struct GLMWithDataMG : ML::GLM::detail::GLMWithData<T, GLMObjective> {
+  /**
+   * @brief Aggregates local gradient vectors and loss values from local training data. This
+   * class is the multi-node multi-GPU version of GLMWithData.
+   *
+   * The implementation overrides the existing GLMWithData::operator() function. The purpose is
+   * to aggregate local gradient vectors and loss values from distributed X, y, where X
+   * represents the input vectors and y represents the labels.
+   *
+   * GLMWithData::operator() currently invokes three functions: linearFwd, getLossAndDz and
+   * linearBwd. linearFwd multiplies local input vectors with the coefficient vector (i.e. coef_),
+   * so it does not require communication. getLossAndDz calculates the local loss, so it requires
+   * an allreduce to obtain the global loss. linearBwd calculates the local gradient vector, so it
+   * requires an allreduce to obtain the global gradient vector. The global loss and the global
+   * gradient vector are then used by min_lbfgs to update the coefficients. The update runs
+   * independently on every GPU and, when finished, all GPUs hold the same coefficient values.
+   */
const raft::handle_t* handle_p;
int rank;
int64_t n_samples;
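The comment added above summarizes the MNMG flow: linearFwd stays local, while the loss from getLossAndDz and the gradient from linearBwd are each summed across ranks with an allreduce, after which every GPU runs the same L-BFGS update on identical global values. As a rough sketch of that aggregation step (the helper name, argument shapes, and in-place SUM allreduce are assumptions for illustration, not code from this commit):

```cpp
#include <raft/core/comms.hpp>
#include <raft/core/handle.hpp>

#include <cstddef>

// Hypothetical helper: turn a rank-local loss and gradient into global ones.
// After the allreduces every rank holds identical values, so the subsequent
// L-BFGS coefficient update can run independently on each GPU and still agree.
template <typename T>
void aggregate_loss_and_grad(const raft::handle_t& handle,
                             T* dev_loss,          // device scalar: local in, global out
                             T* dev_grad,          // device vector: local in, global out
                             std::size_t n_coefs)  // gradient length
{
  const auto& comm  = handle.get_comms();
  cudaStream_t strm = handle.get_stream();
  // In-place SUM allreduce of the scalar loss and the gradient vector.
  comm.allreduce(dev_loss, dev_loss, 1, raft::comms::op_t::SUM, strm);
  comm.allreduce(dev_grad, dev_grad, n_coefs, raft::comms::op_t::SUM, strm);
  comm.sync_stream(strm);  // make the results visible before the optimizer step
}
```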
13 changes: 8 additions & 5 deletions cpp/src/glm/qn_mg.cu
@@ -22,14 +22,14 @@
#include <cuml/linear_model/qn.h>
#include <cuml/linear_model/qn_mg.hpp>
#include <raft/core/comms.hpp>
+#include <raft/core/error.hpp>
#include <raft/core/handle.hpp>
#include <raft/util/cudart_utils.hpp>
using namespace MLCommon;

#include "qn/glm_base_mg.cuh"

#include <cuda_runtime.h>
-#include <iostream>

namespace ML {
namespace GLM {
@@ -53,10 +53,12 @@ void qnFit_impl(const raft::handle_t& handle,
{
switch (pams.loss) {
case QN_LOSS_LOGISTIC: {
-ASSERT(C == 2, "qn_mg.cu: logistic loss invalid C");
+RAFT_EXPECTS(
+  C == 2,
+  "qn_mg.cu: only the LOGISTIC loss is supported currently. The number of classes must be 2");
} break;
default: {
-ASSERT(false, "qn_mg.cu: unknown loss function type (id = %d).", pams.loss);
+RAFT_EXPECTS(false, "qn_mg.cu: unknown loss function type (id = %d).", pams.loss);
}
}

@@ -107,8 +109,9 @@ void qnFit_impl(raft::handle_t& handle,
T* f,
int* num_iters)
{
-ASSERT(input_data.size() == 1, "qn_mg.cu currently does not accept more than one input matrix");
-ASSERT(labels.size() == input_data.size(), "labels size does not equal to input_data size");
+RAFT_EXPECTS(input_data.size() == 1,
+             "qn_mg.cu currently does not accept more than one input matrix");
+RAFT_EXPECTS(labels.size() == input_data.size(), "labels size does not equal input_data size");

auto data_X = input_data[0];
auto data_y = labels[0];
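The checks above replace cuML's older ASSERT macro with RAFT_EXPECTS from raft/core/error.hpp (the include added in the first hunk of this file). RAFT_EXPECTS takes a condition plus a printf-style message and throws raft::logic_error when the condition is false. A minimal usage sketch, with hypothetical argument types and messages:

```cpp
#include <raft/core/error.hpp>

#include <vector>

// Hypothetical validation helper mirroring the checks above.
void check_inputs(const std::vector<float*>& input_data,
                  const std::vector<float*>& labels)
{
  // Each RAFT_EXPECTS throws raft::logic_error with the formatted message
  // when its condition is false.
  RAFT_EXPECTS(input_data.size() == 1,
               "expected exactly one input matrix, got %zu", input_data.size());
  RAFT_EXPECTS(labels.size() == input_data.size(),
               "labels size (%zu) does not match input_data size (%zu)",
               labels.size(), input_data.size());
}
```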
1 change: 1 addition & 0 deletions python/cuml/linear_model/logistic_regression_mg.pyx
@@ -39,6 +39,7 @@ from cuml.common.opg_data_utils_mg cimport *
# the cdef was copied from cuml.linear_model.qn
cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM" nogil:

+# TODO: Use single-GPU version qn_loss_type and qn_params https://github.com/rapidsai/cuml/issues/5502
cdef enum qn_loss_type "ML::GLM::qn_loss_type":
QN_LOSS_LOGISTIC "ML::GLM::QN_LOSS_LOGISTIC"
QN_LOSS_SQUARED "ML::GLM::QN_LOSS_SQUARED"
