Support Accumulate Gradient #76

Status: Open. This pull request wants to merge 21 commits into the base branch main.

Commits (21)
2f7612a  add swin transformer  (firestonelib, Oct 26, 2021)
b5c2286  add swin transformer  (firestonelib, Oct 26, 2021)
ba20e04  add swin transformer  (firestonelib, Oct 26, 2021)
382f820  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Nov 3, 2021)
5805bc3  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Nov 10, 2021)
225f6f4  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Nov 13, 2021)
9e90431  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Nov 15, 2021)
efee9a2  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Nov 18, 2021)
0e88cee  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Nov 23, 2021)
db6353f  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Nov 24, 2021)
f802aa7  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Nov 24, 2021)
5b57327  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Nov 25, 2021)
f7aa5e6  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Nov 30, 2021)
30f7baf  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Dec 17, 2021)
23559e3  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Dec 21, 2021)
866e205  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Dec 21, 2021)
d068f93  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Dec 22, 2021)
0c1ed78  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Dec 23, 2021)
70f6a84  Merge branch 'PaddlePaddle:main' into main  (firestonelib, Dec 26, 2021)
bbaec05  add accumulate gradients  (firestonelib, Dec 28, 2021)
670e946  modify optimizer  (firestonelib, Dec 28, 2021)
Changed config file:

@@ -2,6 +2,8 @@ epochs: 300
 output_dir: output_dir
 seed: 0
 
+accumulate_grad_steps: 1
+
 model:
   name: SwinWrapper
   architecture:
Changed config file:

@@ -14,15 +14,11 @@ AMP:
     "c_softmax_with_cross_entropy", "elementwise_div"]
   level: 'O1'
 
-hybrid:
-  dp_degree: 8
-  mp_degree: 1
-  pp_degree: 1
-
 sharding:
   sharding_stage: 2 # 2 or 'dp'
   offload: False
-  accumulate_grad: False
 
+accumulate_grad_steps: 1
+
 model:
   name: SwinWrapper
Changed config file:

@@ -14,16 +14,11 @@ AMP:
   custom_black_list: ["reduce_mean", "reduce_sum",
     "c_softmax_with_cross_entropy", "elementwise_div"]
   level: 'O1'
-
-hybrid:
-  dp_degree: 8
-  mp_degree: 1
-  pp_degree: 1
-
 sharding:
   sharding_stage: 2 # 2 or 'dp'
   offload: False
-  accumulate_grad: False
 
+accumulate_grad_steps: 2
+
 model:
   name: SwinWrapper
Changed config file:

@@ -1,7 +1,7 @@
 epochs: 300
 output_dir: output_dir
 seed: 0
-
+accumulate_grad_steps: 1
 model:
   name: SwinWrapper
   architecture:
Changed config file:

@@ -2,6 +2,8 @@ epochs: 300
 output_dir: output_dir
 seed: 0
 
+accumulate_grad_steps: 1
+
 model:
   name: SwinWrapper
   architecture:
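The new accumulate_grad_steps key replaces the per-strategy accumulate_grad flag in the sharding/hybrid blocks: gradients are summed over that many forward/backward passes and the optimizer steps once per group, so the effective batch size becomes accumulate_grad_steps times the per-iteration batch size. A minimal sketch of how such a key would be read (the cfg dict and the batch-size value here are illustrative, not the project's actual config loader):

```python
# Illustrative only: mimics the cfg.get('accumulate_grad_steps', 1) pattern used in the diffs below.
cfg = {"epochs": 300, "accumulate_grad_steps": 2}

accumulate_grad_steps = cfg.get("accumulate_grad_steps", 1)  # defaults to 1, i.e. no accumulation
per_iter_batch_size = 128                                    # assumed value, for illustration
effective_batch_size = per_iter_batch_size * accumulate_grad_steps
print(effective_batch_size)  # 256 when accumulate_grad_steps is 2
```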
passl/engine/trainer.py (18 changes: 11 additions & 7 deletions)
@@ -113,6 +113,8 @@ def __init__(self, cfg):
         use_simclr_iters = cfg.get('use_simclr_iters', False)
         self.use_simclr_iters = use_simclr_iters
         self.epochs = cfg.get('epochs', None)
+        self.accumulate_grad_steps = cfg.get('accumulate_grad_steps', 1)
+        self.accumulate_grads = True if self.accumulate_grad_steps > 1 else False
         self.timestamp = cfg.timestamp
         self.logs = OrderedDict()
         # Ensure that the vdl log file can be closed normally
@@ -147,7 +149,7 @@ def __init__(self, cfg):
         # distributed settings
         if dist.get_world_size() > 1:
             strategy = fleet.DistributedStrategy()
-            ## Hybrid Parallel Training
+            # Hybrid Parallel Training
             strategy.hybrid_configs = cfg.pop('hybrid') if 'hybrid' in cfg else {}
             fleet.init(is_collective=True, strategy=strategy)
             hcg = fleet.get_hybrid_communicate_group()
@@ -157,7 +159,7 @@ def __init__(self, cfg):
             set_hyrbid_parallel_seed(seed, 0, mp_rank, pp_rank)
 
         # amp training
-        self.use_amp = cfg.get('use_amp', False) #if 'use_amp' in cfg else False
+        self.use_amp = cfg.get('use_amp', False)
         if self.use_amp:
             amp_cfg = cfg.pop('AMP')
             self.auto_cast = amp_cfg.pop('auto_cast')
@@ -170,22 +172,24 @@ def __init__(self, cfg):
         self.sharding_strategies = cfg.get('sharding', False)
         if self.sharding_strategies:
             self.sharding_stage = self.sharding_strategies['sharding_stage']
-            accumulate_grad = self.sharding_strategies['accumulate_grad']
             offload = self.sharding_strategies['offload']
             # Note: Only support partition optimizer stages and gradient now!
             if self.sharding_stage == 2:
                 # Partition Optimizer
                 self.optimizer = ShardingOptimizerStage2(
                     params=self.model.parameters(),
                     optim=self.optimizer,
                     offload=offload)
                 # Partition Gradients
                 self.model = ShardingStage2(
                     self.model,
                     self.optimizer,
-                    accumulate_grads=accumulate_grad)
+                    accumulate_grads=self.accumulate_grads)
                 self.scaler = ShardingScaler(self.scaler)
             elif self.sharding_stage == 'dp' and dist.get_world_size() > 1:
                 self.model = fleet.distributed_model(self.model)
             else:
                 raise NotImplementedError()
         elif dist.get_world_size() > 1:
             self.model = fleet.distributed_model(self.model)
@@ -374,7 +378,7 @@ def val(self, **kargs):
                 outs[k] = AverageMeter(k, ':6.3f')
                 outs[k].update(float(v), current_samples)
 
-        log_str = f'Validate Epoch [{self.current_epoch + 1}] '
+        log_str = f'Validate Epoch [{self.current_epoch + 1}]'
         log_items = []
         for name, val in outs.items():
             if isinstance(val, AverageMeter):
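In the trainer, accumulate_grads is derived once from accumulate_grad_steps and forwarded to ShardingStage2, while the actual step/clear schedule lives in the optimizer hook below. A rough standalone sketch of that derivation and of the sync-step test the hook relies on (plain Python, not the Trainer class itself; the values are illustrative):

```python
# Names mirror the diff above; the config value 4 is just an example.
accumulate_grad_steps = 4                     # would come from cfg.get('accumulate_grad_steps', 1)
accumulate_grads = accumulate_grad_steps > 1  # equivalent to the True/False ternary in the diff

# The optimizer hook only steps and clears gradients on "sync" iterations:
for current_iter in range(1, 9):
    is_sync_step = current_iter % accumulate_grad_steps == 0
    print(current_iter, "step + clear_grad" if is_sync_step else "accumulate only")
# Iterations 4 and 8 trigger an optimizer step; the others only run backward().
```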
passl/hooks/optimizer_hook.py (55 changes: 38 additions & 17 deletions)
@@ -20,28 +20,49 @@
 class OptimizerHook(Hook):
     def __init__(self, priority=1):
         self.priority = priority
 
     def train_iter_end(self, trainer):
-        if 'Lars' in trainer.cfg['optimizer']['name']:
-            trainer.optimizer.clear_gradients()
-        else:
-            trainer.optimizer.clear_grad()
-
-        loss = 0
-        loss = trainer.outputs['loss']
-
-        if trainer.use_amp:
-            scaled_loss = trainer.scaler.scale(loss)
-            scaled_loss.backward()
-            trainer.scaler.step(trainer.optimizer)
-            trainer.scaler.update()
-
-        else:
-            loss.backward()
-            if 'lars' in trainer.optimizer.type:
-                trainer.optimizer.minimize(loss)
-            else:
-                trainer.optimizer.step()
+        accumulate_steps = trainer.accumulate_grad_steps
+        if accumulate_steps > 1:
+            if trainer.current_iter % accumulate_steps == 0:
+                if 'Lars' in trainer.cfg['optimizer']['name']:
+                    trainer.optimizer.clear_gradients()
+                else:
+                    trainer.optimizer.clear_grad()
+
+                loss = trainer.outputs['loss'] / accumulate_steps
+                if trainer.use_amp:
+                    scaled_loss = trainer.scaler.scale(loss)
+                    scaled_loss.backward()
+                    trainer.scaler.step(trainer.optimizer)
+                    trainer.scaler.update()
+
+                else:
+                    loss.backward()
+                    if 'lars' in trainer.optimizer.type:
+                        trainer.optimizer.minimize(loss)
+                    else:
+                        trainer.optimizer.step()
+            else:
+                loss = trainer.outputs['loss'] / accumulate_steps
+                if trainer.use_amp:
+                    scaled_loss = trainer.scaler.scale(loss)
+                    scaled_loss.backward()
+                else:
+                    loss.backward()
+        else:
+            loss = trainer.outputs['loss']
+            if trainer.use_amp:
+                scaled_loss = trainer.scaler.scale(loss)
+                scaled_loss.backward()
+                trainer.scaler.step(trainer.optimizer)
+                trainer.scaler.update()
+            else:
+                loss.backward()
+                if 'lars' in trainer.optimizer.type:
+                    trainer.optimizer.minimize(loss)
+                else:
+                    trainer.optimizer.step()
 
         if 'loss' not in trainer.outputs:
             trainer.outputs['loss'] = loss
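For reference, the same accumulation pattern outside the hook machinery: a minimal, self-contained PaddlePaddle loop (toy model and random data, not PASSL code) that scales the loss by the accumulation steps and only steps and clears the optimizer every accumulate_grad_steps iterations, matching the logic in the hook above:

```python
import paddle
import paddle.nn.functional as F

# Toy setup, for illustration only.
model = paddle.nn.Linear(4, 1)
opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())
accumulate_grad_steps = 2

for current_iter in range(1, 9):          # 1-based, like trainer.current_iter
    x = paddle.randn([8, 4])
    y = paddle.randn([8, 1])
    loss = F.mse_loss(model(x), y) / accumulate_grad_steps  # scale so the summed grads average out
    loss.backward()                       # gradients accumulate across iterations until cleared
    if current_iter % accumulate_grad_steps == 0:
        opt.step()                        # one update per accumulate_grad_steps iterations
        opt.clear_grad()                  # clear only after the update, as in the hook
```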