code submission
XingyuXie committed Sep 1, 2022
1 parent 69adc04 · commit 920e25b
Showing 89 changed files with 18,455 additions and 0 deletions.
CV/MAE/README.md (126 additions)
# Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models

We provide instructions for modifying the official pre-training and fine-tuning files from [MAE](https://github.com/facebookresearch/mae) so that you can use Adan to train MAE. **Please follow the MAE instructions to install the necessary packages.**



## Environment

Our experiments for this task are based on the following package versions:

```python
torch.__version__ = '1.7.1+cu110'
torchvision.__version__ = '0.8.2+cu110'
timm.__version__ = '0.4.5'
torchaudio.__version__ = '0.7.2'
```
If you want to strictly follow our environment, please refer to our released docker image [xyxie/adan-image:mae](https://hub.docker.com/repository/docker/xyxie/adan-image).



## Usage of Adan for MAE

### Two steps to use Adan

**Step 1.** Add the following arguments to `main_pretrain.py` and `main_finetune.py`:

```python
parser.add_argument('--use-adan', action='store_true', default=False, help='whether to use Adan')
parser.add_argument('--max-grad-norm', type=float, default=0.0, help='if the l2 norm of the gradient is larger than this hyper-parameter, the gradient is clipped (default: 0.0, no gradient clipping)')
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', help='optimizer epsilon to avoid the bad case where second-order moment is zero (default: None, use opt default 1e-8 in adan)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', help='optimizer betas in Adan (default: None, use opt default [0.98, 0.92, 0.99] in Adan)')
```

* `use-adan`: whether to use Adan. The default optimizer is AdamW.

* `max-grad-norm`: the threshold for gradient clipping; the default `0.0` disables clipping.

* `opt-eps`: optimizer epsilon, added to the denominator to avoid the degenerate case where the second-order moment is zero.

* `opt-betas`: optimizer betas for Adan (a minimal parsing sketch follows this list).
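
For reference, here is a minimal, self-contained sketch (not taken from the MAE codebase) of how these flags parse and how unset values fall back to Adan's defaults:

```python
import argparse

# Hypothetical standalone parser mirroring the arguments added in Step 1.
parser = argparse.ArgumentParser()
parser.add_argument('--use-adan', action='store_true', default=False)
parser.add_argument('--max-grad-norm', type=float, default=0.0)
parser.add_argument('--opt-eps', default=None, type=float)
parser.add_argument('--opt-betas', default=None, type=float, nargs='+')

args = parser.parse_args(['--use-adan', '--opt-betas', '0.98', '0.92', '0.99'])

# Unset flags fall back to Adan's own defaults.
betas = tuple(args.opt_betas) if args.opt_betas is not None else (0.98, 0.92, 0.99)
eps = args.opt_eps if args.opt_eps is not None else 1e-8
print(args.use_adan, betas, eps, args.max_grad_norm)  # True (0.98, 0.92, 0.99) 1e-08 0.0
```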



**Step 2.** Create the Adan optimizer as follows. In this step, you can directly replace the vanilla optimizer creation:

```python
# following timm: set wd as 0 for bias and norm layers
param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
if args.use_adan:
    if args.bias_decay:
        param = model_without_ddp.parameters()
    else:
        param = param_groups
        args.weight_decay = 0.0
    optimizer = Adan(param, weight_decay=args.weight_decay,
                     lr=args.lr, betas=args.opt_betas,
                     eps=args.opt_eps, max_grad_norm=args.max_grad_norm)
else:
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
```
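
For context, `optim_factory.add_weight_decay` from timm returns two parameter groups, one with `weight_decay=0` for biases and norm layers and one carrying the requested decay; because per-group values override the optimizer-level default, `args.weight_decay` can safely be zeroed when those groups are passed to Adan (an extra `bias_decay` flag, not added in Step 1, would instead pass all parameters and apply decay uniformly). Below is a minimal sketch of that setup with a toy model and illustrative hyper-parameters, not the exact MAE code:

```python
import torch
import timm.optim.optim_factory as optim_factory

from adan import Adan  # CV/MAE/adan.py from this commit

# Toy stand-in for model_without_ddp.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))

# Two groups: bias/norm parameters with weight_decay=0, the rest with the requested decay.
param_groups = optim_factory.add_weight_decay(model, weight_decay=0.02)
print([g['weight_decay'] for g in param_groups])  # e.g. [0.0, 0.02]

# The per-group weight_decay overrides the optimizer-level value, so 0.0 is safe here.
optimizer = Adan(param_groups, weight_decay=0.0, lr=2e-3,
                 betas=(0.98, 0.92, 0.99), eps=1e-8, max_grad_norm=10.0)
```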



## MAE Pre-training

```bash
python main_pretrain.py \
    --batch_size 256 --accum_iter 1 \
    --model ${MODEL_NAME} --norm_pix_loss --mask_ratio 0.75 \
    --epochs 800 \
    --lr ${LR} --weight_decay 0.02 --warmup_epochs ${WR_EPOCH} \
    --min_lr ${MIN_LR} \
    --opt-betas 0.98 0.92 0.90 --opt-eps 1e-8 --max-grad-norm 10.0 \
    --use-adan \
    --data_path ${IMAGENET_DIR} \
    --output_dir ${OUT_DIR}
```

- The pre-training file `main_pretrain.py` comes from [MAE](https://github.com/facebookresearch/mae).
- We use **16** A100 GPUs for MAE-Base and **32** A100 GPUs for MAE-Large.
- The hyper-parameters differ between MAE-Base and MAE-Large (a batch-size note follows the table):

| | MODEL_NAME | LR | MIN_LR | WR_EPOCH |
| :-------: | :-------------------: | :----: | :----: | :------: |
| MAE-Base | mae_vit_base_patch16 | 2.0e-3 | 1e-8 | 40 |
| MAE-Large | mae_vit_large_patch16 | 2.2e-3 | 1e-4 | 80 |
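
For orientation, `main_pretrain.py` computes the effective (global) batch size as `batch_size × accum_iter × world_size`; since `--lr` is passed explicitly here, it is used as the absolute learning rate rather than being derived from MAE's `--blr` scaling rule. A quick check for MAE-Base under the settings above (illustrative arithmetic only):

```python
# Effective batch size for MAE-Base as configured above:
# per-GPU batch 256, accum_iter 1, 16 A100 GPUs.
eff_batch_size = 256 * 1 * 16
print(eff_batch_size)  # 4096
```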



## MAE Fine-tuning

```bash
python main_finetune.py \
    --accum_iter 1 \
    --batch_size 256 \
    --model ${MODEL_NAME} \
    --finetune ${PATH_TO_PRETRAINED_MODEL} \
    --epochs ${EPOCH} \
    --lr 1.5e-2 --layer_decay ${LAYER_DECAY} \
    --min_lr ${MIN_LR} \
    --opt-betas 0.98 0.92 0.99 \
    --opt-eps 1e-8 --max-grad-norm 0 \
    --use-adan --warmup_epochs ${WR_EPOCH} \
    --weight_decay ${WD} --drop_path ${DROP_PATH} \
    --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
    --dist_eval --data_path ${IMAGENET_DIR}
```

- The fine-tuning file `main_finetune.py` comes from [MAE](https://github.com/facebookresearch/mae).
- We use **16** A100 GPUs for MAE-Base and **32** A100 GPUs for MAE-Large.
- The hyper-parameters differ between MAE-Base and MAE-Large (a layer-decay sketch follows the table):

| | MODEL_NAME | EPOCH | MIN_LR | LAYER_DECAY | WR_EPOCH | WD | DROP_PATH |
| :-------: | :---------------: | :---: | :----: | :---------: | :------: | :--: | :-------: |
| MAE-Base | vit_base_patch16 | 100 | 1e-6 | 0.65 | 40 | 5e-3 | 0.1 |
| MAE-Large | vit_large_patch16 | 50 | 1e-5 | 0.75 | 10 | 1e-3 | 0.2 |
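
`LAYER_DECAY` controls the layer-wise learning-rate decay used during fine-tuning (implemented in MAE's `util/lr_decay.py`): parameters closer to the input are scaled by higher powers of `layer_decay`, so early blocks are updated more conservatively while the head uses the full learning rate. A small illustrative sketch with the MAE-Base values above (12 transformer blocks); the exact parameter grouping is left to MAE's code:

```python
# Illustrative layer-wise lr scaling (sketch of the idea behind --layer_decay;
# the exact grouping is done by MAE's util/lr_decay.py).
base_lr, layer_decay = 1.5e-2, 0.65  # MAE-Base fine-tuning values from the table above
num_layers = 12                      # ViT-Base transformer blocks

# depth 0 ~ patch embedding, depth num_layers + 1 ~ classification head
for depth in (0, 6, 12, 13):
    scale = layer_decay ** (num_layers + 1 - depth)
    print(f"depth {depth:2d}: lr = {base_lr * scale:.2e}")
```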



## Results and Logs

| | MAE-Base | MAE-Large |
| :------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| Top-1 Acc. (%) | 83.8 | 85.9 |
| download | [log-pretrain](./exp_results/MAE/base/log_base_pretrain.txt)/[log-finetune](./exp_results/MAE/base/log_base_ft.txt)/model | [log-pretrain](./exp_results/MAE/large/log_large_pretrain.txt)/[log-finetune](./exp_results/MAE/large/log_large_ft.txt)/model |

CV/MAE/adan.py (154 additions)

```python
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
import torch
from torch.optim.optimizer import Optimizer
from timm.utils import *


class Adan(Optimizer):
    """
    Implements a pytorch variant of Adan.
    Adan was proposed in
    Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022.
    https://arxiv.org/abs/2208.06677
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.98, 0.92, 0.99))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay (L2 penalty) (default: 0)
        max_grad_norm (float, optional): value used to clip
            global grad norm (default: 0.0, no clipping)
        no_prox (bool): how to perform the decoupled weight decay (default: False)
    """

    def __init__(self, params, lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-8,
                 weight_decay=0.0, max_grad_norm=0.0, no_prox=False):
        if not 0.0 <= max_grad_norm:
            raise ValueError("Invalid Max grad norm: {}".format(max_grad_norm))
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm, no_prox=no_prox)
        super(Adan, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Adan, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('no_prox', False)

    @torch.no_grad()
    def restart_opt(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                if p.requires_grad:
                    state = self.state[p]
                    # State initialization

                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    # Exponential moving average of gradient difference
                    state['exp_avg_diff'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self):
        """
        Performs a single optimization step.
        """
        if self.defaults['max_grad_norm'] > 0:
            device = self.param_groups[0]['params'][0].device
            global_grad_norm = torch.zeros(1, device=device)

            max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is not None:
                        grad = p.grad
                        global_grad_norm.add_(grad.pow(2).sum())

            global_grad_norm = torch.sqrt(global_grad_norm)

            clip_global_grad_norm = torch.clamp(max_grad_norm / (global_grad_norm + group['eps']), max=1.0)
        else:
            clip_global_grad_norm = 1.0

        for group in self.param_groups:
            beta1, beta2, beta3 = group['betas']
            # assume the same step across the group for now to simplify things
            # a per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            bias_correction1 = 1.0 - beta1 ** group['step']
            bias_correction2 = 1.0 - beta2 ** group['step']
            bias_correction3 = 1.0 - beta3 ** group['step']

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['exp_avg_diff'] = torch.zeros_like(p)

                grad = p.grad.mul_(clip_global_grad_norm)
                if 'pre_grad' not in state or group['step'] == 1:
                    state['pre_grad'] = grad

                copy_grad = grad.clone()

                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']
                diff = grad - state['pre_grad']

                update = grad + beta2 * diff
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)  # m_t
                exp_avg_diff.mul_(beta2).add_(diff, alpha=1 - beta2)  # diff_t
                exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1 - beta3)  # n_t

                denom = ((exp_avg_sq).sqrt() / math.sqrt(bias_correction3)).add_(group['eps'])
                update = ((exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2)).div_(denom)

                if group['no_prox']:
                    p.data.mul_(1 - group['lr'] * group['weight_decay'])
                    p.add_(update, alpha=-group['lr'])
                else:
                    p.add_(update, alpha=-group['lr'])
                    p.data.div_(1 + group['lr'] * group['weight_decay'])

                state['pre_grad'] = copy_grad
```
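
A quick, hypothetical smoke test for the optimizer above, assuming the file is saved as `adan.py` on the Python path: minimize a simple quadratic and check that the parameter converges.

```python
import torch

from adan import Adan  # assumes CV/MAE/adan.py is importable as adan

# Minimize f(x) = sum((x - 3)^2); x should approach 3.
x = torch.nn.Parameter(torch.zeros(4))
opt = Adan([x], lr=0.1, betas=(0.98, 0.92, 0.99), weight_decay=0.0, max_grad_norm=0.0)

for _ in range(300):
    loss = ((x - 3.0) ** 2).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()

print(x.detach())  # expected: values close to 3.0
```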
