Skip to content

Commit

Permalink
Reformatted and renamed client side enrich script
Browse files Browse the repository at this point in the history
  • Loading branch information
mturk24 committed Oct 29, 2024
1 parent 4e0db6a commit 1cdc0dd
Showing 1 changed file with 42 additions and 33 deletions.
75 changes: 42 additions & 33 deletions cleanlab_studio/utils/data_enrichment/enrich.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,33 @@
from typing import Any, List, Optional, Tuple, Union, Dict
from typing import Any, List, Tuple, Union, Dict
from functools import lru_cache
from cleanlab_studio.internal.enrichment_utils import (
extract_df_subset,
get_prompt_outputs,
get_regex_match_or_replacement,
get_constrain_outputs_match,
get_optimized_prompt,
Replacement,
)
from cleanlab_studio.studio.enrichment import EnrichmentOptions


@lru_cache(maxsize=None)
def _get_pandas():
import pandas as pd

return pd


@lru_cache(maxsize=None)
def _get_tqdm():
from tqdm import tqdm

return tqdm

def online_inference(
data: Union['pd.DataFrame', List[dict]],

def run_online(
data: Union["pd.DataFrame", List[dict]],
options: EnrichmentOptions,
new_column_name: str,
studio: Any, # Add this parameter
studio: Any,
) -> Dict[str, Any]:
"""
Enrich data in real-time using the same logic as the run() method, but client-side.
Expand All @@ -40,7 +43,7 @@ def online_inference(
"""
pd = _get_pandas()
tqdm = _get_tqdm()

# Validate options
_validate_enrichment_options(options)

Expand All @@ -51,16 +54,18 @@ def online_inference(
df = data.copy()

# Extract options
prompt = options['prompt']
regex = options.get('regex')
constrain_outputs = options.get('constrain_outputs')
optimize_prompt = options.get('optimize_prompt', True)
quality_preset = options.get('quality_preset', 'medium')
prompt = options["prompt"]
regex = options.get("regex")
constrain_outputs = options.get("constrain_outputs")
optimize_prompt = options.get("optimize_prompt", True)
quality_preset = options.get("quality_preset", "medium")

if optimize_prompt:
prompt = get_optimized_prompt(prompt, constrain_outputs)

outputs = get_prompt_outputs(studio, prompt, df, quality_preset=quality_preset, **options.get('tlm_options', {}))
outputs = get_prompt_outputs(
studio, prompt, df, quality_preset=quality_preset, **options.get("tlm_options", {})
)
column_name_prefix = new_column_name + "_"

df[f"{column_name_prefix}trustworthiness"] = [
Expand All @@ -86,57 +91,61 @@ def online_inference(
lambda x: get_constrain_outputs_match(x, constrain_outputs)
)

enriched_df = df[[
f"{new_column_name}",
f"{column_name_prefix}trustworthiness",
f"{column_name_prefix}log",
]]
enriched_df = df[
[
f"{new_column_name}",
f"{column_name_prefix}trustworthiness",
f"{column_name_prefix}log",
]
]

# Simulate the response structure of the run() method
job_info = {
"job_id": "online_inference",
"job_id": "run_online",
"status": "SUCCEEDED",
"num_rows": len(enriched_df),
"processed_rows": len(enriched_df),
"average_trustworthiness_score": enriched_df[f"{column_name_prefix}trustworthiness"].mean(),
"results": enriched_df
"results": enriched_df,
}

return job_info


def _validate_enrichment_options(options: EnrichmentOptions) -> None:
required_keys = ['prompt']
required_keys = ["prompt"]
for key in required_keys:
if key not in options or options[key] is None:
raise ValueError(f"'{key}' is required in the options.")

# Validate types and values
if not isinstance(options['prompt'], str):
if not isinstance(options["prompt"], str):
raise TypeError("'prompt' must be a string.")

if 'constrain_outputs' in options and options['constrain_outputs'] is not None:
if not isinstance(options['constrain_outputs'], list):
if "constrain_outputs" in options and options["constrain_outputs"] is not None:
if not isinstance(options["constrain_outputs"], list):
raise TypeError("'constrain_outputs' must be a list if provided.")

if 'optimize_prompt' in options and options['optimize_prompt'] is not None:
if not isinstance(options['optimize_prompt'], bool):
if "optimize_prompt" in options and options["optimize_prompt"] is not None:
if not isinstance(options["optimize_prompt"], bool):
raise TypeError("'optimize_prompt' must be a boolean if provided.")

if 'quality_preset' in options and options['quality_preset'] is not None:
if not isinstance(options['quality_preset'], str):
if "quality_preset" in options and options["quality_preset"] is not None:
if not isinstance(options["quality_preset"], str):
raise TypeError("'quality_preset' must be a string if provided.")

if 'regex' in options and options['regex'] is not None:
regex = options['regex']
if "regex" in options and options["regex"] is not None:
regex = options["regex"]
if not isinstance(regex, (str, tuple, list)):
raise TypeError("'regex' must be a string, tuple, or list of tuples.")
if isinstance(regex, list) and not all(isinstance(item, tuple) for item in regex):
raise TypeError("All items in 'regex' list must be tuples.")


def process_regex(
column_data: Union['pd.Series', List[str]],
column_data: Union["pd.Series", List[str]],
regex: Union[str, Tuple[str, str], List[Tuple[str, str]]],
) -> Union['pd.Series', List[str]]:
) -> Union["pd.Series", List[str]]:
"""
Performs regex matches or replacements to the given string according to the given matching patterns and replacement strings.
Expand All @@ -162,7 +171,7 @@ def process_regex(
Extracted matches to the provided regular expression from each element of the data column (specifically, the first match is returned).
"""
pd = _get_pandas()

if isinstance(column_data, list):
return [get_regex_match_or_replacement(x, regex) for x in column_data]
elif isinstance(column_data, pd.Series):
Expand Down

0 comments on commit 1cdc0dd

Please sign in to comment.