diff --git a/src/databricks/labs/lakebridge/transpiler/execute.py b/src/databricks/labs/lakebridge/transpiler/execute.py index d2fa067e5..c194e957f 100644 --- a/src/databricks/labs/lakebridge/transpiler/execute.py +++ b/src/databricks/labs/lakebridge/transpiler/execute.py @@ -178,9 +178,7 @@ def _process_non_mime_result(context: TranspilingContext, error_list: list[Trans output_code: str = context.transpiled_code or "" output_path = cast(Path, context.output_path) - if any(err.kind == ErrorKind.PARSING for err in error_list): - output_code = context.source_code or "" - elif output_path.suffix == ".sql": + if output_path.suffix == ".sql": output_code = _validate_transpiled_sql(context, output_code, error_list) with output_path.open("w") as w: # The above adds a java-style comment block at the top of the output file diff --git a/tests/unit/transpiler/test_execute.py b/tests/unit/transpiler/test_execute.py index 23510b22e..1ec7e88ae 100644 --- a/tests/unit/transpiler/test_execute.py +++ b/tests/unit/transpiler/test_execute.py @@ -181,36 +181,75 @@ def test_with_file(input_source, error_file, mock_workspace_client): check_error_lines(status["error_log_file"], expected_errors) -class IdentityTranspileEngine(TranspileEngine): - """A simple "identity" transpiler that does not change the source it is given.""" +class ConfigurableTestEngine(TranspileEngine): + """Expand test transpiler engine.""" + + def __init__( + self, + *, + transpiler_name: str = "test", + supported_dialects: list[str] | None = None, + transform: Any | None = None, # Callable[[str], str] | None + transpiled_code: str | None = None, + errors: list[TranspileError] | None = None, + file_extensions: list[str] | None = None, + success_count: int | None = None, + ): + + if transform is not None and transpiled_code is not None: + raise ValueError("Cannot specify both transform and transpiled_code") + + self._transpiler_name = transpiler_name + self._supported_dialects = supported_dialects or ["test"] + self._transform = transform + self._static_code = transpiled_code + self._errors = errors or [] + self._file_extensions = file_extensions + + # Auto-calculate success_count if not provided + if success_count is None: + self._success_count = 0 if self._errors else 1 + else: + self._success_count = success_count @property def transpiler_name(self) -> str: - return "identity" + return self._transpiler_name @property def supported_dialects(self) -> list[str]: - return ["identity"] + return self._supported_dialects async def initialize(self, config: TranspileConfig) -> None: assert config.source_dialect in self.supported_dialects async def shutdown(self) -> None: - # Nothing needed here. - return + pass async def transpile( self, source_dialect: str, target_dialect: str, source_code: str, file_path: Path ) -> TranspileResult: assert source_dialect in self.supported_dialects + + # Determine the transpiled code + if self._static_code is not None: + code = self._static_code + elif self._transform is not None: + code = self._transform(source_code) + else: + # Default: identity transform + code = source_code + return TranspileResult( - transpiled_code=source_code, - success_count=1, - error_list=[], + transpiled_code=code, + success_count=self._success_count, + error_list=self._errors, ) def is_supported_file(self, file: Path) -> bool: - return True + if self._file_extensions is None: + return True + return file.suffix in self._file_extensions @pytest.mark.parametrize("encoding", ["utf-32-le", "utf-32-be", "utf-16-le", "utf-16-be", "utf-8-sig", "utf-8"]) @@ -233,7 +272,13 @@ def test_transpile_unicode_files( output_folder=str(output_folder), skip_validation=True, ) - status, _ = transpile(mock_workspace_client, IdentityTranspileEngine(), transpile_config) + # Use identity transform (returns source unchanged) for all files + identity_engine = ConfigurableTestEngine( + transpiler_name="identity", + supported_dialects=["identity"], + file_extensions=None, # Support all file types + ) + status, _ = transpile(mock_workspace_client, identity_engine, transpile_config) assert status.get("total_files_processed") == 1 transpiled_query = (output_folder / "unicode_query.sql").read_text(encoding="utf-8") @@ -469,7 +514,12 @@ def test_encoding_error_utf8_decode_error( ) # This reads files using the default system encoding, and this test assumes it's UTF-8. - status, errors = transpile(mock_workspace_client, IdentityTranspileEngine(), transpile_config) + identity_engine = ConfigurableTestEngine( + transpiler_name="identity", + supported_dialects=["identity"], + file_extensions=None, + ) + status, errors = transpile(mock_workspace_client, identity_engine, transpile_config) # Verify error handling assert status.get("total_files_processed") == 1 # File was processed (but failed) @@ -501,7 +551,12 @@ def test_encoding_error_lookup_error( skip_validation=True, ) - status, errors = transpile(mock_workspace_client, IdentityTranspileEngine(), transpile_config) + identity_engine = ConfigurableTestEngine( + transpiler_name="identity", + supported_dialects=["identity"], + file_extensions=None, + ) + status, errors = transpile(mock_workspace_client, identity_engine, transpile_config) # Verify error handling assert status.get("total_files_processed") == 1 # File was processed (but failed) @@ -698,3 +753,145 @@ def test_make_header_with_one_repeated_warning(): */ """ ) + + +def test_transpiled_code_output_on_parsing_error(tmp_path: Path, mock_workspace_client: WorkspaceClient): + """Test that transpiled code is output even when parsing errors occur.""" + input_file = tmp_path / "test.sql" + output_folder = tmp_path / "output" + output_folder.mkdir() + + original_sql = "SELECT NUMBER_COL, VARCHAR_COL FROM my_table" + input_file.write_text(original_sql) + transpiled_sql = "SELECT DECIMAL_COL, STRING_COL FROM my_table" + + parsing_error = TranspileError( + code="PARSE_ERROR", + kind=ErrorKind.PARSING, + severity=ErrorSeverity.ERROR, + path=input_file, + message="Parsing error: unexpected token", + range=CodeRange(start=CodePosition(0, 0), end=CodePosition(0, 10)), + ) + + # Mock engine returns static transpiled code with parsing errors + mock_engine = ConfigurableTestEngine( + transpiler_name="mock", + supported_dialects=["snowflake", "tsql"], + transpiled_code=transpiled_sql, + errors=[parsing_error], + file_extensions=[".sql"], + ) + + config = TranspileConfig( + transpiler_config_path="mock", + input_source=str(input_file), + output_folder=str(output_folder), + source_dialect="snowflake", + skip_validation=True, + ) + + with patch("databricks.labs.lakebridge.helpers.db_sql.get_sql_backend", return_value=MockBackend()): + status, errors = transpile(mock_workspace_client, mock_engine, config) + + output_file = output_folder / "test.sql" + assert output_file.exists(), "Output file was not created" + output_content = output_file.read_text() + + assert output_content == transpiled_sql, f"Expected transpiled code '{transpiled_sql}' but got '{output_content}'" + assert output_content != original_sql, "Output should not be the original SQL" + + assert len(errors) == 1 + assert errors[0].kind == ErrorKind.PARSING + assert status["parsing_error_count"] == 1 + + +def test_transpiled_code_output_without_errors(tmp_path: Path, mock_workspace_client: WorkspaceClient): + """Test that transpiled code is output correctly when no errors occur.""" + input_file = tmp_path / "success.sql" + output_folder = tmp_path / "output" + output_folder.mkdir() + + original_sql = "CREATE TABLE test (id NUMBER, name VARCHAR(100))" + input_file.write_text(original_sql) + transpiled_sql = "CREATE TABLE test (id DECIMAL(38, 0), name STRING)" + + # Mock engine returns static transpiled code with no errors + mock_engine = ConfigurableTestEngine( + transpiler_name="mock", + supported_dialects=["snowflake", "tsql"], + transpiled_code=transpiled_sql, + errors=[], + file_extensions=[".sql"], + ) + + config = TranspileConfig( + transpiler_config_path="mock", + input_source=str(input_file), + output_folder=str(output_folder), + source_dialect="snowflake", + skip_validation=True, + ) + + with patch("databricks.labs.lakebridge.helpers.db_sql.get_sql_backend", return_value=MockBackend()): + status, errors = transpile(mock_workspace_client, mock_engine, config) + + output_file = output_folder / "success.sql" + assert output_file.exists(), "Output file was not created" + output_content = output_file.read_text() + + assert output_content == transpiled_sql, f"Expected '{transpiled_sql}' but got '{output_content}'" + assert output_content != original_sql, "Output should be transpiled, not original" + + assert len(errors) == 0 + assert status["parsing_error_count"] == 0 + assert status["total_files_processed"] == 1 + + +def test_empty_transpiled_code_with_parsing_error(tmp_path: Path, mock_workspace_client: WorkspaceClient): + """Test handling when transpiled_code is empty/None during parsing error.""" + input_file = tmp_path / "error.sql" + output_folder = tmp_path / "output" + output_folder.mkdir() + + original_sql = "INVALID SQL SYNTAX !!!" + input_file.write_text(original_sql) + + parsing_error = TranspileError( + code="PARSE_ERROR", + kind=ErrorKind.PARSING, + severity=ErrorSeverity.ERROR, + path=input_file, + message="Fatal parsing error", + range=None, + ) + + # Mock engine returns empty transpiled code with parsing errors + mock_engine = ConfigurableTestEngine( + transpiler_name="mock", + supported_dialects=["snowflake", "tsql"], + transpiled_code="", + errors=[parsing_error], + file_extensions=[".sql"], + ) + + config = TranspileConfig( + transpiler_config_path="mock", + input_source=str(input_file), + output_folder=str(output_folder), + source_dialect="snowflake", + skip_validation=True, + ) + + with patch("databricks.labs.lakebridge.helpers.db_sql.get_sql_backend", return_value=MockBackend()): + status, errors = transpile(mock_workspace_client, mock_engine, config) + + output_file = output_folder / "error.sql" + assert output_file.exists(), "Output file should be created even with errors" + output_content = output_file.read_text() + + assert output_content == "", "Output should be empty string when transpilation fails completely" + + assert len(errors) == 1 + assert errors[0].kind == ErrorKind.PARSING + assert status["parsing_error_count"] == 1