From 459fe95f00cf29b4266d4060f2ae89efcd89fabe Mon Sep 17 00:00:00 2001
From: niklub
Date: Thu, 22 Aug 2024 08:15:22 +0100
Subject: [PATCH] fix: RND-109: Add simple overlapping entities disambiguation
 (#187)

Co-authored-by: nik
Co-authored-by: matt-bernstein <60152561+matt-bernstein@users.noreply.github.com>
---
 adala/skills/collection/entity_extraction.py | 32 +++++++++---
 poetry.lock                                  | 53 +++++---------------
 pyproject.toml                               |  4 +-
 3 files changed, 38 insertions(+), 51 deletions(-)

diff --git a/adala/skills/collection/entity_extraction.py b/adala/skills/collection/entity_extraction.py
index 7b42ff24..35d8ed49 100644
--- a/adala/skills/collection/entity_extraction.py
+++ b/adala/skills/collection/entity_extraction.py
@@ -270,22 +270,38 @@ def extract_indices(self, df):
             text = row[input_field_name]
             entities = row[output_field_name]
             to_remove = []
+            found_entities_ends = {}
             for entity in entities:
-                # TODO: current naive implementation assumes that the quote_string is unique in the text.
-                # this can be as a baseline for now
-                # and we can improve this to handle entities ambiguity (for example, requesting "prefix" in response model)
-                # as well as fuzzy pattern matching
+                # TODO: the current naive implementation uses exact string matching, which serves as a baseline.
+                # We can improve it further by handling ambiguity, for example:
+                # - requesting surrounding context from the LLM
+                # - performing fuzzy matching over strings if the model still hallucinates when copying the text
+                ent_str = entity[self._quote_string_field_name]
+                # to avoid overlapping entities, collect the end indices of already placed entities whose quote string starts with this one
+                matching_end_indices = [
+                    found_entities_ends[found_ent]
+                    for found_ent in found_entities_ends
+                    if found_ent.startswith(ent_str)
+                ]
+                if matching_end_indices:
+                    # start searching from the end of the last matching entity
+                    start_search_idx = max(matching_end_indices)
+                else:
+                    # start searching from the beginning
+                    start_search_idx = 0
+
                 start_idx = text.lower().find(
-                    entity[self._quote_string_field_name].lower()
+                    entity[self._quote_string_field_name].lower(),
+                    start_search_idx,
                 )
                 if start_idx == -1:
                     # we need to remove the entity if it is not found in the text
                     to_remove.append(entity)
                 else:
+                    end_index = start_idx + len(entity[self._quote_string_field_name])
                     entity["start"] = start_idx
-                    entity["end"] = start_idx + len(
-                        entity[self._quote_string_field_name]
-                    )
+                    entity["end"] = end_index
+                    found_entities_ends[ent_str] = end_index
             for entity in to_remove:
                 entities.remove(entity)
         return df
diff --git a/poetry.lock b/poetry.lock
index b5661d1f..b07897f6 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1535,34 +1535,6 @@ files = [
 
 [package.extras]
 tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"]
 
-[[package]]
-name = "faker"
-version = "24.14.1"
-description = "Faker is a Python package that generates fake data for you."
-optional = false -python-versions = ">=3.8" -files = [ - {file = "Faker-24.14.1-py3-none-any.whl", hash = "sha256:a5edba3aa17a1d689c8907e5b0cd1653079c2466a4807f083aa7b5f80a00225d"}, - {file = "Faker-24.14.1.tar.gz", hash = "sha256:380a3697e696ae4fcf50a93a3d9e0286fab7dfbf05a9caa4421fa4727c6b1e89"}, -] - -[package.dependencies] -python-dateutil = ">=2.4" - -[[package]] -name = "faker-openai-api-provider" -version = "0.2.0" -description = "Generate fake data that resembles fields in OpenAI API responses" -optional = false -python-versions = "<4.0,>=3.9" -files = [ - {file = "faker_openai_api_provider-0.2.0-py3-none-any.whl", hash = "sha256:67d46f159810cd02e5c4fad6f3714b9f83188c3ee74a126e3aaf013f88f7a2e0"}, - {file = "faker_openai_api_provider-0.2.0.tar.gz", hash = "sha256:168ecfd4ad06113ad64de9ee095e5e88593b575145efd4e8b84cba602e73c7d3"}, -] - -[package.dependencies] -faker = ">=24.2.0,<25.0.0" - [[package]] name = "fakeredis" version = "2.23.2" @@ -4381,44 +4353,43 @@ sympy = "*" [[package]] name = "openai" -version = "1.34.0" +version = "1.42.0" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.34.0-py3-none-any.whl", hash = "sha256:018623c2f795424044675c6230fa3bfbf98d9e0aab45d8fd116f2efb2cfb6b7e"}, - {file = "openai-1.34.0.tar.gz", hash = "sha256:95c8e2da4acd6958e626186957d656597613587195abd0fb2527566a93e76770"}, + {file = "openai-1.42.0-py3-none-any.whl", hash = "sha256:dc91e0307033a4f94931e5d03cc3b29b9717014ad5e73f9f2051b6cb5eda4d80"}, + {file = "openai-1.42.0.tar.gz", hash = "sha256:c9d31853b4e0bc2dc8bd08003b462a006035655a701471695d0bfdc08529cde3"}, ] [package.dependencies] anyio = ">=3.5.0,<5" distro = ">=1.7.0,<2" httpx = ">=0.23.0,<1" +jiter = ">=0.4.0,<1" pydantic = ">=1.9.0,<3" sniffio = "*" tqdm = ">4" -typing-extensions = ">=4.7,<5" +typing-extensions = ">=4.11,<5" [package.extras] datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] [[package]] name = "openai-responses" -version = "0.8.1" +version = "0.9.1" description = "🧪🤖 Pytest plugin for automatically mocking OpenAI requests" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "openai_responses-0.8.1-py3-none-any.whl", hash = "sha256:8c4f5cc72aea11d01d20d6081f75136f85c39c55909a7384e601fc353191723d"}, - {file = "openai_responses-0.8.1.tar.gz", hash = "sha256:a80a3efc9f6eb7a7a4c44ee8106511b1a4594dcdd2bfdc3ee3f25db54a9f42f7"}, + {file = "openai_responses-0.9.1-py3-none-any.whl", hash = "sha256:d6f10f28efba6f7a1c0302b71ee30cfd6d289f04592b1c57b0686c10bc2762d0"}, + {file = "openai_responses-0.9.1.tar.gz", hash = "sha256:37525b50a22009d8f2a1bc058cd9fb1189b85e906ad6f28f7bd304ebb03af4dd"}, ] [package.dependencies] -faker = ">=24.2.0,<25.0.0" -faker-openai-api-provider = "0.2.0" -openai = ">=1.32,<1.36" -requests-toolbelt = ">=1.0.0,<2.0.0" -respx = ">=0.20.2,<0.21.0" +openai = ">=1.32,<1.43" +requests-toolbelt = ">=1,<2" +respx = ">=0.20,<0.21" [[package]] name = "opentelemetry-api" @@ -7606,4 +7577,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "b469dd5ba23a1cd6129aee5bd5344c3a28ccfd71bb00f3b1455e036b03d463e8" +content-hash = "1c631832619f23822eaf3b5bda829e92c557a1d4d6ce3f84b96073446aa056b7" diff --git a/pyproject.toml b/pyproject.toml index c9bbb355..bff98666 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.9,<3.12" 
pandas = "*" -openai = "^1.14.3" +openai = "^1.42.0" guidance = "0.0.64" pydantic = "^2" rich = "^13" @@ -62,7 +62,7 @@ fakeredis = "^2.23.2" flower = "^2.0.1" pytest-asyncio = "^0.23.7" celery = {extras = ["pytest"], version = "^5.4.0"} -openai-responses = "^0.8.1" +openai-responses = "^0.9.1" pytest-recording = "^0.13.1" mockafka-py = "^0.1.57"