Skip to content

Commit 95a766c

Browse files
committed
some fixes for hyper-parameter conversion
1 parent 02a9c0e commit 95a766c

File tree

2 files changed

+22
-16
lines changed

2 files changed

+22
-16
lines changed

trainer/app_code/yolov5_trainer.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,11 @@ async def _start(self, model: str, additional_parameters: str = ''):
266266
cmd = f'python /app/train_det.py --exist-ok --patience {self.patience} \
267267
--batch-size {batch_size} --img {resolution} --data dataset.yaml --weights {model} \
268268
--project {self.training.training_folder} --name result --hyp {self.hyperparameter_path} \
269-
--epochs {self.epochs} {additional_parameters} \
270-
--point_sizes_by_id {p_sizes_by_id[:-1]} --flip_label_pairs {flip_label_pairs[:-1]}'
269+
--epochs {self.epochs} {additional_parameters}'
270+
if p_sizes_by_id:
271+
cmd += f' --point_sizes_by_id {p_sizes_by_id[:-1]}'
272+
if flip_label_pairs:
273+
cmd += f' --flip_label_pairs {flip_label_pairs[:-1]}'
271274

272275
await self.executor.start(cmd, env={'WANDB_MODE': 'disabled'})
273276

@@ -284,24 +287,27 @@ def _save_additional_hyperparameters(self) -> None:
284287
raise CriticalError(f'No hyperparameter file found at {self.hyperparameter_path}')
285288

286289
with open(self.hyperparameter_path, errors='ignore') as f:
287-
hyp = yaml.safe_load(f) # load hyps dict
288-
hyp = {k: float(v) for k, v in hyp.items()}
290+
hyp = dict(yaml.safe_load(f)) # load hyps dict
289291

290292
self.epochs = int(hyp.get('epochs', self.epochs))
291-
self.detect_nms_conf_thres = hyp.get('detect_nms_conf_thres', self.detect_nms_conf_thres)
292-
self.detect_nms_iou_thres = hyp.get('detect_nms_iou_thres', self.detect_nms_iou_thres)
293+
self.detect_nms_conf_thres = float(hyp.get('detect_nms_conf_thres', self.detect_nms_conf_thres))
294+
self.detect_nms_iou_thres = float(hyp.get('detect_nms_iou_thres', self.detect_nms_iou_thres))
293295

294-
for item in str(hyp.get('point_sizes_by_id', '')).split(','):
295-
k, v = item.split(':')
296-
self.point_sizes_by_uuid[str(k)] = float(v)
296+
if point_sizes_by_id_str := str(hyp.get('point_sizes_by_id', '')):
297+
for item in point_sizes_by_id_str.split(','):
298+
k, v = item.split(':')
299+
self.point_sizes_by_uuid[str(k)] = float(v)
297300

298-
for item in str(hyp.get('flip_label_pairs', '')).split(','):
299-
k, v = item.split(':')
300-
self.flip_label_uuid_pairs.append((str(k), str(v)))
301+
if flip_label_pairs_str := str(hyp.get('flip_label_pairs', '')):
302+
for item in flip_label_pairs_str.split(','):
303+
k, v = item.split(':')
304+
self.flip_label_uuid_pairs.append((str(k), str(v)))
301305

302306
hyp_str = ', '.join(f'{k}={v}' for k, v in hyp.items())
303307
logging.info('parsed hyperparameters %s: epochs: %d, detect_nms_conf_thres: %f, detect_nms_iou_thres: %f',
304308
hyp_str, self.epochs, self.detect_nms_conf_thres, self.detect_nms_iou_thres)
309+
logging.info('point_sizes_by_id: %s', self.point_sizes_by_uuid)
310+
logging.info('flip_label_pairs: %s', self.flip_label_uuid_pairs)
305311

306312
def _parse(self, labels_path: str, images_folder: str, model_information: ModelInformation) -> list[Detections]:
307313
detections = []

trainer/hyp_det.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ translate: 0.245
2626
scale: 0.898
2727
shear: 0.602
2828
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
29-
flipud: 0 # 0.00856 # image flip up-down (probability)
30-
fliplr: 0 # 0.5 # image flip left-right (probability)
29+
flipud: 0.0 # 0.00856 # image flip up-down (probability)
30+
fliplr: 0.0 # 0.5 # image flip left-right (probability)
3131
mosaic: 1.0 # image mosaic (probability)
3232
mixup: 0.243 # image mixup (probability)
3333
copy_paste: 0.0
@@ -39,5 +39,5 @@ detect_nms_conf_thres: 0.2
3939
detect_nms_iou_thres: 0.45
4040
#
4141
#extra point parameters (not directly used by yolov5 but by the trainer logic)
42-
# point_sizes_by_id: "1111-2222-3333-4444:0.03,5555-6666-7777-8888:0.05"
43-
# flip_label_pairs: "1111-2222-3333-4444:5555-6666-7777-8888"
42+
point_sizes_by_id: "" # e.g "1111-2222-3333-4444:0.03,5555-6666-7777-8888:0.05"
43+
flip_label_pairs: "" # e.g "1111-2222-3333-4444:5555-6666-7777-8888"

0 commit comments

Comments
 (0)