Skip to content

Commit

Permalink
Accept multiple inputs for get_batches
Browse files Browse the repository at this point in the history
  • Loading branch information
ljvmiranda921 committed Apr 24, 2023
1 parent 1a95234 commit 3d7b919
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions integrations/prodigy_openai/scripts/get_batches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from collections import Counter
from pathlib import Path
from typing import List

import typer
from wasabi import msg

import spacy
from spacy.tokens import DocBin

# Short aliases for typer's parameter declarations (positional argument and
# named option). NOTE(review): the visible part of this script spells out
# typer.Argument / typer.Option directly and does not reference these aliases.
Arg = typer.Argument
Opt = typer.Option


def get_distribution(
    # fmt: off
    input_path: List[Path] = typer.Argument(..., help="Path to the spaCy file."),
    n: int = typer.Option(5, "-n", "--top-n", help="Top-n entities to include in the report."),
    # fmt: on
):
    """Get the distribution of entities given a list of spaCy files.

    Every input file is loaded as a DocBin, the entity labels of all
    documents are tallied together, and the n most frequent labels are
    reported with their share of all entity mentions.
    """
    # A blank pipeline is enough: only its vocab is needed to deserialize docs.
    nlp = spacy.blank("en")

    docs = []
    for path in input_path:
        doc_bin = DocBin().from_disk(path)
        docs.extend(doc_bin.get_docs(nlp.vocab))

    num_docs = len(docs)
    msg.info(f"Found {num_docs} documents in {', '.join([str(p) for p in input_path])}")

    # Count every entity mention per label. Feeding the labels straight into
    # Counter makes the first occurrence count as 1 (a hand-rolled
    # "initialize to 0, else increment" tally under-counts each label by one).
    entity_counts = Counter(ent.label_ for doc in docs for ent in doc.ents)

    # Get the distribution (normalize everything)
    total = sum(entity_counts.values())
    if not total:
        # Nothing to normalize; say so instead of printing an empty report.
        msg.warn("No entities found in the given documents")
        return

    _fmt_counts = " ".join(
        f"{label} ({(count / total) * 100:.2f}%)"
        for label, count in entity_counts.most_common(n)
    )
    msg.text(f"Top-{n} entities by count: {_fmt_counts}")


# Script entry point: expose get_distribution as a command-line program when
# this file is executed directly (not when it is imported).
if __name__ == "__main__":
    typer.run(get_distribution)

0 comments on commit 3d7b919

Please sign in to comment.