Conversation
```cpp
std::vector<std::shared_ptr<Tensor>> params_;
float learning_rate_ = 0.0f;
float initial_learning_rate_ = 0.0f;
bool initial_lr_set_ = false;
```
This part is somewhat redundant. The optimizer only needs to store `learning_rate_`, representing the current learning rate; there is no need to additionally track initial-LR state. Semantically, the initial learning rate can live solely in the LR scheduler (which you effectively already do, as `base_lr` in the scheduler).
This is to match PyTorch's behavior at scheduler initialization (source link).
When PyTorch initializes a scheduler, it visits the param groups of the associated optimizer and calls `setdefault` to set `initial_lr`: for an optimizer being attached for the first time, the current learning rate is recorded as `initial_lr`; for an optimizer that has already been attached to a scheduler, the existing value is returned.
The only use case I can think of is that when multiple schedulers are attached to the same optimizer (ChainedScheduler, SequentialLR, etc.), it guarantees that their `base_lr_` all equal the optimizer's learning rate at the time the first scheduler was initialized. I am not aware of other scenarios, but I added these members for consistency with PyTorch. If only ChainedScheduler and SequentialLR are involved, there are indeed alternative designs; should I change it?
OK, let's stay aligned with torch and keep the `initial_lr` member.
infini_train/include/lr_scheduler.h (Outdated)
```cpp
std::shared_ptr<Optimizer> optimizer_;
int64_t last_step_;
float current_lr_;
```
`current_lr_` also looks somewhat redundant. Semantically, `current_lr_` and `optimizer_->GetLearningRate()` should be equal at all times, yet in the current design the two are stored separately and used interchangeably (after reading through, `current_lr_` appears to be just a copy of `optimizer_->GetLearningRate()`). The numerical behavior is handled correctly right now, but this design is likely to cause ambiguity when someone extends it later.
I suggest keeping a single source of truth for the current learning rate: either track it entirely via `optimizer_->GetLearningRate()` and drop `current_lr_` from the LR scheduler, or track it in the scheduler and set it back into the optimizer after each computation. Personally I think the former is more appropriate.
Fixed. Since schedulers need to support resuming training, and schedulers such as SequentialLR and ChainedScheduler have no closed-form computation (the learning rate cannot be derived quickly from `base_lr` and `last_epoch`), the member is kept solely for learning-rate recovery and renamed to `recover_lr` to avoid confusion.
Pull request overview
This PR introduces a learning-rate scheduler system to infini_train, integrates it with optimizers (including distributed optimizer), and adds standalone C++ test executables plus example CLI wiring to exercise the new schedulers.
Changes:
- Add `LRScheduler` base + concrete schedulers (ConstantLR/StepLR/LinearLR/LambdaLR/SequentialLR/ChainedScheduler) and a `CreateLRScheduler` factory.
- Extend `Optimizer` with a runtime-settable learning rate and initial-learning-rate tracking; propagate the LR to `DistributedOptimizer`.
- Add scheduler coverage tests and wire scheduler flags into `example/gpt2` and `example/llama3`; register new test executables in CMake.
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| `infini_train/include/lr_scheduler.h` | Declares scheduler APIs, configs, and concrete scheduler types. |
| `infini_train/src/lr_scheduler.cc` | Implements scheduler logic, factory creation, state save/load, and sequential/chained behavior. |
| `infini_train/include/optimizer.h` | Adds LR getters/setters and initial-LR tracking to support schedulers. |
| `infini_train/src/optimizer.cc` | Implements optimizer LR plumbing and updates SGD/Adam to use the base LR storage. |
| `infini_train/include/nn/parallel/ddp/distributed_optimizer.h` | Overrides LR get/set for the distributed optimizer so schedulers affect the real base optimizer. |
| `infini_train/src/nn/parallel/ddp/distributed_optimizer.cc` | Implements LR propagation to/from the wrapped base optimizer. |
| `example/gpt2/main.cc` | Adds scheduler CLI flags and steps the scheduler during training. |
| `example/llama3/main.cc` | Adds scheduler CLI flags and steps the scheduler during training. |
| `test/lr_scheduler/test_helpers.h` | Shared minimal test helpers/macros for scheduler tests. |
| `test/lr_scheduler/test_*.cc` | Adds functional, state, and validation tests for schedulers. |
| `CMakeLists.txt` | Adds new scheduler test executables to the build. |
A few other places that need changes for coding-convention compliance:
…r accessors, passthrough SetLearningRate/GetLearningRate, and add initial_learning_rate and its accessors
…StepLR, LinearLR, LambdaLR and SequentialLR
…base class, add factory method Create<T>() with two-phase init, and update all tests to use the Create<T>() factory method. - Change Step() to virtual with a default implementation. - Add pure virtual ComputeLR() for subclasses to implement. - Adapt test helpers (IdentityScheduler, LinearDecayScheduler) to implement ComputeLR() instead of Step(). - All existing tests pass without behavioral changes. BREAKING CHANGE: Subclasses must implement ComputeLR() instead of Step().
…closed and chained form, adjust LinearLR and SequentialLR - enhance LRScheduler with chained and closed-form learning rate methods - adapt methods (Step, InitialStep, GetClosedFormLR, GetChainedFormLR) to match PyTorch's design - add tests for consistency - refactor LinearLR: add end_factor, and rename this class - add SequentialLR InitialStep and UndoChildInitialSteps BREAKING CHANGE: Subclasses must implement GetClosedFormLR instead of ComputeLR(). Use LinearLR instead of LinearwarmupLR.
- Add LRSchedulerConfig struct with parameters for all basic schedulers (constant, linear, step) - Add CreateLRScheduler() factory function - Support automatic warmup wrapping via SequentialLR when warmup_steps > 0 - Adapt test files
…ogs, and integrate scheduler into training loop
…s, add validation tests for learning rate schedulers - it is now only used for learning rate recovery when loading state
Force-pushed from dc748bd to 327d263.