From 28218077119f414c092af356da27fc8074d7415b Mon Sep 17 00:00:00 2001 From: jinfeng Date: Wed, 19 Jul 2023 15:30:20 -0700 Subject: [PATCH] revise code and comments --- cpp/src/glm/qn/glm_base_mg.cuh | 32 +++++++++---------- cpp/src/glm/qn_mg.cu | 13 +++++--- .../linear_model/logistic_regression_mg.pyx | 1 + 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/cpp/src/glm/qn/glm_base_mg.cuh b/cpp/src/glm/qn/glm_base_mg.cuh index 0816824e14..1304ddaf60 100644 --- a/cpp/src/glm/qn/glm_base_mg.cuh +++ b/cpp/src/glm/qn/glm_base_mg.cuh @@ -68,24 +68,24 @@ inline void linearBwdMG(const raft::handle_t& handle, } } +/** + * @brief Aggregates local gradient vectors and loss values from local training data. This + * class is the multi-node-multi-gpu version of GLMWithData. + * + * The implementation overrides existing GLMWithData::() function. The purpose is to + * aggregate local gradient vectors and loss values from distributed X, y, where X represents the + * input vectors and y represents labels. + * + * GLMWithData::() currently invokes three functions: linearFwd, getLossAndDz and linearBwd. + * linearFwd multiplies local input vectors with the coefficient vector (i.e. coef_), so does not + * require communication. getLossAndDz calculates local loss so requires allreduce to obtain a + * global loss. linearBwd calculates local gradient vector so requires allreduce to obtain a + * global gradient vector. The global loss and the global gradient vector will be used in + * min_lbfgs to update coefficient. The update runs individually on every GPU and when finished, + * all GPUs have the same value of coefficient. + */ template struct GLMWithDataMG : ML::GLM::detail::GLMWithData { - /** - * @brief Aggregates local gradient vectors and loss values from local training data. This - * class is the multi-node-multi-gpu version of GLMWithData. - * - * The implementation overrides existing GLMWithData::() function.
The purpose is to - * aggregate local gradient vectors and loss values from distributed X, y, where X represents the - * input vectors and y represents labels. - * - * GLMWithData::() currently invokes three functions: linearFwd, getLossAndDz and linearBwd. - * linearFwd multiplies local input vectors with the coefficient vector (i.e. coef_), so does not - * require communication. getLossAndDz calculates local loss so requires allreduce to obtain a - * global loss. linearBwd calculates local gradient vector so requires allreduce to obtain a - * global gradient vector. The global loss and the global gradient vector will be used my - * min_lbfgs to update coefficient. The update runs individually on every GPU and when finished, - * all GPUs have the same value of coefficient. - */ const raft::handle_t* handle_p; int rank; int64_t n_samples; diff --git a/cpp/src/glm/qn_mg.cu b/cpp/src/glm/qn_mg.cu index abcc4ccfeb..2a20be37ae 100644 --- a/cpp/src/glm/qn_mg.cu +++ b/cpp/src/glm/qn_mg.cu @@ -22,6 +22,7 @@ #include #include #include +#include #include #include using namespace MLCommon; @@ -29,7 +30,6 @@ using namespace MLCommon; #include "qn/glm_base_mg.cuh" #include -#include namespace ML { namespace GLM { @@ -53,10 +53,12 @@ void qnFit_impl(const raft::handle_t& handle, { switch (pams.loss) { case QN_LOSS_LOGISTIC: { - ASSERT(C == 2, "qn_mg.cu: logistic loss invalid C"); + RAFT_EXPECTS( + C == 2, + "qn_mg.cu: only the LOGISTIC loss is supported currently.
The number of classes must be 2"); } break; default: { - ASSERT(false, "qn_mg.cu: unknown loss function type (id = %d).", pams.loss); + RAFT_EXPECTS(false, "qn_mg.cu: unknown loss function type (id = %d).", pams.loss); } } @@ -107,8 +109,9 @@ void qnFit_impl(raft::handle_t& handle, T* f, int* num_iters) { - ASSERT(input_data.size() == 1, "qn_mg.cu currently does not accept more than one input matrix"); - ASSERT(labels.size() == input_data.size(), "labels size does not equal to input_data size"); + RAFT_EXPECTS(input_data.size() == 1, + "qn_mg.cu currently does not accept more than one input matrix"); + RAFT_EXPECTS(labels.size() == input_data.size(), "labels size does not equal to input_data size"); auto data_X = input_data[0]; auto data_y = labels[0]; diff --git a/python/cuml/linear_model/logistic_regression_mg.pyx b/python/cuml/linear_model/logistic_regression_mg.pyx index 884bb49943..83eb915186 100644 --- a/python/cuml/linear_model/logistic_regression_mg.pyx +++ b/python/cuml/linear_model/logistic_regression_mg.pyx @@ -39,6 +39,7 @@ from cuml.common.opg_data_utils_mg cimport * # the cdef was copied from cuml.linear_model.qn cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM" nogil: + # TODO: Use single-GPU version qn_loss_type and qn_params https://github.com/rapidsai/cuml/issues/5502 cdef enum qn_loss_type "ML::GLM::qn_loss_type": QN_LOSS_LOGISTIC "ML::GLM::QN_LOSS_LOGISTIC" QN_LOSS_SQUARED "ML::GLM::QN_LOSS_SQUARED"