Skip to content

Commit

Permalink
checkpoint tlm_openai
Browse files Browse the repository at this point in the history
  • Loading branch information
huiwengoh committed Oct 24, 2024
1 parent 3c095db commit 882379d
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 1 deletion.
52 changes: 52 additions & 0 deletions cleanlab_studio/internal/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
cleanset_base_url = f"{base_url}/cleansets"
model_base_url = f"{base_url}/v1/deployment"
tlm_base_url = f"{base_url}/v0/trustworthy_llm"
# Backend route for the OpenAI-compatible TLM endpoint (used by tlm_openai_prompt).
tlm_openai_base_url = f"{base_url}/v0/trustworthy_llm_openai"


def _construct_headers(
Expand Down Expand Up @@ -1081,5 +1082,56 @@ async def tlm_get_confidence_score(
return cast(JSONDict, res_json)


@tlm_retry
async def tlm_openai_prompt(
    api_key: str,
    prompt: str,
    quality_preset: str,
    options: Optional[JSONDict],
    openai_args: Optional[JSONDict],
    rate_handler: TlmRateHandler,
    client_session: Optional[aiohttp.ClientSession] = None,
    batch_index: Optional[int] = None,
) -> JSONDict:
    """
    Prompt the Trustworthy Language Model via its OpenAI-compatible backend endpoint.

    Args:
        api_key: Cleanlab Studio API key; used for auth headers and sent as both
            ``user_id`` and ``client_id`` in the request payload.
        prompt: single prompt string to send to the TLM.
        quality_preset: TLM quality preset (sent as ``quality``).
        options: TLM options dictionary; ``{}`` is sent when None.
        openai_args: extra OpenAI-style arguments (e.g. ``response_format``)
            forwarded to the backend; ``{}`` is sent when None.
        rate_handler: async context manager that rate-limits concurrent TLM requests.
        client_session: optional aiohttp session; when None a locally scoped session
            is created and always closed in the ``finally`` block.
        batch_index: optional index of this prompt within a batch, passed to the
            error handlers for clearer error messages.

    NOTE(review): callers pass ``retries=...`` to this function — presumably that
    keyword is consumed by the ``@tlm_retry`` decorator before this body runs;
    confirm, since the signature itself does not accept it.

    Returns:
        JSON dictionary containing the TLM response.

    Raises:
        TlmPartialSuccess: when the backend reports ``deberta_success`` as False,
            signaling the caller to slow its request rate.
    """
    local_scoped_client = False
    if not client_session:
        client_session = aiohttp.ClientSession()
        local_scoped_client = True

    try:
        async with rate_handler:
            # NOTE(review): this override env var is shared with the non-OpenAI TLM
            # endpoint; when set, requests go to "<override>/prompt" rather than the
            # trustworthy_llm_openai route — confirm that is intended.
            base_api_url = os.environ.get("CLEANLAB_API_TLM_BASE_URL", tlm_openai_base_url)
            res = await client_session.post(
                f"{base_api_url}/prompt",
                json=dict(
                    prompt=prompt,
                    quality=quality_preset,
                    options=options or {},
                    openai_args=openai_args or {},
                    user_id=api_key,
                    client_id=api_key,
                ),
                headers=_construct_headers(api_key),
            )

            res_json = await res.json()

            # Order matters: rate-limit errors first, then client errors, then
            # generic API errors; each may raise and skip the partial-success check.
            handle_rate_limit_error_from_resp(res)
            await handle_tlm_client_error_from_resp(res, batch_index)
            await handle_tlm_api_error_from_resp(res, batch_index)

            if not res_json.get("deberta_success", True):
                raise TlmPartialSuccess("Partial failure on deberta call -- slowdown request rate.")

    finally:
        # Only close sessions we created ourselves; caller-owned sessions stay open.
        if local_scoped_client:
            await client_session.close()

    return cast(JSONDict, res_json)


def send_telemetry(info: JSONDict) -> None:
    """Send a usage-telemetry payload to the Cleanlab CLI backend.

    Args:
        info: JSON-serializable telemetry payload.
    """
    # requests has no default timeout; without one a stalled connection would
    # hang the caller indefinitely on a fire-and-forget telemetry call.
    requests.post(f"{cli_base_url}/telemetry", json=info, timeout=30)
25 changes: 24 additions & 1 deletion cleanlab_studio/studio/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
init_dataset_source,
telemetry,
)
from cleanlab_studio.utils import tlm_lite, tlm_calibrated
from cleanlab_studio.utils import tlm_lite, tlm_calibrated, tlm_openai

from . import enrichment, inference, trustworthy_language_model

Expand Down Expand Up @@ -673,6 +673,29 @@ def TLMCalibrated(
verbose=verbose,
)

def TLMOpenAI(
    self,
    quality_preset: TLMQualityPreset = "medium",
    *,
    options: Optional[trustworthy_language_model.TLMOptions] = None,
    response_format: Optional[Any] = None,  # TODO: typing
    timeout: Optional[float] = None,
    verbose: Optional[bool] = None,
) -> tlm_openai.TLMOpenAI:
    """
    Instantiate a Trustworthy Language Model variant that supports OpenAI-style
    structured outputs: pass a ``response_format`` (a pydantic model class or an
    OpenAI response-format dict) and TLM responses are parsed into that format.

    Args:
        quality_preset: quality preset to use for TLM calls (default ``"medium"``).
        options: optional TLM configuration options.
        response_format: optional OpenAI-style response format (pydantic model
            class or response-format dict); requires the ``openai`` package.
        timeout: optional per-call timeout in seconds.
        verbose: whether to print verbose progress output.

    Returns:
        A ``tlm_openai.TLMOpenAI`` instance bound to this Studio's API key.
    """
    return tlm_openai.TLMOpenAI(
        self._api_key,
        quality_preset,
        options=options,
        response_format=response_format,
        timeout=timeout,
        verbose=verbose,
    )

def poll_cleanset_status(self, cleanset_id: str, timeout: Optional[int] = None) -> bool:
"""
This method has been deprecated, instead use: `wait_until_cleanset_ready()`
Expand Down
101 changes: 101 additions & 0 deletions cleanlab_studio/utils/tlm_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from __future__ import annotations

import asyncio
from typing import Any, Optional, cast

import json
import aiohttp
from cleanlab_studio.studio.trustworthy_language_model import (
TLM,
TLMOptions,
TLMResponse,
handle_tlm_exceptions,
)
from cleanlab_studio.internal.api import api
from cleanlab_studio.internal.constants import (
_TLM_MAX_RETRIES,
)
from cleanlab_studio.internal.types import TLMQualityPreset


class TLMOpenAI(TLM):
    """
    Trustworthy Language Model variant supporting OpenAI-style structured outputs.

    When a ``response_format`` is supplied (a pydantic model class or an OpenAI
    response-format dict), it is converted to the OpenAI wire format, forwarded to
    the backend, and the model's response is parsed back into that format.
    """

    def __init__(
        self,
        api_key: str,
        quality_preset: TLMQualityPreset,
        *,
        options: Optional[TLMOptions] = None,
        response_format: Optional[Any] = None,  # TODO: change type
        timeout: Optional[float] = None,
        verbose: Optional[bool] = None,
    ) -> None:
        """
        Initialize the TLM with optional OpenAI-style structured-output support.

        Args:
            api_key: Cleanlab Studio API key.
            quality_preset: TLM quality preset.
            options: optional TLM configuration options.
            response_format: optional pydantic model class or OpenAI
                response-format dict; requires the ``openai`` package.
            timeout: optional per-call timeout in seconds.
            verbose: whether to print verbose progress output.

        Raises:
            ImportError: if ``response_format`` is given but ``openai`` is not installed.
        """
        # TODO: figure out which validations diverge
        super().__init__(api_key, quality_preset, options=options, timeout=timeout, verbose=verbose)

        self._response_format = response_format
        # True when response_format is a pydantic model class (i.e. not a plain
        # dict); None when no response_format was supplied.
        self._response_format_is_pydantic_object: Optional[bool] = None
        # OpenAI wire-format representation of response_format, once converted.
        self._response_format_json: Optional[Any] = None

        # Extra OpenAI-style arguments forwarded to the backend with each prompt.
        self._openai_args: dict = {}

        if self._response_format is not None:
            try:
                from openai.lib._parsing import type_to_response_format_param
            except ImportError as e:
                # Chain the original ImportError so the real failure isn't masked.
                raise ImportError(
                    "Cannot import openai which is required to use TLMOpenAI. "
                    "Please install it using `pip install openai` and try again."
                ) from e

            self._response_format_is_pydantic_object = not isinstance(self._response_format, dict)

            # TODO: could implement ourselves
            self._response_format_json = type_to_response_format_param(self._response_format)

            self._openai_args["response_format"] = self._response_format_json

    @handle_tlm_exceptions("TLMResponse")
    async def _prompt_async(
        self,
        prompt: str,
        client_session: Optional[aiohttp.ClientSession] = None,
        timeout: Optional[float] = None,
        capture_exceptions: bool = False,
        batch_index: Optional[int] = None,
    ) -> TLMResponse:
        """
        Asynchronously prompt the TLM and parse the response per ``response_format``.

        Args:
            prompt: single prompt string to send to the TLM.
            client_session: optional aiohttp session to reuse for the request.
            timeout: optional timeout in seconds for this call.
            capture_exceptions: consumed by the ``@handle_tlm_exceptions`` decorator
                to return a partial result instead of raising.
            batch_index: optional index of this prompt within a batch, for error messages.

        Returns:
            TLMResponse dict with ``response`` (parsed when a response_format is set)
            and ``trustworthiness_score``; includes ``log`` when return-log is enabled.
        """
        # NOTE(review): retries is presumably consumed by the @tlm_retry decorator
        # on api.tlm_openai_prompt (its signature does not declare it) — confirm.
        response_json = await asyncio.wait_for(
            api.tlm_openai_prompt(
                self._api_key,
                prompt,
                self._quality_preset,
                self._options,
                self._openai_args,
                self._rate_handler,
                client_session,
                batch_index=batch_index,
                retries=_TLM_MAX_RETRIES,
            ),
            timeout=timeout,
        )

        # TODO: error handling
        response = response_json["response"]
        if self._response_format is not None:
            if self._response_format_is_pydantic_object:
                # Pydantic model class: construct an instance from the JSON payload.
                response = self._response_format(**json.loads(response))
            else:
                # Plain response-format dict: return the decoded JSON as-is.
                response = json.loads(response)

        tlm_response = {
            "response": response,
            "trustworthiness_score": response_json["confidence_score"],
        }

        if self._return_log:
            tlm_response["log"] = response_json["log"]

        # TODO: wrong typing here (need to update TLMResponse)
        return cast(TLMResponse, tlm_response)

0 comments on commit 882379d

Please sign in to comment.