Skip to content

Commit 7efd6db

Browse files
authored
Fix type hints on LSP configuration classes and improve validation when loading (#2041)
## Changes This PR updates type hints on the configuration classes, and implements proper checking of the configuration during loading. Changes include: - We now distinguish between missing attributes and attributes that are present but malformed. - With the options for a dialect, the prompt is now optional instead of stubbing with an empty string. (The prompt is not needed for the `FORCE` option type.) - For options, the combinations of choices/default are now dealt with properly depending on the kind of option it is. - Some tests that were using invalid fixtures have been fixed. ### Caveats/things to watch out for when reviewing During deserialisation, in many places we now handle unhappy paths via exceptions rather than checking in advance: the former is unambiguous, whereas it's easy to have subtle issues when the check doesn't quite correctly anticipate or correspond to the actual usage. ### Relevant implementation details Sadly, most of this has been done by hand even though there are libraries that do this sort of thing for us. ### Tests - existing unit tests
1 parent 4ce9033 commit 7efd6db

File tree

5 files changed

+214
-62
lines changed

5 files changed

+214
-62
lines changed

src/databricks/labs/lakebridge/config.py

Lines changed: 93 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import logging
2+
from collections.abc import Mapping, Sequence
23
from dataclasses import dataclass
34
from enum import Enum, auto
45
from pathlib import Path
5-
from typing import Any, Literal, cast
6+
from typing import Any, Literal, TypeVar, cast
67

78
from databricks.labs.blueprint.installation import JsonValue
89
from databricks.labs.blueprint.tui import Prompts
@@ -20,39 +21,112 @@ class LSPPromptMethod(Enum):
2021
CONFIRM = auto()
2122

2223

23-
@dataclass
24+
E = TypeVar("E", bound=Enum)
25+
26+
27+
def extract_string_field(data: Mapping[str, JsonValue], name: str) -> str:
28+
"""Extract a string field from the given mapping.
29+
30+
Parameters:
31+
data: The mapping to get the string field from.
32+
name: The name of the field to extract.
33+
Raises:
34+
ValueError: If the field is not present, not a string, or an empty string.
35+
"""
36+
value = _maybe_extract_string_field(data, name, is_required=True)
37+
if not value:
38+
msg = f"Invalid '{name}' attribute, must be a non-empty string: {value}"
39+
raise ValueError(msg)
40+
return value
41+
42+
43+
def _maybe_extract_string_field(data: Mapping[str, JsonValue], name: str, *, is_required: bool) -> str | None:
44+
# A variant of extract_string_field() with two differences:
45+
# - It allows for optional fields.
46+
# - A provided string may be empty.
47+
# (This can't easily be folded into extract_string_field() because of the different return type.)
48+
try:
49+
value = data[name]
50+
except KeyError as e:
51+
if is_required:
52+
raise ValueError(f"Missing '{name}' attribute in {data}") from e
53+
return None
54+
if not isinstance(value, str):
55+
msg = f"Invalid '{name}' entry in {data}, expecting a string: {value}"
56+
raise ValueError(msg)
57+
return value
58+
59+
60+
def extract_enum_field(data: Mapping[str, JsonValue], name: str, enum_type: type[E]) -> E:
61+
"""Extract an enum field from the given mapping.
62+
63+
Parameters:
64+
data: The mapping to get the enum field from.
65+
name: The name of the field to extract.
66+
enum_type: The enum type to use for parsing the value.
67+
Raises:
68+
ValueError: If the field is not present and no default is provided, or if it's present but not a valid enum value.
69+
"""
70+
enum_value = extract_string_field(data, name)
71+
try:
72+
return enum_type[enum_value]
73+
except ValueError as e:
74+
valid_values = [m.name for m in enum_type]
75+
msg = f"Invalid '{name}' entry in {data}, expecting one of [{', '.join(valid_values)}]: {enum_value}"
76+
raise ValueError(msg) from e
77+
78+
79+
@dataclass(frozen=True)
2480
class LSPConfigOptionV1:
2581
flag: str
2682
method: LSPPromptMethod
27-
prompt: str = ""
83+
prompt: str | None = None
2884
choices: list[str] | None = None
29-
default: Any = None
85+
default: str | None = None
3086

3187
@classmethod
32-
def parse_all(cls, data: dict[str, Any]) -> dict[str, list["LSPConfigOptionV1"]]:
88+
def parse_all(cls, data: dict[str, Sequence[JsonValue]]) -> dict[str, list["LSPConfigOptionV1"]]:
3389
return {key: list(LSPConfigOptionV1.parse(item) for item in value) for (key, value) in data.items()}
3490

3591
@classmethod
36-
def parse(cls, data: Any) -> "LSPConfigOptionV1":
92+
def _extract_choices_field(cls, data: Mapping[str, JsonValue], prompt_method: LSPPromptMethod) -> list[str] | None:
93+
try:
94+
choices_unsafe = data["choices"]
95+
except KeyError as e:
96+
if prompt_method == LSPPromptMethod.CHOICE:
97+
raise ValueError(f"Missing 'choices' attribute in {data}") from e
98+
return None
99+
if not isinstance(choices_unsafe, list) or not all(isinstance(item, str) for item in choices_unsafe):
100+
msg = f"Invalid 'choices' entry in {data}, expecting a list of strings: {choices_unsafe}"
101+
raise ValueError(msg)
102+
return cast(list[str], choices_unsafe)
103+
104+
@classmethod
105+
def parse(cls, data: JsonValue) -> "LSPConfigOptionV1":
37106
if not isinstance(data, dict):
38107
raise ValueError(f"Invalid transpiler config option, expecting a dict entry, got {data}")
39-
flag: str = data.get("flag", "")
40-
if not flag:
41-
raise ValueError(f"Missing 'flag' entry in {data}")
42-
method_name: str = data.get("method", "")
43-
if not method_name:
44-
raise ValueError(f"Missing 'method' entry in {data}")
45-
method: LSPPromptMethod = cast(LSPPromptMethod, LSPPromptMethod[method_name])
46-
prompt: str = data.get("prompt", "")
47-
if not prompt:
48-
raise ValueError(f"Missing 'prompt' entry in {data}")
49-
choices = data.get("choices", [])
50-
default = data.get("default", None)
51-
return LSPConfigOptionV1(flag, method, prompt, choices, default)
108+
109+
# Field extraction is factored out mainly to ensure the complexity of this method is not too high.
110+
111+
flag = extract_string_field(data, "flag")
112+
method = extract_enum_field(data, "method", LSPPromptMethod)
113+
prompt = _maybe_extract_string_field(data, "prompt", is_required=method != LSPPromptMethod.FORCE)
114+
115+
optional: dict[str, Any] = {}
116+
choices = cls._extract_choices_field(data, method)
117+
if choices is not None:
118+
optional["choices"] = choices
119+
120+
default = _maybe_extract_string_field(data, "default", is_required=False)
121+
if default is not None:
122+
optional["default"] = default
123+
124+
return LSPConfigOptionV1(flag, method, prompt, **optional)
52125

53126
def prompt_for_value(self, prompts: Prompts) -> JsonValue:
54127
if self.method == LSPPromptMethod.FORCE:
55128
return self.default
129+
assert self.prompt is not None
56130
if self.method == LSPPromptMethod.CONFIRM:
57131
return prompts.confirm(self.prompt)
58132
if self.method == LSPPromptMethod.QUESTION:

src/databricks/labs/lakebridge/transpiler/lsp/lsp_engine.py

Lines changed: 106 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
import shutil
88
import sys
99
import venv
10-
from collections.abc import Callable, Sequence, Mapping
10+
from collections.abc import Callable, Iterable, Mapping, Sequence
1111
from dataclasses import dataclass
1212
from pathlib import Path
13-
from typing import Any, Literal
13+
from typing import Any, Literal, cast
1414

1515
import attrs
1616
import yaml
@@ -34,8 +34,9 @@
3434
from pygls.exceptions import FeatureRequestError
3535
from pygls.lsp.client import BaseLanguageClient
3636

37+
from databricks.labs.blueprint.installation import JsonValue, RootJsonValue
3738
from databricks.labs.blueprint.wheels import ProductInfo
38-
from databricks.labs.lakebridge.config import LSPConfigOptionV1, TranspileConfig, TranspileResult
39+
from databricks.labs.lakebridge.config import LSPConfigOptionV1, TranspileConfig, TranspileResult, extract_string_field
3940
from databricks.labs.lakebridge.errors.exceptions import IllegalStateException
4041
from databricks.labs.lakebridge.helpers.file_utils import is_dbt_project_file, is_sql_file
4142
from databricks.labs.lakebridge.transpiler.transpile_engine import TranspileEngine
@@ -50,6 +51,16 @@
5051
logger = logging.getLogger(__name__)
5152

5253

54+
def _is_all_strings(values: Iterable[object]) -> bool:
55+
"""Typeguard, to check if all values in the iterable are strings."""
56+
return all(isinstance(x, str) for x in values)
57+
58+
59+
def _is_all_sequences(values: Iterable[object]) -> bool:
60+
"""Typeguard, to check if all values in the iterable are sequences."""
61+
return all(isinstance(x, Sequence) for x in values)
62+
63+
5364
@dataclass
5465
class _LSPRemorphConfigV1:
5566
name: str
@@ -58,29 +69,63 @@ class _LSPRemorphConfigV1:
5869
command_line: Sequence[str]
5970

6071
@classmethod
61-
def parse(cls, data: Mapping[str, Any]) -> _LSPRemorphConfigV1:
62-
version = data.get("version", 0)
72+
def parse(cls, data: Mapping[str, JsonValue]) -> _LSPRemorphConfigV1:
73+
cls._check_version(data)
74+
name = extract_string_field(data, "name")
75+
dialects = cls._extract_dialects(data)
76+
env_vars = cls._extract_env_vars(data)
77+
command_line = cls._extract_command_line(data)
78+
return _LSPRemorphConfigV1(name, dialects, env_vars, command_line)
79+
80+
@classmethod
81+
def _check_version(cls, data: Mapping[str, JsonValue]) -> None:
82+
try:
83+
version = data["version"]
84+
except KeyError as e:
85+
raise ValueError("Missing 'version' attribute") from e
6386
if version != 1:
6487
raise ValueError(f"Unsupported transpiler config version: {version}")
65-
name: str | None = data.get("name", None)
66-
if not name:
67-
raise ValueError("Missing 'name' entry")
68-
dialects = data.get("dialects", [])
69-
if len(dialects) == 0:
70-
raise ValueError("Missing 'dialects' entry")
71-
env_vars = data.get("environment", {})
72-
command_line = data.get("command_line", [])
73-
if len(command_line) == 0:
74-
raise ValueError("Missing 'command_line' entry")
75-
return _LSPRemorphConfigV1(name, dialects, env_vars, command_line)
88+
89+
@classmethod
90+
def _extract_dialects(cls, data: Mapping[str, JsonValue]) -> Sequence[str]:
91+
try:
92+
dialects_unsafe = data["dialects"]
93+
except KeyError as e:
94+
raise ValueError("Missing 'dialects' attribute") from e
95+
if not isinstance(dialects_unsafe, list) or not dialects_unsafe or not _is_all_strings(dialects_unsafe):
96+
msg = f"Invalid 'dialects' attribute, expected a non-empty list of strings but got: {dialects_unsafe}"
97+
raise ValueError(msg)
98+
return cast(list[str], dialects_unsafe)
99+
100+
@classmethod
101+
def _extract_env_vars(cls, data: Mapping[str, JsonValue]) -> Mapping[str, str]:
102+
try:
103+
env_vars_unsafe = data["environment"]
104+
if not isinstance(env_vars_unsafe, Mapping) or not _is_all_strings(env_vars_unsafe.values()):
105+
msg = f"Invalid 'environment' entry, expected a mapping with string values but got: {env_vars_unsafe}"
106+
raise ValueError(msg)
107+
return cast(dict[str, str], env_vars_unsafe)
108+
except KeyError:
109+
return {}
110+
111+
@classmethod
112+
def _extract_command_line(cls, data: Mapping[str, JsonValue]) -> Sequence[str]:
113+
try:
114+
command_line = data["command_line"]
115+
except KeyError as e:
116+
raise ValueError("Missing 'command_line' attribute") from e
117+
if not isinstance(command_line, list) or not command_line or not _is_all_strings(command_line):
118+
msg = f"Invalid 'command_line' attribute, expected a non-empty list of strings but got: {command_line}"
119+
raise ValueError(msg)
120+
return cast(list[str], command_line)
76121

77122

78123
@dataclass
79124
class LSPConfig:
80125
path: Path
81126
remorph: _LSPRemorphConfigV1
82127
options: Mapping[str, Sequence[LSPConfigOptionV1]]
83-
custom: Mapping[str, Any]
128+
custom: Mapping[str, JsonValue]
84129

85130
@property
86131
def name(self):
@@ -92,20 +137,52 @@ def options_for_dialect(self, source_dialect: str) -> Sequence[LSPConfigOptionV1
92137
@classmethod
93138
def load(cls, path: Path) -> LSPConfig:
94139
yaml_text = path.read_text()
95-
data = yaml.safe_load(yaml_text)
96-
if not isinstance(data, dict):
97-
raise ValueError(f"Invalid transpiler config, expecting a dict, got a {type(data).__name__}")
98-
remorph_data = data.get("remorph", None)
99-
if not isinstance(remorph_data, dict):
100-
raise ValueError(f"Invalid transpiler config, expecting a 'remorph' dict entry, got {remorph_data}")
101-
remorph = _LSPRemorphConfigV1.parse(remorph_data)
102-
options_data = data.get("options", {})
103-
if not isinstance(options_data, dict):
104-
raise ValueError(f"Invalid transpiler config, expecting an 'options' dict entry, got {options_data}")
105-
options = LSPConfigOptionV1.parse_all(options_data)
106-
custom = data.get("custom", {})
140+
data: RootJsonValue = yaml.safe_load(yaml_text)
141+
if not isinstance(data, Mapping):
142+
msg = f"Invalid transpiler configuration, expecting a root object but got: {data}"
143+
raise ValueError(msg)
144+
145+
remorph = cls._extract_remorph_data(data)
146+
options = cls._extract_options(data)
147+
custom = cls._extract_custom(data)
107148
return LSPConfig(path, remorph, options, custom)
108149

150+
@classmethod
151+
def _extract_remorph_data(cls, data: Mapping[str, JsonValue]) -> _LSPRemorphConfigV1:
152+
try:
153+
remorph_data = data["remorph"]
154+
except KeyError as e:
155+
raise ValueError("Missing 'remorph' attribute") from e
156+
if not isinstance(remorph_data, Mapping):
157+
msg = f"Invalid transpiler config, 'remorph' entry must be an object but got: {remorph_data}"
158+
raise ValueError(msg)
159+
return _LSPRemorphConfigV1.parse(remorph_data)
160+
161+
@classmethod
162+
def _extract_options(cls, data: Mapping[str, JsonValue]) -> Mapping[str, Sequence[LSPConfigOptionV1]]:
163+
try:
164+
options_data_unsfe = data["options"]
165+
except KeyError:
166+
# Optional, so no problem if missing
167+
return {}
168+
if not isinstance(options_data_unsfe, Mapping) or not _is_all_sequences(options_data_unsfe.values()):
169+
msg = f"Invalid transpiler config, 'options' must be an object with list properties but got: {options_data_unsfe}"
170+
raise ValueError(msg)
171+
options_data = cast(dict[str, Sequence[JsonValue]], options_data_unsfe)
172+
return LSPConfigOptionV1.parse_all(options_data)
173+
174+
@classmethod
175+
def _extract_custom(cls, data: Mapping[str, JsonValue]) -> Mapping[str, JsonValue]:
176+
try:
177+
custom = data["custom"]
178+
if not isinstance(custom, Mapping):
179+
msg = f"Invalid 'custom' entry, expected a mapping but got: {custom}"
180+
raise ValueError(msg)
181+
return custom
182+
except KeyError:
183+
# Optional, so no problem if missing
184+
return {}
185+
109186

110187
def lsp_feature(name: str, options: Any | None = None):
111188
def wrapped(func: Callable):

tests/integration/transpile/test_repository.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ def test_installed_transpiler_info(transpiler_repository: TranspilerRepository)
1414
"overrides-file",
1515
LSPPromptMethod.QUESTION,
1616
"Specify the config file to override the default[Bladebridge] config - press <enter> for none",
17-
[],
1817
default='<none>',
1918
)
2019
target_tech = LSPConfigOptionV1(

tests/unit/test_install.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,7 +1131,7 @@ def test_runs_and_stores_force_config_option(
11311131
)
11321132

11331133
transpiler_repository = _StubTranspilerRepository(
1134-
tmp_path / "labs", config_options=(LSPConfigOptionV1(flag="-XX", method=LSPPromptMethod.FORCE, default=1254),)
1134+
tmp_path / "labs", config_options=(LSPConfigOptionV1(flag="-XX", method=LSPPromptMethod.FORCE, default="1254"),)
11351135
)
11361136

11371137
workspace_installer = ws_installer(
@@ -1150,7 +1150,7 @@ def test_runs_and_stores_force_config_option(
11501150
expected_config = LakebridgeConfiguration(
11511151
transpile=TranspileConfig(
11521152
transpiler_config_path=PATH_TO_TRANSPILER_CONFIG,
1153-
transpiler_options={"-XX": 1254},
1153+
transpiler_options={"-XX": "1254"},
11541154
source_dialect="snowflake",
11551155
input_source="/tmp/queries/snow",
11561156
output_folder="/tmp/queries/databricks",
@@ -1165,7 +1165,7 @@ def test_runs_and_stores_force_config_option(
11651165
"config.yml",
11661166
{
11671167
"transpiler_config_path": PATH_TO_TRANSPILER_CONFIG,
1168-
"transpiler_options": {'-XX': 1254},
1168+
"transpiler_options": {'-XX': "1254"},
11691169
"catalog_name": "remorph_test",
11701170
"input_source": "/tmp/queries/snow",
11711171
"output_folder": "/tmp/queries/databricks",

0 commit comments

Comments
 (0)