-
Notifications
You must be signed in to change notification settings - Fork 44
【训练营】学习率调度器实现 #113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
【训练营】学习率调度器实现 #113
Changes from 19 commits
7a16589
0514862
81295e8
8e7cda0
1e65881
d924d3d
baca2ef
7df75d7
d0ac538
df4c68d
5b4ef6d
8c11dd9
6823244
fb9d997
7a29a61
b64566e
3a7abb4
f7b3fcb
1f95e29
afd98ff
151dda0
2980f93
dc748bd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
| #include "infini_train/include/core/runtime/device_guard.h" | ||
| #include "infini_train/include/dataloader.h" | ||
| #include "infini_train/include/device.h" | ||
| #include "infini_train/include/lr_scheduler.h" | ||
| #include "infini_train/include/nn/modules/loss.h" | ||
| #include "infini_train/include/nn/modules/module.h" | ||
| #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" | ||
|
|
@@ -54,6 +55,16 @@ DEFINE_uint32(text_length, 64, "the length of the generated text"); | |
| // optimization | ||
| DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations"); | ||
| DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); | ||
| // lr scheduler | ||
| DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear"); | ||
| DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)"); | ||
| DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)"); | ||
| DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)"); | ||
| DEFINE_int64(step_size, 30, "StepLR: period of learning rate decay"); | ||
| DEFINE_double(gamma, 0.1, "StepLR: multiplicative factor of lr decay"); | ||
| DEFINE_double(start_factor, 0.333333, "LinearLR: starting multiplicative factor"); | ||
| DEFINE_double(end_factor, 1.0, "LinearLR: ending multiplicative factor"); | ||
| DEFINE_int64(lr_total_iters, 5, "ConstantLR/LinearLR: total iterations for the scheduler"); | ||
| // evaluation | ||
| DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); | ||
| DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); | ||
|
|
@@ -247,6 +258,20 @@ void Train(const nn::parallel::Rank &rank) { | |
| optimizer = optimizer_creator(model->Parameters()); | ||
| } | ||
|
|
||
| LRSchedulerConfig sched_config; | ||
| sched_config.type = FLAGS_lr_scheduler; | ||
| sched_config.warmup_steps = FLAGS_warmup_steps; | ||
| sched_config.warmup_start_factor = static_cast<float>(FLAGS_warmup_start_factor); | ||
| sched_config.warmup_end_factor = static_cast<float>(FLAGS_warmup_end_factor); | ||
| sched_config.step_size = FLAGS_step_size; | ||
| sched_config.step_gamma = static_cast<float>(FLAGS_gamma); | ||
| sched_config.linear_start_factor = static_cast<float>(FLAGS_start_factor); | ||
| sched_config.linear_end_factor = static_cast<float>(FLAGS_end_factor); | ||
| sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // 复用 | ||
| sched_config.constant_total_iters = FLAGS_lr_total_iters; | ||
| sched_config.linear_total_iters = FLAGS_lr_total_iters; | ||
| auto scheduler = CreateLRScheduler(optimizer, sched_config); | ||
|
|
||
| auto train_iter = train_loader.begin(); | ||
| std::shared_ptr<nn::Module> loss_fn | ||
| = (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(std::make_shared<VocabParallelCrossEntropyLoss>()) | ||
|
|
@@ -330,6 +355,9 @@ void Train(const nn::parallel::Rank &rank) { | |
| } | ||
|
|
||
| optimizer->Step(); | ||
| if (scheduler) { | ||
| scheduler->Step(); | ||
| } | ||
| } else { | ||
| auto [x, y] = *train_iter; | ||
| // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below | ||
|
|
@@ -339,6 +367,9 @@ void Train(const nn::parallel::Rank &rank) { | |
| y = std::make_shared<Tensor>(y->To(device)); | ||
|
|
||
| lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); | ||
| if (scheduler) { | ||
| scheduler->Step(); | ||
| } | ||
| } | ||
|
|
||
| if (ddp_world_size > 1) { | ||
|
|
@@ -354,11 +385,11 @@ void Train(const nn::parallel::Rank &rank) { | |
| if (rank.IsLastRank()) { | ||
| size_t used_mb = 0, reserved_mb = 0; | ||
| std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); | ||
|
|
||
| const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate); | ||
|
||
| LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " | ||
| "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", | ||
| step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, | ||
| tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, | ||
| step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps, | ||
| used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, | ||
| pp_world_size); | ||
|
|
||
| if ((step + 1) % FLAGS_freq_generate_txt == 0) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,186 @@ | ||
| #pragma once | ||
|
|
||
| #include <cmath> | ||
| #include <cstdint> | ||
| #include <functional> | ||
| #include <memory> | ||
| #include <string> | ||
| #include <unordered_map> | ||
| #include <variant> | ||
| #include <vector> | ||
|
|
||
| namespace infini_train { | ||
|
|
||
| class Optimizer; | ||
|
|
||
| using StateValue = std::variant<int64_t, float, double, std::string, std::vector<float>>; | ||
| using StateDict = std::unordered_map<std::string, StateValue>; | ||
|
|
||
| struct LRSchedulerConfig { | ||
| std::string type = "none"; | ||
| // ConstantLR | ||
| float constant_factor = 1.0f / 3.0f; | ||
| int constant_total_iters = 5; | ||
| // StepLR | ||
| int64_t step_size = 10; | ||
| float step_gamma = 0.1f; | ||
| // LinearLR | ||
| float linear_start_factor = 1.0f / 3.0f; | ||
| float linear_end_factor = 1.0f; | ||
| int linear_total_iters = 5; | ||
| // LambdaLR | ||
| std::function<float(int64_t)> lambda_fn = nullptr; | ||
| // SequentialLR | ||
| std::vector<LRSchedulerConfig> sequential_configs; | ||
| std::vector<int64_t> sequential_milestones; | ||
| // ChainedScheduler | ||
| std::vector<LRSchedulerConfig> chained_configs; | ||
| // warmup | ||
| int64_t warmup_steps = 0; | ||
| float warmup_start_factor = 1.0f / 3.0f; | ||
| float warmup_end_factor = 1.0f; | ||
| }; | ||
|
|
||
| class LRScheduler { | ||
| public: | ||
| template <typename T, typename... Args> static std::shared_ptr<T> Create(Args &&...args) { | ||
| auto scheduler = std::make_shared<T>(std::forward<Args>(args)...); | ||
| scheduler->InitialStep(); | ||
| return scheduler; | ||
| } | ||
|
|
||
| explicit LRScheduler(std::shared_ptr<Optimizer> optimizer, int64_t last_step = -1); | ||
| virtual ~LRScheduler() = default; | ||
|
|
||
| LRScheduler(const LRScheduler &) = delete; | ||
| LRScheduler &operator=(const LRScheduler &) = delete; | ||
|
|
||
| virtual void Step(); | ||
| virtual void Step(int64_t epoch); | ||
| virtual void InitialStep(); | ||
|
|
||
| float GetLR() const; | ||
| float BaseLR() const; | ||
| int64_t LastStep() const; | ||
|
|
||
| void ResetStep(int64_t step = -1); | ||
| virtual StateDict State() const; | ||
| virtual void LoadState(const StateDict &state); | ||
|
|
||
| protected: | ||
| virtual float GetClosedFormLR() const = 0; | ||
| virtual float GetChainedFormLR() const; | ||
| void ApplyLR(float lr); | ||
|
|
||
| std::shared_ptr<Optimizer> optimizer_; | ||
| int64_t last_step_; | ||
| float current_lr_; | ||
|
||
| float base_lr_; | ||
| bool is_initial_ = false; | ||
| }; | ||
|
|
||
| std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimizer, const LRSchedulerConfig &config); | ||
|
|
||
| namespace lr_schedulers { | ||
|
|
||
| class ConstantLR : public LRScheduler { | ||
| public: | ||
| ConstantLR(std::shared_ptr<Optimizer> optimizer, float factor = 1.0f / 3.0f, int total_iters = 5, | ||
| int64_t last_step = -1); | ||
| ~ConstantLR() override = default; | ||
|
|
||
| protected: | ||
| float GetChainedFormLR() const override; | ||
| float GetClosedFormLR() const override; | ||
|
|
||
| private: | ||
| const float factor_; | ||
| const int64_t total_iters_; | ||
| }; | ||
|
|
||
| class StepLR : public LRScheduler { | ||
| public: | ||
| StepLR(std::shared_ptr<Optimizer> optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1); | ||
| ~StepLR() override = default; | ||
|
|
||
| protected: | ||
| float GetChainedFormLR() const override; | ||
| float GetClosedFormLR() const override; | ||
|
|
||
| private: | ||
| const int64_t step_size_; | ||
| const float gamma_; | ||
| }; | ||
|
|
||
| class LinearLR : public LRScheduler { | ||
| public: | ||
| LinearLR(std::shared_ptr<Optimizer> optimizer, float start_factor = 1.0f / 3.0f, float end_factor = 1.0f, | ||
| int64_t total_iters = 5, int64_t last_step = -1); | ||
| ~LinearLR() override = default; | ||
|
|
||
| protected: | ||
| float GetChainedFormLR() const override; | ||
| float GetClosedFormLR() const override; | ||
|
|
||
| private: | ||
| const float start_factor_; | ||
| const float end_factor_; | ||
| const int64_t total_iters_; | ||
| }; | ||
|
|
||
| class LambdaLR : public LRScheduler { | ||
| public: | ||
| using LambdaFunc = std::function<float(int64_t)>; | ||
|
|
||
| LambdaLR(std::shared_ptr<Optimizer> optimizer, LambdaFunc lr_lambda, int64_t last_step = -1); | ||
| ~LambdaLR() override = default; | ||
|
|
||
| protected: | ||
| float GetClosedFormLR() const override; | ||
|
|
||
| private: | ||
| const LambdaFunc lr_lambda_; | ||
| }; | ||
|
|
||
| class SequentialLR : public LRScheduler { | ||
| public: | ||
| SequentialLR(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers, | ||
| std::vector<int64_t> milestones, int64_t last_step = -1); | ||
| ~SequentialLR() override = default; | ||
|
|
||
| void Step() override; | ||
| void InitialStep() override; | ||
|
|
||
| StateDict State() const override; | ||
| void LoadState(const StateDict &state) override; | ||
|
|
||
| protected: | ||
| float GetClosedFormLR() const override { return current_lr_; } | ||
|
||
| void UndoChildInitialSteps(); | ||
|
|
||
| private: | ||
| std::vector<std::shared_ptr<LRScheduler>> schedulers_; | ||
| std::vector<int64_t> milestones_; | ||
| }; | ||
|
|
||
| class ChainedScheduler : public LRScheduler { | ||
| public: | ||
| ChainedScheduler(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers, | ||
| int64_t last_step = -1); | ||
| ~ChainedScheduler() override = default; | ||
|
|
||
| void Step() override; | ||
| void InitialStep() override; | ||
|
|
||
| StateDict State() const override; | ||
| void LoadState(const StateDict &state) override; | ||
|
|
||
| protected: | ||
| float GetClosedFormLR() const override { return current_lr_; } | ||
|
|
||
| private: | ||
| std::vector<std::shared_ptr<LRScheduler>> schedulers_; | ||
| }; | ||
|
|
||
| } // namespace lr_schedulers | ||
| } // namespace infini_train | ||
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
scheduler 在前面已经 Step 过了，所以这里 GetLR() 语义上是“下一步要用到的 lr”；而我们这里想打印的是每一步实际用到的 lr，所以这里的逻辑需要修改一下。llama3 部分的 main.cc 里同理。