Skip to content

Commit

Permalink
Refactored static and dynamic enrichment APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
mturk24 committed Oct 28, 2024
1 parent 3c095db commit 4e0db6a
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 93 deletions.
32 changes: 32 additions & 0 deletions cleanlab_studio/studio/enrichment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pandas as pd
from tqdm import tqdm
from typing_extensions import NotRequired
from functools import lru_cache

from cleanlab_studio.errors import EnrichmentProjectError
from cleanlab_studio.internal.api import api
Expand Down Expand Up @@ -49,6 +50,12 @@ def _response_timestamp_to_datetime(timestamp_string: str) -> datetime:
return datetime.strptime(timestamp_string, response_timestamp_format_str)


@lru_cache(maxsize=None)
def _get_online_inference():
from cleanlab_studio.utils.data_enrichment.enrich import online_inference
return online_inference


class EnrichmentProject:
"""Represents an Enrichment Project instance, which is bound to a Cleanlab Studio account.
Expand Down Expand Up @@ -399,6 +406,27 @@ def resume(self) -> JSONDict:
latest_job = self._get_latest_job()
return api.resume_enrichment_job(api_key=self._api_key, job_id=latest_job["id"])

def online_inference(
self,
data: Union[pd.DataFrame, List[dict]],
options: EnrichmentOptions,
new_column_name: str,
) -> Dict[str, Any]:
"""
Enrich data in real-time using the same logic as the run() method, but client-side.
Args:
data (Union[pd.DataFrame, List[dict]]): The dataset to enrich.
options (EnrichmentOptions): Options for enriching the dataset.
new_column_name (str): The name of the new column to store the results.
Returns:
Dict[str, Any]: A dictionary containing information about the enrichment job and the enriched dataset.
"""
online_inference = _get_online_inference()
job_info = online_inference(data, options, new_column_name, None)
return job_info


class EnrichmentJob(TypedDict):
"""Represents an Enrichment Job instance.
Expand Down Expand Up @@ -611,3 +639,7 @@ def _handle_replacements_and_extraction_pattern(
else:
raise ValueError(REGEX_PARAMETER_ERROR_MESSAGE)
return extraction_pattern, replacements




196 changes: 103 additions & 93 deletions cleanlab_studio/utils/data_enrichment/enrich.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any, List, Optional, Tuple, Union
import pandas as pd
from typing import Any, List, Optional, Tuple, Union, Dict
from functools import lru_cache
from cleanlab_studio.internal.enrichment_utils import (
extract_df_subset,
get_prompt_outputs,
Expand All @@ -8,84 +8,59 @@
get_optimized_prompt,
Replacement,
)

from cleanlab_studio.studio.studio import Studio


def enrich_data(
studio: Studio,
data: pd.DataFrame,
prompt: str,
*,
regex: Optional[Union[str, Replacement, List[Replacement]]] = None,
constrain_outputs: Optional[List[str]] = None,
optimize_prompt: bool = True,
subset_indices: Optional[Union[Tuple[int, int], List[int]]] = (0, 3),
new_column_name: str = "metadata",
disable_warnings: bool = False,
**kwargs: Any,
) -> pd.DataFrame:
from cleanlab_studio.studio.enrichment import EnrichmentOptions

@lru_cache(maxsize=None)
def _get_pandas():
import pandas as pd
return pd

@lru_cache(maxsize=None)
def _get_tqdm():
from tqdm import tqdm
return tqdm

def online_inference(
data: Union['pd.DataFrame', List[dict]],
options: EnrichmentOptions,
new_column_name: str,
studio: Any, # Add this parameter
) -> Dict[str, Any]:
"""
Generate a column of arbitrary metadata for your DataFrame, reliably at scale with Generative AI.
The metadata is separately generated for each row of your DataFrame, based on a prompt that specifies what information you need and what existing columns' data it should be derived from.
Each row of generated metadata is accompanied by a trustworthiness score, which helps you discover which metadata is most/least reliable.
You can optionally apply regular expressions to further reformat your metadata beyond raw LLM outputs, or specify that each row of the metadata must be constrained to a particular set of values.
Enrich data in real-time using the same logic as the run() method, but client-side.
Args:
studio (Studio): Cleanlab Studio client object, which you must instantiate before calling this method.
data (pd.DataFrame): A pandas DataFrame containing your data.
prompt (str): Formatted f-string, that contains both the prompt, and names of columns to embed.
**Example:** "Is this a numeric value, answer Yes or No only. Value: {column_name}"
regex (str | Replacement | List[Replacement], optional): A string, tuple, or list of tuples specifying regular expressions to apply for post-processing the raw LLM outputs.
If a string value is passed in, a regex match will be performed and the matched pattern will be returned (if the pattern cannot be matched, None will be returned).
Specifically the provided string will be passed into Python's `re.match()` method.
Pass in a tuple `(R1, R2)` instead if you wish to perform find and replace operations rather than matching/extraction.
`R1` should be a string containing the regex pattern to match, and `R2` should be a string to replace matches with.
Pass in a list of tuples instead if you wish to apply multiple replacements. Replacements will be applied in the order they appear in the list.
Note that you cannot pass in a list of strings (chaining of multiple regex processing steps is only allowed for replacement operations).
These tuples specify the desired patterns to match and replace from the raw LLM response,
This regex processing is useful in settings where you are unable to prompt the LLM to generate valid outputs 100% of the time,
but can easily transform the raw LLM outputs to be valid through regular expressions that extract and replace parts of the raw output string.
When this regex is applied, the processed results can be seen ithe ``{new_column_name}`` column, and the raw outpus (before any regex processing)
will be saved in the ``{new_column_name}_log`` column of the results dataframe.
**Example 1:** ``regex = '.*The answer is: (Bird|[Rr]abbit).*'`` will extract strings that are the words 'Bird', 'Rabbit' or 'rabbit' after the characters "The answer is: " from the raw response.
**Example 2:** ``regex = [('True', 'T'), ('False', 'F')]`` will replace the words True and False with T and F.
**Example 3:** ``regex = (' Explanation:.*', '') will remove everything after and including the words "Explanation:".
For instance, the response "True. Explanation: 3+4=7, and 7 is an odd number." would return "True." after the regex replacement.
constrain_outputs (List[str], optional): List of all possible output values for the `metadata` column.
If specified, every entry in the `metadata` column will exactly match one of these values (for less open-ended data enrichment tasks). If None, the `metadata` column can contain arbitrary values (for more open-ended data enrichment tasks).
There may be additional transformations applied to ensure the returned value is one of these. If regex is also specified, then these transformations occur after your regex is applied.
If `optimize_prompt` is True, the prompt will be automatically adjusted to include a statement that the response must match one of the `constrain_outputs`.
The last value of this list should be considered the baseline value (eg. “other”) that will be returned where there are no close matches between the raw LLM response and any of the classes mentioned,
that value will be returned if no close matches can be made.
optimize_prompt (bool, default = True): When False, your provided prompt will not be modified in any way. When True, your provided prompt may be automatically adjusted in an effort to produce better results.
For instance, if the constrain_outputs are constrained, we may automatically append the following statement to your prompt: "Your answer must exactly match one of the following values: `constrain_outputs`."
subset_indices (Tuple[int, int] | List[int], optional): What subset of the supplied data rows to generate metadata for. If None, we run on all of the data.
This can be either a list of unique indices or a range. These indices are passed into pandas ``.iloc`` method, so should be integers based on row order as opposed to row-index labels pointing to `df.index`.
We advise against collecting results for all of your data at first. First collect results for a smaller data subset, and use this subset to experiment with different values of the `prompt` or `regex` arguments. Only once the results look good for your subset should you run on the full dataset.
new_column_name (str): Optional name for the returned enriched column. Name acts as a prefix appended to all additional columns that are returned.
disable_warnings (bool, default = False): When True, warnings are disabled.
**kwargs: Optional keyword arguments to pass to the underlying TLM object, such as ``quality_preset`` and ``options`` to specify the TLM quality present and TLMOptions respectively.
For more information on valid TLM arguments, view the TLM documentation here: https://help.cleanlab.ai/reference/python/studio/#method-tlm
data (Union[pd.DataFrame, List[dict]]): The dataset to enrich.
options (EnrichmentOptions): Options for enriching the dataset.
new_column_name (str): The name of the new column to store the results.
studio (Any): A required parameter for the Studio object.
Returns:
A DataFrame that contains `metadata` and `trustworthiness` columns related to the prompt in order of original data. Some columns names will have `new_column_name` prepended to them.
`metadata` column = responses to the prompt and other data mutations if `regex` or `constrain_outputs` is not specified.
`trustworthiness` column = trustworthiness of the prompt responses (which ignore the data mutations).
**Note**: If you specified the `regex` or `constrain_outputs` arguments, some additional transformations may be applied to raw LLM outputs to produce the returned values. In these cases, an additional `log` column will be added to the returned DataFrame that records the raw LLM outputs (feel free to disregard these).
Dict[str, Any]: A dictionary containing information about the enrichment job and the enriched dataset.
"""
if subset_indices:
df = extract_df_subset(data, subset_indices)
else:
df = data.copy()
pd = _get_pandas()
tqdm = _get_tqdm()

# Validate options
_validate_enrichment_options(options)

# Ensure data is a DataFrame
if isinstance(data, list):
data = pd.DataFrame(data)

df = data.copy()

# Extract options
prompt = options['prompt']
regex = options.get('regex')
constrain_outputs = options.get('constrain_outputs')
optimize_prompt = options.get('optimize_prompt', True)
quality_preset = options.get('quality_preset', 'medium')

if optimize_prompt:
prompt = get_optimized_prompt(prompt, constrain_outputs)

outputs = get_prompt_outputs(studio, prompt, df, **kwargs)
outputs = get_prompt_outputs(studio, prompt, df, quality_preset=quality_preset, **options.get('tlm_options', {}))
column_name_prefix = new_column_name + "_"

df[f"{column_name_prefix}trustworthiness"] = [
Expand All @@ -94,41 +69,74 @@ def enrich_data(
df[f"{new_column_name}"] = [
output["response"] if output is not None else None for output in outputs
]

if (
regex is None and constrain_outputs is None
): # we do not need to have a "log" column as original output is not augmented by regex replacements or contrained outputs
return df[[f"{new_column_name}", f"{column_name_prefix}trustworthiness"]]

df[f"{column_name_prefix}log"] = [
output["response"] if output is not None else None for output in outputs
]

if regex:
df[f"{new_column_name}"] = df[f"{new_column_name}"].apply(
lambda x: get_regex_match_or_replacement(x, regex)
)
if regex is None and constrain_outputs is None:
enriched_df = df[[f"{new_column_name}", f"{column_name_prefix}trustworthiness"]]
else:
if regex:
df[f"{new_column_name}"] = df[f"{new_column_name}"].apply(
lambda x: get_regex_match_or_replacement(x, regex)
)

if constrain_outputs:
df[f"{new_column_name}"] = df[f"{new_column_name}"].apply(
lambda x: get_constrain_outputs_match(
x, constrain_outputs, disable_warnings=disable_warnings
if constrain_outputs:
df[f"{new_column_name}"] = df[f"{new_column_name}"].apply(
lambda x: get_constrain_outputs_match(x, constrain_outputs)
)
)

return df[
[
enriched_df = df[[
f"{new_column_name}",
f"{column_name_prefix}trustworthiness",
f"{column_name_prefix}log",
]
]

]]

# Simulate the response structure of the run() method
job_info = {
"job_id": "online_inference",
"status": "SUCCEEDED",
"num_rows": len(enriched_df),
"processed_rows": len(enriched_df),
"average_trustworthiness_score": enriched_df[f"{column_name_prefix}trustworthiness"].mean(),
"results": enriched_df
}

return job_info

def _validate_enrichment_options(options: EnrichmentOptions) -> None:
required_keys = ['prompt']
for key in required_keys:
if key not in options or options[key] is None:
raise ValueError(f"'{key}' is required in the options.")

# Validate types and values
if not isinstance(options['prompt'], str):
raise TypeError("'prompt' must be a string.")

if 'constrain_outputs' in options and options['constrain_outputs'] is not None:
if not isinstance(options['constrain_outputs'], list):
raise TypeError("'constrain_outputs' must be a list if provided.")

if 'optimize_prompt' in options and options['optimize_prompt'] is not None:
if not isinstance(options['optimize_prompt'], bool):
raise TypeError("'optimize_prompt' must be a boolean if provided.")

if 'quality_preset' in options and options['quality_preset'] is not None:
if not isinstance(options['quality_preset'], str):
raise TypeError("'quality_preset' must be a string if provided.")

if 'regex' in options and options['regex'] is not None:
regex = options['regex']
if not isinstance(regex, (str, tuple, list)):
raise TypeError("'regex' must be a string, tuple, or list of tuples.")
if isinstance(regex, list) and not all(isinstance(item, tuple) for item in regex):
raise TypeError("All items in 'regex' list must be tuples.")

def process_regex(
column_data: Union[pd.Series, List[str]],
regex: Union[str, Replacement, List[Replacement]],
) -> Union[pd.Series, List[str]]:
column_data: Union['pd.Series', List[str]],
regex: Union[str, Tuple[str, str], List[Tuple[str, str]]],
) -> Union['pd.Series', List[str]]:
"""
Performs regex matches or replacements to the given string according to the given matching patterns and replacement strings.
Expand All @@ -153,6 +161,8 @@ def process_regex(
Returns:
Extracted matches to the provided regular expression from each element of the data column (specifically, the first match is returned).
"""
pd = _get_pandas()

if isinstance(column_data, list):
return [get_regex_match_or_replacement(x, regex) for x in column_data]
elif isinstance(column_data, pd.Series):
Expand Down

0 comments on commit 4e0db6a

Please sign in to comment.