Externalize OpenAI base URL
tygern committed Jun 11, 2024
1 parent 10aed6c commit 89b4d97
Showing 8 changed files with 15 additions and 6 deletions.
3 changes: 2 additions & 1 deletion .devcontainer/devcontainer.json
@@ -16,6 +16,7 @@
"USE_FLASK_DEBUG_MODE": "true",
"FEEDS": "https://feed.infoq.com/development/,https://blog.jetbrains.com/feed/,https://feed.infoq.com/Devops/,https://feed.infoq.com/architecture-design/",
"ROOT_LOG_LEVEL": "INFO",
-"STARTER_LOG_LEVEL": "DEBUG"
+"STARTER_LOG_LEVEL": "DEBUG",
+"OPEN_AI_BASE_URL": "https://api.openai.com/v1/"
}
}
1 change: 1 addition & 0 deletions .env.example
@@ -1 +1,2 @@
export OPEN_AI_KEY=fill_me_in
+export OPEN_AI_BASE_URL=fill_me_in
6 changes: 3 additions & 3 deletions README.md
@@ -88,7 +88,7 @@ and uses pgvector to store the embeddings in PostgreSQL.

The web application collects the user's query and creates an embedding with the OpenAI Embeddings API.
It then searches the PostgreSQL for similar embeddings (using pgvector) and provides the corresponding chunk of text as
-context for a query to the [Azure AI Chat Completion API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#completions).
+context for a query to the [Azure AI Chat Completion API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions).

## Local development

@@ -113,7 +113,7 @@ context for a query to the [Azure AI Chat Completion API](https://learn.microsof
[localhost:5001](http://localhost:5001).

```shell
-python collector.py
-python analyzer.py
+python collect.py
+python analyze.py
python -m starter
```
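The query flow the README describes (embed the query with the OpenAI Embeddings API, search pgvector for similar embeddings, pass the matching chunk as context to the chat completion call) roughly takes the shape below. This is a sketch for orientation only: `fetch_embedding`, `fetch_chat_completion`, `ChatMessage`, `EmbeddingsGateway`, and `ChunksGateway` appear elsewhere in this commit, but the gateway method names, `ChatMessage` fields, and prompt wording are illustrative assumptions.

```python
# Sketch of the RAG query flow described in the README. Gateway method names
# (find_similar, find_chunk_text), ChatMessage fields, and the prompt wording
# are assumptions, not the repository's actual API.
from typing import List

from starter.ai.open_ai_client import ChatMessage


def answer_query(query: str, ai_client, embeddings_gateway, chunks_gateway) -> str:
    # 1. Embed the user's query via the OpenAI Embeddings API.
    embedding: List[float] = ai_client.fetch_embedding(query).value

    # 2. Search PostgreSQL (pgvector) for similar embeddings (assumed gateway method).
    chunk_ids = embeddings_gateway.find_similar(embedding)

    # 3. Pass the matching chunks as context to the chat completion call.
    context = "\n".join(chunks_gateway.find_chunk_text(chunk_id) for chunk_id in chunk_ids)
    result = ai_client.fetch_chat_completion([
        ChatMessage(role="system", content=f"Answer using this context:\n{context}"),
        ChatMessage(role="user", content=query),
    ])
    return result.value
```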
2 changes: 1 addition & 1 deletion analyze.py
@@ -19,7 +19,7 @@
chunks_gateway = ChunksGateway(db_template)
embeddings_gateway = EmbeddingsGateway(db_template)
ai_client = OpenAIClient(
-base_url="https://api.openai.com/v1/",
+base_url=env.open_ai_base_url,
api_key=env.open_ai_key,
embeddings_model="text-embedding-3-small",
chat_model="gpt-4o"
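With the base URL injected instead of hard-coded, the client presumably joins it with the endpoint paths when making requests; the tests later in this diff stub `https://openai.example.com/embeddings` and `https://openai.example.com/chat/completions`, which suggests something like the following. This is a sketch under that assumption, not the actual body of `OpenAIClient`.

```python
# Minimal sketch of how a configurable base_url is typically combined with the
# OpenAI endpoint paths; the real open_ai_client.py is not shown in this diff.
import requests


class OpenAIClientSketch:
    def __init__(self, base_url: str, api_key: str) -> None:
        self.base_url = base_url  # e.g. "https://api.openai.com/v1/" (trailing slash expected)
        self.api_key = api_key

    def _post(self, path: str, payload: dict) -> requests.Response:
        # path is "embeddings" or "chat/completions"
        return requests.post(
            f"{self.base_url}{path}",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json=payload,
        )
```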
2 changes: 2 additions & 0 deletions starter/ai/open_ai_client.py
@@ -36,6 +36,7 @@ def fetch_embedding(self, text) -> Result[List[float]]:
},
)
if not response.ok:
+logger.error(f"Received {response.status_code} response from {self.base_url}: {response.text}")
return Failure("Failed to fetch embedding")

return Success(response.json()["data"][0]["embedding"])
@@ -55,6 +56,7 @@ def fetch_chat_completion(self, messages: List[ChatMessage]) -> Result[str]:
]},
)
if not response.ok:
+logger.error(f"Received {response.status_code} response from {self.base_url}: {response.text}")
return Failure("Failed to fetch completion")

return Success(response.json()["choices"][0]["message"]["content"])
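The two error lines added above rely on a module-level `logger` in `starter/ai/open_ai_client.py`. That logger is not part of this diff; if it is not already defined there, the usual pattern would be:

```python
import logging

# Module-level logger assumed by the logger.error(...) calls added in this commit.
logger = logging.getLogger(__name__)
```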
2 changes: 1 addition & 1 deletion starter/app.py
@@ -28,7 +28,7 @@ def create_app(env: Environment = Environment.from_env()) -> Flask:
chunks_gateway = ChunksGateway(db_template)
embeddings_gateway = EmbeddingsGateway(db_template)
ai_client = OpenAIClient(
-base_url="https://api.openai.com/v1/",
+base_url=env.open_ai_base_url,
api_key=env.open_ai_key,
embeddings_model="text-embedding-3-small",
chat_model="gpt-4o"
2 changes: 2 additions & 0 deletions starter/environment.py
@@ -10,6 +10,7 @@ class Environment:
use_flask_debug_mode: bool
feeds: List[str]
open_ai_key: str
+open_ai_base_url: str
root_log_level: str
starter_log_level: str

@@ -21,6 +22,7 @@ def from_env(cls) -> 'Environment':
use_flask_debug_mode=os.environ.get('USE_FLASK_DEBUG_MODE', 'false') == 'true',
feeds=cls.__require_env('FEEDS').strip().split(','),
open_ai_key=cls.__require_env('OPEN_AI_KEY'),
+open_ai_base_url=cls.__require_env('OPEN_AI_BASE_URL'),
root_log_level=os.environ.get('ROOT_LOG_LEVEL', 'INFO'),
starter_log_level=os.environ.get('STARTER_LOG_LEVEL', 'INFO'),
)
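`from_env` reads the new variable through `cls.__require_env('OPEN_AI_BASE_URL')`, a helper that is not shown in this diff. A plausible shape, assuming it fails fast when the variable is unset (the actual error type and message may differ):

```python
import os


class Environment:
    # Sketch of the helper referenced by from_env; starter/environment.py
    # defines the real one, which this diff does not show.
    @classmethod
    def __require_env(cls, name: str) -> str:
        value = os.environ.get(name)
        if value is None:
            raise RuntimeError(f"{name} environment variable is required")
        return value
```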
3 changes: 3 additions & 0 deletions tests/ai/test_open_ai_client.py
@@ -5,6 +5,7 @@
from starter.ai.open_ai_client import OpenAIClient, ChatMessage
from tests.chat_support import chat_response
from tests.embeddings_support import embedding_response, embedding_vector
+from tests.logging_support import disable_logging


class TestOpenAIClient(unittest.TestCase):
@@ -22,6 +23,7 @@ def test_fetch_embedding(self):

self.assertEqual(embedding_vector(2), self.client.fetch_embedding("some query").value)

+@disable_logging
@responses.activate
def test_fetch_embedding_failure(self):
responses.add(responses.POST, "https://openai.example.com/embeddings", "bad news", status=400)
@@ -39,6 +41,7 @@ def test_fetch_chat_completion(self):
]).value,
)

+@disable_logging
@responses.activate
def test_fetch_chat_completion_failure(self):
responses.add(responses.POST, "https://openai.example.com/chat/completions", "bad news", status=400)
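`disable_logging` comes from `tests/logging_support.py`, which is not part of this diff. A plausible implementation is a decorator that silences logging for the duration of a test, so the new `logger.error` output does not clutter the failure-path tests (a sketch, not the repository's helper):

```python
import logging
from functools import wraps


def disable_logging(test):
    # Suppress all log output while the wrapped test runs, then restore it.
    @wraps(test)
    def wrapper(*args, **kwargs):
        logging.disable(logging.CRITICAL)
        try:
            return test(*args, **kwargs)
        finally:
            logging.disable(logging.NOTSET)
    return wrapper
```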
