Skip to content

Commit

Permalink
Add option to get secrets from keyvault
Browse files Browse the repository at this point in the history
  • Loading branch information
abhahn committed Jun 28, 2024
1 parent 2e9317b commit 7c81fd0
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 16 deletions.
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest


def pytest_addoption(parser):
parser.addoption(
"--use-keyvault-secrets",
help='Get secrets from a keyvault instead of the environment.',
action='store_true', default=False
)


@pytest.fixture(scope="session")
def use_keyvault_secrets(request) -> str:
return request.config.getoption("use_keyvault_secrets")
34 changes: 34 additions & 0 deletions tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,32 @@
import json
import os
import pytest
from azure.identity import AzureCliCredential
from azure.keyvault.secrets import SecretClient
from pydantic.alias_generators import to_snake


VAULT_NAME = os.environ.get("VAULT_NAME")


@pytest.fixture(scope="module")
def secret_client() -> SecretClient:
kv_uri = f"https://{VAULT_NAME}.vault.azure.net"
print(f"init secret_client from kv_uri={kv_uri}")
credential = AzureCliCredential(additionally_allowed_tenants="*")
return SecretClient(vault_url=kv_uri, credential=credential)


@pytest.fixture(scope="module")
def dotenv_template_params_from_kv(secret_client: SecretClient) -> dict[str, str]:
secrets_properties_list = secret_client.list_properties_of_secrets()
secrets = {}
for secret in secrets_properties_list:
secret_name = to_snake(secret.name).upper()
secrets[secret_name] = secret_client.get_secret(secret.name).value

return secrets


@pytest.fixture(scope="module")
def dotenv_template_params_from_env() -> dict[str, str]:
Expand Down Expand Up @@ -33,3 +59,11 @@ def get_and_unset_variable(var_name):

return {s: get_and_unset_variable(s) for s in env_secrets}


@pytest.fixture(scope="module")
def dotenv_template_params(request, use_keyvault_secrets):
if use_keyvault_secrets:
return request.getfixturevalue("dotenv_template_params_from_kv")

return request.getfixturevalue("dotenv_template_params_from_env")

32 changes: 16 additions & 16 deletions tests/integration_tests/test_datasources.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def use_elasticsearch_embeddings(request):
@pytest.fixture(scope="function")
def dotenv_rendered_template_path(
request,
dotenv_template_params_from_env,
dotenv_template_params,
datasource,
enable_chat_history,
stream,
Expand All @@ -83,30 +83,30 @@ def dotenv_rendered_template_path(
)

if datasource != "none":
dotenv_template_params_from_env["DATASOURCE_TYPE"] = datasource
dotenv_template_params["DATASOURCE_TYPE"] = datasource

if datasource != "Elasticsearch" and use_elasticsearch_embeddings:
pytest.skip("Elasticsearch embeddings not supported for test.")

if datasource == "Elasticsearch":
dotenv_template_params_from_env["USE_ELASTICSEARCH_EMBEDDINGS"] = use_elasticsearch_embeddings
dotenv_template_params["USE_ELASTICSEARCH_EMBEDDINGS"] = use_elasticsearch_embeddings

dotenv_template_params_from_env["USE_AOAI_EMBEDDINGS"] = use_aoai_embeddings
dotenv_template_params["USE_AOAI_EMBEDDINGS"] = use_aoai_embeddings

if use_aoai_embeddings or use_elasticsearch_embeddings:
dotenv_template_params_from_env["AZURE_SEARCH_QUERY_TYPE"] = "vector"
dotenv_template_params_from_env["ELASTICSEARCH_QUERY_TYPE"] = "vector"
dotenv_template_params["AZURE_SEARCH_QUERY_TYPE"] = "vector"
dotenv_template_params["ELASTICSEARCH_QUERY_TYPE"] = "vector"
else:
dotenv_template_params_from_env["AZURE_SEARCH_QUERY_TYPE"] = "simple"
dotenv_template_params_from_env["ELASTICSEARCH_QUERY_TYPE"] = "simple"
dotenv_template_params["AZURE_SEARCH_QUERY_TYPE"] = "simple"
dotenv_template_params["ELASTICSEARCH_QUERY_TYPE"] = "simple"

dotenv_template_params_from_env["ENABLE_CHAT_HISTORY"] = enable_chat_history
dotenv_template_params_from_env["AZURE_OPENAI_STREAM"] = stream
dotenv_template_params["ENABLE_CHAT_HISTORY"] = enable_chat_history
dotenv_template_params["AZURE_OPENAI_STREAM"] = stream

return render_template_to_tempfile(
rendered_template_name,
template_path,
**dotenv_template_params_from_env
**dotenv_template_params
)


Expand All @@ -121,12 +121,12 @@ def test_app(dotenv_rendered_template_path) -> Quart:


@pytest.mark.asyncio
async def test_dotenv(test_app: Quart, dotenv_template_params_from_env: dict[str, str]):
if dotenv_template_params_from_env["DATASOURCE_TYPE"] == "AzureCognitiveSearch":
message_content = dotenv_template_params_from_env["AZURE_SEARCH_QUERY"]
async def test_dotenv(test_app: Quart, dotenv_template_params: dict[str, str]):
if dotenv_template_params["DATASOURCE_TYPE"] == "AzureCognitiveSearch":
message_content = dotenv_template_params["AZURE_SEARCH_QUERY"]

elif dotenv_template_params_from_env["DATASOURCE_TYPE"] == "Elasticsearch":
message_content = dotenv_template_params_from_env["ELASTICSEARCH_QUERY"]
elif dotenv_template_params["DATASOURCE_TYPE"] == "Elasticsearch":
message_content = dotenv_template_params["ELASTICSEARCH_QUERY"]

else:
message_content = "What is Contoso?"
Expand Down

0 comments on commit 7c81fd0

Please sign in to comment.