Skip to content

Commit

Permalink
feat: DIA-1450: Image Classification support in adala (#264)
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-bernstein authored Nov 26, 2024
1 parent 97ee09c commit 70b38bd
Show file tree
Hide file tree
Showing 8 changed files with 1,277 additions and 184 deletions.
8 changes: 6 additions & 2 deletions adala/runtimes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from .base import Runtime, AsyncRuntime
from ._openai import OpenAIChatRuntime, OpenAIVisionRuntime, AsyncOpenAIChatRuntime
from ._litellm import LiteLLMChatRuntime, AsyncLiteLLMChatRuntime
from ._openai import OpenAIChatRuntime, AsyncOpenAIChatRuntime, AsyncOpenAIVisionRuntime
from ._litellm import (
LiteLLMChatRuntime,
AsyncLiteLLMChatRuntime,
AsyncLiteLLMVisionRuntime,
)
415 changes: 254 additions & 161 deletions adala/runtimes/_litellm.py

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions adala/runtimes/_openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from ._litellm import AsyncLiteLLMChatRuntime, LiteLLMChatRuntime, LiteLLMVisionRuntime
from ._litellm import (
AsyncLiteLLMChatRuntime,
LiteLLMChatRuntime,
AsyncLiteLLMVisionRuntime,
)


# litellm already reads the OPENAI_API_KEY env var, which was the reason for this class
OpenAIChatRuntime = LiteLLMChatRuntime
AsyncOpenAIChatRuntime = AsyncLiteLLMChatRuntime
OpenAIVisionRuntime = LiteLLMVisionRuntime
AsyncOpenAIVisionRuntime = AsyncLiteLLMVisionRuntime
74 changes: 56 additions & 18 deletions adala/skills/collection/label_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
import pandas as pd
from typing import Type, Iterator, Optional
from functools import cached_property
from collections import defaultdict
from adala.skills._base import TransformSkill
from adala.runtimes import AsyncLiteLLMVisionRuntime
from adala.runtimes._litellm import MessageChunkType
from pydantic import BaseModel, Field, model_validator

from adala.runtimes import Runtime, AsyncRuntime
from adala.utils.internal_data import InternalDataFrame

from label_studio_sdk.label_interface import LabelInterface
from label_studio_sdk.label_interface.control_tags import ControlTag
from label_studio_sdk.label_interface.control_tags import ControlTag, ObjectTag
from label_studio_sdk._extensions.label_studio_tools.core.utils.json_schema import (
json_schema_to_pydantic,
)
Expand All @@ -35,39 +38,60 @@ class LabelStudioSkill(TransformSkill):

# TODO: implement postprocessing to verify Taxonomy

@cached_property
def label_interface(self) -> LabelInterface:
return LabelInterface(self.label_config)

@cached_property
def ner_tags(self) -> Iterator[ControlTag]:
# check if the input config has NER tag (<Labels> + <Text>), and return its `from_name` and `to_name`
interface = LabelInterface(self.label_config)
for tag in interface.controls:
# NOTE: don't need to check object tag because at this point, unusable control tags should have been stripped out of the label config
if tag.tag.lower() == "labels":
control_tag_names = self.allowed_control_tags or list(
self.label_interface._controls.keys()
)
for tag_name in control_tag_names:
tag = self.label_interface.get_control(tag_name)
if tag.tag.lower() in {"labels", "hypertextlabels"}:
yield tag

@cached_property
def image_tags(self) -> Iterator[ObjectTag]:
# check if any image tags are used as input variables
object_tag_names = self.allowed_object_tags or list(
self.label_interface._objects.keys()
)
for tag_name in object_tag_names:
tag = self.label_interface.get_object(tag_name)
if tag.tag.lower() == "image":
yield tag

@model_validator(mode="after")
def validate_response_model(self):

interface = LabelInterface(self.label_config)
logger.debug(f"Read labeling config {self.label_config}")

if self.allowed_control_tags or self.allowed_object_tags:
if self.allowed_control_tags:
control_tags = {
tag: interface._controls[tag] for tag in self.allowed_control_tags
tag: self.label_interface._controls[tag]
for tag in self.allowed_control_tags
}
else:
control_tags = interface._controls
control_tags = self.label_interface._controls
if self.allowed_object_tags:
object_tags = {
tag: interface._objects[tag] for tag in self.allowed_object_tags
tag: self.label_interface._objects[tag]
for tag in self.allowed_object_tags
}
else:
object_tags = interface._objects
object_tags = self.label_interface._objects
interface = LabelInterface.create_instance(
tags={**control_tags, **object_tags}
)
logger.debug(
f"Filtered labeling config based on allowed tags {self.allowed_control_tags=} and {self.allowed_object_tags=} to {interface.config}"
)
else:
interface = self.label_interface

# NOTE: filtered label config is used for the response model, but full label config is used for the prompt, so that the model has as much context as possible.
self.field_schema = interface.to_json_schema()
Expand Down Expand Up @@ -100,14 +124,28 @@ async def aapply(
) -> InternalDataFrame:

with json_schema_to_pydantic(self.field_schema) as ResponseModel:
output = await runtime.batch_to_batch(
input,
input_template=self.input_template,
output_template="",
instructions_template=self.instructions,
response_model=ResponseModel,
)
for ner_tag in self.ner_tags():
# special handling to flag image inputs if they exist
if isinstance(runtime, AsyncLiteLLMVisionRuntime):
input_field_types = defaultdict(lambda: MessageChunkType.TEXT)
for tag in self.image_tags:
input_field_types[tag.name] = MessageChunkType.IMAGE_URL
output = await runtime.batch_to_batch(
input,
input_template=self.input_template,
output_template="",
instructions_template=self.instructions,
response_model=ResponseModel,
input_field_types=input_field_types,
)
else:
output = await runtime.batch_to_batch(
input,
input_template=self.input_template,
output_template="",
instructions_template=self.instructions,
response_model=ResponseModel,
)
for ner_tag in self.ner_tags:
input_field_name = ner_tag.objects[0].value.lstrip("$")
output_field_name = ner_tag.name
quote_string_field_name = "text"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
interactions:
- request:
body: '{"messages": [{"role": "user", "content": "Hey, how''s it going?"}], "model":
"gpt-4o-mini", "max_tokens": 1000, "seed": 47, "temperature": 0.0}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '143'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.47.1
x-stainless-arch:
- x64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- Linux
x-stainless-package-version:
- 1.47.1
x-stainless-raw-response:
- 'true'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.11.5
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAA4xSy27bMBC86yu2vPRiFbbs1o9LUPTS9NBDgz7QIBBociWxobgsuYLjBgb6G/29
fklB2bEUNAV6IcCZncHMkvcZgDBabECoRrJqvc1ff9FXTf3+zXfeVl/ryw/x46d3P1q7M/X+6k5M
koK231Dxg+qFotZbZEPuSKuAkjG5zpbz4uVyvVque6IljTbJas/5gvLWOJMX02KRT5f5bHVSN2QU
RrGB6wwA4L4/U06n8U5sYDp5QFqMUdYoNuchABHIJkTIGE1k6VhMBlKRY3R99MvnLWgyroYdWjsB
bqS7hT11z+At7UBuqeN0vYDPjeTfP39FIJeAAK1xGpi03F+MzQNWXZSpoOusPeGHc1pLtQ+0jSf+
jFfGmdiUAWUkl5JFJi969pAB3PRb6R4VFT5Q67lkukWXDGeLo50Y3mJErk4kE0s74PNi8oRbqZGl
sXG0VaGkalAPyuEJZKcNjYhs1PnvME95H3sbV/+P/UAohZ5Rlz6gNupx4WEsYPqp/xo777gPLOI+
MrZlZVyNwQdz/CeVL+caZ8VqNX21Ftkh+wMAAP//AwDs57wINQMAAA==
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8e85a915891d6208-ORD
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Tue, 26 Nov 2024 00:11:19 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=f2eAWUmcSjgkraa7rJvzhr53.Kz3y7EZniQwAmrWmHg-1732579879-1.0.1.1-FcTG.L1LC0IYeDrJNsA3S_9CqAeK8RVmE9li1oKj8OrrEOFELgjJ.wfKOQqQi8SWUsocl.oe2kGwriII9BVQ5Q;
path=/; expires=Tue, 26-Nov-24 00:41:19 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=5M11WH7821NNRxCf3t86tF5_JSGA0RXiNMeAxl1Pa4A-1732579879834-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
openai-organization:
- heartex
openai-processing-ms:
- '488'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149998994'
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_c89ae189bd037c2fdf4605f19a3115f5
status:
code: 200
message: OK
- request:
body: '{"messages": [{"role": "user", "content": "\n Given
the title of a museum painting:\nIt''s definitely not the Mona Lisa\n and the
image of the painting:\nhttps://upload.wikimedia.org/wikipedia/commons/thumb/e/ec/Mona_Lisa%2C_by_Leonardo_da_Vinci%2C_from_C2RMF_retouched.jpg/687px-Mona_Lisa%2C_by_Leonardo_da_Vinci%2C_from_C2RMF_retouched.jpg\n,\n classify
the painting as either \"Mona Lisa\" or \"Not Mona Lisa\".\n They
may or may not agree with each other. If the title and image disagree, believe
the image.\n "}], "model": "gpt-4o-mini", "max_tokens": 1000,
"seed": 47, "temperature": 0.0, "tool_choice": {"type": "function", "function":
{"name": "MyModel"}}, "tools": [{"type": "function", "function": {"name": "MyModel",
"description": "Correctly extracted `MyModel` with all the required parameters
with correct types", "parameters": {"properties": {"classification": {"description":
"Choices for image", "enum": ["Mona Lisa", "Not Mona Lisa"], "title": "Classification",
"type": "string"}}, "required": ["classification"], "type": "object"}}}]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '1124'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- AsyncOpenAI/Python 1.47.1
x-stainless-arch:
- x64
x-stainless-async:
- async:asyncio
x-stainless-lang:
- python
x-stainless-os:
- Linux
x-stainless-package-version:
- 1.47.1
x-stainless-raw-response:
- 'true'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.11.5
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAA4xTUW/TMBB+z6+w7rlBSelIyNsmgRhqERpioFEUuc4l9ebYnu1IlKr/HdnpkrQU
iTxY1n33fXf3+bKPCAFeQUGAbaljrRbx9ffqC79/Xr3p3n+6edfJ5cOH34+39c0i+5YgzDxDbR6R
uRfWK6ZaLdBxJXuYGaQOvWqavZ5fZW/zPAlAqyoUntZoFy9U3HLJ43kyX8RJFqf5kb1VnKGFgvyI
CCFkH07fp6zwFxQkaIVIi9bSBqEYkggBo4SPALWWW0elg9kIMiUdSt+67ISYAE4pUTIqxFi4//aT
+2gWFaLcpJpf7+rlXf7ViXZ7e/f8RD9/5GZSr5fe6dBQ3Uk2mDTBh3hxVowQkLQN3NVuFbybnSdQ
03QtSufbhv0amPBz15xRL7mGYg0rJSlZckvXcIAT/iG6dP85scVg3Vkqjn4d44fhAYRqtFEbe+Yn
1Fxyuy0NUhvmAuuU7mv7OqECdCdvB9qoVrvSqSeUXnCepr0ejPs1otkRc8pRMSXlswtyZYWO8vC2
wzoxyrZYjdRxrWhXcTUBosnQfzdzSbsfnMvmf+RHgDHUDqtSG6w4Ox14TDPo/75/pQ0mh4bB7qzD
tqy5bNBow8PuQ63LJEuuNnWesQSiQ/QHAAD//wMAq681QQkEAAA=
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8e85a91a5a7622f3-ORD
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Tue, 26 Nov 2024 00:11:20 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=mR.lQGByVqO3YXPOJhOAYfQSCaSh.GGUAiqvmTKYeF4-1732579880-1.0.1.1-kjoNgd4tNmz.8ile246dtkSjbL3C9pTtBxM35zH_sQENgFJuN91lWEVTAYebM_Au.qq8D_Sr1S1_DegpYxCo7A;
path=/; expires=Tue, 26-Nov-24 00:41:20 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=vJrrlSUKQKX62ERSv.300oGbNMFud1yC5ztTRDPBooA-1732579880316-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
openai-organization:
- heartex
openai-processing-ms:
- '189'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149998865'
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_5d437fcbab69225ff907cf1da14e1bb7
status:
code: 200
message: OK
version: 1
Loading

0 comments on commit 70b38bd

Please sign in to comment.