Commit 2c8214d

feat: DIA-1323: Pass NER tasks from Adala server to Adala's Entity Extraction during Inference Run (#181)

Co-authored-by: nik <[email protected]>
niklub and nik authored Aug 16, 2024
1 parent 97efe2c commit 2c8214d
Showing 4 changed files with 115 additions and 57 deletions.
1 change: 1 addition & 0 deletions adala/skills/__init__.py
@@ -1,5 +1,6 @@
 from .skillset import SkillSet, LinearSkillSet, ParallelSkillSet
 from .collection.classification import ClassificationSkill
+from .collection.entity_extraction import EntityExtraction
 from .collection.rag import RAGSkill
 from .collection.ontology_creation import OntologyCreator, OntologyMerger
 from ._base import Skill, TransformSkill, AnalysisSkill, SynthesisSkill
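With this re-export, the new skill is importable from the package root alongside ClassificationSkill. A minimal sketch of constructing it directly, assuming its fields mirror the "EntityExtraction" skill dict used in tests/test_server.py further down (treat the exact constructor signature as an assumption):

    from adala.skills import EntityExtraction

    # Field names assumed from the test payload below; not a confirmed signature
    ner_skill = EntityExtraction(
        name="entity_extraction",
        input_template='Extract entities from the input text.\n\nInput:\n"""\n{text}\n"""',
        labels=["PERSON", "AGE", "ORG"],
    )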
9 changes: 9 additions & 0 deletions server/app.py
@@ -4,6 +4,9 @@
 import json
 
 import fastapi
+from fastapi import Request, status
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
 from adala.agents import Agent
 from aiokafka import AIOKafkaProducer
 from aiokafka.errors import UnknownTopicOrPartitionError
@@ -132,6 +135,12 @@ def get_index():
     return {"status": "ok"}
 
 
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    logger.error(f'Request validation error: {exc}')
+    return JSONResponse(content=str(exc), status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
+
+
 @app.post("/jobs/submit-streaming", response_model=Response[JobCreated])
 async def submit_streaming(request: SubmitStreamingRequest):
     """
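With the new handler, a malformed submit payload now comes back as an explicit 422 whose body is the stringified validation error, and the same error is written to the server log. A rough illustration with FastAPI's test client (the import path and the deliberately invalid payload are assumptions for the example):

    from fastapi.testclient import TestClient
    from server.app import app  # module path assumed from this repo's server/ layout

    client = TestClient(app)
    # Omitting the fields SubmitStreamingRequest expects (e.g. the "agent" block)
    # should trigger RequestValidationError and hit the new handler
    resp = client.post("/jobs/submit-streaming", json={"bogus_field": 1})
    assert resp.status_code == 422
    print(resp.json())  # the stringified RequestValidationError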
66 changes: 32 additions & 34 deletions server/handlers/result_handlers.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Any, List, Dict
 import json
 from abc import abstractmethod
 from pydantic import BaseModel, Field, computed_field, ConfigDict, model_validator
@@ -74,7 +74,7 @@ class LSEBatchItem(BaseModel):
 
     task_id: int
     # TODO this field no longer populates if there was an error, so validation fails without a default - should probably split this item into 3 different constructors corresponding to new internal adala objects (or just reuse those objects)
-    output: Optional[str] = None
+    output: Optional[Dict] = None
     # we don't need to use reserved names anymore here because they're not in a DataFrame, but a structure with proper typing available
     error: bool = Field(False, alias="_adala_error")
     message: Optional[str] = Field(None, alias="_adala_message")
@@ -95,6 +95,32 @@ def check_error_consistency(self):
 
         return self
 
+    @classmethod
+    def from_result(cls, result: Dict) -> "LSEBatchItem":
+        """
+        Prepare a result for processing by the handler:
+        - extract error, message and detail if result is a failed prediction
+        - otherwise, put the result payload to the output field
+        """
+        # Copy system fields
+        prepared_result = {k: v for k, v in result.items() if k in (
+            'task_id', '_adala_error', '_adala_message', '_adala_details')}
+
+        # Normalize results if they contain NaN
+        if result.get('_adala_error') != result.get('_adala_error'):
+            prepared_result['_adala_error'] = False
+        if result.get('_adala_message') != result.get('_adala_message'):
+            prepared_result['_adala_message'] = None
+        if result.get('_adala_details') != result.get('_adala_details'):
+            prepared_result['_adala_details'] = None
+
+        # filter out the rest of custom fields
+        prepared_result['output'] = {k: v for k, v in result.items() if k not in prepared_result}
+
+        logger.debug(f'Prepared result: {prepared_result}')
+
+        return cls(**prepared_result)
+
 
 class LSEHandler(ResultHandler):
     """
@@ -143,26 +169,11 @@ def prepare_errors_payload(self, error_batch):
 
         return transformed_errors
 
-    def __call__(self, result_batch: list[LSEBatchItem]):
+    def __call__(self, result_batch: list[Dict]):
         logger.debug(f"\n\nHandler received batch: {result_batch}\n\n")
 
-        # coerce dicts to LSEBatchItems for validation
-        norm_result_batch = []
-        for result in result_batch:
-
-            # This is checking for NaNs to avoid validation errors
-            if result.get('_adala_error') != result.get('_adala_error'):
-                result['_adala_error'] = False
-            if result.get('_adala_message') != result.get('_adala_message'):
-                result['_adala_message'] = None
-            if result.get('_adala_details') != result.get('_adala_details'):
-                result['_adala_details'] = None
-            if result.get('output') != result.get('output'):
-                result['output'] = None
-
-            logger.debug('Record in LSEHandler: %s', result)
-
-            norm_result_batch.append(LSEBatchItem(**result))
+        norm_result_batch = [LSEBatchItem.from_result(result) for result in result_batch]
 
         result_batch = [record for record in norm_result_batch if not record.error]
         error_batch = [record for record in norm_result_batch if record.error]
@@ -219,24 +230,11 @@ def write_header(self):
 
         return self
 
-    def __call__(self, result_batch: list[LSEBatchItem]):
+    def __call__(self, result_batch: List[Dict]):
         logger.debug(f"\n\nHandler received batch: {result_batch}\n\n")
 
-        # coerce dicts to LSEBatchItems for validation
-        norm_result_batch = []
-        for result in result_batch:
-
-            # This is checking for NaNs to avoid validation errors
-            if result.get('_adala_error') != result.get('_adala_error'):
-                result['_adala_error'] = False
-            if result.get('_adala_message') != result.get('_adala_message'):
-                result['_adala_message'] = None
-            if result.get('_adala_details') != result.get('_adala_details'):
-                result['_adala_details'] = None
-            if result.get('output') != result.get('output'):
-                result['output'] = None
-
-            norm_result_batch.append(LSEBatchItem(**result))
+        norm_result_batch = [LSEBatchItem.from_result(result) for result in result_batch]
 
         # open and write to file
         with open(self.output_path, "a") as f:
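Both handlers now delegate per-record cleanup to LSEBatchItem.from_result, so a raw batch of dicts can be passed straight to either callable. A rough usage sketch for the CSV side; constructing CSVHandler directly with output_path is an assumption based on the {"type": "CSVHandler", "output_path": ...} payload used in the tests:

    from server.handlers.result_handlers import CSVHandler  # module path as in this diff

    handler = CSVHandler(output_path="/tmp/adala_results.csv")  # hypothetical path
    handler([
        # a successful record: system fields arrive as NaN, the payload lands in `output`
        {"task_id": 1, "entities": [{"start": 0, "end": 8, "label": "PERSON"}],
         "_adala_error": float("nan"), "_adala_message": float("nan"), "_adala_details": float("nan")},
    ])
    # each dict is validated via LSEBatchItem.from_result and appended as a CSV row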
96 changes: 73 additions & 23 deletions tests/test_server.py
@@ -5,6 +5,7 @@
 import pytest
 from tempfile import NamedTemporaryFile
 import pandas as pd
+from copy import deepcopy
 
 
 # TODO manage which keys correspond to which models/deployments, probably using a litellm Router
@@ -42,9 +43,9 @@
         "runtimes": {
             "default": {
                 "type": "AsyncLiteLLMChatRuntime",
-                "model": "gpt-3.5-turbo-0125",
+                "model": "gpt-4o-mini",
                 "api_key": OPENAI_API_KEY,
-                "max_tokens": 10,
+                "max_tokens": 200,
                 "temperature": 0,
                 "batch_size": 100,
                 "timeout": 10,
@@ -144,29 +145,68 @@ def test_ready_endpoint(client, redis_mock):
     assert result == "ok", f"Expected status = ok, but instead returned {result}."
 
 
-@pytest.mark.use_openai
-@pytest.mark.use_server
-def test_streaming(client):
-
-    data = pd.DataFrame.from_records(
+@pytest.mark.parametrize("input_data, skills, output_column", [
+    # text classification
+    (
         [
             {"task_id": 1, "text": "anytexthere", "output": "Feature Lack"},
             {"task_id": 2, "text": "othertexthere", "output": "Feature Lack"},
             {"task_id": 3, "text": "anytexthere", "output": "Feature Lack"},
             {"task_id": 4, "text": "othertexthere", "output": "Feature Lack"},
-        ]
+        ],
+        [{
+            "type": "ClassificationSkill",
+            "name": "text_classifier",
+            "instructions": "Always return the answer 'Feature Lack'.",
+            "input_template": "{text}",
+            "output_template": "{output}",
+            "labels": {
+                "output": [
+                    "Feature Lack",
+                    "Price",
+                    "Integration Issues",
+                    "Usability Concerns",
+                    "Competitor Advantage",
+                ]
+            },
+        }],
+        "output"
+    ),
+    # entity extraction
+    (
+        [
+            {"task_id": 1, "text": "John Doe, 26 years old, works at Google", "entities": [{"start": 0, "end": 8, "label": "PERSON"}, {"start": 26, "end": 36, "label": "AGE"}, {"start": 47, "end": 53, "label": "ORG"}]},
+            {"task_id": 2, "text": "Jane Doe, 30 years old, works at Microsoft", "entities": [{"start": 0, "end": 8, "label": "PERSON"}, {"start": 26, "end": 36, "label": "AGE"}, {"start": 47, "end": 55, "label": "ORG"}]},
+            {"task_id": 3, "text": "John Smith, 40 years old, works at Amazon", "entities": [{"start": 0, "end": 10, "label": "PERSON"}, {"start": 28, "end": 38, "label": "AGE"}, {"start": 49, "end": 55, "label": "ORG"}]},
+            {"task_id": 4, "text": "Jane Smith, 35 years old, works at Facebook", "entities": [{"start": 0, "end": 10, "label": "PERSON"}, {"start": 28, "end": 38, "label": "AGE"}, {"start": 49, "end": 57, "label": "ORG"}]},
+        ],
+        [{
+            "type": "EntityExtraction",
+            "name": "entity_extraction",
+            "input_template": 'Extract entities from the input text.\n\nInput:\n"""\n{text}\n"""',
+            "labels": ["PERSON", "AGE", "ORG"]
+        }],
+        "entities"
     )
-    batch_data = data.drop("output", axis=1).to_dict(orient="records")
-    expected_output = data.set_index("task_id")["output"]
+])
+@pytest.mark.use_openai
+@pytest.mark.use_server
+def test_streaming_use_cases(client, input_data, skills, output_column):
+
+    data = pd.DataFrame.from_records(input_data)
+    batch_data = data.drop(output_column, axis=1).to_dict(orient="records")
+    submit_payload = deepcopy(SUBMIT_PAYLOAD)
 
     with NamedTemporaryFile(mode="r") as f:
 
-        SUBMIT_PAYLOAD["result_handler"] = {
+        submit_payload["agent"]["skills"] = skills
+
+        submit_payload["result_handler"] = {
             "type": "CSVHandler",
            "output_path": f.name,
         }
 
-        resp = client.post("/jobs/submit-streaming", json=SUBMIT_PAYLOAD)
+        resp = client.post("/jobs/submit-streaming", json=submit_payload)
         resp.raise_for_status()
         job_id = resp.json()["data"]["job_id"]
 
@@ -198,9 +238,19 @@ def test_streaming(client):
 
         output = pd.read_csv(f.name).set_index("task_id")
         assert not output["error"].any(), "adala returned errors"
-        assert (
-            output["output"] == expected_output
-        ).all(), "adala did not return expected output"
+
+        # check for expected output
+        expected_outputs = data.set_index("task_id")[output_column].tolist()
+        actual_outputs = [eval(item)[output_column] for item in output.output.tolist()]
+        for actual_output, expected_output in zip(actual_outputs, expected_outputs):
+            if skills[0]["type"] == "EntityExtraction":
+                # Live generations may be flaky, so check only that the 3 entity labels are present
+                actual_labels = [entity["label"] for entity in actual_output]
+                expected_labels = [entity["label"] for entity in expected_output]
+                assert actual_labels == expected_labels
+                continue
+
+            assert actual_output == expected_output, "adala did not return expected output"
 
 
 @pytest.mark.use_openai
@@ -222,7 +272,6 @@ async def test_streaming_n_concurrent_requests(async_client):
     )
     batch_payload_data = data.drop("output", axis=1).to_dict(orient="records")
     batch_payload_datas = [batch_payload_data[:2], batch_payload_data[2:]]
-    expected_output = data.set_index("task_id")["output"]
 
     # this sometimes takes too long and flakes, set timeout_sec if behavior continues
     outputs = await asyncio.gather(
@@ -236,9 +285,10 @@ async def test_streaming_n_concurrent_requests(async_client):
 
     for output in outputs:
         assert not output["error"].any(), "adala returned errors"
-        assert (
-            output["output"] == expected_output
-        ).all(), "adala did not return expected output"
+        expected_outputs = data.set_index("task_id")["output"].tolist()
+        actual_outputs = [eval(item)["output"] for item in output.output.tolist()]
+        for actual_output, expected_output in zip(actual_outputs, expected_outputs):
+            assert actual_output == expected_output, "adala did not return expected output"
 
 
 @pytest.mark.skip(
@@ -332,7 +382,6 @@ def test_streaming_azure(client):
         ]
     )
     batch_data = data.drop("output", axis=1).to_dict(orient="records")
-    expected_output = data.set_index("task_id")["output"]
 
     with NamedTemporaryFile(mode="r") as f:
 
@@ -386,6 +435,7 @@ def test_streaming_azure(client):
 
         output = pd.read_csv(f.name).set_index("task_id")
         assert not output["error"].any(), "adala returned errors"
-        assert (
-            output["output"] == expected_output
-        ).all(), "adala did not return expected output"
+        expected_outputs = data.set_index("task_id")["output"].tolist()
+        actual_outputs = [eval(item)["output"] for item in output.output.tolist()]
+        for actual_output, expected_output in zip(actual_outputs, expected_outputs):
+            assert actual_output == expected_output, "adala did not return expected output"
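The updated assertions read each `output` cell back from the CSV as a stringified dict and compare it per task. Roughly, the parsing step works as below; the cell value is a made-up example of the entity-extraction case, and `ast.literal_eval` is shown as the stricter equivalent of the `eval` call the tests use:

    import ast

    # Example of what a CSV `output` cell might contain for the entity extraction case
    cell = "{'entities': [{'start': 0, 'end': 8, 'label': 'PERSON'}, {'start': 26, 'end': 36, 'label': 'AGE'}]}"
    parsed = ast.literal_eval(cell)  # tests use eval(); literal_eval is the safer drop-in
    labels = [entity["label"] for entity in parsed["entities"]]
    assert labels == ["PERSON", "AGE"]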
