Skip to content

Commit

Permalink
Add better context support.
Browse files Browse the repository at this point in the history
  • Loading branch information
Liraim committed May 21, 2024
1 parent f1b1a5c commit a44f1c4
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 27 deletions.
23 changes: 6 additions & 17 deletions src/evidently/descriptors/openai_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
class OpenAIPrompting(FeatureDescriptor):
prompt: str
prompt_replace_string: str
context: str
context: Optional[str]
context_column: Optional[str]
context_replace_string: str
openai_params: Optional[dict]
model: str
Expand All @@ -22,7 +23,8 @@ def __init__(
prompt: str,
model: str,
feature_type: str,
context: str = "",
context: Optional[str] = None,
context_column: Optional[str] = None,
prompt_replace_string: str = "REPLACE",
context_replace_string: str = "CONTEXT",
display_name: Optional[str] = None,
Expand All @@ -37,26 +39,12 @@ def __init__(
self.display_name = display_name
self.possible_values = possible_values
self.context = context
self.context_column = context_column
self.context_replace_string = context_replace_string
self.openai_params = openai_params
self.check_mode = check_mode
super().__init__()

def for_column(self, column_name: str):
    """Return the generated feature's column name for *column_name*.

    Builds the same :class:`OpenAIFeature` that :meth:`feature` would and
    asks it for its feature name.
    """
    feature_kwargs = dict(
        model=self.model,
        prompt=self.prompt,
        prompt_replace_string=self.prompt_replace_string,
        feature_type=self.feature_type,
        display_name=self.display_name,
        possible_values=self.possible_values,
        context=self.context,
        context_replace_string=self.context_replace_string,
        openai_params=self.openai_params,
        check_mode=self.check_mode,
    )
    return OpenAIFeature(column_name, **feature_kwargs).feature_name()

def feature(self, column_name: str) -> GeneratedFeature:
return OpenAIFeature(
column_name,
Expand All @@ -67,6 +55,7 @@ def feature(self, column_name: str) -> GeneratedFeature:
display_name=self.display_name,
possible_values=self.possible_values,
context=self.context,
context_column=self.context_column,
context_replace_string=self.context_replace_string,
openai_params=self.openai_params,
check_mode=self.check_mode,
Expand Down
63 changes: 53 additions & 10 deletions src/evidently/features/openai_feature.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
from itertools import repeat
from typing import List
from typing import Optional
from typing import Union
Expand All @@ -19,7 +20,8 @@ class OpenAIFeature(GeneratedFeature):
feature_id: str
prompt: str
prompt_replace_string: str
context: str
context: Optional[str]
context_column: Optional[str]
context_replace_string: str
openai_params: dict
model: str
Expand All @@ -32,7 +34,8 @@ def __init__(
model: str,
prompt: str,
feature_type: str,
context: str = "",
context: Optional[str] = None,
context_column: Optional[str] = None,
prompt_replace_string: str = "REPLACE",
context_replace_string: str = "CONTEXT",
check_mode: str = "any_line",
Expand All @@ -43,7 +46,10 @@ def __init__(
self.feature_id = str(uuid.uuid4())
self.prompt = prompt
self.prompt_replace_string = prompt_replace_string
if context is not None and context_column is not None:
raise ValueError("Context and context_column are mutually exclusive")
self.context = context
self.context_column = context_column
self.context_replace_string = context_replace_string
self.openai_params = openai_params or {}
self.model = model
Expand All @@ -60,16 +66,28 @@ def generate_feature(self, data: pd.DataFrame, data_definition: DataDefinition)
column_data = data[self.column_name].values.tolist()
client = OpenAI()
result: List[Union[str, float, None]] = []
prompt = self.prompt.replace(self.context_replace_string, self.context)

if self.model in _legacy_models:
func = _completions_openai_call
else:
func = _chat_completions_openai_call

for value in column_data:
prompt = prompt.replace(self.prompt_replace_string, value)
prompt_answer = func(client, model=self.model, prompt=prompt, **self.openai_params)
if self.context_column is not None:
context_column = data[self.context_column].values.tolist()
else:
context_column = repeat(self.context)

for message, context in zip(column_data, context_column):
prompt_answer = func(
client,
model=self.model,
prompt=self.prompt,
prompt_replace_string=self.prompt_replace_string,
context_replace_string=self.context_replace_string,
prompt_message=message,
context=context,
**self.openai_params,
)
processed_response = _postprocess_response(
prompt_answer,
self.check_mode,
Expand Down Expand Up @@ -98,14 +116,39 @@ def _feature_column_name(self) -> str:
return self.column_name + "_" + self.feature_id


def _completions_openai_call(client, model: str, prompt: str, params: dict):
prompt_answer = client.completions.create(model=model, prompt=prompt, **params)
def _completions_openai_call(
client,
model: str,
prompt: str,
prompt_replace_string: str,
context_replace_string: str,
prompt_message: str,
context: str,
params: dict,
):
final_prompt = prompt.replace(context_replace_string, context).replace(prompt_replace_string, prompt_message)
prompt_answer = client.completions.create(model=model, prompt=final_prompt, **params)
return prompt_answer.choices[0].text


def _chat_completions_openai_call(client, model: str, prompt: str, params: dict):
def _chat_completions_openai_call(
client,
model: str,
prompt: str,
prompt_replace_string: str,
context_replace_string: str,
prompt_message: str,
context: str,
params: dict,
):
final_prompt = prompt.replace(prompt_replace_string, prompt_message)
prompt_answer = client.chat.completions.create(
model=model, messages=[{"role": "user", "content": prompt}], **params
model=model,
messages=[
{"role": "system", "content": context},
{"role": "user", "content": final_prompt},
],
**params,
)
return prompt_answer.choices[0].message.content

Expand Down

0 comments on commit a44f1c4

Please sign in to comment.