diff --git a/src/databricks/labs/lakebridge/cli.py b/src/databricks/labs/lakebridge/cli.py
index a2d198f47..53c312ac5 100644
--- a/src/databricks/labs/lakebridge/cli.py
+++ b/src/databricks/labs/lakebridge/cli.py
@@ -73,7 +73,6 @@ def raise_validation_exception(msg: str) -> NoReturn:
 
 
 def _create_warehouse(ws: WorkspaceClient) -> str:
-
     dbsql = ws.warehouses.create_and_wait(
         name=f"lakebridge-warehouse-{time.time_ns()}",
         warehouse_type=CreateWarehouseRequestWarehouseType.PRO,
@@ -838,7 +837,6 @@ def _validate_llm_transpile_args(
     source_dialect: str | None,
     prompts: Prompts,
 ) -> tuple[str, str, str]:
-
     _switch_dialects = get_switch_dialects()
 
     # Validate presence after attempting to source from config
@@ -879,6 +877,7 @@ def llm_transpile(
     schema_name: str | None = None,
     volume: str | None = None,
     foundation_model: str | None = None,
+    output_sdp: bool = False,
     ctx: ApplicationContext | None = None,
 ) -> None:
     """Transpile source code to Databricks using LLM Transpiler (Switch)"""
@@ -960,6 +959,7 @@ def llm_transpile(
         schema=schema_name,
         foundation_model=foundation_model,
         job_id=job_id,
+        output_sdp=output_sdp,
     )
 
 
diff --git a/src/databricks/labs/lakebridge/transpiler/switch_runner.py b/src/databricks/labs/lakebridge/transpiler/switch_runner.py
index 4c7ba86b3..d9f095d85 100644
--- a/src/databricks/labs/lakebridge/transpiler/switch_runner.py
+++ b/src/databricks/labs/lakebridge/transpiler/switch_runner.py
@@ -30,9 +30,14 @@ def run(
         schema: str,
         foundation_model: str,
         job_id: int,
+        output_sdp: bool = False,
     ) -> RootJsonValue:
         """Trigger Switch job."""
 
+        switch_options = {}
+        if output_sdp:
+            switch_options["output_sdp"] = "true"
+
         job_params = self._build_job_parameters(
             input_dir=volume_input_path,
             output_dir=output_ws_folder,
@@ -40,6 +45,7 @@
             catalog=catalog,
             schema=schema,
             foundation_model=foundation_model,
+            switch_options=switch_options,
         )
 
         logger.info(f"Triggering Switch job with job_id: {job_id}")
@@ -55,7 +61,7 @@ def upload_to_volume(
         """Upload local files to UC Volume with unique timestamped path."""
         now = datetime.now(timezone.utc)
         time_part = now.strftime("%Y%m%d%H%M%S")
-        random_part = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4))
+        random_part = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
 
         volume_base_path = f"/Volumes/{catalog}/{schema}/{volume}"
         volume_input_path = f"{volume_base_path}/input-{time_part}-{random_part}"
@@ -63,11 +69,11 @@
 
         # File upload
         if local_path.is_file():
-            if local_path.name.startswith('.'):
+            if local_path.name.startswith("."):
                 logger.debug(f"Skipping hidden file: {local_path}")
                 return volume_input_path
             volume_file_path = f"{volume_input_path}/{local_path.name}"
-            with open(local_path, 'rb') as f:
+            with open(local_path, "rb") as f:
                 content = f.read()
             self._ws.files.upload(file_path=volume_file_path, contents=io.BytesIO(content), overwrite=True)
             logger.debug(f"Uploaded: {local_path} -> {volume_file_path}")
@@ -76,15 +82,15 @@
         else:
             for root, dirs, files in os.walk(local_path):
                 # remove hidden directories
-                dirs[:] = [d for d in dirs if not d.startswith('.')]
+                dirs[:] = [d for d in dirs if not d.startswith(".")]
                 # skip hidden files
-                files = [f for f in files if not f.startswith('.')]
+                files = [f for f in files if not f.startswith(".")]
 
                 for file in files:
                     local_file = Path(root) / file
                     relative_path = local_file.relative_to(local_path)
                     volume_file_path = f"{volume_input_path}/{relative_path}"
-                    with open(local_file, 'rb') as f:
+                    with open(local_file, "rb") as f:
                         content = f.read()
                     self._ws.files.upload(file_path=volume_file_path, contents=io.BytesIO(content), overwrite=True)
 
diff --git a/tests/unit/test_cli_llm_transpile.py b/tests/unit/test_cli_llm_transpile.py
index bce9dd8de..22c766df0 100644
--- a/tests/unit/test_cli_llm_transpile.py
+++ b/tests/unit/test_cli_llm_transpile.py
@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from unittest.mock import create_autospec
+from unittest.mock import create_autospec, MagicMock
 from typing import cast
 
 import pytest
@@ -39,7 +39,7 @@ def make_mock_prompts(input_path: str, output_folder: str, source_dialect: str =
 
 def create_switch_workspace_client_mock() -> WorkspaceClient:
     ws = create_autospec(spec=WorkspaceClient, instance=True)
-    ws.config.host = 'https://workspace.databricks.com'
+    ws.config.host = "https://workspace.databricks.com"
     ws.files.upload.return_value = None
     ws.jobs.run_now.return_value.run_id = _RUN_ID
     ws.jobs.run_now_and_wait_result.return_value.run_id = _RUN_ID
@@ -194,7 +194,6 @@ def test_llm_transpile_with_incorrect_output_parms(
     mock_installation_with_switch: MockInstallation,
     tmp_path: Path,
 ) -> None:
-
     input_source = tmp_path / "input.sql"
     input_source.write_text("SELECT * FROM table1;")
     output_folder = "/Users/test/output"
@@ -223,7 +222,6 @@ def test_llm_transpile_with_incorrect_dialect(
     mock_installation_with_switch: MockInstallation,
     tmp_path: Path,
 ) -> None:
-
     input_source = tmp_path / "input.sql"
     input_source.write_text("SELECT * FROM table1;")
     output_folder = "/Workspace/Users/test/output"
@@ -247,3 +245,113 @@
     error_msg = "Invalid value for '--source-dialect': 'agent_sql' must be one of: airflow, mssql, mysql, netezza, oracle, postgresql, redshift, snowflake, synapse, teradata"
     with pytest.raises(ValueError, match=rf"{error_msg}"):
         cli.llm_transpile(w=mock_ws, accept_terms=True, source_dialect="agent_sql", ctx=ctx)
+
+
+def test_llm_transpile_with_output_sdp_flag(
+    mock_installation_with_switch: MockInstallation,
+    tmp_path: Path,
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    """Test that output_sdp flag is properly passed to job parameters."""
+    input_source = tmp_path / "input.sql"
+    input_source.write_text("SELECT * FROM table1;")
+    output_folder = "/Workspace/Users/test/output"
+
+    # Use a dedicated WorkspaceClient mock tailored for SwitchRunner
+    mock_ws = create_switch_workspace_client_mock()
+    mock_configurator = mock_resource_configurator(mock_ws, make_mock_prompts(str(input_source), output_folder))
+
+    ctx = ApplicationContext(mock_ws)
+    ctx.replace(
+        installation=mock_installation_with_switch,
+        add_user_agent_extra=lambda w, *args, **kwargs: w,
+        resource_configurator=mock_configurator,
+    )
+
+    with caplog.at_level(logging.INFO):
+        cli.llm_transpile(
+            w=mock_ws,
+            accept_terms=True,
+            input_source=str(input_source),
+            output_ws_folder=output_folder,
+            source_dialect="mssql",
+            catalog_name="lakebridge",
+            schema_name="switch",
+            volume="switch_volume",
+            foundation_model="databricks-claude-sonnet-4-5",
+            output_sdp=True,
+            ctx=ctx,
+        )
+
+    # Verify that the job was called with the correct parameters including output_sdp
+    run_now_mock = cast(MagicMock, mock_ws.jobs.run_now)
+    run_now_mock.assert_called_once()
+    call_args = run_now_mock.call_args
+    job_params = call_args.kwargs["job_parameters"]
+
+    # Verify output_sdp is in the job parameters
+    assert "output_sdp" in job_params
+    assert job_params["output_sdp"] == "true"
+
+    # Verify other expected parameters are still present
+    assert job_params["source_tech"] == "mssql"
+    assert job_params["catalog"] == "lakebridge"
+    assert job_params["schema"] == "switch"
+    assert job_params["foundation_model"] == "databricks-claude-sonnet-4-5"
+
+    expected_msg = (
+        f"Switch LLM transpilation job started: https://workspace.databricks.com/jobs/{_JOB_ID}/runs/{_RUN_ID}"
+    )
+    info_messages = [record.message for record in caplog.records if record.levelno == logging.INFO]
+    assert expected_msg in info_messages
+
+
+def test_llm_transpile_without_output_sdp_flag(
+    mock_installation_with_switch: MockInstallation,
+    tmp_path: Path,
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    """Test that output_sdp is not in job parameters when flag is not set."""
+    input_source = tmp_path / "input.sql"
+    input_source.write_text("SELECT * FROM table1;")
+    output_folder = "/Workspace/Users/test/output"
+
+    # Use a dedicated WorkspaceClient mock tailored for SwitchRunner
+    mock_ws = create_switch_workspace_client_mock()
+    mock_configurator = mock_resource_configurator(mock_ws, make_mock_prompts(str(input_source), output_folder))
+
+    ctx = ApplicationContext(mock_ws)
+    ctx.replace(
+        installation=mock_installation_with_switch,
+        add_user_agent_extra=lambda w, *args, **kwargs: w,
+        resource_configurator=mock_configurator,
+    )
+
+    with caplog.at_level(logging.INFO):
+        cli.llm_transpile(
+            w=mock_ws,
+            accept_terms=True,
+            input_source=str(input_source),
+            output_ws_folder=output_folder,
+            source_dialect="mssql",
+            catalog_name="lakebridge",
+            schema_name="switch",
+            volume="switch_volume",
+            foundation_model="databricks-claude-sonnet-4-5",
+            output_sdp=False,
+            ctx=ctx,
+        )
+
+    # Verify that the job was called
+    run_now_mock = cast(MagicMock, mock_ws.jobs.run_now)
+    run_now_mock.assert_called_once()
+    call_args = run_now_mock.call_args
+    job_params = call_args.kwargs["job_parameters"]
+
+    # Verify output_sdp is NOT in the job parameters when flag is False
+    assert "output_sdp" not in job_params
+
+    # Verify other expected parameters are still present
+    assert job_params["source_tech"] == "mssql"
+    assert job_params["catalog"] == "lakebridge"
+    assert job_params["schema"] == "switch"