diff --git a/mdai_utils/download_annotations.py b/mdai_utils/download_annotations.py index 7ba04d1..820567f 100644 --- a/mdai_utils/download_annotations.py +++ b/mdai_utils/download_annotations.py @@ -64,6 +64,8 @@ def get_mask_from_vertices(vertices, upscale_factor=100, original_image_shape=No vertices = np.array(vertices).reshape((-1, 2)) """ slice_label_mask = get_empty_mask(original_image_shape) + if len(vertices) == 0: + return slice_label_mask slice_shape = slice_label_mask.shape apply_supersample = False if upscale_factor is None else upscale_factor > 1 if not apply_supersample: @@ -86,13 +88,21 @@ def get_mask_from_vertices(vertices, upscale_factor=100, original_image_shape=No def get_mask_from_annotation(row, original_image_shape=None, upscale_factor=100): annotationMode = row["annotationMode"] + rowdata = row.get("data", None) + if rowdata is None: + return get_empty_mask(original_image_shape) if annotationMode == "freeform": - vertices = np.array(row["data"]["vertices"]).reshape((-1, 2)) + rowvertices = rowdata.get("vertices", []) + vertices = np.array(rowvertices).reshape((-1, 2)) slice_label_mask = get_mask_from_vertices( vertices, upscale_factor, original_image_shape ) elif annotationMode == "mask": slice_label_mask = get_empty_mask(original_image_shape) + if rowdata.get("foreground", None) is None: + return slice_label_mask + if rowdata.get("background", None) is None: + return slice_label_mask if row.data["foreground"]: for i in row.data["foreground"]: slice_label_mask = cv2.fillPoly( @@ -305,9 +315,10 @@ def merge_slices_into3D( ] = global_annotations_dict["mdai_label_group_ids"] if study_id in global_annotations_dict: # Check for study labels first: - for label in global_annotations_dict[study_id]["study_labels"]: - global_labels_dict.setdefault(label, 0) - global_labels_dict[label] += 1 + if "study_labels" in global_annotations_dict[study_id]: + for label in global_annotations_dict[study_id]["study_labels"]: + global_labels_dict.setdefault(label, 0) + global_labels_dict[label] += 1 # Check for series labels: if series_id in global_annotations_dict[study_id]: for label in global_annotations_dict[study_id][series_id]: @@ -497,7 +508,7 @@ def main(args): ) # Get the json for annotations - last_json_file = get_last_json_file(out_folder, match_str=mdai_project_id) + last_json_file = get_last_json_file(out_folder, match_str=mdai_dataset_id) logger.info(f"Last json file: {last_json_file}") @@ -547,7 +558,13 @@ def main(args): hash_path = Path(hash_id.replace(sep, "/") + ".dcm") raw_slice_path = Path(match_folder) / hash_path pair_data_entry["image"] = str(raw_slice_path.resolve()) - raw_slice_image = itk.imread(raw_slice_path) + try: + raw_slice_image = itk.imread(raw_slice_path) + except RuntimeError as e: + logger.warning( + f"Could not read dicom file: {raw_slice_path}.\n{e}\nIt might be an invalid json data from mdai. Safe to ignore." + ) + continue for _, row in group.iterrows(): row_label_id = row["labelId"] @@ -566,8 +583,6 @@ def main(args): # slice_label_mask = np.flipud(slice_label_mask) # Use itk to save the mask, even in nifti format label_image = itk.image_from_array(slice_label_mask) - # We are going to save a 3D slice, we are interested in storing the z-position. - label_name = ( mdai_label_ids.inverse.get(row_label_id, False) or row.get("labelName", False) @@ -598,6 +613,9 @@ def main(args): itk.imwrite(label_image, str(label_file)) + # Check that pair_data_entry contains at least one label + if len(pair_data_entry) < 3: + continue pair_data.append(pair_data_entry) with open(pair_data_json_file, "w") as f: