Skip to content

Commit e7a2412

Browse files
authored
Fix "training_staus" typo (#476)
1 parent ab27810 commit e7a2412

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

rtdetr_paddle/ppdet/engine/callbacks.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,15 @@ def on_step_end(self, status):
105105
epoch_id = status['epoch_id']
106106
step_id = status['step_id']
107107
steps_per_epoch = status['steps_per_epoch']
108-
training_staus = status['training_staus']
108+
training_status = status['training_status']
109109
batch_time = status['batch_time']
110110
data_time = status['data_time']
111111

112112
epoches = self.model.cfg.epoch
113113
batch_size = self.model.cfg['{}Reader'.format(mode.capitalize(
114114
))]['batch_size']
115115

116-
logs = training_staus.log()
116+
logs = training_status.log()
117117
space_fmt = ':' + str(len(str(steps_per_epoch))) + 'd'
118118
if step_id % self.model.cfg.log_iter == 0:
119119
eta_steps = (epoches - epoch_id) * steps_per_epoch - step_id
@@ -278,8 +278,8 @@ def on_step_end(self, status):
278278
mode = status['mode']
279279
if dist.get_world_size() < 2 or dist.get_rank() == 0:
280280
if mode == 'train':
281-
training_staus = status['training_staus']
282-
for loss_name, loss_value in training_staus.get().items():
281+
training_status = status['training_status']
282+
for loss_name, loss_value in training_status.get().items():
283283
self.vdl_writer.add_scalar(loss_name, loss_value,
284284
self.vdl_loss_step)
285285
self.vdl_loss_step += 1
@@ -401,7 +401,7 @@ def on_step_end(self, status):
401401
mode = status['mode']
402402
if dist.get_world_size() < 2 or dist.get_rank() == 0:
403403
if mode == 'train':
404-
training_status = status['training_staus'].get()
404+
training_status = status['training_status'].get()
405405
for k, v in training_status.items():
406406
training_status[k] = float(v)
407407

rtdetr_paddle/ppdet/engine/trainer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def train(self, validate=False):
300300
self.cfg.log_iter, fmt='{avg:.4f}')
301301
self.status['data_time'] = stats.SmoothedValue(
302302
self.cfg.log_iter, fmt='{avg:.4f}')
303-
self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)
303+
self.status['training_status'] = stats.TrainingStats(self.cfg.log_iter)
304304

305305
profiler_options = self.cfg.get('profiler_options', None)
306306

@@ -385,7 +385,7 @@ def train(self, validate=False):
385385
self.status['learning_rate'] = curr_lr
386386

387387
if self._nranks < 2 or self._local_rank == 0:
388-
self.status['training_staus'].update(outputs)
388+
self.status['training_status'].update(outputs)
389389

390390
self.status['batch_time'].update(time.time() - iter_tic)
391391
self._compose_callback.on_step_end(self.status)

0 commit comments

Comments
 (0)