diff --git a/src/encord_active/lib/embeddings/cnn.py b/src/encord_active/lib/embeddings/cnn.py index 8bcc20d93..1225d67bd 100644 --- a/src/encord_active/lib/embeddings/cnn.py +++ b/src/encord_active/lib/embeddings/cnn.py @@ -88,7 +88,7 @@ def assemble_object_batch(data_unit: dict, img_path: Path, transforms: Optional[ @torch.inference_mode() -def generate_cnn_image_embeddings(iterator: Iterator, filepath: str) -> None: +def generate_cnn_image_embeddings(iterator: Iterator) -> List[LabelEmbedding]: start = time.perf_counter() feature_extractor, transforms = get_model_and_transforms() @@ -117,16 +117,15 @@ def generate_cnn_image_embeddings(iterator: Iterator, filepath: str) -> None: ) collections.append(entry) - with open(filepath, "wb") as f: - pickle.dump(collections, f) - logger.info( f"Generating {len(iterator)} embeddings took {str(time.perf_counter() - start)} seconds", ) + return collections + @torch.inference_mode() -def generate_cnn_object_embeddings(iterator: Iterator, filepath: str) -> None: +def generate_cnn_object_embeddings(iterator: Iterator) -> List[LabelEmbedding]: start = time.perf_counter() feature_extractor, transforms = get_model_and_transforms() @@ -163,16 +162,15 @@ def generate_cnn_object_embeddings(iterator: Iterator, filepath: str) -> None: collections.append(entry) - with open(filepath, "wb") as f: - pickle.dump(collections, f) - logger.info( f"Generating {len(iterator)} embeddings took {str(time.perf_counter() - start)} seconds", ) + return collections + @torch.inference_mode() -def generate_cnn_classification_embeddings(iterator: Iterator, filepath: str) -> None: +def generate_cnn_classification_embeddings(iterator: Iterator) -> List[LabelEmbedding]: image_collections = get_cnn_embeddings(iterator, embedding_type=EmbeddingType.IMAGE) ontology_class_hash_to_index: dict[str, dict] = {} @@ -259,13 +257,12 @@ def generate_cnn_classification_embeddings(iterator: Iterator, filepath: str) -> ) collections.append(entry) - with open(filepath, "wb") as f: - pickle.dump(collections, f) - logger.info( f"Generating {len(iterator)} embeddings took {str(time.perf_counter() - start)} seconds", ) + return collections + def get_cnn_embeddings( iterator: Iterator, embedding_type: EmbeddingType, *, force: bool = False @@ -292,14 +289,18 @@ def get_cnn_embeddings( def generate_cnn_embeddings(iterator: Iterator, embedding_type: EmbeddingType, target: str): if embedding_type == EmbeddingType.IMAGE: - generate_cnn_image_embeddings(iterator, filepath=target) + cnn_embeddings = generate_cnn_image_embeddings(iterator) elif embedding_type == EmbeddingType.OBJECT: - generate_cnn_object_embeddings(iterator, filepath=target) + cnn_embeddings = generate_cnn_object_embeddings(iterator) elif embedding_type == EmbeddingType.CLASSIFICATION: - generate_cnn_classification_embeddings(iterator, filepath=target) + cnn_embeddings = generate_cnn_classification_embeddings(iterator) + else: + raise ValueError(f"Unsupported embedding type {embedding_type}") + + target_path = Path(target) + target_path.parent.mkdir(parents=True, exist_ok=True) + target_path.write_bytes(pickle.dumps(cnn_embeddings)) - with open(target, "rb") as f: - cnn_embeddings = pickle.load(f) logger.info("Done!") return cnn_embeddings