diff --git a/src/databricks/labs/lakebridge/connections/credential_manager.py b/src/databricks/labs/lakebridge/connections/credential_manager.py
index bd1e42a492..e1bb22f9a5 100644
--- a/src/databricks/labs/lakebridge/connections/credential_manager.py
+++ b/src/databricks/labs/lakebridge/connections/credential_manager.py
@@ -40,8 +40,10 @@ def get_secret(self, key: str) -> str:
 class CredentialManager:
     def __init__(self, credential_loader: dict, secret_providers: dict):
         self._credentials = credential_loader
-        self._secret_providers = secret_providers
         self._default_vault = self._credentials.get('secret_vault_type', 'local').lower()
+        self._provider = secret_providers.get(self._default_vault)
+        if not self._provider:
+            raise ValueError(f"Unsupported secret vault type: {self._default_vault}")
 
     def get_credentials(self, source: str) -> dict:
         if source not in self._credentials:
@@ -54,10 +56,8 @@ def get_credentials(self, source: str) -> dict:
         return {k: self._get_secret_value(v) for k, v in value.items()}
 
     def _get_secret_value(self, key: str) -> str:
-        provider = self._secret_providers.get(self._default_vault)
-        if not provider:
-            raise ValueError(f"Unsupported secret vault type: {self._default_vault}")
-        return provider.get_secret(key)
+        assert self._provider is not None
+        return self._provider.get_secret(key)
 
 
 def _get_home() -> Path:
@@ -77,7 +77,7 @@ def _load_credentials(path: Path) -> dict:
 
 
 def create_credential_manager(product_name: str, env_getter: EnvGetter):
-    file_path = Path(f"{_get_home()}/.databricks/labs/{product_name}/.credentials.yml")
+    creds_path = cred_file(product_name)
 
     secret_providers = {
        'local': LocalSecretProvider(),
@@ -85,5 +85,5 @@ def create_credential_manager(product_name: str, env_getter: EnvGetter):
        'databricks': DatabricksSecretProvider(),
     }
 
-    loader = _load_credentials(file_path)
+    loader = _load_credentials(creds_path)
     return CredentialManager(loader, secret_providers)
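
Note on the CredentialManager change: vault resolution now happens once in __init__, so a misconfigured secret_vault_type fails when the manager is constructed instead of on the first secret lookup. A minimal sketch, with an invented vault name:

    from databricks.labs.lakebridge.connections.credential_manager import CredentialManager

    loader = {'secret_vault_type': 'vault9000', 'mssql': {'user': 'TEST_TSQL_USER'}}
    CredentialManager(loader, secret_providers={})  # raises ValueError: Unsupported secret vault type: vault9000

The assert in _get_secret_value is then only a type-narrowing guard for the checker; the constructor has already rejected unknown vault types.
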
database check credentials") raise ConnectionError("Error connecting to the database check credentials") from None diff --git a/tests/integration/assessments/test_pipeline.py b/tests/integration/assessments/test_pipeline.py index 541a846e35..1494943419 100644 --- a/tests/integration/assessments/test_pipeline.py +++ b/tests/integration/assessments/test_pipeline.py @@ -5,12 +5,6 @@ from databricks.labs.lakebridge.assessments.pipeline import PipelineClass, DB_NAME, StepExecutionStatus from databricks.labs.lakebridge.assessments.profiler_config import Step, PipelineConfig -from ..connections.helpers import get_db_manager - - -@pytest.fixture() -def extractor(mock_credentials): - return get_db_manager("remorph", "mssql") @pytest.fixture(scope="module") @@ -55,12 +49,9 @@ def python_failure_config(): return config -def test_run_pipeline(extractor, pipeline_config, get_logger): - pipeline = PipelineClass(config=pipeline_config, executor=extractor) +def test_run_pipeline(sandbox_sqlserver, pipeline_config, get_logger): + pipeline = PipelineClass(config=pipeline_config, executor=sandbox_sqlserver) results = pipeline.execute() - print("*******************\n") - print(results) - print("\n*******************") # Verify all steps completed successfully for result in results: @@ -72,8 +63,8 @@ def test_run_pipeline(extractor, pipeline_config, get_logger): assert verify_output(get_logger, pipeline_config.extract_folder) -def test_run_sql_failure_pipeline(extractor, sql_failure_config, get_logger): - pipeline = PipelineClass(config=sql_failure_config, executor=extractor) +def test_run_sql_failure_pipeline(sandbox_sqlserver, sql_failure_config, get_logger): + pipeline = PipelineClass(config=sql_failure_config, executor=sandbox_sqlserver) results = pipeline.execute() # Find the failed SQL step @@ -82,8 +73,8 @@ def test_run_sql_failure_pipeline(extractor, sql_failure_config, get_logger): assert "SQL execution failed" in failed_steps[0].error_message -def test_run_python_failure_pipeline(extractor, python_failure_config, get_logger): - pipeline = PipelineClass(config=python_failure_config, executor=extractor) +def test_run_python_failure_pipeline(sandbox_sqlserver, python_failure_config, get_logger): + pipeline = PipelineClass(config=python_failure_config, executor=sandbox_sqlserver) results = pipeline.execute() # Find the failed Python step @@ -92,8 +83,8 @@ def test_run_python_failure_pipeline(extractor, python_failure_config, get_logge assert "Script execution failed" in failed_steps[0].error_message -def test_run_python_dep_failure_pipeline(extractor, pipeline_dep_failure_config, get_logger): - pipeline = PipelineClass(config=pipeline_dep_failure_config, executor=extractor) +def test_run_python_dep_failure_pipeline(sandbox_sqlserver, pipeline_dep_failure_config, get_logger): + pipeline = PipelineClass(config=pipeline_dep_failure_config, executor=sandbox_sqlserver) results = pipeline.execute() # Find the failed Python step @@ -102,12 +93,12 @@ def test_run_python_dep_failure_pipeline(extractor, pipeline_dep_failure_config, assert "Script execution failed" in failed_steps[0].error_message -def test_skipped_steps(extractor, pipeline_config, get_logger): +def test_skipped_steps(sandbox_sqlserver, pipeline_config, get_logger): # Modify config to have some inactive steps for step in pipeline_config.steps: step.flag = "inactive" - pipeline = PipelineClass(config=pipeline_config, executor=extractor) + pipeline = PipelineClass(config=pipeline_config, executor=sandbox_sqlserver) results = pipeline.execute() # 
diff --git a/tests/integration/assessments/test_pipeline.py b/tests/integration/assessments/test_pipeline.py
index 541a846e35..1494943419 100644
--- a/tests/integration/assessments/test_pipeline.py
+++ b/tests/integration/assessments/test_pipeline.py
@@ -5,12 +5,6 @@
 from databricks.labs.lakebridge.assessments.pipeline import PipelineClass, DB_NAME, StepExecutionStatus
 from databricks.labs.lakebridge.assessments.profiler_config import Step, PipelineConfig
 
-from ..connections.helpers import get_db_manager
-
-
-@pytest.fixture()
-def extractor(mock_credentials):
-    return get_db_manager("remorph", "mssql")
 
 
 @pytest.fixture(scope="module")
@@ -55,12 +49,9 @@ def python_failure_config():
     return config
 
 
-def test_run_pipeline(extractor, pipeline_config, get_logger):
-    pipeline = PipelineClass(config=pipeline_config, executor=extractor)
+def test_run_pipeline(sandbox_sqlserver, pipeline_config, get_logger):
+    pipeline = PipelineClass(config=pipeline_config, executor=sandbox_sqlserver)
     results = pipeline.execute()
-    print("*******************\n")
-    print(results)
-    print("\n*******************")
 
     # Verify all steps completed successfully
     for result in results:
@@ -72,8 +63,8 @@ def test_run_pipeline(extractor, pipeline_config, get_logger):
     assert verify_output(get_logger, pipeline_config.extract_folder)
 
 
-def test_run_sql_failure_pipeline(extractor, sql_failure_config, get_logger):
-    pipeline = PipelineClass(config=sql_failure_config, executor=extractor)
+def test_run_sql_failure_pipeline(sandbox_sqlserver, sql_failure_config, get_logger):
+    pipeline = PipelineClass(config=sql_failure_config, executor=sandbox_sqlserver)
     results = pipeline.execute()
 
     # Find the failed SQL step
@@ -82,8 +73,8 @@ def test_run_sql_failure_pipeline(extractor, sql_failure_config, get_logger):
     assert "SQL execution failed" in failed_steps[0].error_message
 
 
-def test_run_python_failure_pipeline(extractor, python_failure_config, get_logger):
-    pipeline = PipelineClass(config=python_failure_config, executor=extractor)
+def test_run_python_failure_pipeline(sandbox_sqlserver, python_failure_config, get_logger):
+    pipeline = PipelineClass(config=python_failure_config, executor=sandbox_sqlserver)
     results = pipeline.execute()
 
     # Find the failed Python step
@@ -92,8 +83,8 @@ def test_run_python_failure_pipeline(extractor, python_failure_config, get_logger):
     assert "Script execution failed" in failed_steps[0].error_message
 
 
-def test_run_python_dep_failure_pipeline(extractor, pipeline_dep_failure_config, get_logger):
-    pipeline = PipelineClass(config=pipeline_dep_failure_config, executor=extractor)
+def test_run_python_dep_failure_pipeline(sandbox_sqlserver, pipeline_dep_failure_config, get_logger):
+    pipeline = PipelineClass(config=pipeline_dep_failure_config, executor=sandbox_sqlserver)
     results = pipeline.execute()
 
     # Find the failed Python step
@@ -102,12 +93,12 @@ def test_run_python_dep_failure_pipeline(extractor, pipeline_dep_failure_config, get_logger):
     assert "Script execution failed" in failed_steps[0].error_message
 
 
-def test_skipped_steps(extractor, pipeline_config, get_logger):
+def test_skipped_steps(sandbox_sqlserver, pipeline_config, get_logger):
     # Modify config to have some inactive steps
     for step in pipeline_config.steps:
         step.flag = "inactive"
 
-    pipeline = PipelineClass(config=pipeline_config, executor=extractor)
+    pipeline = PipelineClass(config=pipeline_config, executor=sandbox_sqlserver)
     results = pipeline.execute()
 
     # Verify all steps are marked as skipped
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index a6a54effb6..b9e496edbd 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -1,11 +1,13 @@
 import os
 import logging
-from unittest.mock import patch
+from urllib.parse import urlparse
 
 import pytest
 from pyspark.sql import SparkSession
 
-from databricks.labs.lakebridge.__about__ import __version__
+from databricks.labs.lakebridge.__about__ import __version__
+from databricks.labs.lakebridge.connections.database_manager import DatabaseManager
+from tests.integration.debug_envgetter import TestEnvGetter
 
 logging.getLogger("tests").setLevel("DEBUG")
 logging.getLogger("databricks.labs.lakebridge").setLevel("DEBUG")
@@ -53,48 +55,26 @@ def mock_spark() -> SparkSession:
     return SparkSession.builder.appName("Remorph Reconcile Test").remote("sc://localhost").getOrCreate()
 
 
-@pytest.fixture(scope="session")
-def mock_credentials():
-    with patch(
-        'databricks.labs.lakebridge.connections.credential_manager._load_credentials',
-        return_value={
-            'secret_vault_type': 'env',
-            'secret_vault_name': '',
-            'mssql': {
-                'user': 'TEST_TSQL_USER',
-                'password': 'TEST_TSQL_PASS',
-                'server': 'TEST_TSQL_JDBC',
-                'database': 'TEST_TSQL_JDBC',
-                'driver': 'ODBC Driver 18 for SQL Server',
-            },
-            'synapse': {
-                'workspace': {
-                    'name': 'test-workspace',
-                    'dedicated_sql_endpoint': 'test-dedicated-endpoint',
-                    'serverless_sql_endpoint': 'test-serverless-endpoint',
-                    'sql_user': 'test-user',
-                    'sql_password': 'test-password',
-                    'tz_info': 'UTC',
-                },
-                'azure_api_access': {
-                    'development_endpoint': 'test-dev-endpoint',
-                    'azure_client_id': 'test-client-id',
-                    'azure_tenant_id': 'test-tenant-id',
-                    'azure_client_secret': 'test-client-secret',
-                },
-                'jdbc': {
-                    'auth_type': 'sql_authentication',
-                    'fetch_size': '1000',
-                    'login_timeout': '30',
-                },
-                'profiler': {
-                    'exclude_serverless_sql_pool': False,
-                    'exclude_dedicated_sql_pools': False,
-                    'exclude_spark_pools': False,
-                    'exclude_monitoring_metrics': False,
-                    'redact_sql_pools_sql_text': False,
-                },
-            },
-        },
-    ):
-        yield
+@pytest.fixture()
+def sandbox_sqlserver_config():
+    env = TestEnvGetter(True)
+    db_url = env.get("TEST_TSQL_JDBC").removeprefix("jdbc:")
+    base_url, params = db_url.replace("jdbc:", "", 1).split(";", 1)
+    url_parts = urlparse(base_url)
+    server = url_parts.hostname
+    query_params = dict(param.split("=", 1) for param in params.split(";") if "=" in param)
+    database = query_params.get("database", "")
+
+    config = {
+        "user": env.get("TEST_TSQL_USER"),
+        "password": env.get("TEST_TSQL_PASS"),
+        "server": server,
+        "database": database,
+        "driver": "ODBC Driver 18 for SQL Server",
+    }
+    return config
+
+
+@pytest.fixture()
+def sandbox_sqlserver(sandbox_sqlserver_config):
+    return DatabaseManager("mssql", sandbox_sqlserver_config)
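
To make the fixture's parsing concrete, here are the same steps run against a made-up JDBC URL (host and database invented):

    from urllib.parse import urlparse

    db_url = "jdbc:sqlserver://sandbox.example.net:1433;database=remorph;encrypt=true".removeprefix("jdbc:")
    base_url, params = db_url.split(";", 1)
    urlparse(base_url).hostname                                   # 'sandbox.example.net'
    dict(p.split("=", 1) for p in params.split(";") if "=" in p)  # {'database': 'remorph', 'encrypt': 'true'}

Since removeprefix already strips the scheme, the fixture's extra replace("jdbc:", "", 1) is a no-op carried over from the deleted helper below.
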
diff --git a/tests/integration/connections/helpers.py b/tests/integration/connections/helpers.py
deleted file mode 100644
index a1156f6ad2..0000000000
--- a/tests/integration/connections/helpers.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from urllib.parse import urlparse
-from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager
-from databricks.labs.lakebridge.connections.database_manager import DatabaseManager
-from .debug_envgetter import TestEnvGetter
-
-
-def get_db_manager(product_name: str, source: str) -> DatabaseManager:
-    env = TestEnvGetter(True)
-    config = create_credential_manager(product_name, env).get_credentials(source)
-
-    # since the kv has only URL so added explicit parse rules
-    base_url, params = config['server'].replace("jdbc:", "", 1).split(";", 1)
-
-    url_parts = urlparse(base_url)
-    server = url_parts.hostname
-    query_params = dict(param.split("=", 1) for param in params.split(";") if "=" in param)
-    database = query_params.get("database", "")
-    config['server'] = server
-    config['database'] = database
-
-    return DatabaseManager(source, config)
diff --git a/tests/integration/connections/test_mssql_connector.py b/tests/integration/connections/test_mssql_connector.py
index 563dba8479..873c75e908 100644
--- a/tests/integration/connections/test_mssql_connector.py
+++ b/tests/integration/connections/test_mssql_connector.py
@@ -1,25 +1,17 @@
-import pytest
-
 from databricks.labs.lakebridge.connections.database_manager import MSSQLConnector
-from .helpers import get_db_manager
-
-
-@pytest.fixture()
-def db_manager(mock_credentials):
-    return get_db_manager("remorph", "mssql")
 
 
-def test_mssql_connector_connection(db_manager):
-    assert isinstance(db_manager.connector, MSSQLConnector)
+def test_mssql_connector_connection(sandbox_sqlserver):
+    assert isinstance(sandbox_sqlserver.connector, MSSQLConnector)
 
 
-def test_mssql_connector_execute_query(db_manager):
+def test_mssql_connector_execute_query(sandbox_sqlserver):
     # Test executing a query
     query = "SELECT 101 AS test_column"
-    result = db_manager.execute_query(query)
+    result = sandbox_sqlserver.execute_query(query)
     row = result.fetchone()
     assert row[0] == 101
 
 
-def test_connection_test(db_manager):
-    assert db_manager.check_connection()
+def test_connection_test(sandbox_sqlserver):
+    assert sandbox_sqlserver.check_connection()
diff --git a/tests/integration/connections/debug_envgetter.py b/tests/integration/debug_envgetter.py
similarity index 100%
rename from tests/integration/connections/debug_envgetter.py
rename to tests/integration/debug_envgetter.py
diff --git a/tests/integration/discovery/test_tsql_table_definition.py b/tests/integration/discovery/test_tsql_table_definition.py
index 00d81e18af..8fc447c88d 100644
--- a/tests/integration/discovery/test_tsql_table_definition.py
+++ b/tests/integration/discovery/test_tsql_table_definition.py
@@ -1,22 +1,14 @@
-import pytest
-
 from databricks.labs.lakebridge.discovery.tsql_table_definition import TsqlTableDefinitionService
-from ..connections.helpers import get_db_manager
-
-
-@pytest.fixture()
-def extractor(mock_credentials):
-    return get_db_manager("remorph", "mssql")
 
 
-def test_tsql_get_catalog(extractor):
-    tss = TsqlTableDefinitionService(extractor)
+def test_tsql_get_catalog(sandbox_sqlserver):
+    tss = TsqlTableDefinitionService(sandbox_sqlserver)
     catalogs = list(tss.get_all_catalog())
     assert catalogs is not None
     assert len(catalogs) > 0
 
 
-def test_tsql_get_table_definition(extractor):
-    tss = TsqlTableDefinitionService(extractor)
+def test_tsql_get_table_definition(sandbox_sqlserver):
+    tss = TsqlTableDefinitionService(sandbox_sqlserver)
     table_def = tss.get_table_definition("labs_azure_sandbox_remorph")
     assert table_def is not None
diff --git a/tests/integration/reconcile/connectors/test_read_schema.py b/tests/integration/reconcile/connectors/test_read_schema.py
index c25df12e25..12cf51deea 100644
--- a/tests/integration/reconcile/connectors/test_read_schema.py
+++ b/tests/integration/reconcile/connectors/test_read_schema.py
@@ -7,7 +7,7 @@
 
 from databricks.sdk import WorkspaceClient
 
-from tests.integration.connections.debug_envgetter import TestEnvGetter
+from tests.integration.debug_envgetter import TestEnvGetter
 
 
 class TSQLServerDataSourceUnderTest(TSQLServerDataSource):
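
With the fixtures now living in tests/integration/conftest.py, any integration test can request sandbox_sqlserver directly instead of importing a helper; a hypothetical example:

    def test_sandbox_selects_current_database(sandbox_sqlserver):
        result = sandbox_sqlserver.execute_query("SELECT DB_NAME() AS current_db")
        assert result.fetchone()[0] is not None
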
diff --git a/tests/unit/connections/test_database_manager.py b/tests/unit/connections/test_database_manager.py
index 1b64dd1728..71cd810684 100644
--- a/tests/unit/connections/test_database_manager.py
+++ b/tests/unit/connections/test_database_manager.py
@@ -43,7 +43,24 @@ def test_execute_query(mock_mssql_connector):
     result = db_manager.execute_query(query)
 
     assert result == mock_result
-    mock_connector_instance.execute_query.assert_called_once_with(query)
+    mock_connector_instance.execute_query.assert_called_once_with(query, False)
+
+
+@patch('databricks.labs.lakebridge.connections.database_manager.MSSQLConnector')
+def test_execute_query_commit(mock_mssql_connector):
+    mock_connector_instance = MagicMock()
+    mock_mssql_connector.return_value = mock_connector_instance
+
+    db_manager = DatabaseManager("mssql", sample_config)
+
+    mutate_query = "TRUNCATE users"
+    mock_result = MagicMock()
+    mock_connector_instance.execute_query.return_value = mock_result
+
+    mutate_result = db_manager.execute_query(mutate_query, commit=True)
+
+    assert mutate_result == mock_result
+    mock_connector_instance.execute_query.assert_called_once_with(mutate_query, True)
 
 
 def running_on_ci() -> bool:
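
One note on the assertions above: DatabaseManager forwards the flag positionally (self.connector.execute_query(query, commit)), and the connector mock is an unspecced MagicMock, so the expected call must be expressed positionally as well:

    mock_connector_instance.execute_query.assert_called_once_with(mutate_query, True)  # matches the positional call
    # assert_called_once_with(mutate_query, commit=True) would fail: unspecced mocks
    # do not treat positional and keyword arguments as interchangeable
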