From bcdc11271137861520e447371417f889f8ce09f1 Mon Sep 17 00:00:00 2001 From: matt-bernstein <60152561+matt-bernstein@users.noreply.github.com> Date: Mon, 18 Nov 2024 01:01:09 -0500 Subject: [PATCH] feat: DIA-1524: support for partially supported projects in prompts (#250) --- adala/skills/collection/label_studio.py | 29 +- ...bel_studio_skill_partial_label_config.yaml | 334 ++++++++++++++++++ tests/test_label_studio_skill.py | 72 ++++ tests/test_stream_inference.py | 4 +- 4 files changed, 435 insertions(+), 4 deletions(-) create mode 100644 tests/cassettes/test_label_studio_skill/test_label_studio_skill_partial_label_config.yaml diff --git a/adala/skills/collection/label_studio.py b/adala/skills/collection/label_studio.py index 40f5038c..edf3a0d0 100644 --- a/adala/skills/collection/label_studio.py +++ b/adala/skills/collection/label_studio.py @@ -1,6 +1,6 @@ import logging import pandas as pd -from typing import Type, Iterator +from typing import Type, Iterator, Optional from functools import cached_property from adala.skills._base import TransformSkill from pydantic import BaseModel, Field, model_validator @@ -30,6 +30,8 @@ class LabelStudioSkill(TransformSkill): ) # ------------------------------ label_config: str = "" + allowed_control_tags: Optional[list[str]] = None + allowed_object_tags: Optional[list[str]] = None # TODO: implement postprocessing to verify Taxonomy @@ -37,8 +39,8 @@ def ner_tags(self) -> Iterator[ControlTag]: # check if the input config has NER tag ( + ), and return its `from_name` and `to_name` interface = LabelInterface(self.label_config) for tag in interface.controls: - # TODO: don't need to check object tag because at this point, unusable control tags should have been stripped out of the label config, but confirm this - maybe move this logic to LSE - if tag.tag == "Labels": + # NOTE: don't need to check object tag because at this point, unusable control tags should have been stripped out of the label config + if tag.tag.lower() == "labels": yield tag @model_validator(mode="after") @@ -47,6 +49,27 @@ def validate_response_model(self): interface = LabelInterface(self.label_config) logger.debug(f"Read labeling config {self.label_config}") + if self.allowed_control_tags or self.allowed_object_tags: + if self.allowed_control_tags: + control_tags = { + tag: interface._controls[tag] for tag in self.allowed_control_tags + } + else: + control_tags = interface._controls + if self.allowed_object_tags: + object_tags = { + tag: interface._objects[tag] for tag in self.allowed_object_tags + } + else: + object_tags = interface._objects + interface = LabelInterface.create_instance( + tags={**control_tags, **object_tags} + ) + logger.debug( + f"Filtered labeling config based on allowed tags {self.allowed_control_tags=} and {self.allowed_object_tags=} to {interface.config}" + ) + + # NOTE: filtered label config is used for the response model, but full label config is used for the prompt, so that the model has as much context as possible. self.field_schema = interface.to_json_schema() logger.debug(f"Converted labeling config to json schema: {self.field_schema}") diff --git a/tests/cassettes/test_label_studio_skill/test_label_studio_skill_partial_label_config.yaml b/tests/cassettes/test_label_studio_skill/test_label_studio_skill_partial_label_config.yaml new file mode 100644 index 00000000..6295ac14 --- /dev/null +++ b/tests/cassettes/test_label_studio_skill/test_label_studio_skill_partial_label_config.yaml @@ -0,0 +1,334 @@ +interactions: +- request: + body: '{"messages": [{"role": "user", "content": "Hey, how''s it going?"}], "model": + "gpt-4o-mini", "max_tokens": 200, "seed": 47, "temperature": 0.0}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '142' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.47.1 + x-stainless-arch: + - x64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - Linux + x-stainless-package-version: + - 1.47.1 + x-stainless-raw-response: + - 'true' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.11.5 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAA4xSwY7TMBS85ysevnBpUNIGtvSyQnuAlVaA2ANCCEWu/Zp46/gZ+0WlWlXiN/g9 + vgQ57TZZsUhcLHnmzWjm2fcZgDBarECoVrLqvM3f3PK2u7q9vrl5XX1Y3Bn8fvV+t/20L77sP74V + s6Sg9R0qflC9UNR5i2zIHWkVUDIm1/JiUVYvi+WraiA60miTrPGcV5R3xpl8XsyrvLjIy+VJ3ZJR + GMUKvmYAAPfDmXI6jT/ECorZA9JhjLJBsToPAYhANiFCxmgiS8diNpKKHKMbol8/70CTcQ3s0NoZ + cCvdFvbUP4N3tAO5pp7T9RI+t5J///wVgVwCAnTGaWDScn85NQ+46aNMBV1v7Qk/nNNaanygdTzx + Z3xjnIltHVBGcilZZPJiYA8ZwLdhK/2josIH6jzXTFt0ybCsjnZifIsJuTyRTCztiC/msyfcao0s + jY2TrQolVYt6VI5PIHttaEJkk85/h3nK+9jbuOZ/7EdCKfSMuvYBtVGPC49jAdNP/dfYecdDYBH3 + kbGrN8Y1GHwwx3+y8XWxloUu59WmFNkh+wMAAP//AwDNFQ3bNQMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8e19fd3f8a1329bd-ORD + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Tue, 12 Nov 2024 22:34:25 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=HyTXGMAnampYodd1FBKXDn1fFn._JJQjwqdXC7a_s3Y-1731450865-1.0.1.1-nAAbEyHIdCcb.qzV6kgyU.cvivIWguvH8pRLTma34zZDa8uJap5atJ75MWVBx.v5qW.CVLIliF8ObHpXl8wO9Q; + path=/; expires=Tue, 12-Nov-24 23:04:25 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=9tSzGGt58kMce1IKq064RtvIP0MLmM6pnpleBDUlOkQ-1731450865040-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - heartex + openai-processing-ms: + - '407' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '30000' + x-ratelimit-limit-tokens: + - '150000000' + x-ratelimit-remaining-requests: + - '29999' + x-ratelimit-remaining-tokens: + - '149999794' + x-ratelimit-reset-requests: + - 2ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_6a70cd04fcd3e0434d662aaa032a14ee + status: + code: 200 + message: OK +- request: + body: '{"messages": [{"role": "user", "content": "\n Given + the github issue title:\nI can''t login\n and the description:\nI can''t login + to the platform\n, \n classify the issue. Provide a rationale + for your classification. \n Evaluate the final classification + on a Likert scale from 1 to 5, \n where 1 is \"Completely + irrelevant\" and 5 is \"Completely relevant\"."}], "model": "gpt-4o-mini", "max_tokens": + 200, "seed": 47, "temperature": 0.0, "tool_choice": {"type": "function", "function": + {"name": "MyModel"}}, "tools": [{"type": "function", "function": {"name": "MyModel", + "description": "Correctly extracted `MyModel` with all the required parameters + with correct types", "parameters": {"properties": {"classification": {"description": + "Choices for [''title'']", "enum": ["Bug report", "Feature request", "Question", + "Other"], "title": "Classification", "type": "string"}, "evaluation": {"description": + "Rating for [''title''] (0 to 5)", "maximum": 5, "minimum": 0, "title": "Evaluation", + "type": "integer"}, "rationale": {"description": "Text for [''title'']", "title": + "Rationale", "type": "string"}}, "required": ["classification", "evaluation", + "rationale"], "type": "object"}}}]}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1252' + content-type: + - application/json + cookie: + - __cf_bm=HyTXGMAnampYodd1FBKXDn1fFn._JJQjwqdXC7a_s3Y-1731450865-1.0.1.1-nAAbEyHIdCcb.qzV6kgyU.cvivIWguvH8pRLTma34zZDa8uJap5atJ75MWVBx.v5qW.CVLIliF8ObHpXl8wO9Q; + _cfuvid=9tSzGGt58kMce1IKq064RtvIP0MLmM6pnpleBDUlOkQ-1731450865040-0.0.1.1-604800000 + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.47.1 + x-stainless-arch: + - x64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - Linux + x-stainless-package-version: + - 1.47.1 + x-stainless-raw-response: + - 'true' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.11.5 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAA4xUTW/bMAy951cQPCdFkiVtkduC7VAUHYZu3WUZDEamba2y5EpU2yDIfx/kfNjp + OmA+GDIf3+OHSG8HAKhzXACqikTVjRl9/CaPdr6Mlfq0vr/1y+fXl/gwv4mXn6/mHoeJ4da/WcmR + daFc3RgW7eweVp5JOKlOrj5MZvPx9eW8BWqXs0m0spHRzI1qbfVoOp7ORuOr0eT6wK6cVhxwAT8H + AADb9p3ytDm/4gLGw6Ol5hCoZFycnADQO5MsSCHoIGQFhx2onBW2KXUbjekB4pzJFBnTBd4/2965 + axYZk319vVdPtxylMNdP/EA3X26X5N2PXry99KZpEyqiVacm9fCTffEmGABaqlvu3eau7d3wrQP5 + MtZsJaWN2xUqk+outKIkucLFCpexBM+N87LC4Qr5mUw8ovPhCn37QYZb7+8Vgw4hMijD5M0GtM2T + HAcgaLxbG67hRUsFxpWltiVoKw6kYmgMSeF8PYSXSqsKdKIoV9fOQuoCuALWsQSpSICKgpUEiIE9 + kFIcApDN4dgOMlo2Fyvc4VnNu8F751+9q/RcxEDmcMcH++40NMaVqYrwZgaw0FaHKvNMob0LDOKa + fewUp42A8WzesPGubiQT98g2CU7H870edjvRobPD5KI4IdNjzY6sM70sZyHdDuRpBxSpivOO2u0C + xVy7HjDoVf13Nu9p7yvXtvwf+Q5QihvhPGs851qdV9y5eU6/jH+5nbrcJoxhE4TrrNC2ZN943S4s + Fk02XtM4n0xnxQQHu8EfAAAA//8DABuRKKC+BAAA + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8e19fd436ebd29bd-ORD + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Tue, 12 Nov 2024 22:34:26 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - heartex + openai-processing-ms: + - '1123' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '30000' + x-ratelimit-limit-tokens: + - '150000000' + x-ratelimit-remaining-requests: + - '29999' + x-ratelimit-remaining-tokens: + - '149999703' + x-ratelimit-reset-requests: + - 2ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_4e031827f28ce0f1485cab379aad87a3 + status: + code: 200 + message: OK +- request: + body: '{"messages": [{"role": "user", "content": "\n Given + the github issue title:\nSupport new file types\n and the description:\nIt would + be great if we could upload files of type .docx\n, \n classify + the issue. Provide a rationale for your classification. \n Evaluate + the final classification on a Likert scale from 1 to 5, \n where + 1 is \"Completely irrelevant\" and 5 is \"Completely relevant\"."}], "model": + "gpt-4o-mini", "max_tokens": 200, "seed": 47, "temperature": 0.0, "tool_choice": + {"type": "function", "function": {"name": "MyModel"}}, "tools": [{"type": "function", + "function": {"name": "MyModel", "description": "Correctly extracted `MyModel` + with all the required parameters with correct types", "parameters": {"properties": + {"classification": {"description": "Choices for [''title'']", "enum": ["Bug + report", "Feature request", "Question", "Other"], "title": "Classification", + "type": "string"}, "evaluation": {"description": "Rating for [''title''] (0 + to 5)", "maximum": 5, "minimum": 0, "title": "Evaluation", "type": "integer"}, + "rationale": {"description": "Text for [''title'']", "title": "Rationale", "type": + "string"}}, "required": ["classification", "evaluation", "rationale"], "type": + "object"}}}]}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1288' + content-type: + - application/json + cookie: + - __cf_bm=HyTXGMAnampYodd1FBKXDn1fFn._JJQjwqdXC7a_s3Y-1731450865-1.0.1.1-nAAbEyHIdCcb.qzV6kgyU.cvivIWguvH8pRLTma34zZDa8uJap5atJ75MWVBx.v5qW.CVLIliF8ObHpXl8wO9Q; + _cfuvid=9tSzGGt58kMce1IKq064RtvIP0MLmM6pnpleBDUlOkQ-1731450865040-0.0.1.1-604800000 + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.47.1 + x-stainless-arch: + - x64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - Linux + x-stainless-package-version: + - 1.47.1 + x-stainless-raw-response: + - 'true' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.11.5 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAA4xUXW/bMAx8z68g+LQBTuGkSVrkbR8o1m0psLUbhs6Doci0rUWWVElumwb574Ps + JHa6DpgfBJvHOx5FyZsBAIoM54C8ZJ5XRg7fXPuVvrue2sXtw/tPT+ubZX01iz9/yaq39wVGgaGX + v4n7PeuE68pI8kKrFuaWmKegOjo7HU2m8fls1gCVzkgGWmH8cKKHlVBiOI7Hk2F8Nhyd79ilFpwc + zuHnAABg06zBp8roEecQR/tIRc6xgnB+SAJAq2WIIHNOOM+Ux6gDuVaeVLCuail7gNdappxJ2RVu + n03vvdssJmV6G+t3Hz4+XV79uPCnavHNXX5ffT0n2avXSq9NYyivFT9sUg8/xOfPigGgYlXDXawX + zd5FzxOYLeqKlA+2cZMgl6HvXHAWJBOcJ3hBzNeWwNJdTc4nGCVI90zW+5RplKBtPpikhnJTEgjn + 6rDueUIV4GpjtPWQawsMFD1ALiRB6BBenWSaP76O4KEUvAShsmCCHDDIyAlLLUsBqZIpTsE1aNuq + 7CwKBb4kcGvnqTpJcItH/W4HL73/6o3RUl47Jnfz3cW3hwMjdWGsXrpn88dcKOHK1BJzzRzQeW3a + 2qFOUwHro7OGxurK+NTrFakgOB5NWj3s7kOHTsY70GvPZI81nUUv6KUZeSaaw3g4/5zxkrKO2t0D + VmdC94BBr+u/3byk3XYuVPE/8h3AORlPWWosZYIfd9ylWQq/i3+lHXa5MYzt7NNcqIKssaK5rJib + NF6yOBuNJ/kIB9vBHwAAAP//AwC1d/r9ugQAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8e19fd4b593729bd-ORD + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Tue, 12 Nov 2024 22:34:27 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - heartex + openai-processing-ms: + - '748' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '30000' + x-ratelimit-limit-tokens: + - '150000000' + x-ratelimit-remaining-requests: + - '29999' + x-ratelimit-remaining-tokens: + - '149999694' + x-ratelimit-reset-requests: + - 2ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_891bd00fef969262eafc45e9a07f3be4 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/test_label_studio_skill.py b/tests/test_label_studio_skill.py index bcfe0d33..ca3279a1 100644 --- a/tests/test_label_studio_skill.py +++ b/tests/test_label_studio_skill.py @@ -74,6 +74,78 @@ async def test_label_studio_skill(): ] +@pytest.mark.asyncio +@pytest.mark.vcr +async def test_label_studio_skill_partial_label_config(): + + df = pd.DataFrame( + [ + {"title": "I can't login", "description": "I can't login to the platform"}, + { + "title": "Support new file types", + "description": "It would be great if we could upload files of type .docx", + }, + ] + ) + + agent_payload = { + "runtimes": { + "default": { + "type": "AsyncLiteLLMChatRuntime", + "model": "gpt-4o-mini", + "api_key": os.getenv("OPENAI_API_KEY"), + "max_tokens": 200, + "temperature": 0, + "batch_size": 100, + "timeout": 10, + "verbose": False, + } + }, + "skills": [ + { + "type": "LabelStudioSkill", + "name": "AnnotationResult", + "input_template": """ + Given the github issue title:\n{title}\n and the description:\n{description}\n, + classify the issue. Provide a rationale for your classification. + Evaluate the final classification on a Likert scale from 1 to 5, + where 1 is "Completely irrelevant" and 5 is "Completely relevant".""", + "label_config": """ + +
+ + +