Skip to content

Commit

Permalink
fix: ensure embeddings folder exists (#105)
Browse files (browse the repository at this point in the history)
  • Loading branch information
Encord-davids authored Jan 20, 2023
1 parent 1e65d33 commit 62b75bc
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions src/encord_active/lib/embeddings/cnn.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -88,7 +88,7 @@ def assemble_object_batch(data_unit: dict, img_path: Path, transforms: Optional[


@torch.inference_mode()
def generate_cnn_image_embeddings(iterator: Iterator, filepath: str) -> None:
def generate_cnn_image_embeddings(iterator: Iterator) -> List[LabelEmbedding]:
start = time.perf_counter()
feature_extractor, transforms = get_model_and_transforms()

Expand Down Expand Up @@ -117,16 +117,15 @@ def generate_cnn_image_embeddings(iterator: Iterator, filepath: str) -> None:
)
collections.append(entry)

with open(filepath, "wb") as f:
pickle.dump(collections, f)

logger.info(
f"Generating {len(iterator)} embeddings took {str(time.perf_counter() - start)} seconds",
)

return collections


@torch.inference_mode()
def generate_cnn_object_embeddings(iterator: Iterator, filepath: str) -> None:
def generate_cnn_object_embeddings(iterator: Iterator) -> List[LabelEmbedding]:
start = time.perf_counter()
feature_extractor, transforms = get_model_and_transforms()

Expand Down Expand Up @@ -163,16 +162,15 @@ def generate_cnn_object_embeddings(iterator: Iterator, filepath: str) -> None:

collections.append(entry)

with open(filepath, "wb") as f:
pickle.dump(collections, f)

logger.info(
f"Generating {len(iterator)} embeddings took {str(time.perf_counter() - start)} seconds",
)

return collections


@torch.inference_mode()
def generate_cnn_classification_embeddings(iterator: Iterator, filepath: str) -> None:
def generate_cnn_classification_embeddings(iterator: Iterator) -> List[LabelEmbedding]:
image_collections = get_cnn_embeddings(iterator, embedding_type=EmbeddingType.IMAGE)

ontology_class_hash_to_index: dict[str, dict] = {}
Expand Down Expand Up @@ -259,13 +257,12 @@ def generate_cnn_classification_embeddings(iterator: Iterator, filepath: str) ->
)
collections.append(entry)

with open(filepath, "wb") as f:
pickle.dump(collections, f)

logger.info(
f"Generating {len(iterator)} embeddings took {str(time.perf_counter() - start)} seconds",
)

return collections


def get_cnn_embeddings(
iterator: Iterator, embedding_type: EmbeddingType, *, force: bool = False
Expand All @@ -292,14 +289,18 @@ def get_cnn_embeddings(

def generate_cnn_embeddings(iterator: Iterator, embedding_type: EmbeddingType, target: str):
if embedding_type == EmbeddingType.IMAGE:
generate_cnn_image_embeddings(iterator, filepath=target)
cnn_embeddings = generate_cnn_image_embeddings(iterator)
elif embedding_type == EmbeddingType.OBJECT:
generate_cnn_object_embeddings(iterator, filepath=target)
cnn_embeddings = generate_cnn_object_embeddings(iterator)
elif embedding_type == EmbeddingType.CLASSIFICATION:
generate_cnn_classification_embeddings(iterator, filepath=target)
cnn_embeddings = generate_cnn_classification_embeddings(iterator)
else:
raise ValueError(f"Unsupported embedding type {embedding_type}")

target_path = Path(target)
target_path.parent.mkdir(parents=True, exist_ok=True)
target_path.write_bytes(pickle.dumps(cnn_embeddings))

with open(target, "rb") as f:
cnn_embeddings = pickle.load(f)
logger.info("Done!")

return cnn_embeddings

0 comments on commit 62b75bc

Please sign in to comment.