WIP Write Generalized Linear Machine class #5006
@@ -0,0 +1,170 @@
/*
 * This software is distributed under BSD 3-clause license (see LICENSE file).
 *
 * Author: Tej Sukhatme
 */

#include <shogun/features/DotFeatures.h>
#include <shogun/labels/Labels.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/lib/observers/ObservedValueTemplated.h>
#include <shogun/machine/GLM.h>
#include <shogun/machine/LinearMachine.h>
#include <shogun/mathematics/NormalDistribution.h>
#include <shogun/mathematics/RandomNamespace.h>
#include <shogun/mathematics/linalg/LinalgNamespace.h>
#include <shogun/optimization/ConstLearningRate.h>
#include <shogun/optimization/ElasticNetPenalty.h>
#include <shogun/optimization/GradientDescendUpdater.h>
#include <shogun/optimization/SGDMinimizer.h>

#include <cmath>

using namespace shogun;

GLM::GLM()
{
    SG_ADD_OPTIONS(
        (machine_int_t*)&distribution, "distribution_type",
        "variable to store name of distribution type",
        ParameterProperties::HYPER, SG_OPTIONS(POISSON));
    SG_ADD(
        &m_eta, "eta",
        "threshold parameter that linearizes the exp() function above eta",
        ParameterProperties::HYPER);
    SG_ADD(
        &m_lambda, "lambda", "regularization parameter of penalty term",
        ParameterProperties::HYPER);
    SG_ADD(
        &m_alpha, "alpha",
        "weighting between L1 penalty and L2 penalty term of the loss function",
        ParameterProperties::HYPER);
    SG_ADD(
        &m_tolerance, "tolerance", "convergence threshold or stopping criteria",
        ParameterProperties::HYPER);
    SG_ADD(
        &m_learning_rate, "learning_rate", "learning rate for gradient descent",
        ParameterProperties::HYPER);

    m_gradient_updater = std::make_shared<GradientDescendUpdater>();
    m_penalty = std::make_shared<ElasticNetPenalty>();
    m_cost_function = std::make_shared<GLMCostFunction>();
}

GLM::GLM(
    GLM_DISTRIBUTION distr, float64_t alpha, float64_t lambda,
    float64_t learning_rate, int32_t max_iterations, float64_t tolerance,
    float64_t eta)
    : GLM()
{
    distribution = distr;
    m_alpha = alpha;
    m_lambda = lambda;
    m_learning_rate = learning_rate;
    m_max_iterations = max_iterations;
    m_tolerance = tolerance;
    m_eta = eta;

    m_penalty->set_l1_ratio(m_alpha);
}

std::shared_ptr<RegressionLabels>
GLM::apply_regression(std::shared_ptr<Features> data)
{
    if (data)
    {
        if (!data->has_property(FP_DOT))
            error("Specified features are not of type CDotFeatures");
        set_features(std::static_pointer_cast<DotFeatures>(data));
    }

    require(features, "Features are not provided");

    auto num = features->get_num_vectors();
    ASSERT(num > 0)
    ASSERT(m_w.vlen == features->get_dim_feature_space())
    SGVector<float64_t> out(num);
    features->dense_dot_range(
        out.vector, 0, num, NULL, m_w.vector, m_w.vlen, bias);
    auto result = m_cost_function->non_linearity(
        out, m_compute_bias, m_eta, distribution);

    return std::make_shared<RegressionLabels>(result);
}

void GLM::init_model(const std::shared_ptr<Features> data)
{
    ASSERT(m_labels)
    if (data)
    {
        if (!data->has_property(FP_DOT))
            error("Specified features are not of type CDotFeatures");
        set_features(std::static_pointer_cast<DotFeatures>(data));
    }
    ASSERT(features)

    NormalDistribution<float64_t> normal_dist;
    const auto& n_features = features->get_dim_feature_space();

    if (m_w.vlen == 0)
    {
        if (m_compute_bias && bias == 0)
            bias = 1.0 / (n_features + 1) * normal_dist(m_prng);

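        // The bias above and each weight below are initialised from a
        // standard normal draw scaled by 1 / (n_features + 1).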
        if (n_features > 0)
        {
            m_w = SGVector<float64_t>(n_features);

            std::generate(m_w.begin(), m_w.end(), [&]() {
                auto rand = normal_dist(m_prng);
                return 1.0 / (n_features + 1) * rand;
            });
        }
    }

}

void GLM::iteration()
{
    SGVector<float64_t> w_old = m_w.clone();

    auto X = get_features()->get_computed_dot_feature_matrix();
    auto y = regression_labels(get_labels())->get_labels();
[Review thread on the line above]
"this should be fixed at some point. I guess when we have the LabelEncoder in Machine, otherwise you are performing a potentially expensive operation in each iteration"
"it seems like that"
"True, but it should be possible to add it?"
(A sketch of caching the converted labels once is given after the diff.)

    auto gradient_w = m_cost_function->get_gradient_weights(
        X, y, m_w, bias, m_lambda, m_alpha, m_compute_bias, m_eta,
        distribution);
    auto gradient_bias = m_cost_function->get_gradient_bias(
        X, y, m_w, bias, m_compute_bias, m_eta, distribution);

    // Update
    // TODO: Use gradient updater
    // m_gradient_updater->update_variable(m_w, gradient_w, learning_rate);
    m_w = linalg::add(m_w, gradient_w, 1.0, -1 * m_learning_rate);

    if (m_compute_bias)
        bias -= m_learning_rate * gradient_bias;

    // Apply proximal operator
    // TODO: Use proximity updater.
    // m_penalty->update_variable_for_proximity(m_w, m_lambda * m_alpha);
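    // Soft-thresholding (the L1 part of the elastic-net proximal step):
    // weights whose magnitude is below lambda * alpha are set to zero,
    // all other weights are shrunk towards zero by lambda * alpha.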
    for (auto i : range(m_w.vlen))
    {
        if (std::abs(m_w[i]) < (m_lambda * m_alpha))
            m_w[i] = 0;
        else
        {
            if (m_w[i] > 0)
                m_w[i] -= (m_lambda * m_alpha);
            else
                m_w[i] += (m_lambda * m_alpha);
        }
    }

    // Convergence by relative parameter change tolerance
    auto norm_update = linalg::norm(linalg::add(m_w, w_old, 1.0, -1.0));
    float32_t checker = linalg::norm(m_w) == 0
                            ? norm_update
                            : std::abs(norm_update / linalg::norm(m_w));
    if (m_current_iteration > 0 && checker < m_tolerance)
        m_complete = true;
}
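[Not part of the PR: a minimal sketch of the caching idea raised in the review thread on the get_labels() line above. It assumes a hypothetical member m_cached_y and that init_model() runs once before the iteration loop, as in Shogun's IterativeMachine pattern; the name and placement are illustrative only.]

// Hypothetical addition to GLM.h:
//     SGVector<float64_t> m_cached_y;

void GLM::init_model(const std::shared_ptr<Features> data)
{
    // ... existing initialisation as in the diff ...

    // Convert the labels once here instead of on every call to iteration().
    m_cached_y = regression_labels(get_labels())->get_labels();
}

void GLM::iteration()
{
    SGVector<float64_t> w_old = m_w.clone();

    auto X = get_features()->get_computed_dot_feature_matrix();
    const auto& y = m_cached_y; // no per-iteration label conversion

    // ... gradient, update, proximal and convergence steps unchanged ...
}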
[Review thread on the set_features() calls in apply_regression() and init_model()]
"is there a reason why the features are added to the state here?"
"No particular reason, I did it this way as it was done like this in LinearMachine: shogun/src/shogun/machine/LinearMachine.cpp, lines 70 to 76 in ec557d8."
"Should I remove it?"
"I am not sure, you'll have to see if this has side effects that you rely on. @LiuYuHui do you know if this is required? Because you had to refactor this code recently in the feature branch."
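[Not part of the PR: a purely illustrative sketch of the alternative discussed above, i.e. apply_regression() using the passed-in features locally instead of overwriting the machine's feature state with set_features(). Whether anything relies on that side effect is exactly the open question in the thread.]

std::shared_ptr<RegressionLabels>
GLM::apply_regression(std::shared_ptr<Features> data)
{
    auto feats = features; // fall back to the stored features
    if (data)
    {
        if (!data->has_property(FP_DOT))
            error("Specified features are not of type CDotFeatures");
        feats = std::static_pointer_cast<DotFeatures>(data);
    }
    require(feats, "Features are not provided");

    auto num = feats->get_num_vectors();
    ASSERT(num > 0)
    ASSERT(m_w.vlen == feats->get_dim_feature_space())
    SGVector<float64_t> out(num);
    feats->dense_dot_range(
        out.vector, 0, num, NULL, m_w.vector, m_w.vlen, bias);
    auto result = m_cost_function->non_linearity(
        out, m_compute_bias, m_eta, distribution);

    return std::make_shared<RegressionLabels>(result);
}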