diff --git a/langtest/embeddings/openai.py b/langtest/embeddings/openai.py index 8fe9c686c..490e414f4 100644 --- a/langtest/embeddings/openai.py +++ b/langtest/embeddings/openai.py @@ -10,7 +10,7 @@ class OpenaiEmbeddings: LIB_NAME = "openai" - def __init__(self, model="text-embedding-ada-002"): + def __init__(self, model="text-embedding-3-small"): self.model = model self.api_key = os.environ.get("OPENAI_API_KEY") self.openai = None @@ -18,7 +18,7 @@ def __init__(self, model="text-embedding-ada-002"): if not self.api_key: raise ValueError(Errors.E032()) - self.openai.api_key = self.api_key + # self.openai.api_key = self.api_key def _check_openai_package(self): """Check if the 'openai' package is installed and import the required functions. @@ -44,13 +44,17 @@ def get_embedding( list[float]: A list of floating-point values representing the text's embedding. """ if isinstance(text, list): - response = self.openai.Embedding.create(input=text, model=self.model) + response = self.openai.Client(api_key=self.api_key).embedding.create( + input=text, model=self.model + ) embedding = [ - np.array(response["data"][i]["embedding"]).reshape(1, -1) + np.array(response.data[i].embedding).reshape(1, -1) for i in range(len(text)) ] return embedding else: - response = self.openai.Embedding.create(input=[text], model=self.model) - embedding = np.array(response["data"][0]["embedding"]).reshape(1, -1) + response = self.openai.Client(api_key=self.api_key).embedding.create( + input=[text], model=self.model + ) + embedding = np.array(response.data[0].embedding).reshape(1, -1) return embedding