Skip to content

Commit

Permalink
Merge pull request #12 from NREL/gb/oai_v1_debug
Browse files Browse the repository at this point in the history
fixed chat completion response object attributes for openai v1+
  • Loading branch information
grantbuster authored Jan 4, 2024
2 parents 323066e + 0bb795f commit aadab57
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
4 changes: 2 additions & 2 deletions elm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def chat(self, query, temperature=0):
stream=False)

response = self._client.chat.completions.create(**kwargs)
response = response["choices"][0]["message"]["content"]
response = response.choices[0].message.content
self.messages.append({'role': 'assistant', 'content': response})

return response
Expand Down Expand Up @@ -242,7 +242,7 @@ def generic_query(self, query, model_role=None, temperature=0):
stream=False)

response = self._client.chat.completions.create(**kwargs)
response = response["choices"][0]["message"]["content"]
response = response.choices[0].message.content
return response

async def generic_async_query(self, queries, model_role=None,
Expand Down
4 changes: 2 additions & 2 deletions elm/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, corpus, model=None, token_budget=3500, ref_col=None):
"""

super().__init__(model)

self.corpus = self.preflight_corpus(corpus)
self.token_budget = token_budget
self.embedding_arr = np.vstack(self.corpus['embedding'].values)
Expand Down Expand Up @@ -294,7 +294,7 @@ def chat(self, query,
print(chunk_msg, end='')

else:
response_message = response["choices"][0]["message"]["content"]
response_message = response.choices[0].message.content

self.messages.append({'role': 'assistant',
'content': response_message})
Expand Down
10 changes: 9 additions & 1 deletion tests/test_wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
TEXT = f.read()


class MockObject:
"""Dummy class for mocking api response objects"""


class MockClass:
"""Dummy class to mock various api calls"""

Expand All @@ -38,7 +42,11 @@ def call(*args, **kwargs): # pylint: disable=unused-argument
@staticmethod
def create(*args, **kwargs): # pylint: disable=unused-argument
"""Mock for openai.ChatCompletion.create()"""
response = {'choices': [{'message': {'content': 'hello!'}}]}
# pylint: disable=attribute-defined-outside-init
response = MockObject()
response.choices = [MockObject()]
response.choices[0].message = MockObject()
response.choices[0].message.content = 'hello!'
return response


Expand Down

0 comments on commit aadab57

Please sign in to comment.