Skip to content

Commit be0e24b

Browse files
committed
Remove patching of credential loading and read test credentials directly in integration tests
1 parent 41d71da commit be0e24b

File tree

9 files changed

+78
-121
lines changed

9 files changed

+78
-121
lines changed

src/databricks/labs/lakebridge/connections/credential_manager.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,10 @@ def get_secret(self, key: str) -> str:
4040
class CredentialManager:
4141
def __init__(self, credential_loader: dict, secret_providers: dict):
4242
self._credentials = credential_loader
43-
self._secret_providers = secret_providers
4443
self._default_vault = self._credentials.get('secret_vault_type', 'local').lower()
44+
self._provider = secret_providers.get(self._default_vault)
45+
if not self._provider:
46+
raise ValueError(f"Unsupported secret vault type: {self._default_vault}")
4547

4648
def get_credentials(self, source: str) -> dict:
4749
if source not in self._credentials:
@@ -54,10 +56,8 @@ def get_credentials(self, source: str) -> dict:
5456
return {k: self._get_secret_value(v) for k, v in value.items()}
5557

5658
def _get_secret_value(self, key: str) -> str:
57-
provider = self._secret_providers.get(self._default_vault)
58-
if not provider:
59-
raise ValueError(f"Unsupported secret vault type: {self._default_vault}")
60-
return provider.get_secret(key)
59+
assert self._provider is not None
60+
return self._provider.get_secret(key)
6161

6262

6363
def _get_home() -> Path:
@@ -77,13 +77,13 @@ def _load_credentials(path: Path) -> dict:
7777

7878

7979
def create_credential_manager(product_name: str, env_getter: EnvGetter):
80-
file_path = Path(f"{_get_home()}/.databricks/labs/{product_name}/.credentials.yml")
80+
creds_path = cred_file(product_name)
8181

8282
secret_providers = {
8383
'local': LocalSecretProvider(),
8484
'env': EnvSecretProvider(env_getter),
8585
'databricks': DatabricksSecretProvider(),
8686
}
8787

88-
loader = _load_credentials(file_path)
88+
loader = _load_credentials(creds_path)
8989
return CredentialManager(loader, secret_providers)

src/databricks/labs/lakebridge/connections/database_manager.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def _connect(self) -> Engine:
1818
pass
1919

2020
@abstractmethod
21-
def execute_query(self, query: str) -> Result[Any]:
21+
def execute_query(self, query: str, commit: bool = False) -> Result[Any]:
2222
pass
2323

2424

@@ -30,12 +30,15 @@ def __init__(self, config: dict[str, Any]):
3030
def _connect(self) -> Engine:
3131
raise NotImplementedError("Subclasses should implement this method")
3232

33-
def execute_query(self, query: str) -> Result[Any]:
33+
def execute_query(self, query: str, commit: bool = False) -> Result[Any]:
3434
if not self.engine:
3535
raise ConnectionError("Not connected to the database.")
3636
session = sessionmaker(bind=self.engine)
3737
connection = session()
38-
return connection.execute(text(query))
38+
result = connection.execute(text(query))
39+
if commit:
40+
connection.commit()
41+
return result
3942

4043

4144
def _create_connector(db_type: str, config: dict[str, Any]) -> DatabaseConnector:
@@ -82,9 +85,9 @@ class DatabaseManager:
8285
def __init__(self, db_type: str, config: dict[str, Any]):
8386
self.connector = _create_connector(db_type, config)
8487

85-
def execute_query(self, query: str) -> Result[Any]:
88+
def execute_query(self, query: str, commit: bool = False) -> Result[Any]:
8689
try:
87-
return self.connector.execute_query(query)
90+
return self.connector.execute_query(query, commit)
8891
except OperationalError:
8992
logger.error("Error connecting to the database check credentials")
9093
raise ConnectionError("Error connecting to the database check credentials") from None

tests/integration/assessments/test_pipeline.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,6 @@
55
from databricks.labs.lakebridge.assessments.pipeline import PipelineClass, DB_NAME, StepExecutionStatus
66

77
from databricks.labs.lakebridge.assessments.profiler_config import Step, PipelineConfig
8-
from ..connections.helpers import get_db_manager
9-
10-
11-
@pytest.fixture()
12-
def extractor(mock_credentials):
13-
return get_db_manager("remorph", "mssql")
148

159

1610
@pytest.fixture(scope="module")
@@ -55,8 +49,8 @@ def python_failure_config():
5549
return config
5650

5751

58-
def test_run_pipeline(extractor, pipeline_config, get_logger):
59-
pipeline = PipelineClass(config=pipeline_config, executor=extractor)
52+
def test_run_pipeline(test_sqlserver, pipeline_config, get_logger):
53+
pipeline = PipelineClass(config=pipeline_config, executor=test_sqlserver)
6054
results = pipeline.execute()
6155
print("*******************\n")
6256
print(results)
@@ -72,8 +66,8 @@ def test_run_pipeline(extractor, pipeline_config, get_logger):
7266
assert verify_output(get_logger, pipeline_config.extract_folder)
7367

7468

75-
def test_run_sql_failure_pipeline(extractor, sql_failure_config, get_logger):
76-
pipeline = PipelineClass(config=sql_failure_config, executor=extractor)
69+
def test_run_sql_failure_pipeline(test_sqlserver, sql_failure_config, get_logger):
70+
pipeline = PipelineClass(config=sql_failure_config, executor=test_sqlserver)
7771
results = pipeline.execute()
7872

7973
# Find the failed SQL step
@@ -82,8 +76,8 @@ def test_run_sql_failure_pipeline(extractor, sql_failure_config, get_logger):
8276
assert "SQL execution failed" in failed_steps[0].error_message
8377

8478

85-
def test_run_python_failure_pipeline(extractor, python_failure_config, get_logger):
86-
pipeline = PipelineClass(config=python_failure_config, executor=extractor)
79+
def test_run_python_failure_pipeline(test_sqlserver, python_failure_config, get_logger):
80+
pipeline = PipelineClass(config=python_failure_config, executor=test_sqlserver)
8781
results = pipeline.execute()
8882

8983
# Find the failed Python step
@@ -92,8 +86,8 @@ def test_run_python_failure_pipeline(extractor, python_failure_config, get_logge
9286
assert "Script execution failed" in failed_steps[0].error_message
9387

9488

95-
def test_run_python_dep_failure_pipeline(extractor, pipeline_dep_failure_config, get_logger):
96-
pipeline = PipelineClass(config=pipeline_dep_failure_config, executor=extractor)
89+
def test_run_python_dep_failure_pipeline(test_sqlserver, pipeline_dep_failure_config, get_logger):
90+
pipeline = PipelineClass(config=pipeline_dep_failure_config, executor=test_sqlserver)
9791
results = pipeline.execute()
9892

9993
# Find the failed Python step
@@ -102,12 +96,12 @@ def test_run_python_dep_failure_pipeline(extractor, pipeline_dep_failure_config,
10296
assert "Script execution failed" in failed_steps[0].error_message
10397

10498

105-
def test_skipped_steps(extractor, pipeline_config, get_logger):
99+
def test_skipped_steps(test_sqlserver, pipeline_config, get_logger):
106100
# Modify config to have some inactive steps
107101
for step in pipeline_config.steps:
108102
step.flag = "inactive"
109103

110-
pipeline = PipelineClass(config=pipeline_config, executor=extractor)
104+
pipeline = PipelineClass(config=pipeline_config, executor=test_sqlserver)
111105
results = pipeline.execute()
112106

113107
# Verify all steps are marked as skipped

tests/integration/conftest.py

Lines changed: 27 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import os
22
import logging
3-
from unittest.mock import patch
3+
from urllib.parse import urlparse
44

55
import pytest
66
from pyspark.sql import SparkSession
7-
from databricks.labs.lakebridge.__about__ import __version__
87

8+
from databricks.labs.lakebridge.__about__ import __version__
9+
from databricks.labs.lakebridge.connections.database_manager import DatabaseManager
10+
from tests.integration.debug_envgetter import TestEnvGetter
911

1012
logging.getLogger("tests").setLevel("DEBUG")
1113
logging.getLogger("databricks.labs.lakebridge").setLevel("DEBUG")
@@ -53,48 +55,26 @@ def mock_spark() -> SparkSession:
5355
return SparkSession.builder.appName("Remorph Reconcile Test").remote("sc://localhost").getOrCreate()
5456

5557

56-
@pytest.fixture(scope="session")
57-
def mock_credentials():
58-
with patch(
59-
'databricks.labs.lakebridge.connections.credential_manager._load_credentials',
60-
return_value={
61-
'secret_vault_type': 'env',
62-
'secret_vault_name': '',
63-
'mssql': {
64-
'user': 'TEST_TSQL_USER',
65-
'password': 'TEST_TSQL_PASS',
66-
'server': 'TEST_TSQL_JDBC',
67-
'database': 'TEST_TSQL_JDBC',
68-
'driver': 'ODBC Driver 18 for SQL Server',
69-
},
70-
'synapse': {
71-
'workspace': {
72-
'name': 'test-workspace',
73-
'dedicated_sql_endpoint': 'test-dedicated-endpoint',
74-
'serverless_sql_endpoint': 'test-serverless-endpoint',
75-
'sql_user': 'test-user',
76-
'sql_password': 'test-password',
77-
'tz_info': 'UTC',
78-
},
79-
'azure_api_access': {
80-
'development_endpoint': 'test-dev-endpoint',
81-
'azure_client_id': 'test-client-id',
82-
'azure_tenant_id': 'test-tenant-id',
83-
'azure_client_secret': 'test-client-secret',
84-
},
85-
'jdbc': {
86-
'auth_type': 'sql_authentication',
87-
'fetch_size': '1000',
88-
'login_timeout': '30',
89-
},
90-
'profiler': {
91-
'exclude_serverless_sql_pool': False,
92-
'exclude_dedicated_sql_pools': False,
93-
'exclude_spark_pools': False,
94-
'exclude_monitoring_metrics': False,
95-
'redact_sql_pools_sql_text': False,
96-
},
97-
},
98-
},
99-
):
100-
yield
58+
@pytest.fixture()
59+
def test_sqlserver_db_config():
60+
env = TestEnvGetter(True)
61+
db_url = env.get("TEST_TSQL_JDBC").removeprefix("jdbc:")
62+
base_url, params = db_url.replace("jdbc:", "", 1).split(";", 1)
63+
url_parts = urlparse(base_url)
64+
server = url_parts.hostname
65+
query_params = dict(param.split("=", 1) for param in params.split(";") if "=" in param)
66+
database = query_params.get("database", "")
67+
68+
config = {
69+
"user": env.get("TEST_TSQL_USER"),
70+
"password": env.get("TEST_TSQL_PASS"),
71+
"server": server,
72+
"database": database,
73+
"driver": "ODBC Driver 18 for SQL Server",
74+
}
75+
return config
76+
77+
78+
@pytest.fixture()
79+
def test_sqlserver(test_sqlserver_db_config):
80+
return DatabaseManager("mssql", test_sqlserver_db_config)

tests/integration/connections/helpers.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

tests/integration/connections/test_mssql_connector.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,14 @@
1-
import pytest
2-
31
from databricks.labs.lakebridge.connections.database_manager import MSSQLConnector
4-
from .helpers import get_db_manager
5-
6-
7-
@pytest.fixture()
8-
def db_manager(mock_credentials):
9-
return get_db_manager("remorph", "mssql")
102

113

12-
def test_mssql_connector_connection(db_manager):
13-
assert isinstance(db_manager.connector, MSSQLConnector)
4+
def test_mssql_connector_connection(test_sqlserver):
5+
assert isinstance(test_sqlserver.connector, MSSQLConnector)
146

157

16-
def test_mssql_connector_execute_query(db_manager):
8+
def test_mssql_connector_execute_query(test_sqlserver):
179
# Test executing a query
1810
query = "SELECT 101 AS test_column"
19-
result = db_manager.execute_query(query)
11+
result = test_sqlserver.execute_query(query)
2012
row = result.fetchone()
2113
assert row[0] == 101
2214

File renamed without changes.
Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,14 @@
1-
import pytest
2-
31
from databricks.labs.lakebridge.discovery.tsql_table_definition import TsqlTableDefinitionService
4-
from ..connections.helpers import get_db_manager
5-
6-
7-
@pytest.fixture()
8-
def extractor(mock_credentials):
9-
return get_db_manager("remorph", "mssql")
102

113

12-
def test_tsql_get_catalog(extractor):
13-
tss = TsqlTableDefinitionService(extractor)
4+
def test_tsql_get_catalog(test_sqlserver):
5+
tss = TsqlTableDefinitionService(test_sqlserver)
146
catalogs = list(tss.get_all_catalog())
157
assert catalogs is not None
168
assert len(catalogs) > 0
179

1810

19-
def test_tsql_get_table_definition(extractor):
20-
tss = TsqlTableDefinitionService(extractor)
11+
def test_tsql_get_table_definition(test_sqlserver):
12+
tss = TsqlTableDefinitionService(test_sqlserver)
2113
table_def = tss.get_table_definition("labs_azure_sandbox_remorph")
2214
assert table_def is not None

tests/unit/connections/test_database_manager.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,24 @@ def test_execute_query(mock_mssql_connector):
4343
result = db_manager.execute_query(query)
4444

4545
assert result == mock_result
46-
mock_connector_instance.execute_query.assert_called_once_with(query)
46+
mock_connector_instance.execute_query.assert_called_once_with(query, False)
47+
48+
49+
@patch('databricks.labs.lakebridge.connections.database_manager.MSSQLConnector')
50+
def test_execute_query_commit(mock_mssql_connector):
51+
mock_connector_instance = MagicMock()
52+
mock_mssql_connector.return_value = mock_connector_instance
53+
54+
db_manager = DatabaseManager("mssql", sample_config)
55+
56+
mutate_query = "TRUNCATE users"
57+
mock_result = MagicMock()
58+
mock_connector_instance.execute_query.return_value = mock_result
59+
60+
mutate_result = db_manager.execute_query(mutate_query, commit=True)
61+
62+
assert mutate_result == mock_result
63+
mock_connector_instance.execute_query.assert_called_once_with(mutate_query, True)
4764

4865

4966
def running_on_ci() -> bool:

0 commit comments

Comments (0)