Skip to content

Commit

Permalink
Use token provider instead of token wrapper (Azure-Samples#1228)
Browse files Browse the repository at this point in the history
* Use token provider instead of token wrapper

* Upgrade to class based syntax

* Use Union
  • Loading branch information
pamelafox authored Feb 2, 2024
1 parent 15af3a8 commit 0ee189d
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions scripts/prepdocslib/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import time
from abc import ABC
from typing import List, Optional, Union
from typing import Awaitable, Callable, List, Optional, Union
from urllib.parse import urljoin

import aiohttp
import tiktoken
from azure.core.credentials import AccessToken, AzureKeyCredential
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
from azure.identity.aio import get_bearer_token_provider
from openai import AsyncAzureOpenAI, AsyncOpenAI, RateLimitError
from tenacity import (
AsyncRetrying,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from typing_extensions import TypedDict


class EmbeddingBatch:
Expand Down Expand Up @@ -139,28 +140,29 @@ def __init__(
self.open_ai_service = open_ai_service
self.open_ai_deployment = open_ai_deployment
self.credential = credential
self.cached_token: Optional[AccessToken] = None

async def create_client(self) -> AsyncOpenAI:
class AuthArgs(TypedDict, total=False):
api_key: str
azure_ad_token_provider: Callable[[], Union[str, Awaitable[str]]]

auth_args = AuthArgs()
if isinstance(self.credential, AzureKeyCredential):
auth_args["api_key"] = self.credential.key
elif isinstance(self.credential, AsyncTokenCredential):
auth_args["azure_ad_token_provider"] = get_bearer_token_provider(
self.credential, "https://cognitiveservices.azure.com/.default"
)
else:
raise TypeError("Invalid credential type")

return AsyncAzureOpenAI(
azure_endpoint=f"https://{self.open_ai_service}.openai.azure.com",
azure_deployment=self.open_ai_deployment,
api_key=await self.wrap_credential(),
api_version="2023-05-15",
**auth_args,
)

async def wrap_credential(self) -> str:
if isinstance(self.credential, AzureKeyCredential):
return self.credential.key

if isinstance(self.credential, AsyncTokenCredential):
if not self.cached_token or self.cached_token.expires_on <= time.time():
self.cached_token = await self.credential.get_token("https://cognitiveservices.azure.com/.default")

return self.cached_token.token

raise TypeError("Invalid credential type")


class OpenAIEmbeddingService(OpenAIEmbeddings):
"""
Expand Down

0 comments on commit 0ee189d

Please sign in to comment.