Try "warm up" phase #41

Merged
merged 10 commits into from
Jun 19, 2024
Merged

Try "warm up" phase #41

merged 10 commits into from
Jun 19, 2024

Conversation

@matsen (Contributor) commented Jun 19, 2024

Didn't seem to help!

@matsen (Contributor, Author) commented Jun 19, 2024

Here's what I ended up adding, which is going to get ripped out:

import torch.optim as optim


class IgnoreMetricLambdaLR:
    """
    The sole purpose of this class is to wrap a LambdaLR scheduler so that we can pass
    a metric to the step method without changing the underlying scheduler. This means
    we don't have to change anything in our training loop to use this scheduler.
    """

    def __init__(self, optimizer, lr_lambda):
        self.scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def step(self, metric=None):
        self.scheduler.step()

    def get_last_lr(self):
        return self.scheduler.get_last_lr()

    def state_dict(self):
        return self.scheduler.state_dict()

    def load_state_dict(self, state_dict):
        self.scheduler.load_state_dict(state_dict)


def linear_bump_lr(epoch, warmup_epochs, total_epochs, max_lr, min_lr):
    """
    Linearly increase the learning rate from min_lr to max_lr over warmup_epochs,
    then linearly decrease the learning rate from max_lr to min_lr.
    
    Meant to be used with LambdaLR.
    
    pd.Series([
        linear_bump_lr(epoch, warmup_epochs=20, total_epochs=200, max_lr=0.01, min_lr=1e-5)
        for epoch in range(200)]).plot()
    """
    if epoch < warmup_epochs:
        lr = min_lr + ((max_lr - min_lr) / warmup_epochs) * epoch
    else:
        lr = max_lr - ((max_lr - min_lr) / (total_epochs - warmup_epochs)) * (
            epoch - warmup_epochs
        )
    return lr


def linear_bump_scheduler(optimizer, warmup_epochs, total_epochs, max_lr, min_lr):
    """
    A learning rate scheduler that linearly increases the learning rate from min_lr to
    max_lr over warmup_epochs, then linearly decreases it from max_lr to min_lr.
    """
    return IgnoreMetricLambdaLR(
        optimizer,
        lambda epoch: linear_bump_lr(
            epoch=epoch,
            warmup_epochs=warmup_epochs,
            total_epochs=total_epochs,
            max_lr=max_lr,
            min_lr=min_lr,
        ),
    )

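For context, here's a minimal sketch of how the wrapper could be driven from an unchanged training loop; the model, optimizer, and val_loss names are placeholders for illustration, not part of the PR:

import torch

# Placeholder model and optimizer purely for illustration; only the scheduler
# wiring mirrors the snippet above.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)  # base LR; see the note below

scheduler = linear_bump_scheduler(
    optimizer, warmup_epochs=20, total_epochs=200, max_lr=0.01, min_lr=1e-5
)

for epoch in range(200):
    # ... run the training epoch, compute a validation metric ...
    val_loss = 0.0  # placeholder metric
    scheduler.step(val_loss)  # the metric argument is accepted but ignored
    print(epoch, scheduler.get_last_lr())
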
@matsen (Contributor, Author) commented Jun 19, 2024

And

        self.scheduler = linear_bump_scheduler(
            self.optimizer, warmup_epochs=20, total_epochs=200, max_lr=0.01, min_lr=1e-5
        )

Note that this is a multiplier on the base LR.

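As a toy illustration of that multiplier (not from the PR): LambdaLR computes lr = base_lr * lr_lambda(epoch), so a base LR other than 1.0 rescales everything linear_bump_lr returns.

import torch

# Toy parameter purely for illustration.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)  # base LR of 0.1
scheduler = linear_bump_scheduler(
    optimizer, warmup_epochs=20, total_epochs=200, max_lr=0.01, min_lr=1e-5
)
# At epoch 0 the effective LR is 0.1 * linear_bump_lr(0, ...) = 0.1 * 1e-5,
# i.e. roughly 1e-6 rather than the min_lr of 1e-5.
print(scheduler.get_last_lr())

Hence setting the base LR to 1.0, as in the sketch above, makes the values returned by linear_bump_lr act as absolute learning rates.
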
@matsen linked an issue on Jun 19, 2024 that may be closed by this pull request.
@matsen merged commit d88b855 into main on Jun 19, 2024 (1 check passed).
@matsen deleted the 39-warm-up branch on June 19, 2024 at 19:02.
Linked issue that may be closed by this pull request: try linear warm-up and higher peak learning rate