Skip to content

Commit dc06767

Browse files
llm: Fix chat_llm for OpenAI model (#902)
This PR fixes the OpenAI model's chat_llm logic to preserve prompt history for each LLM chat. Signed-off-by: Arthur Chan <[email protected]> Co-authored-by: DavidKorczynski <[email protected]>
1 parent 77498e9 commit dc06767

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

Diff for: llm_toolkit/models.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def __init__(
7777
self.temperature = temperature
7878
self.temperature_list = temperature_list
7979

80+
# Preserve chat history for OpenAI
81+
self.messages = []
82+
8083
def cloud_setup(self):
8184
"""Runs Cloud specific-setup."""
8285
# Only a subset of models need a cloud specific set up, so
@@ -275,14 +278,19 @@ def chat_llm(self, client: Any, prompt: prompts.Prompt) -> str:
275278
logger.info('OpenAI does not allow temperature list: %s',
276279
self.temperature_list)
277280

281+
self.messages.extend(prompt.get())
282+
278283
completion = self.with_retry_on_error(
279-
lambda: client.chat.completions.create(messages=prompt.get(),
284+
lambda: client.chat.completions.create(messages=self.messages,
280285
model=self.name,
281286
n=self.num_samples,
282287
temperature=self.temperature),
283288
[openai.OpenAIError])
284289

285-
return completion.choices[0].message.content
290+
llm_response = completion.choices[0].message.content
291+
self.messages.append({'role': 'assistant', 'content': llm_response})
292+
293+
return llm_response
286294

287295
def ask_llm(self, prompt: prompts.Prompt) -> str:
288296
"""Queries LLM a single prompt and returns its response."""

Diff for: llm_toolkit/prompts.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ def gettext(self) -> str:
125125
"""Gets the final formatted prompt in plain text."""
126126
result = ''
127127
for item in self.get():
128-
result = f'{result}\n{item.get("content", "")}'
128+
result = (f'{result}\n{item.get("role", "Unknown")}:'
129+
f'\n{item.get("content", "")}')
129130

130131
return result
131132

0 commit comments

Comments (0)