diff --git a/.gitignore b/.gitignore
index 466c618f0..12de8f342 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,3 +22,4 @@ remorph_transpile/
 /linter/src/main/antlr4/library/gen/
 .databricks-login.json
 .mypy_cache
+.env
diff --git a/labs.yml b/labs.yml
index 5a0cd5805..3209d5cd7 100644
--- a/labs.yml
+++ b/labs.yml
@@ -42,6 +42,19 @@ commands:
         {{range .}}{{.total_files_processed}}\t{{.total_queries_processed}}\t{{.analysis_error_count}}\t{{.parsing_error_count}}\t{{.validation_error_count}}\t{{.generation_error_count}}\t{{.error_log_file}}
         {{end}}
+  - name: llm-transpile
+    description: Transpile source code to Databricks using LLM Transpiler (Switch)
+    flags:
+      - name: input-source
+        description: Input Script Folder or File (local path)
+        default: null
+      - name: output-ws-folder
+        description: Output folder path (Databricks Workspace path starting with /Workspace/)
+        default: null
+      - name: source-dialect
+        description: Source dialect name (e.g., 'snowflake', 'teradata')
+        default: null
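+    # Illustrative invocation, using only the flags defined above:
+    #   databricks labs lakebridge llm-transpile --input-source ./sql_scripts \
+    #     --output-ws-folder /Workspace/Users/<user>/transpiled --source-dialect snowflake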
+
   - name: reconcile
     description: Reconcile source and target data residing on Databricks
@@ -59,6 +72,9 @@ commands:
       - name: interactive
        description: (Optional) Whether installing in interactive mode (`true|false|auto`); configuration settings are prompted for when interactive
        default: auto
+      - name: include-llm-transpiler
+        description: (Optional) Whether to include LLM-based transpiler in installation ('true'|'false')
+        default: false
   - name: describe-transpile
     description: Describe installed transpilers
diff --git a/pyproject.toml b/pyproject.toml
index cdd928ad2..61cf4550f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -447,7 +447,7 @@ bad-functions = ["map", "input"]
 # ignored-parents =
 
 # Maximum number of arguments for function / method.
-max-args = 12
+max-args = 13
 
 # Maximum number of attributes for a class (see R0902).
 max-attributes = 13
diff --git a/src/databricks/labs/lakebridge/cli.py b/src/databricks/labs/lakebridge/cli.py
index 36a5e7461..ce9008dbf 100644
--- a/src/databricks/labs/lakebridge/cli.py
+++ b/src/databricks/labs/lakebridge/cli.py
@@ -17,6 +17,7 @@
 from databricks.labs.blueprint.cli import App
 from databricks.labs.blueprint.entrypoint import get_logger, is_in_debug
 from databricks.labs.blueprint.installation import RootJsonValue
+from databricks.labs.blueprint.installer import InstallState
 from databricks.labs.blueprint.tui import Prompts
@@ -33,9 +34,10 @@
 from databricks.labs.lakebridge.reconcile.recon_config import RECONCILE_OPERATION_NAME, AGG_RECONCILE_OPERATION_NAME
 from databricks.labs.lakebridge.transpiler.describe import TranspilersDescription
 from databricks.labs.lakebridge.transpiler.execute import transpile as do_transpile
-from databricks.labs.lakebridge.transpiler.lsp.lsp_engine import LSPEngine
+from databricks.labs.lakebridge.transpiler.lsp.lsp_engine import LSPConfig, LSPEngine
 from databricks.labs.lakebridge.transpiler.repository import TranspilerRepository
 from databricks.labs.lakebridge.transpiler.sqlglot.sqlglot_engine import SqlglotEngine
+from databricks.labs.lakebridge.transpiler.switch_runner import SwitchConfig, SwitchRunner
 from databricks.labs.lakebridge.transpiler.transpile_engine import TranspileEngine
 from databricks.labs.lakebridge.transpiler.transpile_status import ErrorSeverity
@@ -530,6 +532,234 @@ def _override_workspace_client_config(ctx: ApplicationContext, overrides: dict[s
     ctx.connect_config.cluster_id = cluster_id
 
 
+@lakebridge.command
+def llm_transpile(
+    *,
+    w: WorkspaceClient,
+    input_source: str | None = None,
+    output_ws_folder: str | None = None,
+    source_dialect: str | None = None,
+    transpiler_repository: TranspilerRepository = TranspilerRepository.user_home(),
+) -> None:
+    """Transpile source code to Databricks using the LLM Transpiler (Switch)."""
+    ctx = ApplicationContext(w)
+    ctx.add_user_agent_extra("cmd", "llm-transpile")
+    user = ctx.current_user
+    logger.debug(f"User: {user}")
+
+    checker = _LLMTranspileConfigChecker(ctx.transpile_config, ctx.prompts, ctx.install_state, transpiler_repository)
+    checker.use_input_source(input_source)
+    checker.use_output_ws_folder(output_ws_folder)
+    checker.use_source_dialect(source_dialect)
+    params = checker.check()
+
+    result = _llm_transpile(ctx, params)
+    print(json.dumps(result))
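+    # The printed JSON is the run summary returned by SwitchRunner._run_job, e.g. (illustrative):
+    #   [{"job_id": 123, "run_id": 456, "run_url": "https://<host>/jobs/123/runs/456"}]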
+
+
+class _LLMTranspileConfigChecker:
+    """Helper class for 'llm-transpile' command configuration validation."""
+
+    _transpile_config: TranspileConfig | None
+    _prompts: Prompts
+    _install_state: InstallState
+    _transpiler_repository: TranspilerRepository
+    _input_source: str | None = None
+    _output_ws_folder: str | None = None
+    _source_dialect: str | None = None
+
+    def __init__(
+        self,
+        transpile_config: TranspileConfig | None,
+        prompts: Prompts,
+        install_state: InstallState,
+        transpiler_repository: TranspilerRepository,
+    ):
+        self._transpile_config = transpile_config
+        self._prompts = prompts
+        self._install_state = install_state
+        self._transpiler_repository = transpiler_repository
+
+    @staticmethod
+    def _validate_input_source_path(input_source: str, msg: str) -> None:
+        """Validate the input source: it must be a path that exists."""
+        if not Path(input_source).exists():
+            raise_validation_exception(msg)
+
+    def use_input_source(self, input_source: str | None) -> None:
+        if input_source is not None:
+            logger.debug(f"Setting input_source to: {input_source!r}")
+            self._validate_input_source_path(input_source, f"Invalid path for '--input-source': {input_source}")
+            self._input_source = input_source
+
+    def _prompt_input_source(self) -> None:
+        default_input = None
+        if self._transpile_config and self._transpile_config.input_source:
+            default_input = self._transpile_config.input_source
+
+        if default_input:
+            prompt_text = f"Enter input source path (press <Enter> for default: {default_input})"
+            prompted = self._prompts.question(prompt_text).strip()
+            self._input_source = prompted if prompted else default_input
+        else:
+            prompted = self._prompts.question("Enter input source path (directory or file)").strip()
+            self._input_source = prompted
+
+        logger.debug(f"Setting input_source to: {self._input_source!r}")
+        self._validate_input_source_path(self._input_source, f"Invalid input source: {self._input_source}")
+
+    def _check_input_source(self) -> None:
+        if self._input_source is None:
+            self._prompt_input_source()
+
+    def use_output_ws_folder(self, output_ws_folder: str | None) -> None:
+        if output_ws_folder is not None:
+            logger.debug(f"Setting output_ws_folder to: {output_ws_folder!r}")
+            self._validate_output_ws_folder_path(
+                output_ws_folder, f"Invalid path for '--output-ws-folder': {output_ws_folder}"
+            )
+            self._output_ws_folder = output_ws_folder
+
+    @staticmethod
+    def _validate_output_ws_folder_path(output_ws_folder: str, msg: str) -> None:
+        """Validate that the output folder is a Workspace path."""
+        if not output_ws_folder.startswith("/Workspace/"):
+            raise_validation_exception(f"{msg}. Must start with /Workspace/")
+
+    def _prompt_output_ws_folder(self) -> None:
+        prompted_output_ws_folder = self._prompts.question(
+            "Enter output folder path (Databricks Workspace path starting with /Workspace/)"
+        ).strip()
+        logger.debug(f"Setting output_ws_folder to: {prompted_output_ws_folder!r}")
+        self._validate_output_ws_folder_path(
+            prompted_output_ws_folder, f"Invalid output folder: {prompted_output_ws_folder}"
+        )
+        self._output_ws_folder = prompted_output_ws_folder
+
+    def _check_output_ws_folder(self) -> None:
+        if self._output_ws_folder is None:
+            self._prompt_output_ws_folder()
+
+    def use_source_dialect(self, source_dialect: str | None) -> None:
+        if source_dialect is not None:
+            logger.debug(f"Setting source_dialect to: {source_dialect!r}")
+            self._source_dialect = source_dialect
+
+    def _prompt_source_dialect(self) -> None:
+        """Prompt for the source dialect from the available Switch dialects."""
+        available_dialects = self._get_switch_dialects()
+
+        if not available_dialects:
+            raise_validation_exception(
+                "No Switch dialects available. "
+                "Install with: databricks labs lakebridge install-transpile --include-llm-transpiler"
+            )
+
+        logger.debug(f"Available dialects: {available_dialects!r}")
+        source_dialect = self._prompts.choice("Select the source dialect:", sorted(available_dialects))
+
+        self._source_dialect = source_dialect
+
+    def _check_source_dialect(self) -> None:
+        """Validate the source dialect, prompting for it if not provided."""
+        available_dialects = self._get_switch_dialects()
+
+        if self._source_dialect is None:
+            self._prompt_source_dialect()
+        elif self._source_dialect not in available_dialects:
+            supported = ", ".join(sorted(available_dialects))
+            raise_validation_exception(f"Invalid source-dialect: '{self._source_dialect}'. Available: {supported}")
" f"Available: {supported}") + + def _get_switch_dialects(self) -> set[str]: + """Get Switch dialects from config.yml using LSPConfig.""" + config_path = self._transpiler_repository.transpiler_config_path("Switch") + if not config_path.exists(): + return set() + + try: + lsp_config = LSPConfig.load(config_path) + return set(lsp_config.remorph.dialects) + except (OSError, ValueError) as e: + logger.warning(f"Failed to load Switch dialects: {e}") + return set() + + def _get_switch_options_with_defaults(self) -> dict[str, str]: + """Get default values for Switch options from config.yml.""" + config_path = self._transpiler_repository.transpiler_config_path("Switch") + if not config_path.exists(): + return {} + + try: + lsp_config = LSPConfig.load(config_path) + except (OSError, ValueError) as e: + logger.warning(f"Failed to load Switch options: {e}") + return {} + + options_all = lsp_config.options_for_dialect("all") + result = {} + for option in options_all: + if option.default and option.default != "": + result[option.flag] = option.default + + logger.debug(f"Loaded {len(result)} Switch options with defaults from config.yml") + return result + + def _validate_switch_options(self, options: dict[str, str]) -> None: + """Validate options against config.yml choices.""" + config_path = self._transpiler_repository.transpiler_config_path("Switch") + if not config_path.exists(): + return + + try: + lsp_config = LSPConfig.load(config_path) + except (OSError, ValueError) as e: + logger.warning(f"Failed to validate Switch options: {e}") + return + + options_all = lsp_config.options_for_dialect("all") + for option in options_all: + if option.flag in options and option.choices: + value = options[option.flag] + if value not in option.choices: + raise_validation_exception( + f"Invalid value for '{option.flag}': {value!r}. 
" f"Must be one of: {', '.join(option.choices)}" + ) + + def check(self) -> dict: + """Validate all parameters and return configuration dict.""" + logger.debug("Checking llm-transpile configuration") + + self._check_input_source() + self._check_output_ws_folder() + self._check_source_dialect() + + switch_options = self._get_switch_options_with_defaults() + self._validate_switch_options(switch_options) + + wait_for_completion = str(switch_options.pop("wait_for_completion", "false")).lower() == "true" + + return { + "input_source": self._input_source, + "output_ws_folder": self._output_ws_folder, + "source_dialect": self._source_dialect, + "switch_options": switch_options, + "wait_for_completion": wait_for_completion, + } + + +def _llm_transpile(ctx: ApplicationContext, params: dict) -> RootJsonValue: + """Execute LLM transpilation via Switch job.""" + config = SwitchConfig(ctx.install_state) + resources = config.get_resources() + job_id = config.get_job_id() + + runner = SwitchRunner(ctx.workspace_client, ctx.installation) + + return runner.run( + catalog=resources["catalog"], schema=resources["schema"], volume=resources["volume"], job_id=job_id, **params + ) + + @lakebridge.command def reconcile(*, w: WorkspaceClient) -> None: """[EXPERIMENTAL] Reconciles source to Databricks datasets""" @@ -623,6 +853,7 @@ def install_transpile( w: WorkspaceClient, artifact: str | None = None, interactive: str | None = None, + include_llm_transpiler: bool = False, transpiler_repository: TranspilerRepository = TranspilerRepository.user_home(), ) -> None: """Install or upgrade the Lakebridge transpilers.""" @@ -631,9 +862,13 @@ def install_transpile( ctx.add_user_agent_extra("cmd", "install-transpile") if artifact: ctx.add_user_agent_extra("artifact-overload", Path(artifact).name) + if include_llm_transpiler: + ctx.add_user_agent_extra("include-llm-transpiler", "true") user = w.current_user logger.debug(f"User: {user}") - transpile_installer = installer(w, transpiler_repository, is_interactive=is_interactive) + transpile_installer = installer( + w, transpiler_repository, is_interactive=is_interactive, include_llm=include_llm_transpiler + ) transpile_installer.run(module="transpile", artifact=artifact) diff --git a/src/databricks/labs/lakebridge/config.py b/src/databricks/labs/lakebridge/config.py index 9b5d0d418..85e9f840d 100644 --- a/src/databricks/labs/lakebridge/config.py +++ b/src/databricks/labs/lakebridge/config.py @@ -140,6 +140,13 @@ def prompt_for_value(self, prompts: Prompts) -> JsonValue: raise ValueError(f"Unsupported prompt method: {self.method}") +@dataclass +class SwitchResourcesConfig: + catalog: str + schema: str + volume: str + + @dataclass class TranspileConfig: __file__ = "config.yml" @@ -152,9 +159,11 @@ class TranspileConfig: error_file_path: str | None = None sdk_config: dict[str, str] | None = None skip_validation: bool = False + include_llm: bool = False catalog_name: str = "remorph" schema_name: str = "transpiler" transpiler_options: JsonValue = None + switch_resources: SwitchResourcesConfig | None = None @property def transpiler_path(self) -> Path | None: diff --git a/src/databricks/labs/lakebridge/contexts/application.py b/src/databricks/labs/lakebridge/contexts/application.py index f9e0875d8..4aecfd327 100644 --- a/src/databricks/labs/lakebridge/contexts/application.py +++ b/src/databricks/labs/lakebridge/contexts/application.py @@ -18,7 +18,9 @@ from databricks.labs.lakebridge.deployment.dashboard import DashboardDeployment from 
databricks.labs.lakebridge.deployment.installation import WorkspaceInstallation from databricks.labs.lakebridge.deployment.recon import TableDeployment, JobDeployment, ReconDeployment +from databricks.labs.lakebridge.deployment.switch import SwitchDeployment from databricks.labs.lakebridge.helpers.metastore import CatalogOperations +from databricks.labs.lakebridge.transpiler.repository import TranspilerRepository logger = logging.getLogger(__name__) @@ -119,6 +121,17 @@ def recon_deployment(self) -> ReconDeployment: self.dashboard_deployment, ) + @cached_property + def switch_deployment(self) -> SwitchDeployment: + return SwitchDeployment( + self.workspace_client, + self.installation, + self.install_state, + self.product_info, + self.job_deployment, + TranspilerRepository.user_home(), + ) + @cached_property def workspace_installation(self) -> WorkspaceInstallation: return WorkspaceInstallation( @@ -126,6 +139,7 @@ def workspace_installation(self) -> WorkspaceInstallation: self.prompts, self.installation, self.recon_deployment, + self.switch_deployment, self.product_info, self.upgrades, ) diff --git a/src/databricks/labs/lakebridge/deployment/installation.py b/src/databricks/labs/lakebridge/deployment/installation.py index 7ff283f0e..709e8bc7c 100644 --- a/src/databricks/labs/lakebridge/deployment/installation.py +++ b/src/databricks/labs/lakebridge/deployment/installation.py @@ -13,6 +13,7 @@ from databricks.labs.lakebridge.config import LakebridgeConfiguration from databricks.labs.lakebridge.deployment.recon import ReconDeployment +from databricks.labs.lakebridge.deployment.switch import SwitchDeployment logger = logging.getLogger("databricks.labs.lakebridge.install") @@ -24,6 +25,7 @@ def __init__( prompts: Prompts, installation: Installation, recon_deployment: ReconDeployment, + switch_deployment: SwitchDeployment, product_info: ProductInfo, upgrades: Upgrades, ): @@ -31,6 +33,7 @@ def __init__( self._prompts = prompts self._installation = installation self._recon_deployment = recon_deployment + self._switch_deployment = switch_deployment self._product_info = product_info self._upgrades = upgrades @@ -96,6 +99,16 @@ def install(self, config: LakebridgeConfiguration): if config.reconcile: logger.info("Installing Lakebridge reconcile Metadata components.") self._recon_deployment.install(config.reconcile, wheel_path) + if config.transpile and config.transpile.include_llm: + resources = config.transpile.switch_resources + if resources is None: + logger.error( + "Switch resources are missing. Run `lakebridge install-transpile --include-llm-transpiler true` " + "with interactive prompts to capture the Switch catalog, schema, and volume before retrying." + ) + else: + logger.info("Installing Switch transpiler to workspace.") + self._switch_deployment.install(resources) def uninstall(self, config: LakebridgeConfiguration): # This will remove all the Lakebridge modules @@ -116,9 +129,22 @@ def uninstall(self, config: LakebridgeConfiguration): f"Won't remove transpile validation schema `{config.transpile.schema_name}` " f"from catalog `{config.transpile.catalog_name}`. Please remove it manually." 
             )
+            self._uninstall_switch_job()
         if config.reconcile:
             self._recon_deployment.uninstall(config.reconcile)
 
         self._installation.remove()
         logger.info("Uninstallation completed successfully.")
+
+    def _uninstall_switch_job(self) -> None:
+        """Remove the Switch transpiler job, if one exists."""
+        resources = self._switch_deployment.get_configured_resources()
+        self._switch_deployment.uninstall()
+
+        if resources:
+            logger.info(
+                f"Won't remove Switch resources: catalog=`{resources['catalog']}`, "
+                f"schema=`{resources['schema']}`, volume=`{resources['volume']}`. "
+                "Please remove them manually if needed."
+            )
diff --git a/src/databricks/labs/lakebridge/deployment/switch.py b/src/databricks/labs/lakebridge/deployment/switch.py
new file mode 100644
index 000000000..63aae1ff0
--- /dev/null
+++ b/src/databricks/labs/lakebridge/deployment/switch.py
@@ -0,0 +1,237 @@
+import logging
+import os
+import sys
+from pathlib import Path
+
+from databricks.labs.blueprint.installation import Installation
+from databricks.labs.blueprint.installer import InstallState
+from databricks.labs.blueprint.wheels import ProductInfo
+from databricks.labs.lakebridge.deployment.job import JobDeployment
+from databricks.labs.lakebridge.config import SwitchResourcesConfig
+from databricks.labs.lakebridge.transpiler.repository import TranspilerRepository
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.errors import InvalidParameterValue, NotFound
+from databricks.sdk.service.jobs import JobParameterDefinition, JobSettings, NotebookTask, Source, Task
+
+logger = logging.getLogger(__name__)
+
+
+class SwitchDeployment:
+    _INSTALL_STATE_KEY = "Switch"
+    _TRANSPILER_ID = "switch"
+
+    def __init__(
+        self,
+        ws: WorkspaceClient,
+        installation: Installation,
+        install_state: InstallState,
+        product_info: ProductInfo,
+        job_deployer: JobDeployment,
+        transpiler_repository: TranspilerRepository,
+    ):
+        self._ws = ws
+        self._installation = installation
+        self._install_state = install_state
+        self._product_info = product_info
+        self._job_deployer = job_deployer
+        self._transpiler_repository = transpiler_repository
+
+    def install(self, resources: SwitchResourcesConfig) -> None:
+        """Deploy Switch to the workspace and configure its resources."""
+        logger.info("Deploying Switch to workspace...")
+        self._deploy_workspace(self._get_switch_package_path())
+        self._setup_job()
+        self._record_resources(resources)
+        logger.info("Switch deployment completed")
+
+    def uninstall(self) -> None:
+        """Remove the Switch job from the workspace."""
+        if self._INSTALL_STATE_KEY not in self._install_state.jobs:
+            logger.info("No Switch job found in InstallState")
+            return
+
+        job_id = int(self._install_state.jobs[self._INSTALL_STATE_KEY])
+        try:
+            logger.info(f"Removing Switch job with job_id={job_id}")
+            del self._install_state.jobs[self._INSTALL_STATE_KEY]
+            self._ws.jobs.delete(job_id)
+        except (InvalidParameterValue, NotFound):
+            logger.warning(f"Switch job {job_id} doesn't exist anymore")
+        finally:
+            self._install_state.save()
+
+    def get_configured_resources(self) -> dict[str, str] | None:
+        """Get the configured Switch resources (catalog, schema, volume)."""
+        if self._install_state.switch_resources:
+            return {
+                "catalog": self._install_state.switch_resources.get("catalog"),
+                "schema": self._install_state.switch_resources.get("schema"),
+                "volume": self._install_state.switch_resources.get("volume"),
+            }
+        return None
+
site-packages.""" + try: + logger.info("Deploying Switch package to workspace...") + remote_path = f"{self._TRANSPILER_ID}/databricks" + self._upload_directory(switch_package_dir, remote_path) + logger.info("Switch workspace deployment completed") + except (OSError, ValueError, AttributeError) as e: + logger.error(f"Failed to deploy to workspace: {e}") + + def _upload_directory(self, local_path: Path, remote_prefix: str) -> None: + """Recursively upload directory to workspace, excluding cache files.""" + for root, dirs, files in os.walk(local_path): + # Skip cache directories and hidden directories + dirs[:] = [d for d in dirs if d != "__pycache__" and not d.startswith(".")] + + for file in files: + # Skip compiled Python files and hidden files + if file.endswith((".pyc", ".pyo")) or file.startswith("."): + continue + + local_file = Path(root) / file + rel_path = local_file.relative_to(local_path) + remote_path = f"{remote_prefix}/{rel_path}" + + with open(local_file, "rb") as f: + content = f.read() + + self._installation.upload(remote_path, content) + + def _setup_job(self) -> None: + """Create or update Switch job.""" + existing_job_id = self._get_existing_job_id() + logger.info("Setting up Switch job in workspace...") + try: + job_id = self._create_or_update_switch_job(existing_job_id) + self._install_state.jobs[self._INSTALL_STATE_KEY] = job_id + self._install_state.save() + job_url = f"{self._ws.config.host}/jobs/{job_id}" + logger.info(f"Switch job created/updated: {job_url}") + except (RuntimeError, ValueError, InvalidParameterValue) as e: + logger.error(f"Failed to create/update Switch job: {e}") + + def _get_existing_job_id(self) -> str | None: + """Check if Switch job already exists in workspace.""" + if self._INSTALL_STATE_KEY not in self._install_state.jobs: + return None + try: + job_id = self._install_state.jobs[self._INSTALL_STATE_KEY] + self._ws.jobs.get(int(job_id)) + return job_id + except (InvalidParameterValue, NotFound, ValueError): + return None + + def _create_or_update_switch_job(self, job_id: str | None) -> str: + """Create or update Switch job, returning job ID.""" + job_settings = self._get_switch_job_settings() + + # Try to update existing job + if job_id: + try: + logger.info(f"Updating Switch job: {job_id}") + self._ws.jobs.reset(int(job_id), JobSettings(**job_settings)) + return job_id + except (ValueError, InvalidParameterValue): + logger.warning("Previous Switch job not found, creating new one") + + # Create new job + logger.info("Creating new Switch job") + new_job = self._ws.jobs.create(**job_settings) + new_job_id = str(new_job.job_id) + assert new_job_id is not None + return new_job_id + + def _get_switch_job_settings(self) -> dict: + """Build job settings for Switch transpiler.""" + product = self._installation.product() + job_name = f"{product.upper()}_Switch" + version = ProductInfo.from_class(self.__class__).version() + user_name = self._installation.username() + notebook_path = ( + f"/Workspace/Users/{user_name}/.{product}/{self._TRANSPILER_ID}/" + f"databricks/labs/switch/notebooks/00_main" + ) + + task = Task( + task_key="run_transpilation", + notebook_task=NotebookTask( + notebook_path=notebook_path, + source=Source.WORKSPACE, + ), + disable_auto_optimization=True, # To disable retries on failure + ) + + return { + "name": job_name, + "tags": {"created_by": user_name, "switch_version": f"v{version}"}, + "tasks": [task], + "parameters": self._get_switch_job_parameters(), + "max_concurrent_runs": 100, # Allow simultaneous transpilations + } + + 
+
+    def _get_switch_job_parameters(self) -> list[JobParameterDefinition]:
+        """Build job-level parameter definitions from the installed config.yml."""
+        configs = self._transpiler_repository.all_transpiler_configs()
+        config = configs.get(self._INSTALL_STATE_KEY) or configs.get(self._TRANSPILER_ID)
+
+        if not config:
+            raise ValueError(
+                "Switch config.yml not found. This indicates an incomplete installation. "
+                "Please reinstall the Switch transpiler."
+            )
+
+        # Required runtime parameters that are not in config.yml go first
+        parameters = {
+            "input_dir": "",
+            "output_dir": "",
+            "result_catalog": "",
+            "result_schema": "",
+            "builtin_prompt": "",
+        }
+
+        # Options to exclude from job parameters (local execution only)
+        excluded_options = {"wait_for_completion"}
+
+        # Then add the parameters from config.yml
+        for option in config.options.get("all", []):
+            flag = option.flag
+
+            # Skip local execution-only options
+            if flag in excluded_options:
+                continue
+
+            # Job parameter defaults must be strings
+            default = option.default or ""
+            if isinstance(default, (int, float)):
+                default = str(default)
+
+            parameters[flag] = default
+
+        return [JobParameterDefinition(name=key, default=value) for key, value in parameters.items()]
+
+    def _record_resources(self, resources: SwitchResourcesConfig) -> None:
+        """Persist the configured Switch resources for later reuse."""
+        self._install_state.switch_resources["catalog"] = resources.catalog
+        self._install_state.switch_resources["schema"] = resources.schema
+        self._install_state.switch_resources["volume"] = resources.volume
+        self._install_state.save()
+        logger.info(
+            f"Switch resources stored: catalog=`{resources.catalog}`, "
+            f"schema=`{resources.schema}`, volume=`{resources.volume}`"
+        )
+
+    def _get_switch_package_path(self) -> Path:
+        """Get the Switch package path (the `databricks` directory) from site-packages."""
+        product_path = self._transpiler_repository.transpilers_path() / self._TRANSPILER_ID
+        venv_path = product_path / "lib" / ".venv"
+
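+        # Illustrative resolved path on Linux with Python 3.11 (the repository root is an
+        # installation detail): .../remorph-transpilers/switch/lib/.venv/lib/python3.11/site-packages/databricks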
+        if sys.platform != "win32":
+            major, minor = sys.version_info[:2]
+            return venv_path / "lib" / f"python{major}.{minor}" / "site-packages" / "databricks"
+        return venv_path / "Lib" / "site-packages" / "databricks"
diff --git a/src/databricks/labs/lakebridge/install.py b/src/databricks/labs/lakebridge/install.py
index 7c9edf082..693ba5214 100644
--- a/src/databricks/labs/lakebridge/install.py
+++ b/src/databricks/labs/lakebridge/install.py
@@ -20,6 +20,7 @@
     LakebridgeConfiguration,
     ReconcileMetadataConfig,
     TranspileConfig,
+    SwitchResourcesConfig,
 )
 from databricks.labs.lakebridge.contexts.application import ApplicationContext
 from databricks.labs.lakebridge.deployment.configurator import ResourceConfigurator
@@ -28,6 +29,7 @@
 from databricks.labs.lakebridge.transpiler.installers import (
     BladebridgeInstaller,
     MorpheusInstaller,
+    SwitchInstaller,
     TranspilerInstaller,
 )
 from databricks.labs.lakebridge.transpiler.repository import TranspilerRepository
@@ -50,10 +52,12 @@ def __init__(
         environ: dict[str, str] | None = None,
         *,
         is_interactive: bool = True,
+        include_llm: bool = False,
         transpiler_repository: TranspilerRepository = TranspilerRepository.user_home(),
         transpiler_installers: Sequence[Callable[[TranspilerRepository], TranspilerInstaller]] = (
             BladebridgeInstaller,
             MorpheusInstaller,
+            SwitchInstaller,
         ),
     ):
         self._ws = ws
@@ -65,6 +69,7 @@ def __init__(
         self._ws_installation = workspace_installation
         # TODO: Refactor the 'prompts' property in preference to using this flag, which should be redundant.
         self._is_interactive = is_interactive
+        self._include_llm = include_llm
         self._transpiler_repository = transpiler_repository
         self._transpiler_installer_factories = transpiler_installers
@@ -77,7 +82,12 @@ def __init__(
 
     @property
     def _transpiler_installers(self) -> Set[TranspilerInstaller]:
-        return frozenset(factory(self._transpiler_repository) for factory in self._transpiler_installer_factories)
+        factories = self._transpiler_installer_factories
+        if not self._include_llm:
+            if SwitchInstaller in factories:
+                logger.info("Skipping Switch installation (LLM transpiler not requested)")
+            factories = tuple(f for f in factories if f != SwitchInstaller)
+        return frozenset(factory(self._transpiler_repository) for factory in factories)
 
     def run(
         self,
@@ -92,7 +102,7 @@ def run(
         for transpiler_installer in self._transpiler_installers:
             transpiler_installer.install()
         if not config:
-            config = self.configure(module)
+            config = self.configure(module, self._include_llm)
         if self._is_testing():
             return config
         self._ws_installation.install(config)
@@ -129,18 +139,18 @@ def _install_artifact(self, artifact: str) -> None:
         else:
             logger.fatal(f"Cannot install unsupported artifact: {artifact}")
 
-    def configure(self, module: str) -> LakebridgeConfiguration:
+    def configure(self, module: str, include_llm: bool = False) -> LakebridgeConfiguration:
         match module:
             case "transpile":
                 logger.info("Configuring lakebridge `transpile`.")
-                return LakebridgeConfiguration(self._configure_transpile(), None)
+                return LakebridgeConfiguration(self._configure_transpile(include_llm), None)
             case "reconcile":
                 logger.info("Configuring lakebridge `reconcile`.")
                 return LakebridgeConfiguration(None, self._configure_reconcile())
             case "all":
                 logger.info("Configuring lakebridge `transpile` and `reconcile`.")
                 return LakebridgeConfiguration(
-                    self._configure_transpile(),
+                    self._configure_transpile(include_llm),
                     self._configure_reconcile(),
                 )
             case _:
@@ -149,7 +159,7 @@ def configure(self, module: str) -> LakebridgeConfiguration:
     def _is_testing(self):
         return self._product_info.product_name() != "lakebridge"
 
-    def _configure_transpile(self) -> TranspileConfig | None:
+    def _configure_transpile(self, include_llm: bool = False) -> TranspileConfig | None:
         try:
             config = self._installation.load(TranspileConfig)
             logger.info("Lakebridge `transpile` is already installed on this workspace.")
@@ -170,12 +180,12 @@ def _configure_transpile(self) -> TranspileConfig | None:
             logger.warning("Installation is not interactive, skipping configuration of transpilers.")
             return None
 
-        config = self._configure_new_transpile_installation()
+        config = self._configure_new_transpile_installation(include_llm)
         logger.info("Finished configuring lakebridge `transpile`.")
         return config
 
-    def _configure_new_transpile_installation(self) -> TranspileConfig:
-        default_config = self._prompt_for_new_transpile_installation()
+    def _configure_new_transpile_installation(self, include_llm: bool = False) -> TranspileConfig:
+        default_config = self._prompt_for_new_transpile_installation(include_llm)
         runtime_config = None
         catalog_name = "remorph"
         schema_name = "transpiler"
@@ -204,7 +214,7 @@ def _transpilers_with_dialect(self, dialect: str) -> list[str]:
     def _transpiler_config_path(self, transpiler: str) -> Path:
         return self._transpiler_repository.transpiler_config_path(transpiler)
 
-    def _prompt_for_new_transpile_installation(self) -> TranspileConfig:
+    def _prompt_for_new_transpile_installation(self, include_llm: bool = False) -> TranspileConfig:
         install_later = "Set it later"
         # TODO tidy this up, logger might not display the below in console...
         logger.info("Please answer a few questions to configure lakebridge `transpile`")
@@ -212,20 +222,7 @@ def _prompt_for_new_transpile_installation(self) -> TranspileConfig:
         source_dialect: str | None = self._prompts.choice("Select the source dialect:", all_dialects, sort=False)
         if source_dialect == install_later:
             source_dialect = None
-        transpiler_name: str | None = None
-        transpiler_config_path: Path | None = None
-        if source_dialect:
-            transpilers = self._transpilers_with_dialect(source_dialect)
-            if len(transpilers) > 1:
-                transpilers = [install_later] + transpilers
-                transpiler_name = self._prompts.choice("Select the transpiler:", transpilers, sort=False)
-                if transpiler_name == install_later:
-                    transpiler_name = None
-            else:
-                transpiler_name = next(t for t in transpilers)
-                logger.info(f"Lakebridge will use the {transpiler_name} transpiler")
-            if transpiler_name:
-                transpiler_config_path = self._transpiler_config_path(transpiler_name)
+        transpiler_name, transpiler_config_path = self._get_transpiler_config(install_later, source_dialect)
         transpiler_options: dict[str, JsonValue] | None = None
         if transpiler_config_path:
             transpiler_options = self._prompt_for_transpiler_options(
@@ -248,22 +245,53 @@ def _prompt_for_new_transpile_installation(self) -> TranspileConfig:
             "Would you like to validate the syntax and semantics of the transpiled queries?"
         )
 
+        switch_resources = None
+        if include_llm:
+            switch_resources = self._prompt_for_switch_resources()
+
         return TranspileConfig(
             transpiler_config_path=str(transpiler_config_path) if transpiler_config_path is not None else None,
             transpiler_options=transpiler_options,
             source_dialect=source_dialect,
             skip_validation=(not run_validation),
+            include_llm=include_llm,
             input_source=input_source,
             output_folder=output_folder,
             error_file_path=error_file_path,
+            switch_resources=switch_resources,
         )
 
+    def _get_transpiler_config(self, install_later, source_dialect):
+        transpiler_name: str | None = None
+        transpiler_config_path: Path | None = None
+        if source_dialect:
+            transpilers = self._transpilers_with_dialect(source_dialect)
+            if len(transpilers) > 1:
+                transpilers = [install_later] + transpilers
+                transpiler_name = self._prompts.choice("Select the transpiler:", transpilers, sort=False)
+                if transpiler_name == install_later:
+                    transpiler_name = None
+            else:
+                transpiler_name = next(t for t in transpilers)
+                logger.info(f"Lakebridge will use the {transpiler_name} transpiler")
+            if transpiler_name:
+                transpiler_config_path = self._transpiler_config_path(transpiler_name)
+        return transpiler_name, transpiler_config_path
+
     def _prompt_for_transpiler_options(self, transpiler_name: str, source_dialect: str) -> dict[str, Any] | None:
         config_options = self._transpiler_repository.transpiler_config_options(transpiler_name, source_dialect)
         if len(config_options) == 0:
             return None
         return {option.flag: option.prompt_for_value(self._prompts) for option in config_options}
 
+    def _prompt_for_switch_resources(self) -> SwitchResourcesConfig:
+        logger.info("Configuring Switch resources (catalog, schema, volume)...")
+        catalog = self._resource_configurator.prompt_for_catalog_setup()
+        schema = self._resource_configurator.prompt_for_schema_setup(catalog, "switch")
+        volume = self._resource_configurator.prompt_for_volume_setup(catalog, schema, "switch_volume")
+        self._has_necessary_access(catalog, schema, volume)
+        return SwitchResourcesConfig(catalog=catalog, schema=schema, volume=volume)
+
     def _configure_catalog(self) -> str:
         return self._resource_configurator.prompt_for_catalog_setup()
@@ -386,6 +414,7 @@ def installer(
     transpiler_repository: TranspilerRepository,
     *,
     is_interactive: bool,
+    include_llm: bool = False,
 ) -> WorkspaceInstaller:
     app_context = ApplicationContext(_verify_workspace_client(ws))
     return WorkspaceInstaller(
@@ -398,6 +427,7 @@ def installer(
         app_context.workspace_installation,
         transpiler_repository=transpiler_repository,
         is_interactive=is_interactive,
+        include_llm=include_llm,
     )
diff --git a/src/databricks/labs/lakebridge/transpiler/installers.py b/src/databricks/labs/lakebridge/transpiler/installers.py
index 8bdecb6f5..718efddcb 100644
--- a/src/databricks/labs/lakebridge/transpiler/installers.py
+++ b/src/databricks/labs/lakebridge/transpiler/installers.py
@@ -567,3 +567,18 @@ def _parse_java_version(cls, version: str) -> tuple[int, int, int, int] | None:
         update = int(match["update"] or 0)
         patch = int(match["patch"] or 0)
         return feature, interim, update, patch
+
+
+class SwitchInstaller(TranspilerInstaller):
+    @property
+    def name(self) -> str:
+        return "Switch"
+
+    def can_install(self, artifact: Path) -> bool:
+        return "databricks_switch_plugin" in artifact.name and artifact.suffix == ".whl"
+
+    def install(self, artifact: Path | None = None) -> bool:
+        local_name = "switch"
+        pypi_name = "databricks-switch-plugin"
+        wheel_installer = WheelInstaller(self._transpiler_repository, local_name, pypi_name, artifact)
+        return wheel_installer.install() is not None
diff --git a/src/databricks/labs/lakebridge/transpiler/switch_runner.py b/src/databricks/labs/lakebridge/transpiler/switch_runner.py
new file mode 100644
index 000000000..b6afdbb67
--- /dev/null
+++ b/src/databricks/labs/lakebridge/transpiler/switch_runner.py
@@ -0,0 +1,191 @@
+import io
+import logging
+import os
+import random
+import string
+from datetime import datetime, timezone
+from pathlib import Path
+
+from databricks.labs.blueprint.installation import Installation, RootJsonValue
+from databricks.labs.blueprint.installer import InstallState
+from databricks.sdk import WorkspaceClient
+
+logger = logging.getLogger(__name__)
+
+
+class SwitchConfig:
+    """Helper to load the Switch configuration from InstallState."""
+
+    def __init__(self, install_state: InstallState):
+        self._install_state = install_state
+
+    def get_resources(self) -> dict[str, str]:
+        """Get the catalog, schema, and volume from switch_resources."""
+        resources = self._install_state.switch_resources
+
+        if not resources or not all(k in resources for k in ("catalog", "schema", "volume")):
+            raise SystemExit(
+                "Switch resources not configured. "
+                "Please run 'databricks labs lakebridge install-transpile --include-llm-transpiler true' first."
+            )
+
+        return {
+            "catalog": resources["catalog"],
+            "schema": resources["schema"],
+            "volume": resources["volume"],
+        }
+
+    def get_job_id(self) -> int:
+        """Get the Switch job ID from InstallState."""
+        if "Switch" in self._install_state.jobs:
+            logger.debug("Switch job ID found in InstallState")
+            return int(self._install_state.jobs["Switch"])
+
+        raise SystemExit(
+            "Switch Job ID not found. "
+            "Please run 'databricks labs lakebridge install-transpile --include-llm-transpiler true' first."
+        )
+
+
+class SwitchRunner:
+    """Runner for Switch LLM transpilation jobs."""
+
+    def __init__(
+        self,
+        ws: WorkspaceClient,
+        installation: Installation,
+    ):
+        self._ws = ws
+        self._installation = installation
+
+    def run(
+        self,
+        input_source: str,
+        output_ws_folder: str,
+        source_dialect: str,
+        catalog: str,
+        schema: str,
+        volume: str,
+        job_id: int,
+        switch_options: dict[str, str],
+        wait_for_completion: bool = False,
+    ) -> RootJsonValue:
+        """Upload local files to the volume and trigger the Switch job."""
+        volume_input_path = self._upload_to_volume(Path(input_source), catalog, schema, volume)
+
+        job_params = self._build_job_parameters(
+            volume_input_path=volume_input_path,
+            output_ws_folder=output_ws_folder,
+            catalog=catalog,
+            schema=schema,
+            source_dialect=source_dialect,
+            switch_options=switch_options,
+        )
+        logger.info(f"Triggering Switch job with job_id: {job_id}")
+
+        return self._run_job(job_id, job_params, wait_for_completion)
+
+    def _upload_to_volume(
+        self,
+        local_path: Path,
+        catalog: str,
+        schema: str,
+        volume: str,
+    ) -> str:
+        """Upload local files to a UC Volume under a unique timestamped path."""
+        now = datetime.now(timezone.utc)
+        time_part = now.strftime("%Y%m%d%H%M%S")
+        random_part = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4))
+        timestamp_suffix = f"{time_part}_{random_part}"
+        volume_base_path = f"/Volumes/{catalog}/{schema}/{volume}"
+        volume_input_path = f"{volume_base_path}/input_{timestamp_suffix}"
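+        # e.g. /Volumes/main/switch/switch_volume/input_20250101123456_a1b2 (illustrative values)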
+
+        logger.info(f"Uploading {local_path} to {volume_input_path}...")
+
+        # File upload
+        if local_path.is_file():
+            volume_file_path = f"{volume_input_path}/{local_path.name}"
+            with open(local_path, 'rb') as f:
+                content = f.read()
+            self._ws.files.upload(file_path=volume_file_path, contents=io.BytesIO(content), overwrite=True)
+            logger.debug(f"Uploaded: {local_path} -> {volume_file_path}")
+
+        # Directory upload
+        else:
+            for root, _, files in os.walk(local_path):
+                for file in files:
+                    local_file = Path(root) / file
+                    relative_path = local_file.relative_to(local_path)
+                    volume_file_path = f"{volume_input_path}/{relative_path}"
+
+                    with open(local_file, 'rb') as f:
+                        content = f.read()
+
+                    self._ws.files.upload(file_path=volume_file_path, contents=io.BytesIO(content), overwrite=True)
+                    logger.debug(f"Uploaded: {local_file} -> {volume_file_path}")
+
+        logger.info(f"Upload complete: {volume_input_path}")
+        return volume_input_path
+
+    def _build_job_parameters(
+        self,
+        volume_input_path: str,
+        output_ws_folder: str,
+        catalog: str,
+        schema: str,
+        source_dialect: str,
+        switch_options: dict[str, str],
+    ) -> dict[str, str]:
+        """Build the Switch job parameters."""
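+        # These keys mirror the job-level parameter definitions created by
+        # SwitchDeployment._get_switch_job_parameters(), so each run overrides them explicitly.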
"UNKNOWN" + ), + "result_state": run.state.result_state.value if run.state and run.state.result_state else None, + } + ] + + wait = self._ws.jobs.run_now(job_id, job_parameters=job_params) + + if not wait.run_id: + raise SystemExit(f"Job {job_id} execution failed.") + + job_run_url = f"{self._ws.config.host}/jobs/{job_id}/runs/{wait.run_id}" + logger.info(f"Switch LLM transpilation job started: {job_run_url}") + + return [ + { + "job_id": job_id, + "run_id": wait.run_id, + "run_url": job_run_url, + } + ] diff --git a/tests/conftest.py b/tests/conftest.py index 3ab95ca75..000f07d6d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -312,6 +312,21 @@ def morpheus_artifact() -> Path: return artifact +@pytest.fixture +def switch_artifact() -> Path: + """Get Switch wheel for testing.""" + artifact = ( + Path(__file__).parent + / "resources" + / "transpiler_configs" + / "switch" + / "wheel" + / "databricks_switch_plugin-0.1.2-py3-none-any.whl" + ) + assert artifact.exists(), f"Switch artifact not found: {artifact}" + return artifact + + class FakeDataSource(DataSource): def __init__(self, start_delimiter: str, end_delimiter: str): diff --git a/tests/integration/transpile/test_switch.py b/tests/integration/transpile/test_switch.py new file mode 100644 index 000000000..49531f460 --- /dev/null +++ b/tests/integration/transpile/test_switch.py @@ -0,0 +1,94 @@ +from pathlib import Path + +from databricks.labs.blueprint.installer import InstallState +from databricks.labs.blueprint.wheels import ProductInfo +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound + +from databricks.labs.lakebridge.config import SwitchResourcesConfig +from databricks.labs.lakebridge.contexts.application import ApplicationContext +from databricks.labs.lakebridge.deployment.job import JobDeployment +from databricks.labs.lakebridge.deployment.switch import SwitchDeployment +from databricks.labs.lakebridge.transpiler.installers import SwitchInstaller +from databricks.labs.lakebridge.transpiler.repository import TranspilerRepository +from databricks.labs.lakebridge.transpiler.switch_runner import SwitchConfig + + +def test_switch_installation(ws: WorkspaceClient, switch_artifact: Path): + """Test Switch installation, job creation, resource persistence, and cleanup.""" + context = ApplicationContext(ws) + installation = context.installation + install_state = InstallState.from_installation(installation) + transpiler_repository = TranspilerRepository.user_home() + + # Phase 1: Local installation + installer = SwitchInstaller(transpiler_repository) + result = installer.install(switch_artifact) + assert result, "Switch local installation failed" + + # Phase 2: Workspace deployment + product_info = ProductInfo.from_class(SwitchDeployment) + job_deployer = JobDeployment(ws, installation, install_state, product_info) + switch_deployment = SwitchDeployment( + ws, installation, install_state, product_info, job_deployer, transpiler_repository + ) + + resources = SwitchResourcesConfig(catalog="test_catalog", schema="test_schema", volume="test_volume") + switch_deployment.install(resources) + + try: + install_state = InstallState.from_installation(installation) + job_id = _verify_job_creation(ws, install_state) + _verify_resource_persistence(install_state) + _verify_job_id_retrieval(install_state, job_id) + finally: + try: + switch_deployment.uninstall() + installation.remove() + except NotFound: + pass + + +def _verify_job_creation(ws: WorkspaceClient, install_state: InstallState): + """Verify job 
creation and registration.""" + assert "Switch" in install_state.jobs + job_id = int(install_state.jobs["Switch"]) + + job = ws.jobs.get(job_id) + assert job is not None + assert job.settings is not None + assert job.settings.name is not None + assert "switch" in job.settings.name.lower() + + assert job.settings.tasks is not None + assert len(job.settings.tasks) > 0 + task = job.settings.tasks[0] + assert task.notebook_task is not None + assert "switch" in task.notebook_task.notebook_path.lower() + return job_id + + +def _verify_resource_persistence(install_state: InstallState): + """Verify resource persistence.""" + assert install_state.switch_resources is not None + resources = install_state.switch_resources + assert "catalog" in resources + assert "schema" in resources + assert "volume" in resources + + switch_config = SwitchConfig(install_state) + retrieved_resources = switch_config.get_resources() + + assert retrieved_resources["catalog"] == "test_catalog" + assert retrieved_resources["schema"] == "test_schema" + assert retrieved_resources["volume"] == "test_volume" + + +def _verify_job_id_retrieval(install_state: InstallState, expected_job_id: int): + """Verify job ID retrieval.""" + switch_config = SwitchConfig(install_state) + job_id_from_config = switch_config.get_job_id() + + assert isinstance(job_id_from_config, int) + assert job_id_from_config > 0 + assert job_id_from_config == expected_job_id diff --git a/tests/resources/transpiler_configs/switch/wheel/databricks_switch_plugin-0.1.2-py3-none-any.whl b/tests/resources/transpiler_configs/switch/wheel/databricks_switch_plugin-0.1.2-py3-none-any.whl new file mode 100644 index 000000000..90de13f37 Binary files /dev/null and b/tests/resources/transpiler_configs/switch/wheel/databricks_switch_plugin-0.1.2-py3-none-any.whl differ diff --git a/tests/unit/deployment/test_installation.py b/tests/unit/deployment/test_installation.py index b7579bf7c..157d55d97 100644 --- a/tests/unit/deployment/test_installation.py +++ b/tests/unit/deployment/test_installation.py @@ -1,3 +1,4 @@ +import logging from unittest.mock import create_autospec import pytest @@ -16,9 +17,11 @@ ReconcileConfig, DatabaseConfig, ReconcileMetadataConfig, + SwitchResourcesConfig, ) from databricks.labs.lakebridge.deployment.installation import WorkspaceInstallation from databricks.labs.lakebridge.deployment.recon import ReconDeployment +from databricks.labs.lakebridge.deployment.switch import SwitchDeployment @pytest.fixture @@ -37,6 +40,7 @@ def test_install_all(ws): } ) recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) installation = create_autospec(Installation) product_info = create_autospec(ProductInfo) upgrades = create_autospec(Upgrades) @@ -66,13 +70,16 @@ def test_install_all(ws): ), ) config = LakebridgeConfiguration(transpile=transpile_config, reconcile=reconcile_config) - installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, product_info, upgrades) + installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, product_info, upgrades + ) installation.install(config) def test_no_recon_component_installation(ws): prompts = MockPrompts({}) recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) installation = create_autospec(Installation) product_info = create_autospec(ProductInfo) upgrades = create_autospec(Upgrades) @@ -87,13 +94,16 @@ def test_no_recon_component_installation(ws): 
schema_name="transpiler7", ) config = LakebridgeConfiguration(transpile=transpile_config) - installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, product_info, upgrades) + installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, product_info, upgrades + ) installation.install(config) recon_deployment.install.assert_not_called() def test_recon_component_installation(ws): recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) installation = create_autospec(Installation) prompts = MockPrompts({}) product_info = create_autospec(ProductInfo) @@ -115,7 +125,9 @@ def test_recon_component_installation(ws): ), ) config = LakebridgeConfiguration(reconcile=reconcile_config) - installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, product_info, upgrades) + installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, product_info, upgrades + ) installation.install(config) recon_deployment.install.assert_called() @@ -128,10 +140,13 @@ def test_negative_uninstall_confirmation(ws): ) installation = create_autospec(Installation) recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) wheels = create_autospec(WheelsV2) upgrades = create_autospec(Upgrades) - ws_installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, wheels, upgrades) + ws_installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, wheels, upgrades + ) config = LakebridgeConfiguration() ws_installation.uninstall(config) installation.remove.assert_not_called() @@ -147,10 +162,13 @@ def test_missing_installation(ws): installation.files.side_effect = NotFound("Installation not found") installation.install_folder.return_value = "~/mock" recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) wheels = create_autospec(WheelsV2) upgrades = create_autospec(Upgrades) - ws_installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, wheels, upgrades) + ws_installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, wheels, upgrades + ) config = LakebridgeConfiguration() ws_installation.uninstall(config) installation.remove.assert_not_called() @@ -193,10 +211,13 @@ def test_uninstall_configs_exist(ws): config = LakebridgeConfiguration(transpile=transpile_config, reconcile=reconcile_config) installation = MockInstallation({}) recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) wheels = create_autospec(WheelsV2) upgrades = create_autospec(Upgrades) - ws_installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, wheels, upgrades) + ws_installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, wheels, upgrades + ) ws_installation.uninstall(config) recon_deployment.uninstall.assert_called() installation.assert_removed() @@ -210,11 +231,68 @@ def test_uninstall_configs_missing(ws): ) installation = MockInstallation() recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) wheels = create_autospec(WheelsV2) upgrades = create_autospec(Upgrades) - ws_installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, wheels, upgrades) + 
ws_installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, wheels, upgrades + ) config = LakebridgeConfiguration() ws_installation.uninstall(config) recon_deployment.uninstall.assert_not_called() installation.assert_removed() + + +class TestSwitchInstallation: + """Tests for Switch transpiler installation.""" + + def test_switch_install_uses_configured_resources(self, ws): + prompts = MockPrompts({}) + recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) + installation = create_autospec(Installation) + product_info = create_autospec(ProductInfo) + upgrades = create_autospec(Upgrades) + + switch_resources = SwitchResourcesConfig(catalog="cat", schema="sch", volume="vol") + transpile_config = TranspileConfig( + include_llm=True, + switch_resources=switch_resources, + ) + config = LakebridgeConfiguration(transpile=transpile_config) + + ws_installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, product_info, upgrades + ) + + ws_installation.install(config) + + switch_deployment.install.assert_called_once() + args, _ = switch_deployment.install.call_args + assert args[0] is switch_resources + + def test_switch_install_missing_resources_logs_error(self, ws, caplog): + prompts = MockPrompts({}) + recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) + installation = create_autospec(Installation) + product_info = create_autospec(ProductInfo) + upgrades = create_autospec(Upgrades) + + transpile_config = TranspileConfig(include_llm=True) + config = LakebridgeConfiguration(transpile=transpile_config) + + ws_installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, product_info, upgrades + ) + + with caplog.at_level(logging.ERROR): + ws_installation.install(config) + + switch_deployment.install.assert_not_called() + assert any( + "Switch resources are missing" in record.message + for record in caplog.records + if record.levelno == logging.ERROR + ) diff --git a/tests/unit/deployment/test_switch.py b/tests/unit/deployment/test_switch.py new file mode 100644 index 000000000..c0fac0656 --- /dev/null +++ b/tests/unit/deployment/test_switch.py @@ -0,0 +1,164 @@ +from unittest.mock import Mock, create_autospec + +import pytest + +from databricks.labs.blueprint.installation import Installation +from databricks.labs.blueprint.installer import InstallState +from databricks.labs.blueprint.wheels import ProductInfo +from databricks.labs.lakebridge.config import LSPConfigOptionV1, LSPPromptMethod, SwitchResourcesConfig +from databricks.labs.lakebridge.deployment.job import JobDeployment +from databricks.labs.lakebridge.deployment.switch import SwitchDeployment +from databricks.labs.lakebridge.transpiler.repository import TranspilerRepository +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound +from databricks.sdk.service.jobs import JobParameterDefinition + + +class FriendOfSwitchDeployment(SwitchDeployment): + """A friend class to access protected members for testing purposes.""" + + def get_switch_job_parameters(self) -> list[JobParameterDefinition]: + return self._get_switch_job_parameters() + + +@pytest.fixture() +def workspace_client(): + ws = create_autospec(WorkspaceClient) + ws.jobs = Mock() + ws.jobs.delete = Mock() + ws.jobs.get = Mock() + ws.jobs.reset = Mock() + ws.jobs.create = Mock() + return ws + + +@pytest.fixture() 
+def install_state(): + state = create_autospec(InstallState) + state.jobs = {} + state.switch_resources = {} + return state + + +@pytest.fixture() +def switch_deployment(workspace_client, install_state): + installation = create_autospec(Installation) + product_info = create_autospec(ProductInfo) + job_deployer = create_autospec(JobDeployment) + repository = create_autospec(TranspilerRepository) + + return SwitchDeployment( # type: ignore[call-arg] + workspace_client, installation, install_state, product_info, job_deployer, repository + ) + + +def test_record_resources_persists_install_state(switch_deployment, install_state, monkeypatch, tmp_path): + resources = SwitchResourcesConfig(catalog="cat", schema="sch", volume="vol") + + install_state.switch_resources = {} + install_state.save.reset_mock() + monkeypatch.setattr(switch_deployment, "_get_switch_package_path", lambda: tmp_path) + monkeypatch.setattr(switch_deployment, "_deploy_workspace", lambda _: None) + monkeypatch.setattr(switch_deployment, "_setup_job", lambda: None) + + switch_deployment.install(resources) + + saved = install_state.switch_resources + assert saved["catalog"] == "cat" + assert saved["schema"] == "sch" + assert saved["volume"] == "vol" + install_state.save.assert_called_once() + + +def test_install_records_resources(switch_deployment, monkeypatch, tmp_path): + resources = SwitchResourcesConfig(catalog="cat", schema="sch", volume="vol") + call_order = [] + + monkeypatch.setattr(switch_deployment, "_get_switch_package_path", lambda: tmp_path) + monkeypatch.setattr(switch_deployment, "_deploy_workspace", lambda pkg: call_order.append(("deploy", pkg))) + monkeypatch.setattr(switch_deployment, "_setup_job", lambda: call_order.append(("setup", None))) + monkeypatch.setattr(switch_deployment, "_record_resources", lambda res: call_order.append(("record", res))) + + switch_deployment.install(resources) + + assert ("deploy", tmp_path) in call_order + assert ("setup", None) in call_order + assert ("record", resources) in call_order + + +def test_uninstall_removes_job_and_saves_state(switch_deployment, install_state, workspace_client): + install_state.jobs = {"Switch": "123"} + install_state.save.reset_mock() + + workspace_client.jobs.delete.reset_mock() + + switch_deployment.uninstall() + + assert "Switch" not in install_state.jobs + workspace_client.jobs.delete.assert_called_once_with(123) + install_state.save.assert_called_once() + + +def test_uninstall_handles_missing_job(switch_deployment, install_state, workspace_client): + install_state.jobs = {"Switch": "123"} + workspace_client.jobs.delete.side_effect = NotFound("missing") + + switch_deployment.uninstall() + + install_state.save.assert_called_once() + + +def test_get_configured_resources_returns_mapping(switch_deployment, install_state): + install_state.switch_resources = {"catalog": "c", "schema": "s", "volume": "v"} + resources = switch_deployment.get_configured_resources() + assert resources == {"catalog": "c", "schema": "s", "volume": "v"} + + +def test_get_configured_resources_none_when_absent(switch_deployment, install_state): + install_state.switch_resources = {} + assert switch_deployment.get_configured_resources() is None + + +def test_get_switch_job_parameters_excludes_wait_for_completion(): + config_options = { + "all": [ + LSPConfigOptionV1( + flag="wait_for_completion", + method=LSPPromptMethod.CONFIRM, + prompt="Wait for completion?", + default="true", + ), + LSPConfigOptionV1( + flag="some_other_option", + method=LSPPromptMethod.QUESTION, + 
+                prompt="Enter value",
+                default="default_value",
+            ),
+        ]
+    }
+
+    mock_config = Mock()
+    mock_config.options = config_options
+
+    mock_repository = create_autospec(TranspilerRepository)
+    mock_repository.all_transpiler_configs.return_value = {"switch": mock_config}
+
+    ws = create_autospec(WorkspaceClient)
+    installation = create_autospec(Installation)
+    state = create_autospec(InstallState)
+    state.jobs = {}
+    product_info = create_autospec(ProductInfo)
+    job_deployer = create_autospec(JobDeployment)
+
+    deployment = FriendOfSwitchDeployment(ws, installation, state, product_info, job_deployer, mock_repository)
+
+    job_params = deployment.get_switch_job_parameters()
+    param_names = {param.name for param in job_params}
+
+    assert "wait_for_completion" not in param_names
+    assert "some_other_option" in param_names
+    assert "input_dir" in param_names
+    assert "output_dir" in param_names
+    assert "result_catalog" in param_names
+    assert "result_schema" in param_names
+    assert "builtin_prompt" in param_names
diff --git a/tests/unit/test_cli_transpile.py b/tests/unit/test_cli_transpile.py
index 4164cc964..02d11ffd3 100644
--- a/tests/unit/test_cli_transpile.py
+++ b/tests/unit/test_cli_transpile.py
@@ -10,6 +10,7 @@
 from databricks.labs.lakebridge import cli
+from databricks.labs.blueprint.installer import InstallState
 from databricks.labs.blueprint.tui import MockPrompts
 from databricks.labs.lakebridge.config import TranspileConfig
 from databricks.sdk import WorkspaceClient
@@ -414,3 +415,113 @@ def test_describe_transpile(mock_cli_transpile_no_config, transpiler_repository:
             }
         ],
     }
+
+
+class TestLLMTranspile:
+    """Test suite for llm-transpile command."""
+
+    @pytest.fixture
+    def switch_transpiler_repository(self, tmp_path: Path) -> TranspilerRepository:
+        """Create TranspilerRepository with Switch config."""
+        switch_config = {
+            "remorph": {
+                "version": 1,
+                "name": "Switch",
+                "dialects": ["snowflake", "teradata", "oracle"],
+                "command_line": ["/usr/bin/true"],
+            },
+            "options": {"all": []},
+        }
+        config_path = tmp_path / "remorph-transpilers" / "Switch" / "lib" / "config.yml"
+        config_path.parent.mkdir(parents=True)
+        with config_path.open("w") as f:
+            yaml.dump(switch_config, f)
+        return TranspilerRepository(tmp_path)
+
+    def test_with_valid_inputs(
+        self, empty_input_source: Path, switch_transpiler_repository: TranspilerRepository
+    ) -> None:
+        """Test llm-transpile with valid parameters."""
+        ws = create_autospec(WorkspaceClient)
+        mock_runner_run = MagicMock(return_value=[{"job_id": 123, "run_id": 456, "run_url": "http://test"}])
+
+        with (
+            patch("databricks.labs.lakebridge.cli.ApplicationContext") as mock_ctx,
+            patch(
+                "databricks.labs.lakebridge.cli._llm_transpile", return_value=mock_runner_run.return_value
+            ) as mock_llm,
+        ):
+            mock_ctx.return_value.transpile_config = None
+            mock_ctx.return_value.prompts = MockPrompts({})
+            mock_ctx.return_value.install_state = create_autospec(InstallState)
+
+            cli.llm_transpile(
+                w=ws,
+                input_source=str(empty_input_source),
+                output_ws_folder="/Workspace/Users/test/output",
+                source_dialect="snowflake",
+                transpiler_repository=switch_transpiler_repository,
+            )
+
+        # Verify _llm_transpile was called with correct parameters
+        mock_llm.assert_called_once()
+        call_args = mock_llm.call_args
+        ctx_arg = call_args[0][0]
+        params_arg = call_args[0][1]
+
+        # Verify context
+        assert ctx_arg == mock_ctx.return_value
+
+        # Verify parameters
+        assert params_arg["input_source"] == str(empty_input_source)
+        assert params_arg["output_ws_folder"] == "/Workspace/Users/test/output"
+        assert params_arg["source_dialect"] == "snowflake"
+        assert "switch_options" in params_arg
+        assert "wait_for_completion" in params_arg
+
+    @pytest.mark.parametrize(
+        ("input_source_type", "output_ws_folder", "source_dialect", "expected_error"),
+        (
+            pytest.param(
+                "invalid", "/Workspace/Users/test/output", "snowflake", "Invalid path", id="invalid_input_source"
+            ),
+            pytest.param(
+                "valid", "/invalid/path", "snowflake", "Must start with /Workspace/", id="invalid_output_folder"
+            ),
+            pytest.param(
+                "valid",
+                "/Workspace/Users/test/output",
+                "invalid_dialect",
+                "Invalid source-dialect",
+                id="invalid_dialect",
+            ),
+        ),
+    )
+    def test_with_invalid_params(
+        self,
+        empty_input_source: Path,
+        switch_transpiler_repository: TranspilerRepository,
+        input_source_type: str,
+        output_ws_folder: str,
+        source_dialect: str,
+        expected_error: str,
+    ) -> None:
+        """Test llm-transpile with invalid parameters."""
+        ws = create_autospec(WorkspaceClient)
+        input_source = "/invalid/path/does/not/exist" if input_source_type == "invalid" else str(empty_input_source)
+
+        with (
+            patch("databricks.labs.lakebridge.cli.ApplicationContext") as mock_ctx,
+            pytest.raises(ValueError, match=expected_error),
+        ):
+            mock_ctx.return_value.transpile_config = None
+            mock_ctx.return_value.prompts = MockPrompts({})
+            mock_ctx.return_value.install_state = create_autospec(InstallState)
+
+            cli.llm_transpile(
+                w=ws,
+                input_source=input_source,
+                output_ws_folder=output_ws_folder,
+                source_dialect=source_dialect,
+                transpiler_repository=switch_transpiler_repository,
+            )
diff --git a/tests/unit/test_install.py b/tests/unit/test_install.py
index fd9aee169..17417b885 100644
--- a/tests/unit/test_install.py
+++ b/tests/unit/test_install.py
@@ -1,5 +1,5 @@
 import logging
-from collections.abc import Callable, Generator, Sequence
+from collections.abc import Callable, Generator, Sequence, Set
 from pathlib import Path
 from unittest.mock import create_autospec, patch
@@ -23,7 +23,12 @@
 from databricks.labs.lakebridge.deployment.installation import WorkspaceInstallation
 from databricks.labs.lakebridge.install import WorkspaceInstaller
 from databricks.labs.lakebridge.reconcile.constants import ReconSourceType, ReconReportType
-from databricks.labs.lakebridge.transpiler.installers import TranspilerInstaller
+from databricks.labs.lakebridge.transpiler.installers import (
+    BladebridgeInstaller,
+    MorpheusInstaller,
+    SwitchInstaller,
+    TranspilerInstaller,
+)
 from databricks.labs.lakebridge.transpiler.repository import TranspilerRepository
 from tests.unit.conftest import path_to_resource
@@ -1598,3 +1603,59 @@ def test_no_configure_if_noninteractive(
     assert config.transpile is None
     expected_log_message = "Installation is not interactive, skipping configuration of transpilers."
     assert any(expected_log_message in log.message for log in caplog.records if log.levelno == logging.WARNING)
+
+
+class FriendOfWorkspaceInstaller(WorkspaceInstaller):
+    """A friend class to access protected members for testing purposes."""
+
+    def get_transpiler_installers(self) -> Set[TranspilerInstaller]:
+        return self._transpiler_installers
+
+
+@pytest.mark.parametrize(
+    "include_llm_transpiler,should_include_switch",
+    [
+        (False, False),  # Default: exclude Switch
+        (True, True),  # Flag enabled: include Switch
+        (None, False),  # Not specified: default behavior (exclude Switch)
+    ],
+)
+def test_transpiler_installers_llm_flag(
+    ws: WorkspaceClient, include_llm_transpiler: bool | None, should_include_switch: bool
+) -> None:
+    """Test Switch installer filtering based on the include_llm_transpiler flag."""
+    ctx = ApplicationContext(ws)
+
+    if include_llm_transpiler is not None:
+        installer = FriendOfWorkspaceInstaller(
+            ctx.workspace_client,
+            ctx.prompts,
+            ctx.installation,
+            ctx.install_state,
+            ctx.product_info,
+            ctx.resource_configurator,
+            ctx.workspace_installation,
+            include_llm=include_llm_transpiler,
+        )
+    else:
+        installer = FriendOfWorkspaceInstaller(
+            ctx.workspace_client,
+            ctx.prompts,
+            ctx.installation,
+            ctx.install_state,
+            ctx.product_info,
+            ctx.resource_configurator,
+            ctx.workspace_installation,
+        )
+    installers = installer.get_transpiler_installers()
+    installer_types = {type(i) for i in installers}
+
+    # Verify Switch inclusion/exclusion
+    if should_include_switch:
+        assert SwitchInstaller in installer_types
+    else:
+        assert SwitchInstaller not in installer_types
+
+    # Other transpilers should always be included
+    assert BladebridgeInstaller in installer_types
+    assert MorpheusInstaller in installer_types
diff --git a/tests/unit/transpiler/test_installers.py b/tests/unit/transpiler/test_installers.py
index f729686b8..0e4ef96d6 100644
--- a/tests/unit/transpiler/test_installers.py
+++ b/tests/unit/transpiler/test_installers.py
@@ -6,7 +6,12 @@
 import pytest
 
-from databricks.labs.lakebridge.transpiler.installers import ArtifactInstaller, MorpheusInstaller
+from databricks.labs.lakebridge.transpiler.installers import (
+    ArtifactInstaller,
+    MorpheusInstaller,
+    SwitchInstaller,
+)
+from databricks.labs.lakebridge.transpiler.repository import TranspilerRepository
 
 
 def test_store_product_state(tmp_path) -> None:
@@ -110,3 +115,34 @@ def test_java_version_parse_missing() -> None:
     version_output = "Nothing in here that looks like a version."
     parsed = FriendOfMorpheusInstaller.parse_java_version(version_output)
     assert parsed is None
+
+
+class TestSwitchInstaller:
+    """Test suite for SwitchInstaller."""
+
+    @pytest.fixture
+    def installer(self, tmp_path: Path) -> SwitchInstaller:
+        """Create a SwitchInstaller instance for testing."""
+        repository = TranspilerRepository(tmp_path)
+        return SwitchInstaller(repository)
+
+    def test_name(self, installer: SwitchInstaller) -> None:
+        """Verify the installer name is correct."""
+        assert installer.name == "Switch"
+
+    @pytest.mark.parametrize(
+        ("filename", "expected"),
+        (
+            # Valid Switch wheel files
+            pytest.param("databricks_switch_plugin-0.1.0-py3-none-any.whl", True, id="valid_version"),
+            pytest.param("databricks_switch_plugin-1.2.3-py3-none-any.whl", True, id="valid_multi_digit"),
+            pytest.param("databricks_switch_plugin-0.1.0rc1-py3-none-any.whl", True, id="valid_rc_version"),
+            # Invalid files
+            pytest.param("databricks_bb_plugin-0.1.0-py3-none-any.whl", False, id="wrong_package"),
+            pytest.param("some_other_package-0.1.0-py3-none-any.whl", False, id="other_package"),
+            pytest.param("databricks_switch_plugin-0.1.0.jar", False, id="wrong_extension"),
+        ),
+    )
+    def test_can_install(self, filename: str, expected: bool, installer: SwitchInstaller) -> None:
+        """Verify can_install works for valid and invalid files."""
+        assert installer.can_install(Path(filename)) == expected
diff --git a/tests/unit/transpiler/test_switch_runner.py b/tests/unit/transpiler/test_switch_runner.py
new file mode 100644
index 000000000..40099cdc9
--- /dev/null
+++ b/tests/unit/transpiler/test_switch_runner.py
@@ -0,0 +1,346 @@
+from pathlib import Path
+from unittest.mock import Mock, create_autospec
+
+import pytest
+
+from databricks.labs.blueprint.installer import InstallState
+from databricks.labs.lakebridge.transpiler.switch_runner import SwitchConfig, SwitchRunner
+
+
+class FriendOfSwitchRunner(SwitchRunner):
+    """Friend class to access protected methods for testing."""
+
+    def upload_to_volume(self, local_path: Path, catalog: str, schema: str, volume: str) -> str:
+        return self._upload_to_volume(local_path, catalog, schema, volume)
+
+    def build_job_parameters(
+        self,
+        volume_input_path: str,
+        output_ws_folder: str,
+        catalog: str,
+        schema: str,
+        source_dialect: str,
+        switch_options: dict[str, str],
+    ) -> dict[str, str]:
+        return self._build_job_parameters(
+            volume_input_path, output_ws_folder, catalog, schema, source_dialect, switch_options
+        )
+
+
+class TestSwitchConfig:
+    """Test suite for SwitchConfig."""
+
+    @pytest.fixture
+    def install_state_with_switch(self) -> InstallState:
+        """Create InstallState with Switch configured."""
+        state = create_autospec(InstallState)
+        state.switch_resources = {
+            "catalog": "test_catalog",
+            "schema": "test_schema",
+            "volume": "test_volume",
+        }
+        state.jobs = {"Switch": "12345"}
+        return state
+
+    @pytest.fixture
+    def install_state_without_switch(self) -> InstallState:
+        """Create InstallState without Switch configured."""
+        state = create_autospec(InstallState)
+        state.switch_resources = None
+        state.jobs = {}
+        return state
+
+    def test_get_resources_success(self, install_state_with_switch: InstallState) -> None:
+        """Verify successful resource retrieval from InstallState."""
+        config = SwitchConfig(install_state_with_switch)
+        resources = config.get_resources()
+
+        assert resources == {
+            "catalog": "test_catalog",
+            "schema": "test_schema",
+            "volume": "test_volume",
+        }
+
+    @pytest.mark.parametrize(
+        ("switch_resources", "error_msg_fragment"),
+        (
+            pytest.param(None, "Switch resources not configured", id="resources_none"),
+            pytest.param({"catalog": "cat"}, "Switch resources not configured", id="missing_schema_volume"),
+            pytest.param({"catalog": "cat", "schema": "sch"}, "Switch resources not configured", id="missing_volume"),
+        ),
+    )
+    def test_get_resources_not_configured(self, switch_resources: dict | None, error_msg_fragment: str) -> None:
+        """Test the error raised when switch_resources is missing or incomplete."""
+        state = create_autospec(InstallState)
+        state.switch_resources = switch_resources
+
+        config = SwitchConfig(state)
+
+        with pytest.raises(SystemExit, match=error_msg_fragment):
+            config.get_resources()
+
+    @pytest.mark.parametrize(
+        ("jobs", "should_succeed"),
+        (
+            pytest.param({"Switch": "12345"}, True, id="job_exists"),
+            pytest.param({"Switch": 67890}, True, id="job_exists_int"),
+            pytest.param({}, False, id="job_missing"),
+            pytest.param({"OtherJob": "99999"}, False, id="other_job_only"),
+        ),
+    )
+    def test_get_job_id(self, jobs: dict, should_succeed: bool) -> None:
+        """Test job ID retrieval from InstallState."""
+        state = create_autospec(InstallState)
+        state.jobs = jobs
+
+        config = SwitchConfig(state)
+
+        if should_succeed:
+            job_id = config.get_job_id()
+            assert isinstance(job_id, int)
+            assert job_id == int(jobs["Switch"])
+        else:
+            with pytest.raises(SystemExit, match="Switch Job ID not found"):
+                config.get_job_id()
+
+
+class TestSwitchRunner:
+    """Test suite for SwitchRunner."""
+
+    @pytest.fixture
+    def mock_ws(self) -> Mock:
+        """Create mock WorkspaceClient."""
+        ws = Mock()
+        ws.config.host = "https://test.databricks.com"
+        ws.files.upload = Mock()
+        ws.jobs.run_now = Mock()
+        ws.jobs.run_now_and_wait = Mock()
+        return ws
+
+    @pytest.fixture
+    def mock_installation(self) -> Mock:
+        """Create mock Installation."""
+        return Mock()
+
+    @pytest.fixture
+    def runner(self, mock_ws: Mock, mock_installation: Mock) -> FriendOfSwitchRunner:
+        """Create SwitchRunner instance."""
+        return FriendOfSwitchRunner(mock_ws, mock_installation)
+
+    @pytest.mark.parametrize(
+        "is_file",
+        (
+            pytest.param(True, id="single_file"),
+            pytest.param(False, id="directory"),
+        ),
+    )
+    def test_upload_to_volume(self, runner: FriendOfSwitchRunner, mock_ws: Mock, tmp_path: Path, is_file: bool) -> None:
+        """Test file/directory upload to Volume."""
+        # Setup test files
+        if is_file:
+            test_file = tmp_path / "test.sql"
+            test_file.write_text("SELECT 1;")
+            local_path = test_file
+            expected_calls = 1
+        else:
+            test_dir = tmp_path / "queries"
+            test_dir.mkdir()
+            (test_dir / "query1.sql").write_text("SELECT 1;")
+            (test_dir / "query2.sql").write_text("SELECT 2;")
+            subdir = test_dir / "subdir"
+            subdir.mkdir()
+            (subdir / "query3.sql").write_text("SELECT 3;")
+            local_path = test_dir
+            expected_calls = 3
+
+        # Execute upload
+        volume_path = runner.upload_to_volume(
+            local_path=local_path, catalog="test_cat", schema="test_sch", volume="test_vol"
+        )
+
+        # Verify results
+        assert volume_path.startswith("/Volumes/test_cat/test_sch/test_vol/input_")
+        assert mock_ws.files.upload.call_count == expected_calls
+
+    def test_upload_unique_paths(self, runner: FriendOfSwitchRunner, tmp_path: Path) -> None:
+        """Verify UUID usage prevents path collisions."""
+        test_file = tmp_path / "test.sql"
+        test_file.write_text("SELECT 1;")
+
+        # Generate two upload paths
+        path1 = runner.upload_to_volume(test_file, "cat", "sch", "vol")
+        path2 = runner.upload_to_volume(test_file, "cat", "sch", "vol")
+
+        # Paths should be different due to the UUID
+        assert path1 != path2
+        assert path1.startswith("/Volumes/cat/sch/vol/input_")
+        assert path2.startswith("/Volumes/cat/sch/vol/input_")
+
+    def test_upload_preserves_hierarchy(self, runner: FriendOfSwitchRunner, mock_ws: Mock, tmp_path: Path) -> None:
+        """Verify subdirectory structure is maintained."""
+        # Create nested directory structure
+        root = tmp_path / "root"
+        root.mkdir()
+        (root / "file1.sql").write_text("SELECT 1;")
+        subdir = root / "sub" / "deep"
+        subdir.mkdir(parents=True)
+        (subdir / "file2.sql").write_text("SELECT 2;")
+
+        # Upload directory
+        runner.upload_to_volume(root, "cat", "sch", "vol")
+
+        # Verify hierarchy is preserved
+        calls = [call[1]["file_path"] for call in mock_ws.files.upload.call_args_list]
+        assert any("file1.sql" in path and "/sub/" not in path for path in calls)
+        assert any("file2.sql" in path and "/sub/deep/" in path for path in calls)
+
+    @pytest.mark.parametrize(
+        ("switch_options", "expected_extra_keys"),
+        (
+            pytest.param({}, 0, id="no_options"),
+            pytest.param({"endpoint_name": "claude", "concurrency": "4", "log_level": "DEBUG"}, 3, id="with_options"),
+        ),
+    )
+    def test_build_job_parameters(
+        self, runner: FriendOfSwitchRunner, switch_options: dict, expected_extra_keys: int
+    ) -> None:
+        """Test job parameter construction."""
+        params = runner.build_job_parameters(
+            volume_input_path="/Volumes/cat/sch/vol/input_123",
+            output_ws_folder="/Workspace/Users/test/output",
+            catalog="test_cat",
+            schema="test_sch",
+            source_dialect="snowflake",
+            switch_options=switch_options,
+        )
+
+        # Verify required parameters
+        assert params["input_dir"] == "/Volumes/cat/sch/vol/input_123"
+        assert params["output_dir"] == "/Workspace/Users/test/output"
+        assert params["result_catalog"] == "test_cat"
+        assert params["result_schema"] == "test_sch"
+        assert params["builtin_prompt"] == "snowflake"
+
+        # Verify options are included
+        assert len(params) == 5 + expected_extra_keys
+        for key, value in switch_options.items():
+            assert params[key] == value
+
+    @pytest.mark.parametrize(
+        "wait_for_completion",
+        (
+            pytest.param(False, id="async_execution"),
+            pytest.param(True, id="sync_execution"),
+        ),
+    )
+    def test_run_job_execution(
+        self, runner: FriendOfSwitchRunner, mock_ws: Mock, tmp_path: Path, wait_for_completion: bool
+    ) -> None:
+        """Test async/sync job execution based on the wait_for_completion flag."""
+        # Setup test file
+        test_file = tmp_path / "test.sql"
+        test_file.write_text("SELECT 1;")
+
+        # Mock job execution
+        mock_run = Mock()
+        mock_run.run_id = 99999
+        if wait_for_completion:
+            mock_run.state.life_cycle_state.value = "TERMINATED"
+            mock_run.state.result_state.value = "SUCCESS"
+            mock_ws.jobs.run_now_and_wait.return_value = mock_run
+        else:
+            mock_ws.jobs.run_now.return_value = mock_run
+
+        # Execute run
+        result = runner.run(
+            input_source=str(test_file),
+            output_ws_folder="/Workspace/output",
+            source_dialect="snowflake",
+            catalog="cat",
+            schema="sch",
+            volume="vol",
+            job_id=12345,
+            switch_options={"log_level": "DEBUG"},
+            wait_for_completion=wait_for_completion,
+        )
+
+        # Verify the correct method is called with the proper job_parameters
+        expected_params = {
+            "result_catalog": "cat",
+            "result_schema": "sch",
+            "builtin_prompt": "snowflake",
+            "log_level": "DEBUG",
+        }
+
+        if wait_for_completion:
+            call_args = mock_ws.jobs.run_now_and_wait.call_args
+            mock_ws.jobs.run_now.assert_not_called()
+        else:
+            call_args = mock_ws.jobs.run_now.call_args
+            mock_ws.jobs.run_now_and_wait.assert_not_called()
+
+        # Verify job_id and job_parameters
+        assert call_args[0][0] == 12345  # job_id
+        actual_params = call_args[1]["job_parameters"]
+        assert actual_params["result_catalog"] == expected_params["result_catalog"]
+        assert actual_params["result_schema"] == expected_params["result_schema"]
+        assert actual_params["builtin_prompt"] == expected_params["builtin_prompt"]
+        assert actual_params["log_level"] == expected_params["log_level"]
+        assert "input_dir" in actual_params
+        assert "output_dir" in actual_params
+
+        # Verify result structure
+        assert isinstance(result, list)
+        assert len(result) == 1
+        first_item = result[0]
+        assert isinstance(first_item, dict)
+
+        if wait_for_completion:
+            assert first_item["state"] == "TERMINATED"
+            assert first_item["result_state"] == "SUCCESS"
+        else:
+            assert "state" not in first_item
+            assert "result_state" not in first_item
+
+        # Verify common result fields
+        assert first_item["job_id"] == 12345
+        assert first_item["run_id"] == 99999
+        # Verify run_url format
+        assert first_item["run_url"] == "https://test.databricks.com/jobs/12345/runs/99999"
+
+    @pytest.mark.parametrize(
+        "wait_for_completion",
+        (
+            pytest.param(False, id="async_missing_run_id"),
+            pytest.param(True, id="sync_missing_run_id"),
+        ),
+    )
+    def test_run_job_execution_with_missing_run_id(
+        self, runner: FriendOfSwitchRunner, mock_ws: Mock, tmp_path: Path, wait_for_completion: bool
+    ) -> None:
+        """Test SystemExit when run_id is missing."""
+        # Setup test file
+        test_file = tmp_path / "test.sql"
+        test_file.write_text("SELECT 1;")
+
+        # Mock job execution with missing run_id
+        mock_run = Mock()
+        mock_run.run_id = None  # Simulate missing run_id
+        if wait_for_completion:
+            mock_ws.jobs.run_now_and_wait.return_value = mock_run
+        else:
+            mock_ws.jobs.run_now.return_value = mock_run
+
+        # Execute and expect SystemExit
+        with pytest.raises(SystemExit, match="Job 12345 execution failed"):
+            runner.run(
+                input_source=str(test_file),
+                output_ws_folder="/Workspace/output",
+                source_dialect="snowflake",
+                catalog="cat",
+                schema="sch",
+                volume="vol",
+                job_id=12345,
+                switch_options={},
+                wait_for_completion=wait_for_completion,
+            )
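
Editor's note: for readers following the tests above, here is a minimal sketch of the workflow they exercise (illustrative only, not part of the patch). It assumes a real WorkspaceClient with ambient Databricks authentication; the local input path and workspace output folder are placeholders.

    from databricks.sdk import WorkspaceClient
    from databricks.labs.lakebridge import cli

    ws = WorkspaceClient()  # assumes ambient auth (e.g. a configured profile or env vars)
    cli.llm_transpile(
        w=ws,
        input_source="./queries",                       # placeholder: local file or folder; must exist
        output_ws_folder="/Workspace/Users/me/output",  # must start with /Workspace/
        source_dialect="snowflake",                     # a dialect supported by the Switch transpiler
    )

As the tests show, the command validates these three parameters, uploads the input to a Unity Catalog volume under a unique "input_<uuid>" prefix, and triggers the deployed Switch job; the run metadata (job_id, run_id, run_url) is returned, with state and result_state included when waiting for completion.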