Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ def get_secret(self, key: str) -> str:
class CredentialManager:
def __init__(self, credential_loader: dict, secret_providers: dict):
self._credentials = credential_loader
self._secret_providers = secret_providers
self._default_vault = self._credentials.get('secret_vault_type', 'local').lower()
self._provider = secret_providers.get(self._default_vault)
if not self._provider:
raise ValueError(f"Unsupported secret vault type: {self._default_vault}")

def get_credentials(self, source: str) -> dict:
if source not in self._credentials:
Expand All @@ -54,10 +56,8 @@ def get_credentials(self, source: str) -> dict:
return {k: self._get_secret_value(v) for k, v in value.items()}

def _get_secret_value(self, key: str) -> str:
provider = self._secret_providers.get(self._default_vault)
if not provider:
raise ValueError(f"Unsupported secret vault type: {self._default_vault}")
return provider.get_secret(key)
assert self._provider is not None
return self._provider.get_secret(key)


def _get_home() -> Path:
Expand All @@ -77,13 +77,13 @@ def _load_credentials(path: Path) -> dict:


def create_credential_manager(product_name: str, env_getter: EnvGetter):
file_path = Path(f"{_get_home()}/.databricks/labs/{product_name}/.credentials.yml")
creds_path = cred_file(product_name)

secret_providers = {
'local': LocalSecretProvider(),
'env': EnvSecretProvider(env_getter),
'databricks': DatabricksSecretProvider(),
}

loader = _load_credentials(file_path)
loader = _load_credentials(creds_path)
return CredentialManager(loader, secret_providers)
13 changes: 8 additions & 5 deletions src/databricks/labs/lakebridge/connections/database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def _connect(self) -> Engine:
pass

@abstractmethod
def execute_query(self, query: str) -> Result[Any]:
def execute_query(self, query: str, commit: bool = False) -> Result[Any]:
pass


Expand All @@ -30,12 +30,15 @@ def __init__(self, config: dict[str, Any]):
def _connect(self) -> Engine:
raise NotImplementedError("Subclasses should implement this method")

def execute_query(self, query: str) -> Result[Any]:
def execute_query(self, query: str, commit: bool = False) -> Result[Any]:
if not self.engine:
raise ConnectionError("Not connected to the database.")
session = sessionmaker(bind=self.engine)
connection = session()
return connection.execute(text(query))
result = connection.execute(text(query))
if commit:
connection.commit()
return result


def _create_connector(db_type: str, config: dict[str, Any]) -> DatabaseConnector:
Expand Down Expand Up @@ -82,9 +85,9 @@ class DatabaseManager:
def __init__(self, db_type: str, config: dict[str, Any]):
self.connector = _create_connector(db_type, config)

def execute_query(self, query: str) -> Result[Any]:
def execute_query(self, query: str, commit: bool = False) -> Result[Any]:
try:
return self.connector.execute_query(query)
return self.connector.execute_query(query, commit)
except OperationalError:
logger.error("Error connecting to the database check credentials")
raise ConnectionError("Error connecting to the database check credentials") from None
Expand Down
29 changes: 10 additions & 19 deletions tests/integration/assessments/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@
from databricks.labs.lakebridge.assessments.pipeline import PipelineClass, DB_NAME, StepExecutionStatus

from databricks.labs.lakebridge.assessments.profiler_config import Step, PipelineConfig
from ..connections.helpers import get_db_manager


@pytest.fixture()
def extractor(mock_credentials):
return get_db_manager("remorph", "mssql")


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -55,12 +49,9 @@ def python_failure_config():
return config


def test_run_pipeline(extractor, pipeline_config, get_logger):
pipeline = PipelineClass(config=pipeline_config, executor=extractor)
def test_run_pipeline(test_sqlserver, pipeline_config, get_logger):
pipeline = PipelineClass(config=pipeline_config, executor=test_sqlserver)
results = pipeline.execute()
print("*******************\n")
print(results)
print("\n*******************")

# Verify all steps completed successfully
for result in results:
Expand All @@ -72,8 +63,8 @@ def test_run_pipeline(extractor, pipeline_config, get_logger):
assert verify_output(get_logger, pipeline_config.extract_folder)


def test_run_sql_failure_pipeline(extractor, sql_failure_config, get_logger):
pipeline = PipelineClass(config=sql_failure_config, executor=extractor)
def test_run_sql_failure_pipeline(test_sqlserver, sql_failure_config, get_logger):
pipeline = PipelineClass(config=sql_failure_config, executor=test_sqlserver)
results = pipeline.execute()

# Find the failed SQL step
Expand All @@ -82,8 +73,8 @@ def test_run_sql_failure_pipeline(extractor, sql_failure_config, get_logger):
assert "SQL execution failed" in failed_steps[0].error_message


def test_run_python_failure_pipeline(extractor, python_failure_config, get_logger):
pipeline = PipelineClass(config=python_failure_config, executor=extractor)
def test_run_python_failure_pipeline(test_sqlserver, python_failure_config, get_logger):
pipeline = PipelineClass(config=python_failure_config, executor=test_sqlserver)
results = pipeline.execute()

# Find the failed Python step
Expand All @@ -92,8 +83,8 @@ def test_run_python_failure_pipeline(extractor, python_failure_config, get_logge
assert "Script execution failed" in failed_steps[0].error_message


def test_run_python_dep_failure_pipeline(extractor, pipeline_dep_failure_config, get_logger):
pipeline = PipelineClass(config=pipeline_dep_failure_config, executor=extractor)
def test_run_python_dep_failure_pipeline(test_sqlserver, pipeline_dep_failure_config, get_logger):
pipeline = PipelineClass(config=pipeline_dep_failure_config, executor=test_sqlserver)
results = pipeline.execute()

# Find the failed Python step
Expand All @@ -102,12 +93,12 @@ def test_run_python_dep_failure_pipeline(extractor, pipeline_dep_failure_config,
assert "Script execution failed" in failed_steps[0].error_message


def test_skipped_steps(extractor, pipeline_config, get_logger):
def test_skipped_steps(test_sqlserver, pipeline_config, get_logger):
# Modify config to have some inactive steps
for step in pipeline_config.steps:
step.flag = "inactive"

pipeline = PipelineClass(config=pipeline_config, executor=extractor)
pipeline = PipelineClass(config=pipeline_config, executor=test_sqlserver)
results = pipeline.execute()

# Verify all steps are marked as skipped
Expand Down
74 changes: 27 additions & 47 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
import logging
from unittest.mock import patch
from urllib.parse import urlparse

import pytest
from pyspark.sql import SparkSession
from databricks.labs.lakebridge.__about__ import __version__

from databricks.labs.lakebridge.__about__ import __version__
from databricks.labs.lakebridge.connections.database_manager import DatabaseManager
from tests.integration.debug_envgetter import TestEnvGetter

logging.getLogger("tests").setLevel("DEBUG")
logging.getLogger("databricks.labs.lakebridge").setLevel("DEBUG")
Expand Down Expand Up @@ -53,48 +55,26 @@ def mock_spark() -> SparkSession:
return SparkSession.builder.appName("Remorph Reconcile Test").remote("sc://localhost").getOrCreate()


@pytest.fixture(scope="session")
def mock_credentials():
with patch(
'databricks.labs.lakebridge.connections.credential_manager._load_credentials',
return_value={
'secret_vault_type': 'env',
'secret_vault_name': '',
'mssql': {
'user': 'TEST_TSQL_USER',
'password': 'TEST_TSQL_PASS',
'server': 'TEST_TSQL_JDBC',
'database': 'TEST_TSQL_JDBC',
'driver': 'ODBC Driver 18 for SQL Server',
},
'synapse': {
'workspace': {
'name': 'test-workspace',
'dedicated_sql_endpoint': 'test-dedicated-endpoint',
'serverless_sql_endpoint': 'test-serverless-endpoint',
'sql_user': 'test-user',
'sql_password': 'test-password',
'tz_info': 'UTC',
},
'azure_api_access': {
'development_endpoint': 'test-dev-endpoint',
'azure_client_id': 'test-client-id',
'azure_tenant_id': 'test-tenant-id',
'azure_client_secret': 'test-client-secret',
},
'jdbc': {
'auth_type': 'sql_authentication',
'fetch_size': '1000',
'login_timeout': '30',
},
'profiler': {
'exclude_serverless_sql_pool': False,
'exclude_dedicated_sql_pools': False,
'exclude_spark_pools': False,
'exclude_monitoring_metrics': False,
'redact_sql_pools_sql_text': False,
},
},
},
):
yield
@pytest.fixture()
def test_sqlserver_db_config():
    """Build an MSSQL connection config dict from the TEST_TSQL_* environment.

    Parses the JDBC URL in ``TEST_TSQL_JDBC`` (e.g.
    ``jdbc:sqlserver://host:1433;database=mydb;encrypt=true``) to extract the
    server hostname and the ``database`` connection property.
    """
    env = TestEnvGetter(True)
    # Strip the single "jdbc:" scheme prefix once so urlparse can read the URL.
    # (The original stripped it twice: removeprefix() plus a dead replace().)
    db_url = env.get("TEST_TSQL_JDBC").removeprefix("jdbc:")
    # JDBC-style URLs carry key=value properties after the first ";"; partition
    # tolerates a URL with no properties instead of raising like split(";", 1).
    base_url, _, params = db_url.partition(";")
    server = urlparse(base_url).hostname
    query_params = dict(param.split("=", 1) for param in params.split(";") if "=" in param)

    config = {
        "user": env.get("TEST_TSQL_USER"),
        "password": env.get("TEST_TSQL_PASS"),
        "server": server,
        "database": query_params.get("database", ""),
        "driver": "ODBC Driver 18 for SQL Server",
    }
    return config


@pytest.fixture()
def test_sqlserver(test_sqlserver_db_config):
    """A DatabaseManager wired to the integration-test SQL Server instance."""
    manager = DatabaseManager("mssql", test_sqlserver_db_config)
    return manager
21 changes: 0 additions & 21 deletions tests/integration/connections/helpers.py

This file was deleted.

20 changes: 6 additions & 14 deletions tests/integration/connections/test_mssql_connector.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,17 @@
import pytest

from databricks.labs.lakebridge.connections.database_manager import MSSQLConnector
from .helpers import get_db_manager


@pytest.fixture()
def db_manager(mock_credentials):
return get_db_manager("remorph", "mssql")


def test_mssql_connector_connection(db_manager):
assert isinstance(db_manager.connector, MSSQLConnector)
def test_mssql_connector_connection(test_sqlserver):
assert isinstance(test_sqlserver.connector, MSSQLConnector)


def test_mssql_connector_execute_query(db_manager):
def test_mssql_connector_execute_query(test_sqlserver):
# Test executing a query
query = "SELECT 101 AS test_column"
result = db_manager.execute_query(query)
result = test_sqlserver.execute_query(query)
row = result.fetchone()
assert row[0] == 101


def test_connection_test(db_manager):
assert db_manager.check_connection()
def test_connection_test(test_sqlserver):
assert test_sqlserver.check_connection()
16 changes: 4 additions & 12 deletions tests/integration/discovery/test_tsql_table_definition.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
import pytest

from databricks.labs.lakebridge.discovery.tsql_table_definition import TsqlTableDefinitionService
from ..connections.helpers import get_db_manager


@pytest.fixture()
def extractor(mock_credentials):
return get_db_manager("remorph", "mssql")


def test_tsql_get_catalog(extractor):
tss = TsqlTableDefinitionService(extractor)
def test_tsql_get_catalog(test_sqlserver):
tss = TsqlTableDefinitionService(test_sqlserver)
catalogs = list(tss.get_all_catalog())
assert catalogs is not None
assert len(catalogs) > 0


def test_tsql_get_table_definition(extractor):
tss = TsqlTableDefinitionService(extractor)
def test_tsql_get_table_definition(test_sqlserver):
tss = TsqlTableDefinitionService(test_sqlserver)
table_def = tss.get_table_definition("labs_azure_sandbox_remorph")
assert table_def is not None
19 changes: 18 additions & 1 deletion tests/unit/connections/test_database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,24 @@ def test_execute_query(mock_mssql_connector):
result = db_manager.execute_query(query)

assert result == mock_result
mock_connector_instance.execute_query.assert_called_once_with(query)
mock_connector_instance.execute_query.assert_called_once_with(query, False)


@patch('databricks.labs.lakebridge.connections.database_manager.MSSQLConnector')
def test_execute_query_commit(mock_mssql_connector):
    """A mutating query run with commit=True is forwarded with the commit flag set."""
    # Stub out the connector and its query result before building the manager.
    connector = MagicMock()
    result_stub = MagicMock()
    connector.execute_query.return_value = result_stub
    mock_mssql_connector.return_value = connector

    manager = DatabaseManager("mssql", sample_config)
    query = "TRUNCATE users"

    outcome = manager.execute_query(query, commit=True)

    assert outcome == result_stub
    connector.execute_query.assert_called_once_with(query, True)


def running_on_ci() -> bool:
Expand Down
Loading