@@ -266,8 +266,11 @@ async def _start(self, model: str, additional_parameters: str = ''):
266
266
cmd = f'python /app/train_det.py --exist-ok --patience { self .patience } \
267
267
--batch-size { batch_size } --img { resolution } --data dataset.yaml --weights { model } \
268
268
--project { self .training .training_folder } --name result --hyp { self .hyperparameter_path } \
269
- --epochs { self .epochs } { additional_parameters } \
270
- --point_sizes_by_id { p_sizes_by_id [:- 1 ]} --flip_label_pairs { flip_label_pairs [:- 1 ]} '
269
+ --epochs { self .epochs } { additional_parameters } '
270
+ if p_sizes_by_id :
271
+ cmd += f' --point_sizes_by_id { p_sizes_by_id [:- 1 ]} '
272
+ if flip_label_pairs :
273
+ cmd += f' --flip_label_pairs { flip_label_pairs [:- 1 ]} '
271
274
272
275
await self .executor .start (cmd , env = {'WANDB_MODE' : 'disabled' })
273
276
@@ -284,24 +287,27 @@ def _save_additional_hyperparameters(self) -> None:
284
287
raise CriticalError (f'No hyperparameter file found at { self .hyperparameter_path } ' )
285
288
286
289
with open (self .hyperparameter_path , errors = 'ignore' ) as f :
287
- hyp = yaml .safe_load (f ) # load hyps dict
288
- hyp = {k : float (v ) for k , v in hyp .items ()}
290
+ hyp = dict (yaml .safe_load (f )) # load hyps dict
289
291
290
292
self .epochs = int (hyp .get ('epochs' , self .epochs ))
291
- self .detect_nms_conf_thres = hyp .get ('detect_nms_conf_thres' , self .detect_nms_conf_thres )
292
- self .detect_nms_iou_thres = hyp .get ('detect_nms_iou_thres' , self .detect_nms_iou_thres )
293
+ self .detect_nms_conf_thres = float ( hyp .get ('detect_nms_conf_thres' , self .detect_nms_conf_thres ) )
294
+ self .detect_nms_iou_thres = float ( hyp .get ('detect_nms_iou_thres' , self .detect_nms_iou_thres ) )
293
295
294
- for item in str (hyp .get ('point_sizes_by_id' , '' )).split (',' ):
295
- k , v = item .split (':' )
296
- self .point_sizes_by_uuid [str (k )] = float (v )
296
+ if point_sizes_by_id_str := str (hyp .get ('point_sizes_by_id' , '' )):
297
+ for item in point_sizes_by_id_str .split (',' ):
298
+ k , v = item .split (':' )
299
+ self .point_sizes_by_uuid [str (k )] = float (v )
297
300
298
- for item in str (hyp .get ('flip_label_pairs' , '' )).split (',' ):
299
- k , v = item .split (':' )
300
- self .flip_label_uuid_pairs .append ((str (k ), str (v )))
301
+ if flip_label_pairs_str := str (hyp .get ('flip_label_pairs' , '' )):
302
+ for item in flip_label_pairs_str .split (',' ):
303
+ k , v = item .split (':' )
304
+ self .flip_label_uuid_pairs .append ((str (k ), str (v )))
301
305
302
306
hyp_str = ', ' .join (f'{ k } ={ v } ' for k , v in hyp .items ())
303
307
logging .info ('parsed hyperparameters %s: epochs: %d, detect_nms_conf_thres: %f, detect_nms_iou_thres: %f' ,
304
308
hyp_str , self .epochs , self .detect_nms_conf_thres , self .detect_nms_iou_thres )
309
+ logging .info ('point_sizes_by_id: %s' , self .point_sizes_by_uuid )
310
+ logging .info ('flip_label_pairs: %s' , self .flip_label_uuid_pairs )
305
311
306
312
def _parse (self , labels_path : str , images_folder : str , model_information : ModelInformation ) -> list [Detections ]:
307
313
detections = []
0 commit comments