Skip to content

Commit

Permalink
fixed azure test
Browse files Browse the repository at this point in the history
  • Loading branch information
eavanvalkenburg committed Nov 13, 2024
1 parent 06db190 commit 65adc9a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 18 deletions.
22 changes: 14 additions & 8 deletions python/semantic_kernel/connectors/memory/azure_ai_search/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,21 @@ def data_model_definition_to_azure_ai_search_index(
algorithm_configuration_name=algo_name,
)
)
algo_class, algo_params = INDEX_ALGORITHM_MAP[field.index_kind or "default"]
search_algos.append(
algo_class(
name=algo_name,
parameters=algo_params(
metric=DISTANCE_FUNCTION_MAP[field.distance_function or "default"],
),
try:
algo_class, algo_params = INDEX_ALGORITHM_MAP[field.index_kind or "default"]
except KeyError as e:
raise ServiceInitializationError(f"Error: {e} not found in INDEX_ALGORITHM_MAP.") from e
try:
search_algos.append(
algo_class(
name=algo_name,
parameters=algo_params(
metric=DISTANCE_FUNCTION_MAP[field.distance_function or "default"],
),
)
)
)
except KeyError as e:
raise ServiceInitializationError(f"Error: {e} not found in DISTANCE_FUNCTION_MAP.") from e
return SearchIndex(
name=collection_name,
fields=fields,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,12 @@
MemoryConnectorInitializationError,
)
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
from semantic_kernel.utils.list_handler import desync_list

BASE_PATH_SEARCH_CLIENT = "azure.search.documents.aio.SearchClient"
BASE_PATH_INDEX_CLIENT = "azure.search.documents.indexes.aio.SearchIndexClient"


class AsyncIter:
def __init__(self, items):
self.items = items

async def __aiter__(self):
for item in self.items:
yield item


@fixture
def vector_store(azure_ai_search_unit_test_env):
"""Fixture to instantiate AzureCognitiveSearchMemoryStore with basic configuration."""
Expand All @@ -58,7 +50,7 @@ def mock_list_collection_names():
"""Fixture to patch 'SearchIndexClient' and its 'create_index' method."""
with patch(f"{BASE_PATH_INDEX_CLIENT}.list_index_names") as mock_list_index_names:
# Setup the mock to return a specific SearchIndex instance when called
mock_list_index_names.return_value = AsyncIter(["test"])
mock_list_index_names.return_value = desync_list(["test"])
yield mock_list_index_names


Expand Down Expand Up @@ -253,6 +245,7 @@ async def test_create_index_from_index_fail(collection, mock_create_collection):
await collection.create_collection(index=index)


@mark.parametrize("distance_function", [("cosine_distance")])
def test_data_model_definition_to_azure_ai_search_index(data_model_definition):
index = data_model_definition_to_azure_ai_search_index("test", data_model_definition)
assert index is not None
Expand Down

0 comments on commit 65adc9a

Please sign in to comment.