4 changes: 2 additions & 2 deletions src/databricks/labs/lakebridge/cli.py
@@ -73,7 +73,6 @@ def raise_validation_exception(msg: str) -> NoReturn:


def _create_warehouse(ws: WorkspaceClient) -> str:

dbsql = ws.warehouses.create_and_wait(
name=f"lakebridge-warehouse-{time.time_ns()}",
warehouse_type=CreateWarehouseRequestWarehouseType.PRO,
@@ -838,7 +837,6 @@ def _validate_llm_transpile_args(
source_dialect: str | None,
prompts: Prompts,
) -> tuple[str, str, str]:

_switch_dialects = get_switch_dialects()

# Validate presence after attempting to source from config
@@ -879,6 +877,7 @@ def llm_transpile(
schema_name: str | None = None,
volume: str | None = None,
foundation_model: str | None = None,
output_sdp: bool = False,
ctx: ApplicationContext | None = None,
) -> None:
"""Transpile source code to Databricks using LLM Transpiler (Switch)"""
@@ -960,6 +959,7 @@
schema=schema_name,
foundation_model=foundation_model,
job_id=job_id,
output_sdp=output_sdp,
)


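Taken together, the cli.py hunks thread a single new keyword through the command: llm_transpile gains an output_sdp: bool = False parameter and forwards it unchanged into the Switch job trigger. A minimal usage sketch, assuming workspace credentials are already configured and that the remaining options are passed explicitly rather than prompted for, as in the tests further down:

from databricks.sdk import WorkspaceClient
from databricks.labs.lakebridge import cli

ws = WorkspaceClient()  # assumes auth via env vars or a config profile
cli.llm_transpile(
    w=ws,
    accept_terms=True,
    input_source="./queries",                    # local file or directory to upload
    output_ws_folder="/Workspace/Users/me/out",  # must be a /Workspace/... path
    source_dialect="mssql",
    catalog_name="lakebridge",
    schema_name="switch",
    volume="switch_volume",
    foundation_model="databricks-claude-sonnet-4-5",
    output_sdp=True,  # new flag: adds output_sdp="true" to the Switch job parameters
)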
18 changes: 12 additions & 6 deletions src/databricks/labs/lakebridge/transpiler/switch_runner.py
@@ -30,16 +30,22 @@ def run(
schema: str,
foundation_model: str,
job_id: int,
output_sdp: bool = False,
) -> RootJsonValue:
"""Trigger Switch job."""

switch_options = {}
if output_sdp:
switch_options["output_sdp"] = "true"
Reviewer comment (Collaborator): The same changes need to go into the SwitchDeployment _generate_switch_job_parameters method as well, which sets up the Switch job during databricks labs lakebridge install-transpile --include-llm-transpile true, the step that deploys the artifacts Switch executes.
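A hedged sketch of what that follow-up could look like. The class and method names come from the comment above; the signature and body here are illustrative assumptions, not the actual deployment code:

class SwitchDeployment:  # assumed host class, named in the review comment
    def _generate_switch_job_parameters(self, output_sdp: bool = False) -> dict[str, str]:
        # Assumed: this method already assembles the deploy-time job parameters.
        params: dict[str, str] = {}  # ... existing Switch job parameters ...
        if output_sdp:
            # Keep the same string-typed toggle that SwitchRunner.run emits,
            # so both entry points produce identical job parameters.
            params["output_sdp"] = "true"
        return params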


job_params = self._build_job_parameters(
input_dir=volume_input_path,
output_dir=output_ws_folder,
source_tech=source_tech,
catalog=catalog,
schema=schema,
foundation_model=foundation_model,
switch_options=switch_options,
)
logger.info(f"Triggering Switch job with job_id: {job_id}")

@@ -55,19 +61,19 @@ def upload_to_volume(
"""Upload local files to UC Volume with unique timestamped path."""
now = datetime.now(timezone.utc)
time_part = now.strftime("%Y%m%d%H%M%S")
random_part = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4))
random_part = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
volume_base_path = f"/Volumes/{catalog}/{schema}/{volume}"
volume_input_path = f"{volume_base_path}/input-{time_part}-{random_part}"

logger.info(f"Uploading {local_path} to {volume_input_path}...")

# File upload
if local_path.is_file():
if local_path.name.startswith('.'):
if local_path.name.startswith("."):
logger.debug(f"Skipping hidden file: {local_path}")
return volume_input_path
volume_file_path = f"{volume_input_path}/{local_path.name}"
with open(local_path, 'rb') as f:
with open(local_path, "rb") as f:
content = f.read()
self._ws.files.upload(file_path=volume_file_path, contents=io.BytesIO(content), overwrite=True)
logger.debug(f"Uploaded: {local_path} -> {volume_file_path}")
@@ -76,15 +82,15 @@
else:
for root, dirs, files in os.walk(local_path):
# remove hidden directories
dirs[:] = [d for d in dirs if not d.startswith('.')]
dirs[:] = [d for d in dirs if not d.startswith(".")]
# skip hidden files
files = [f for f in files if not f.startswith('.')]
files = [f for f in files if not f.startswith(".")]
for file in files:
local_file = Path(root) / file
relative_path = local_file.relative_to(local_path)
volume_file_path = f"{volume_input_path}/{relative_path}"

with open(local_file, 'rb') as f:
with open(local_file, "rb") as f:
content = f.read()

self._ws.files.upload(file_path=volume_file_path, contents=io.BytesIO(content), overwrite=True)
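On the runner side, the boolean becomes a string-valued entry in switch_options ("true" rather than True, since Databricks job parameters are strings), which _build_job_parameters presumably folds into the final job parameters. A usage sketch assembled from the hunks above; the SwitchRunner constructor taking a WorkspaceClient is an assumption based on how the tests mock it, and the upload_to_volume keyword names are inferred from the variables used in its body:

from pathlib import Path

from databricks.sdk import WorkspaceClient
from databricks.labs.lakebridge.transpiler.switch_runner import SwitchRunner

ws = WorkspaceClient()
runner = SwitchRunner(ws)  # assumed constructor

# Upload local sources to a unique input-<timestamp>-<rand> path on the Volume;
# hidden files and directories are skipped.
volume_input_path = runner.upload_to_volume(
    local_path=Path("./queries"),
    catalog="lakebridge",
    schema="switch",
    volume="switch_volume",
)

runner.run(
    volume_input_path=volume_input_path,
    output_ws_folder="/Workspace/Users/me/out",
    source_tech="mssql",
    catalog="lakebridge",
    schema="switch",
    foundation_model="databricks-claude-sonnet-4-5",
    job_id=1234,      # ID of the deployed Switch job (illustrative value)
    output_sdp=True,  # adds output_sdp="true" to the job parameters
)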
116 changes: 112 additions & 4 deletions tests/unit/test_cli_llm_transpile.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
from unittest.mock import create_autospec
from unittest.mock import create_autospec, MagicMock
from typing import cast

import pytest
@@ -39,7 +39,7 @@ def make_mock_prompts(input_path: str, output_folder: str, source_dialect: str =
def create_switch_workspace_client_mock() -> WorkspaceClient:
ws = create_autospec(spec=WorkspaceClient, instance=True)

ws.config.host = 'https://workspace.databricks.com'
ws.config.host = "https://workspace.databricks.com"
ws.files.upload.return_value = None
ws.jobs.run_now.return_value.run_id = _RUN_ID
ws.jobs.run_now_and_wait_result.return_value.run_id = _RUN_ID
@@ -194,7 +194,6 @@ def test_llm_transpile_with_incorrect_output_parms(
mock_installation_with_switch: MockInstallation,
tmp_path: Path,
) -> None:

input_source = tmp_path / "input.sql"
input_source.write_text("SELECT * FROM table1;")
output_folder = "/Users/test/output"
@@ -223,7 +222,6 @@ def test_llm_transpile_with_incorrect_dialect(
mock_installation_with_switch: MockInstallation,
tmp_path: Path,
) -> None:

input_source = tmp_path / "input.sql"
input_source.write_text("SELECT * FROM table1;")
output_folder = "/Workspace/Users/test/output"
@@ -247,3 +245,113 @@ def test_llm_transpile_with_incorrect_dialect(
error_msg = "Invalid value for '--source-dialect': 'agent_sql' must be one of: airflow, mssql, mysql, netezza, oracle, postgresql, redshift, snowflake, synapse, teradata"
with pytest.raises(ValueError, match=rf"{error_msg}"):
cli.llm_transpile(w=mock_ws, accept_terms=True, source_dialect="agent_sql", ctx=ctx)


def test_llm_transpile_with_output_sdp_flag(
mock_installation_with_switch: MockInstallation,
tmp_path: Path,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that output_sdp flag is properly passed to job parameters."""
input_source = tmp_path / "input.sql"
input_source.write_text("SELECT * FROM table1;")
output_folder = "/Workspace/Users/test/output"

# Use a dedicated WorkspaceClient mock tailored for SwitchRunner
mock_ws = create_switch_workspace_client_mock()
mock_configurator = mock_resource_configurator(mock_ws, make_mock_prompts(str(input_source), output_folder))

ctx = ApplicationContext(mock_ws)
ctx.replace(
installation=mock_installation_with_switch,
add_user_agent_extra=lambda w, *args, **kwargs: w,
resource_configurator=mock_configurator,
)

with caplog.at_level(logging.INFO):
cli.llm_transpile(
w=mock_ws,
accept_terms=True,
input_source=str(input_source),
output_ws_folder=output_folder,
source_dialect="mssql",
catalog_name="lakebridge",
schema_name="switch",
volume="switch_volume",
foundation_model="databricks-claude-sonnet-4-5",
output_sdp=True,
ctx=ctx,
)

# Verify that the job was called with the correct parameters including output_sdp
run_now_mock = cast(MagicMock, mock_ws.jobs.run_now)
run_now_mock.assert_called_once()
call_args = run_now_mock.call_args
job_params = call_args.kwargs["job_parameters"]

# Verify output_sdp is in the job parameters
assert "output_sdp" in job_params
assert job_params["output_sdp"] == "true"

# Verify other expected parameters are still present
assert job_params["source_tech"] == "mssql"
assert job_params["catalog"] == "lakebridge"
assert job_params["schema"] == "switch"
assert job_params["foundation_model"] == "databricks-claude-sonnet-4-5"

expected_msg = (
f"Switch LLM transpilation job started: https://workspace.databricks.com/jobs/{_JOB_ID}/runs/{_RUN_ID}"
)
info_messages = [record.message for record in caplog.records if record.levelno == logging.INFO]
assert expected_msg in info_messages


def test_llm_transpile_without_output_sdp_flag(
mock_installation_with_switch: MockInstallation,
tmp_path: Path,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that output_sdp is not in job parameters when flag is not set."""
input_source = tmp_path / "input.sql"
input_source.write_text("SELECT * FROM table1;")
output_folder = "/Workspace/Users/test/output"

# Use a dedicated WorkspaceClient mock tailored for SwitchRunner
mock_ws = create_switch_workspace_client_mock()
mock_configurator = mock_resource_configurator(mock_ws, make_mock_prompts(str(input_source), output_folder))

ctx = ApplicationContext(mock_ws)
ctx.replace(
installation=mock_installation_with_switch,
add_user_agent_extra=lambda w, *args, **kwargs: w,
resource_configurator=mock_configurator,
)

with caplog.at_level(logging.INFO):
cli.llm_transpile(
w=mock_ws,
accept_terms=True,
input_source=str(input_source),
output_ws_folder=output_folder,
source_dialect="mssql",
catalog_name="lakebridge",
schema_name="switch",
volume="switch_volume",
foundation_model="databricks-claude-sonnet-4-5",
output_sdp=False,
ctx=ctx,
)

# Verify that the job was called
run_now_mock = cast(MagicMock, mock_ws.jobs.run_now)
run_now_mock.assert_called_once()
call_args = run_now_mock.call_args
job_params = call_args.kwargs["job_parameters"]

# Verify output_sdp is NOT in the job parameters when flag is False
assert "output_sdp" not in job_params

# Verify other expected parameters are still present
assert job_params["source_tech"] == "mssql"
assert job_params["catalog"] == "lakebridge"
assert job_params["schema"] == "switch"