From bcdc11271137861520e447371417f889f8ce09f1 Mon Sep 17 00:00:00 2001
From: matt-bernstein <60152561+matt-bernstein@users.noreply.github.com>
Date: Mon, 18 Nov 2024 01:01:09 -0500
Subject: [PATCH] feat: DIA-1524: support for partially supported projects in
prompts (#250)
---
adala/skills/collection/label_studio.py | 29 +-
...bel_studio_skill_partial_label_config.yaml | 334 ++++++++++++++++++
tests/test_label_studio_skill.py | 72 ++++
tests/test_stream_inference.py | 4 +-
4 files changed, 435 insertions(+), 4 deletions(-)
create mode 100644 tests/cassettes/test_label_studio_skill/test_label_studio_skill_partial_label_config.yaml
diff --git a/adala/skills/collection/label_studio.py b/adala/skills/collection/label_studio.py
index 40f5038c..edf3a0d0 100644
--- a/adala/skills/collection/label_studio.py
+++ b/adala/skills/collection/label_studio.py
@@ -1,6 +1,6 @@
import logging
import pandas as pd
-from typing import Type, Iterator
+from typing import Type, Iterator, Optional
from functools import cached_property
from adala.skills._base import TransformSkill
from pydantic import BaseModel, Field, model_validator
@@ -30,6 +30,8 @@ class LabelStudioSkill(TransformSkill):
)
# ------------------------------
label_config: str = ""
+ allowed_control_tags: Optional[list[str]] = None
+ allowed_object_tags: Optional[list[str]] = None
# TODO: implement postprocessing to verify Taxonomy
@@ -37,8 +39,8 @@ def ner_tags(self) -> Iterator[ControlTag]:
# check if the input config has NER tag (<Labels> + <Text>), and return its `from_name` and `to_name`
interface = LabelInterface(self.label_config)
for tag in interface.controls:
- # TODO: don't need to check object tag because at this point, unusable control tags should have been stripped out of the label config, but confirm this - maybe move this logic to LSE
- if tag.tag == "Labels":
+ # NOTE: don't need to check object tag because at this point, unusable control tags should have been stripped out of the label config
+ if tag.tag.lower() == "labels":
yield tag
@model_validator(mode="after")
@@ -47,6 +49,27 @@ def validate_response_model(self):
interface = LabelInterface(self.label_config)
logger.debug(f"Read labeling config {self.label_config}")
+ if self.allowed_control_tags or self.allowed_object_tags:
+ if self.allowed_control_tags:
+ control_tags = {
+ tag: interface._controls[tag] for tag in self.allowed_control_tags
+ }
+ else:
+ control_tags = interface._controls
+ if self.allowed_object_tags:
+ object_tags = {
+ tag: interface._objects[tag] for tag in self.allowed_object_tags
+ }
+ else:
+ object_tags = interface._objects
+ interface = LabelInterface.create_instance(
+ tags={**control_tags, **object_tags}
+ )
+ logger.debug(
+ f"Filtered labeling config based on allowed tags {self.allowed_control_tags=} and {self.allowed_object_tags=} to {interface.config}"
+ )
+
+ # NOTE: filtered label config is used for the response model, but full label config is used for the prompt, so that the model has as much context as possible.
self.field_schema = interface.to_json_schema()
logger.debug(f"Converted labeling config to json schema: {self.field_schema}")
diff --git a/tests/cassettes/test_label_studio_skill/test_label_studio_skill_partial_label_config.yaml b/tests/cassettes/test_label_studio_skill/test_label_studio_skill_partial_label_config.yaml
new file mode 100644
index 00000000..6295ac14
--- /dev/null
+++ b/tests/cassettes/test_label_studio_skill/test_label_studio_skill_partial_label_config.yaml
@@ -0,0 +1,334 @@
+interactions:
+- request:
+ body: '{"messages": [{"role": "user", "content": "Hey, how''s it going?"}], "model":
+ "gpt-4o-mini", "max_tokens": 200, "seed": 47, "temperature": 0.0}'
+ headers:
+ accept:
+ - application/json
+ accept-encoding:
+ - gzip, deflate
+ connection:
+ - keep-alive
+ content-length:
+ - '142'
+ content-type:
+ - application/json
+ host:
+ - api.openai.com
+ user-agent:
+ - OpenAI/Python 1.47.1
+ x-stainless-arch:
+ - x64
+ x-stainless-async:
+ - 'false'
+ x-stainless-lang:
+ - python
+ x-stainless-os:
+ - Linux
+ x-stainless-package-version:
+ - 1.47.1
+ x-stainless-raw-response:
+ - 'true'
+ x-stainless-runtime:
+ - CPython
+ x-stainless-runtime-version:
+ - 3.11.5
+ method: POST
+ uri: https://api.openai.com/v1/chat/completions
+ response:
+ body:
+ string: !!binary |
+ H4sIAAAAAAAAA4xSwY7TMBS85ysevnBpUNIGtvSyQnuAlVaA2ANCCEWu/Zp46/gZ+0WlWlXiN/g9
+ vgQ57TZZsUhcLHnmzWjm2fcZgDBarECoVrLqvM3f3PK2u7q9vrl5XX1Y3Bn8fvV+t/20L77sP74V
+ s6Sg9R0qflC9UNR5i2zIHWkVUDIm1/JiUVYvi+WraiA60miTrPGcV5R3xpl8XsyrvLjIy+VJ3ZJR
+ GMUKvmYAAPfDmXI6jT/ECorZA9JhjLJBsToPAYhANiFCxmgiS8diNpKKHKMbol8/70CTcQ3s0NoZ
+ cCvdFvbUP4N3tAO5pp7T9RI+t5J///wVgVwCAnTGaWDScn85NQ+46aNMBV1v7Qk/nNNaanygdTzx
+ Z3xjnIltHVBGcilZZPJiYA8ZwLdhK/2josIH6jzXTFt0ybCsjnZifIsJuTyRTCztiC/msyfcao0s
+ jY2TrQolVYt6VI5PIHttaEJkk85/h3nK+9jbuOZ/7EdCKfSMuvYBtVGPC49jAdNP/dfYecdDYBH3
+ kbGrN8Y1GHwwx3+y8XWxloUu59WmFNkh+wMAAP//AwDNFQ3bNQMAAA==
+ headers:
+ CF-Cache-Status:
+ - DYNAMIC
+ CF-RAY:
+ - 8e19fd3f8a1329bd-ORD
+ Connection:
+ - keep-alive
+ Content-Encoding:
+ - gzip
+ Content-Type:
+ - application/json
+ Date:
+ - Tue, 12 Nov 2024 22:34:25 GMT
+ Server:
+ - cloudflare
+ Set-Cookie:
+ - __cf_bm=HyTXGMAnampYodd1FBKXDn1fFn._JJQjwqdXC7a_s3Y-1731450865-1.0.1.1-nAAbEyHIdCcb.qzV6kgyU.cvivIWguvH8pRLTma34zZDa8uJap5atJ75MWVBx.v5qW.CVLIliF8ObHpXl8wO9Q;
+ path=/; expires=Tue, 12-Nov-24 23:04:25 GMT; domain=.api.openai.com; HttpOnly;
+ Secure; SameSite=None
+ - _cfuvid=9tSzGGt58kMce1IKq064RtvIP0MLmM6pnpleBDUlOkQ-1731450865040-0.0.1.1-604800000;
+ path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
+ Transfer-Encoding:
+ - chunked
+ X-Content-Type-Options:
+ - nosniff
+ access-control-expose-headers:
+ - X-Request-ID
+ alt-svc:
+ - h3=":443"; ma=86400
+ openai-organization:
+ - heartex
+ openai-processing-ms:
+ - '407'
+ openai-version:
+ - '2020-10-01'
+ strict-transport-security:
+ - max-age=31536000; includeSubDomains; preload
+ x-ratelimit-limit-requests:
+ - '30000'
+ x-ratelimit-limit-tokens:
+ - '150000000'
+ x-ratelimit-remaining-requests:
+ - '29999'
+ x-ratelimit-remaining-tokens:
+ - '149999794'
+ x-ratelimit-reset-requests:
+ - 2ms
+ x-ratelimit-reset-tokens:
+ - 0s
+ x-request-id:
+ - req_6a70cd04fcd3e0434d662aaa032a14ee
+ status:
+ code: 200
+ message: OK
+- request:
+ body: '{"messages": [{"role": "user", "content": "\n Given
+ the github issue title:\nI can''t login\n and the description:\nI can''t login
+ to the platform\n, \n classify the issue. Provide a rationale
+ for your classification. \n Evaluate the final classification
+ on a Likert scale from 1 to 5, \n where 1 is \"Completely
+ irrelevant\" and 5 is \"Completely relevant\"."}], "model": "gpt-4o-mini", "max_tokens":
+ 200, "seed": 47, "temperature": 0.0, "tool_choice": {"type": "function", "function":
+ {"name": "MyModel"}}, "tools": [{"type": "function", "function": {"name": "MyModel",
+ "description": "Correctly extracted `MyModel` with all the required parameters
+ with correct types", "parameters": {"properties": {"classification": {"description":
+ "Choices for [''title'']", "enum": ["Bug report", "Feature request", "Question",
+ "Other"], "title": "Classification", "type": "string"}, "evaluation": {"description":
+ "Rating for [''title''] (0 to 5)", "maximum": 5, "minimum": 0, "title": "Evaluation",
+ "type": "integer"}, "rationale": {"description": "Text for [''title'']", "title":
+ "Rationale", "type": "string"}}, "required": ["classification", "evaluation",
+ "rationale"], "type": "object"}}}]}'
+ headers:
+ accept:
+ - application/json
+ accept-encoding:
+ - gzip, deflate
+ connection:
+ - keep-alive
+ content-length:
+ - '1252'
+ content-type:
+ - application/json
+ cookie:
+ - __cf_bm=HyTXGMAnampYodd1FBKXDn1fFn._JJQjwqdXC7a_s3Y-1731450865-1.0.1.1-nAAbEyHIdCcb.qzV6kgyU.cvivIWguvH8pRLTma34zZDa8uJap5atJ75MWVBx.v5qW.CVLIliF8ObHpXl8wO9Q;
+ _cfuvid=9tSzGGt58kMce1IKq064RtvIP0MLmM6pnpleBDUlOkQ-1731450865040-0.0.1.1-604800000
+ host:
+ - api.openai.com
+ user-agent:
+ - OpenAI/Python 1.47.1
+ x-stainless-arch:
+ - x64
+ x-stainless-async:
+ - 'false'
+ x-stainless-lang:
+ - python
+ x-stainless-os:
+ - Linux
+ x-stainless-package-version:
+ - 1.47.1
+ x-stainless-raw-response:
+ - 'true'
+ x-stainless-runtime:
+ - CPython
+ x-stainless-runtime-version:
+ - 3.11.5
+ method: POST
+ uri: https://api.openai.com/v1/chat/completions
+ response:
+ body:
+ string: !!binary |
+ H4sIAAAAAAAAA4xUTW/bMAy951cQPCdFkiVtkduC7VAUHYZu3WUZDEamba2y5EpU2yDIfx/kfNjp
+ OmA+GDIf3+OHSG8HAKhzXACqikTVjRl9/CaPdr6Mlfq0vr/1y+fXl/gwv4mXn6/mHoeJ4da/WcmR
+ daFc3RgW7eweVp5JOKlOrj5MZvPx9eW8BWqXs0m0spHRzI1qbfVoOp7ORuOr0eT6wK6cVhxwAT8H
+ AADb9p3ytDm/4gLGw6Ol5hCoZFycnADQO5MsSCHoIGQFhx2onBW2KXUbjekB4pzJFBnTBd4/2965
+ axYZk319vVdPtxylMNdP/EA3X26X5N2PXry99KZpEyqiVacm9fCTffEmGABaqlvu3eau7d3wrQP5
+ MtZsJaWN2xUqk+outKIkucLFCpexBM+N87LC4Qr5mUw8ovPhCn37QYZb7+8Vgw4hMijD5M0GtM2T
+ HAcgaLxbG67hRUsFxpWltiVoKw6kYmgMSeF8PYSXSqsKdKIoV9fOQuoCuALWsQSpSICKgpUEiIE9
+ kFIcApDN4dgOMlo2Fyvc4VnNu8F751+9q/RcxEDmcMcH++40NMaVqYrwZgaw0FaHKvNMob0LDOKa
+ fewUp42A8WzesPGubiQT98g2CU7H870edjvRobPD5KI4IdNjzY6sM70sZyHdDuRpBxSpivOO2u0C
+ xVy7HjDoVf13Nu9p7yvXtvwf+Q5QihvhPGs851qdV9y5eU6/jH+5nbrcJoxhE4TrrNC2ZN943S4s
+ Fk02XtM4n0xnxQQHu8EfAAAA//8DABuRKKC+BAAA
+ headers:
+ CF-Cache-Status:
+ - DYNAMIC
+ CF-RAY:
+ - 8e19fd436ebd29bd-ORD
+ Connection:
+ - keep-alive
+ Content-Encoding:
+ - gzip
+ Content-Type:
+ - application/json
+ Date:
+ - Tue, 12 Nov 2024 22:34:26 GMT
+ Server:
+ - cloudflare
+ Transfer-Encoding:
+ - chunked
+ X-Content-Type-Options:
+ - nosniff
+ access-control-expose-headers:
+ - X-Request-ID
+ alt-svc:
+ - h3=":443"; ma=86400
+ openai-organization:
+ - heartex
+ openai-processing-ms:
+ - '1123'
+ openai-version:
+ - '2020-10-01'
+ strict-transport-security:
+ - max-age=31536000; includeSubDomains; preload
+ x-ratelimit-limit-requests:
+ - '30000'
+ x-ratelimit-limit-tokens:
+ - '150000000'
+ x-ratelimit-remaining-requests:
+ - '29999'
+ x-ratelimit-remaining-tokens:
+ - '149999703'
+ x-ratelimit-reset-requests:
+ - 2ms
+ x-ratelimit-reset-tokens:
+ - 0s
+ x-request-id:
+ - req_4e031827f28ce0f1485cab379aad87a3
+ status:
+ code: 200
+ message: OK
+- request:
+ body: '{"messages": [{"role": "user", "content": "\n Given
+ the github issue title:\nSupport new file types\n and the description:\nIt would
+ be great if we could upload files of type .docx\n, \n classify
+ the issue. Provide a rationale for your classification. \n Evaluate
+ the final classification on a Likert scale from 1 to 5, \n where
+ 1 is \"Completely irrelevant\" and 5 is \"Completely relevant\"."}], "model":
+ "gpt-4o-mini", "max_tokens": 200, "seed": 47, "temperature": 0.0, "tool_choice":
+ {"type": "function", "function": {"name": "MyModel"}}, "tools": [{"type": "function",
+ "function": {"name": "MyModel", "description": "Correctly extracted `MyModel`
+ with all the required parameters with correct types", "parameters": {"properties":
+ {"classification": {"description": "Choices for [''title'']", "enum": ["Bug
+ report", "Feature request", "Question", "Other"], "title": "Classification",
+ "type": "string"}, "evaluation": {"description": "Rating for [''title''] (0
+ to 5)", "maximum": 5, "minimum": 0, "title": "Evaluation", "type": "integer"},
+ "rationale": {"description": "Text for [''title'']", "title": "Rationale", "type":
+ "string"}}, "required": ["classification", "evaluation", "rationale"], "type":
+ "object"}}}]}'
+ headers:
+ accept:
+ - application/json
+ accept-encoding:
+ - gzip, deflate
+ connection:
+ - keep-alive
+ content-length:
+ - '1288'
+ content-type:
+ - application/json
+ cookie:
+ - __cf_bm=HyTXGMAnampYodd1FBKXDn1fFn._JJQjwqdXC7a_s3Y-1731450865-1.0.1.1-nAAbEyHIdCcb.qzV6kgyU.cvivIWguvH8pRLTma34zZDa8uJap5atJ75MWVBx.v5qW.CVLIliF8ObHpXl8wO9Q;
+ _cfuvid=9tSzGGt58kMce1IKq064RtvIP0MLmM6pnpleBDUlOkQ-1731450865040-0.0.1.1-604800000
+ host:
+ - api.openai.com
+ user-agent:
+ - OpenAI/Python 1.47.1
+ x-stainless-arch:
+ - x64
+ x-stainless-async:
+ - 'false'
+ x-stainless-lang:
+ - python
+ x-stainless-os:
+ - Linux
+ x-stainless-package-version:
+ - 1.47.1
+ x-stainless-raw-response:
+ - 'true'
+ x-stainless-runtime:
+ - CPython
+ x-stainless-runtime-version:
+ - 3.11.5
+ method: POST
+ uri: https://api.openai.com/v1/chat/completions
+ response:
+ body:
+ string: !!binary |
+ H4sIAAAAAAAAA4xUXW/bMAx8z68g+LQBTuGkSVrkbR8o1m0psLUbhs6Doci0rUWWVElumwb574Ps
+ JHa6DpgfBJvHOx5FyZsBAIoM54C8ZJ5XRg7fXPuVvrue2sXtw/tPT+ubZX01iz9/yaq39wVGgaGX
+ v4n7PeuE68pI8kKrFuaWmKegOjo7HU2m8fls1gCVzkgGWmH8cKKHlVBiOI7Hk2F8Nhyd79ilFpwc
+ zuHnAABg06zBp8roEecQR/tIRc6xgnB+SAJAq2WIIHNOOM+Ux6gDuVaeVLCuail7gNdappxJ2RVu
+ n03vvdssJmV6G+t3Hz4+XV79uPCnavHNXX5ffT0n2avXSq9NYyivFT9sUg8/xOfPigGgYlXDXawX
+ zd5FzxOYLeqKlA+2cZMgl6HvXHAWJBOcJ3hBzNeWwNJdTc4nGCVI90zW+5RplKBtPpikhnJTEgjn
+ 6rDueUIV4GpjtPWQawsMFD1ALiRB6BBenWSaP76O4KEUvAShsmCCHDDIyAlLLUsBqZIpTsE1aNuq
+ 7CwKBb4kcGvnqTpJcItH/W4HL73/6o3RUl47Jnfz3cW3hwMjdWGsXrpn88dcKOHK1BJzzRzQeW3a
+ 2qFOUwHro7OGxurK+NTrFakgOB5NWj3s7kOHTsY70GvPZI81nUUv6KUZeSaaw3g4/5zxkrKO2t0D
+ VmdC94BBr+u/3byk3XYuVPE/8h3AORlPWWosZYIfd9ylWQq/i3+lHXa5MYzt7NNcqIKssaK5rJib
+ NF6yOBuNJ/kIB9vBHwAAAP//AwC1d/r9ugQAAA==
+ headers:
+ CF-Cache-Status:
+ - DYNAMIC
+ CF-RAY:
+ - 8e19fd4b593729bd-ORD
+ Connection:
+ - keep-alive
+ Content-Encoding:
+ - gzip
+ Content-Type:
+ - application/json
+ Date:
+ - Tue, 12 Nov 2024 22:34:27 GMT
+ Server:
+ - cloudflare
+ Transfer-Encoding:
+ - chunked
+ X-Content-Type-Options:
+ - nosniff
+ access-control-expose-headers:
+ - X-Request-ID
+ alt-svc:
+ - h3=":443"; ma=86400
+ openai-organization:
+ - heartex
+ openai-processing-ms:
+ - '748'
+ openai-version:
+ - '2020-10-01'
+ strict-transport-security:
+ - max-age=31536000; includeSubDomains; preload
+ x-ratelimit-limit-requests:
+ - '30000'
+ x-ratelimit-limit-tokens:
+ - '150000000'
+ x-ratelimit-remaining-requests:
+ - '29999'
+ x-ratelimit-remaining-tokens:
+ - '149999694'
+ x-ratelimit-reset-requests:
+ - 2ms
+ x-ratelimit-reset-tokens:
+ - 0s
+ x-request-id:
+ - req_891bd00fef969262eafc45e9a07f3be4
+ status:
+ code: 200
+ message: OK
+version: 1
diff --git a/tests/test_label_studio_skill.py b/tests/test_label_studio_skill.py
index bcfe0d33..ca3279a1 100644
--- a/tests/test_label_studio_skill.py
+++ b/tests/test_label_studio_skill.py
@@ -74,6 +74,78 @@ async def test_label_studio_skill():
]
+@pytest.mark.asyncio
+@pytest.mark.vcr
+async def test_label_studio_skill_partial_label_config():
+
+ df = pd.DataFrame(
+ [
+ {"title": "I can't login", "description": "I can't login to the platform"},
+ {
+ "title": "Support new file types",
+ "description": "It would be great if we could upload files of type .docx",
+ },
+ ]
+ )
+
+ agent_payload = {
+ "runtimes": {
+ "default": {
+ "type": "AsyncLiteLLMChatRuntime",
+ "model": "gpt-4o-mini",
+ "api_key": os.getenv("OPENAI_API_KEY"),
+ "max_tokens": 200,
+ "temperature": 0,
+ "batch_size": 100,
+ "timeout": 10,
+ "verbose": False,
+ }
+ },
+ "skills": [
+ {
+ "type": "LabelStudioSkill",
+ "name": "AnnotationResult",
+ "input_template": """
+ Given the github issue title:\n{title}\n and the description:\n{description}\n,
+ classify the issue. Provide a rationale for your classification.
+ Evaluate the final classification on a Likert scale from 1 to 5,
+ where 1 is "Completely irrelevant" and 5 is "Completely relevant".""",
+ "label_config": """
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ "allowed_control_tags": ["classification", "evaluation", "rationale"],
+ }
+ ],
+ }
+
+ agent = Agent(**agent_payload)
+ predictions = await agent.arun(df)
+
+ assert predictions.classification.tolist() == ["Bug report", "Feature request"]
+ assert predictions.evaluation.tolist() == [5, 5]
+ assert "rationale" in predictions.columns
+ assert "screenshot_quality" not in predictions.columns
+
+
@pytest.mark.asyncio
@pytest.mark.vcr
async def test_label_studio_skill_with_ner():
diff --git a/tests/test_stream_inference.py b/tests/test_stream_inference.py
index 1e3f223e..de78b2b6 100644
--- a/tests/test_stream_inference.py
+++ b/tests/test_stream_inference.py
@@ -165,4 +165,6 @@ async def test_run_streaming(
# Verify that producer is called with the correct amount of send_and_wait calls and data
assert mock_kafka_producer.send_and_wait.call_count == 1
- mock_kafka_producer.send_and_wait.assert_any_call("output_topic", value=TEST_OUTPUT_DATA)
+ mock_kafka_producer.send_and_wait.assert_any_call(
+ "output_topic", value=TEST_OUTPUT_DATA
+ )