Skip to content

Commit

Permalink
fix: RND-109: Add simple overlapping entities disambiguation (#187)
Browse files Browse the repository at this point in the history
Co-authored-by: nik <[email protected]>
Co-authored-by: matt-bernstein <[email protected]>
  • Loading branch information
3 people authored Aug 22, 2024
1 parent 42f79be commit 459fe95
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 51 deletions.
32 changes: 24 additions & 8 deletions adala/skills/collection/entity_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,22 +270,38 @@ def extract_indices(self, df):
text = row[input_field_name]
entities = row[output_field_name]
to_remove = []
found_entities_ends = {}
for entity in entities:
# TODO: the current naive implementation assumes that the quote_string is unique in the text.
# This can serve as a baseline for now,
# and we can improve it to handle entity ambiguity (for example, by requesting a "prefix" in the response model)
# as well as fuzzy pattern matching.
# TODO: the current naive implementation uses exact string matching, which serves only as a baseline.
# We can improve this further by handling ambiguity, for example:
# - requesting surrounding context from the LLM
# - performing fuzzy matching over strings if the model still hallucinates when copying the text
ent_str = entity[self._quote_string_field_name]
# to avoid overlapping entities, collect the end offsets of previously found entities whose text starts with this entity's quote string
matching_end_indices = [
found_entities_ends[found_ent]
for found_ent in found_entities_ends
if found_ent.startswith(ent_str)
]
if matching_end_indices:
# start searching from the end of the last entity with the same prefix
start_search_idx = max(matching_end_indices)
else:
# start searching from the beginning
start_search_idx = 0

start_idx = text.lower().find(
entity[self._quote_string_field_name].lower()
entity[self._quote_string_field_name].lower(),
start_search_idx,
)
if start_idx == -1:
# we need to remove the entity if it is not found in the text
to_remove.append(entity)
else:
end_index = start_idx + len(entity[self._quote_string_field_name])
entity["start"] = start_idx
entity["end"] = start_idx + len(
entity[self._quote_string_field_name]
)
entity["end"] = end_index
found_entities_ends[ent_str] = end_index
for entity in to_remove:
entities.remove(entity)
return df
Expand Down
53 changes: 12 additions & 41 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ classifiers = [
[tool.poetry.dependencies]
python = ">=3.9,<3.12"
pandas = "*"
openai = "^1.14.3"
openai = "^1.42.0"
guidance = "0.0.64"
pydantic = "^2"
rich = "^13"
Expand Down Expand Up @@ -62,7 +62,7 @@ fakeredis = "^2.23.2"
flower = "^2.0.1"
pytest-asyncio = "^0.23.7"
celery = {extras = ["pytest"], version = "^5.4.0"}
openai-responses = "^0.8.1"
openai-responses = "^0.9.1"
pytest-recording = "^0.13.1"
mockafka-py = "^0.1.57"

Expand Down

0 comments on commit 459fe95

Please sign in to comment.