diff --git a/python/samples/concepts/memory/new_memory.py b/python/samples/concepts/memory/new_memory.py index a76716659ad2..0819fd2aeadf 100644 --- a/python/samples/concepts/memory/new_memory.py +++ b/python/samples/concepts/memory/new_memory.py @@ -12,6 +12,9 @@ from semantic_kernel.connectors.ai.open_ai import OpenAIEmbeddingPromptExecutionSettings, OpenAITextEmbedding from semantic_kernel.connectors.ai.open_ai.services.azure_text_embedding import AzureTextEmbedding from semantic_kernel.connectors.memory.azure_ai_search import AzureAISearchCollection +from semantic_kernel.connectors.memory.azure_db_for_postgres.azure_db_for_postgres_collection import ( + AzureDBForPostgresCollection, +) from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection from semantic_kernel.connectors.memory.qdrant import QdrantCollection from semantic_kernel.connectors.memory.redis import RedisHashsetCollection, RedisJsonCollection @@ -88,6 +91,10 @@ class MyDataModelList: "ai_search": lambda: AzureAISearchCollection[MyDataModel]( data_model_type=MyDataModel, ), + "azure_db_for_postgres": lambda: AzureDBForPostgresCollection[str, MyDataModel]( + data_model_type=MyDataModel, + collection_name=collection_name, + ), "postgres": lambda: PostgresCollection[str, MyDataModel]( data_model_type=MyDataModel, collection_name=collection_name, diff --git a/python/samples/getting_started/third_party/postgres-memory.ipynb b/python/samples/getting_started/third_party/postgres-memory.ipynb index b0069a59a1c7..1aeca721b2d7 100644 --- a/python/samples/getting_started/third_party/postgres-memory.ipynb +++ b/python/samples/getting_started/third_party/postgres-memory.ipynb @@ -33,6 +33,9 @@ ")\n", "from semantic_kernel.connectors.ai.open_ai.services.azure_text_embedding import AzureTextEmbedding\n", "from semantic_kernel.connectors.ai.open_ai.services.open_ai_text_embedding import OpenAITextEmbedding\n", + "from 
semantic_kernel.connectors.memory.azure_db_for_postgres.azure_db_for_postgres_collection import (\n", + " AzureDBForPostgresCollection,\n", + ")\n", "from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection\n", "from semantic_kernel.data.const import DistanceFunction, IndexKind\n", "from semantic_kernel.data.vector_store_model_decorator import vectorstoremodel\n", @@ -55,10 +58,23 @@ "\n", "To do this, copy the `.env.example` file to `.env` and fill in the necessary information.\n", "\n", + "Note that if you are using VS Code to execute this notebook, ensure you don't have alternate values in the .env file at the root of the workspace, as that will take precedence over the .env file in the notebook.\n", + "\n", "### Postgres configuration\n", "\n", "You'll need to provide a connection string to a Postgres database. You can use a local Postgres instance, or a cloud-hosted one.\n", - "You can provide a connection string, or provide environment variables with the connection information. See the .env.example file for `POSTGRES_` settings.\n", + "You can provide a connection string, or provide environment variables with the connection information. See the .env.example file for `POSTGRES_CONNECTION_STRING` and `PG*` settings.\n", + "\n", + "#### Using Azure DB for Postgres\n", + "\n", + "You can use Azure DB for Postgres by following the steps below:\n", + "\n", + "1. Create an Azure DB for Postgres instance. You can set the database to only allow Entra authentication to avoid\n", + " storing the password in the `.env` file.\n", + "2. Set the `PG*` settings, except for the password if using Entra authentication. If using entra, ensure you\n", + " are logged in via the Azure CLI. You can get the configuration values from the Azure portal Settings -> Connect\n", + " page.\n", + "3. 
Set \"USE_AZURE_DB_FOR_POSTGRES\" to True in the cell below.\n",
     "\n",
     "#### Using Docker\n",
     "\n",
@@ -130,7 +146,13 @@
     "USE_AZURE_OPENAI = True\n",
     "\n",
     "# The name of the OpenAI model or Azure OpenAI deployment to use\n",
-    "EMBEDDING_MODEL = \"text-embedding-3-small\""
+    "EMBEDDING_MODEL = \"text-embedding-3-small\"\n",
+    "\n",
+    "# -- Postgres settings --\n",
+    "\n",
+    "# Use Azure DB For Postgres. This enables Entra authentication against the database instead of\n",
+    "# setting a password in the environment.\n",
+    "USE_AZURE_DB_FOR_POSTGRES = True"
    ]
   },
   {
@@ -265,9 +287,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "collection = PostgresCollection[str, ArxivPaper](\n",
-    "    collection_name=\"arxiv_papers\", data_model_type=ArxivPaper, env_file_path=env_file_path\n",
-    ")"
+    "if USE_AZURE_DB_FOR_POSTGRES:\n",
+    "    collection = AzureDBForPostgresCollection[str, ArxivPaper](\n",
+    "        collection_name=\"arxiv_papers\", data_model_type=ArxivPaper, env_file_path=env_file_path\n",
+    "    )\n",
+    "else:\n",
+    "    collection = PostgresCollection[str, ArxivPaper](\n",
+    "        collection_name=\"arxiv_papers\", data_model_type=ArxivPaper, env_file_path=env_file_path\n",
+    "    )"
    ]
   },
   {
@@ -334,9 +361,44 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
+    "Note that we use the collection as a context manager. This opens and closes the connection pool that is created by the collection. If we want to maintain a persistent connection pool, which is more typical for a long-running application, we can create the connection pool outside of the context manager and pass it in. This is also useful if we want to use the same connection pool for multiple collections. \n",
+    "\n",
+    "The settings objects PostgresSettings and AzureDBForPostgresSettings enable easy creation of connection pools. We use this technique in the next cell.\n",
+    "\n",
     "Here we retrieve the first few models from the database and print out their information."
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from semantic_kernel.connectors.memory.azure_db_for_postgres.azure_db_for_postgres_settings import (\n",
+    "    AzureDBForPostgresSettings,\n",
+    ")\n",
+    "from semantic_kernel.connectors.memory.postgres.postgres_settings import PostgresSettings\n",
+    "\n",
+    "if USE_AZURE_DB_FOR_POSTGRES:\n",
+    "    settings = AzureDBForPostgresSettings(env_file_path=env_file_path)\n",
+    "    connection_pool = await settings.create_connection_pool()\n",
+    "    collection = AzureDBForPostgresCollection[str, ArxivPaper](\n",
+    "        collection_name=\"arxiv_papers\",\n",
+    "        data_model_type=ArxivPaper,\n",
+    "        connection_pool=connection_pool,\n",
+    "        settings=settings,\n",
+    "    )\n",
+    "else:\n",
+    "    settings = PostgresSettings(env_file_path=env_file_path)\n",
+    "    connection_pool = await settings.create_connection_pool()\n",
+    "    collection = PostgresCollection[str, ArxivPaper](\n",
+    "        collection_name=\"arxiv_papers\", data_model_type=ArxivPaper, connection_pool=connection_pool, settings=settings\n",
+    "    )\n",
+    "\n",
+    "# Open the connection pool\n",
+    "await connection_pool.open()"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -360,6 +422,22 @@
     "    print()"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we need to close the connection pool explicitly."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "await connection_pool.close()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
diff --git a/python/semantic_kernel/connectors/memory/azure_db_for_postgres/__init__.py b/python/semantic_kernel/connectors/memory/azure_db_for_postgres/__init__.py
new file mode 100644
index 000000000000..2a50eae89411
--- /dev/null
+++ b/python/semantic_kernel/connectors/memory/azure_db_for_postgres/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Microsoft. All rights reserved.
diff --git a/python/semantic_kernel/connectors/memory/azure_db_for_postgres/azure_db_for_postgres_collection.py b/python/semantic_kernel/connectors/memory/azure_db_for_postgres/azure_db_for_postgres_collection.py new file mode 100644 index 000000000000..9427a390831a --- /dev/null +++ b/python/semantic_kernel/connectors/memory/azure_db_for_postgres/azure_db_for_postgres_collection.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft. All rights reserved. +from typing import TypeVar + +from psycopg_pool import AsyncConnectionPool + +from semantic_kernel.connectors.memory.azure_db_for_postgres.azure_db_for_postgres_settings import ( + AzureDBForPostgresSettings, +) +from semantic_kernel.connectors.memory.postgres.constants import DEFAULT_SCHEMA +from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection +from semantic_kernel.data.vector_store_model_definition import VectorStoreRecordDefinition + +TKey = TypeVar("TKey", str, int) +TModel = TypeVar("TModel") + + +class AzureDBForPostgresCollection(PostgresCollection[TKey, TModel]): + """AzureDBForPostgresCollection class.""" + + def __init__( + self, + collection_name: str, + data_model_type: type[TModel], + data_model_definition: VectorStoreRecordDefinition | None = None, + connection_pool: AsyncConnectionPool | None = None, + db_schema: str = DEFAULT_SCHEMA, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + settings: AzureDBForPostgresSettings | None = None, + ): + """Initialize the collection. + + Args: + collection_name: The name of the collection, which corresponds to the table name. + data_model_type: The type of the data model. + data_model_definition: The data model definition. + connection_pool: The connection pool. + db_schema: The database schema. + env_file_path: Use the environment settings file as a fallback to environment variables. + env_file_encoding: The encoding of the environment settings file. 
+            settings: The settings for the Azure DB for Postgres connection. If not provided, the settings will be
+                created from the environment.
+        """
+        # If the connection pool or settings were not provided, create the settings from the environment.
+        # Passing this to the super class will enforce using Azure DB settings.
+        if not connection_pool and not settings:
+            settings = AzureDBForPostgresSettings.create(
+                env_file_path=env_file_path, env_file_encoding=env_file_encoding
+            )
+        super().__init__(
+            collection_name=collection_name,
+            data_model_type=data_model_type,
+            data_model_definition=data_model_definition,
+            connection_pool=connection_pool,
+            db_schema=db_schema,
+            settings=settings,
+        )
diff --git a/python/semantic_kernel/connectors/memory/azure_db_for_postgres/azure_db_for_postgres_settings.py b/python/semantic_kernel/connectors/memory/azure_db_for_postgres/azure_db_for_postgres_settings.py
new file mode 100644
index 000000000000..ea5e61e646b6
--- /dev/null
+++ b/python/semantic_kernel/connectors/memory/azure_db_for_postgres/azure_db_for_postgres_settings.py
@@ -0,0 +1,82 @@
+# Copyright (c) Microsoft. All rights reserved.
+from typing import Any
+
+from psycopg.conninfo import conninfo_to_dict
+from psycopg_pool import AsyncConnectionPool
+
+from semantic_kernel.connectors.memory.azure_db_for_postgres.entra_connection import AsyncEntraConnection
+from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorInitializationError
+
+from azure.core.credentials import TokenCredential
+from azure.core.credentials_async import AsyncTokenCredential
+
+from semantic_kernel import __version__
+from semantic_kernel.connectors.memory.postgres.postgres_settings import PostgresSettings
+
+
+class AzureDBForPostgresSettings(PostgresSettings):
+    """Azure DB for Postgres model settings.
+
+    This is the same as PostgresSettings, but does not require a password.
+    If a password is not supplied, an Entra token will be used as the password.
+    You can also supply an Azure credential directly.
+    """
+
+    credential: AsyncTokenCredential | TokenCredential | None = None
+
+    def get_connection_args(self, **kwargs) -> dict[str, Any]:
+        """Get connection arguments.
+
+        Args:
+            kwargs: dict[str, Any] - Additional arguments
+                Use this to override any connection arguments.
+
+        Returns:
+            dict[str, Any]: Connection arguments that can be passed to psycopg.connect
+        """
+        result = conninfo_to_dict(self.connection_string.get_secret_value()) if self.connection_string else {}
+
+        if self.host:
+            result["host"] = self.host
+        if self.port:
+            result["port"] = self.port
+        if self.dbname:
+            result["dbname"] = self.dbname
+        if self.user:
+            result["user"] = self.user
+        if self.password:
+            result["password"] = self.password.get_secret_value()
+
+        result = {**result, **kwargs}
+
+        # Ensure required values
+        if "host" not in result:
+            raise MemoryConnectorInitializationError("host is required. Please set PGHOST or connection_string.")
+        if "dbname" not in result:
+            raise MemoryConnectorInitializationError(
+                "database is required. Please set PGDATABASE or connection_string."
+            )
+
+        return result
+
+    async def create_connection_pool(self) -> AsyncConnectionPool:
+        """Creates a connection pool based off of settings.
+
+        Uses AsyncEntraConnection as the connection class, which
+        can set the user and password based on an Entra token.
+ """ + pool: AsyncConnectionPool = AsyncConnectionPool( + min_size=self.min_pool, + max_size=self.max_pool, + open=False, + kwargs={ + **self.get_connection_args(), + **{ + "credential": self.credential, + "application_name": f"semantic_kernel (python) v{__version__}", + }, + }, + connection_class=AsyncEntraConnection, + ) + await pool.open() + return pool diff --git a/python/semantic_kernel/connectors/memory/azure_db_for_postgres/azure_db_for_postgres_store.py b/python/semantic_kernel/connectors/memory/azure_db_for_postgres/azure_db_for_postgres_store.py new file mode 100644 index 000000000000..eba3138b4cfe --- /dev/null +++ b/python/semantic_kernel/connectors/memory/azure_db_for_postgres/azure_db_for_postgres_store.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft. All rights reserved. + +from semantic_kernel.connectors.memory.postgres.postgres_store import PostgresStore + + +class AzureDBForPostgresStore(PostgresStore): + """AzureDBForPostgresStore class.""" + + pass diff --git a/python/semantic_kernel/connectors/memory/azure_db_for_postgres/constants.py b/python/semantic_kernel/connectors/memory/azure_db_for_postgres/constants.py new file mode 100644 index 000000000000..612b938173e1 --- /dev/null +++ b/python/semantic_kernel/connectors/memory/azure_db_for_postgres/constants.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft. All rights reserved. + +AZURE_DB_FOR_POSTGRES_SCOPE = "https://ossrdbms-aad.database.windows.net/.default" diff --git a/python/semantic_kernel/connectors/memory/azure_db_for_postgres/entra_connection.py b/python/semantic_kernel/connectors/memory/azure_db_for_postgres/entra_connection.py new file mode 100644 index 000000000000..2a2d044e3802 --- /dev/null +++ b/python/semantic_kernel/connectors/memory/azure_db_for_postgres/entra_connection.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft. All rights reserved. 
+import base64
+import json
+import logging
+from functools import lru_cache
+
+from azure.core.credentials import TokenCredential
+from azure.core.credentials_async import AsyncTokenCredential
+from azure.identity import DefaultAzureCredential
+from psycopg import AsyncConnection
+
+from semantic_kernel.connectors.memory.azure_db_for_postgres.constants import AZURE_DB_FOR_POSTGRES_SCOPE
+
+logger = logging.getLogger(__name__)
+
+
+async def get_entra_token_async(credential: AsyncTokenCredential) -> str:
+    """Get the password from Entra using the provided credential."""
+    logger.info("Acquiring Entra token for postgres password")
+
+    async with credential:
+        cred = await credential.get_token(AZURE_DB_FOR_POSTGRES_SCOPE)
+        return cred.token
+
+
+def get_entra_token(credential: TokenCredential | None) -> str:
+    """Get the password from Entra using the provided credential."""
+    logger.info("Acquiring Entra token for postgres password")
+    credential = credential or get_default_azure_credentials()
+
+    return credential.get_token(AZURE_DB_FOR_POSTGRES_SCOPE).token
+
+
+@lru_cache(maxsize=1)
+def get_default_azure_credentials() -> DefaultAzureCredential:
+    """Get the default Azure credentials.
+
+    This method caches the credentials to avoid creating new instances.
+    """
+    return DefaultAzureCredential()
+
+
+def decode_jwt(token):
+    """Decode the JWT payload to extract claims."""
+    payload = token.split(".")[1]
+    padding = "=" * (4 - len(payload) % 4)
+    decoded_payload = base64.urlsafe_b64decode(payload + padding)
+    return json.loads(decoded_payload)
+
+
+async def get_entra_conninfo(credential: TokenCredential | AsyncTokenCredential | None) -> dict[str, str]:
+    """Fetch a token and return the username and token."""
+    # Fetch a new token and extract the username
+    if isinstance(credential, AsyncTokenCredential):
+        token = await get_entra_token_async(credential)
+    else:
+        token = get_entra_token(credential)
+    claims = decode_jwt(token)
+    username = claims.get("upn") or claims.get("preferred_username") or claims.get("unique_name")
+    if not username:
+        raise ValueError("Could not extract username from token. Have you logged in?")
+
+    return {"user": username, "password": token}
+
+
+class AsyncEntraConnection(AsyncConnection):
+    """Asynchronous connection class for using Entra auth with Azure DB for PostgreSQL."""
+
+    @classmethod
+    async def connect(cls, *args, **kwargs):
+        """Establish an asynchronous connection using Entra auth with Azure DB for PostgreSQL."""
+        credential = kwargs.pop("credential", None)
+        if credential and not isinstance(credential, (TokenCredential, AsyncTokenCredential)):
+            raise ValueError("credential must be a TokenCredential or AsyncTokenCredential")
+        if credential or not kwargs.get("user") or not kwargs.get("password"):
+            entra_conninfo = await get_entra_conninfo(credential)
+            kwargs["password"] = entra_conninfo["password"]
+            if not kwargs.get("user"):
+                # If user isn't already set, use the username from the token
+                kwargs["user"] = entra_conninfo["user"]
+        return await super().connect(*args, **kwargs)
diff --git a/python/semantic_kernel/connectors/memory/azure_db_for_postgres/utils.py b/python/semantic_kernel/connectors/memory/azure_db_for_postgres/utils.py
new file mode 100644
index 000000000000..ea693f908d3c
--- /dev/null
+++ b/python/semantic_kernel/connectors/memory/azure_db_for_postgres/utils.py
@@ -0,0 +1,37 @@
+# Copyright (c) Microsoft. All rights reserved.
+import logging
+from functools import lru_cache
+
+from azure.core.credentials import TokenCredential
+from azure.core.credentials_async import AsyncTokenCredential
+from azure.identity import DefaultAzureCredential
+
+from semantic_kernel.connectors.memory.azure_db_for_postgres.constants import AZURE_DB_FOR_POSTGRES_SCOPE
+
+logger = logging.getLogger(__name__)
+
+
+async def get_entra_token_async(credential: AsyncTokenCredential) -> str:
+    """Get the password from Entra using the provided credential."""
+    logger.info("Acquiring Entra token for postgres password")
+
+    async with credential:
+        cred = await credential.get_token(AZURE_DB_FOR_POSTGRES_SCOPE)
+        return cred.token
+
+
+def get_entra_token(credential: TokenCredential | None) -> str:
+    """Get the password from Entra using the provided credential."""
+    logger.info("Acquiring Entra token for postgres password")
+    credential = credential or get_default_azure_credentials()
+
+    return credential.get_token(AZURE_DB_FOR_POSTGRES_SCOPE).token
+
+
+@lru_cache(maxsize=1)
+def get_default_azure_credentials() -> DefaultAzureCredential:
+    """Get the default Azure credentials.
+
+    This method caches the credentials to avoid creating new instances.
+ """ + return DefaultAzureCredential() diff --git a/python/semantic_kernel/connectors/memory/postgres/postgres_settings.py b/python/semantic_kernel/connectors/memory/postgres/postgres_settings.py index 32cd56f7b079..c90275237e58 100644 --- a/python/semantic_kernel/connectors/memory/postgres/postgres_settings.py +++ b/python/semantic_kernel/connectors/memory/postgres/postgres_settings.py @@ -71,8 +71,16 @@ class PostgresSettings(KernelBaseSettings): default_dimensionality: int = 100 max_rows_per_transaction: int = 1000 - def get_connection_args(self) -> dict[str, Any]: - """Get connection arguments.""" + def get_connection_args(self, **kwargs) -> dict[str, Any]: + """Get connection arguments. + + Args: + kwargs: dict[str, Any] - Additional arguments + Use this to override any connection arguments. + + Returns: + dict[str, Any]: Connection arguments that can be passed to psycopg.connect + """ result = conninfo_to_dict(self.connection_string.get_secret_value()) if self.connection_string else {} if self.host: @@ -86,6 +94,8 @@ def get_connection_args(self) -> dict[str, Any]: if self.password: result["password"] = self.password.get_secret_value() + result = {**result, **kwargs} + # Ensure required values if "host" not in result: raise MemoryConnectorInitializationError("host is required. Please set PGHOST or connection_string.")