Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify Snowflake object name handling in the Snowpark AST #2789

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
54 changes: 33 additions & 21 deletions src/snowflake/snowpark/_internal/ast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,17 @@ def build_proto_from_struct_type(
ast_field.nullable = field.nullable


def build_sp_name(name: Union[str, Iterable[str]], expr: proto.SpName) -> None:
if isinstance(name, str):
expr.sp_name_flat.name = name
elif isinstance(name, Iterable):
expr.sp_name_structured.name.extend(name)
else:
raise ValueError(
f"Invalid object name: {name}. The object name must be a string or an iterable of strings."
)


# TODO(SNOW-1491199) - This method is not covered by tests until the end of phase 0. Drop the pragma when it is covered.
def _set_fn_name(
name: Union[str, Iterable[str]], fn: proto.FnNameRefExpr
Expand All @@ -358,26 +369,27 @@ def _set_fn_name(
Raises:
ValueError: Raised if the function name is not a string or an iterable of strings.
"""
if isinstance(name, str):
fn.name.fn_name_flat.name = name # type: ignore[attr-defined] # TODO(SNOW-1491199) # "FnNameRefExpr" has no attribute "name"
elif isinstance(name, Iterable):
fn.name.fn_name_structured.name.extend(name) # type: ignore[attr-defined] # TODO(SNOW-1491199) # "FnNameRefExpr" has no attribute "name"
else:
raise ValueError(
f"Invalid function name: {name}. The function name must be a string or an iterable of strings."
)
try:
build_sp_name(name, fn.name.name)
except ValueError as e:
raise ValueError("Invalid function name") from e


# TODO(SNOW-1491199) - This method is not covered by tests until the end of phase 0. Drop the pragma when it is covered.
def build_sp_table_name( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function is missing a return type annotation
expr_builder: proto.SpTableName, name: Union[str, Iterable[str]]
): # pragma: no cover
if isinstance(name, str):
expr_builder.sp_table_name_flat.name = name
elif isinstance(name, Iterable):
expr_builder.sp_table_name_structured.name.extend(name)
else:
raise ValueError(f"Invalid name type {type(name)} for SpTableName entity.")
def build_sp_table_name(
expr_builder: proto.SpNameRef, name: Union[str, Iterable[str]]
) -> None: # pragma: no cover
try:
build_sp_name(name, expr_builder.name)
except ValueError as e:
raise ValueError("Invalid table name") from e


def build_sp_view_name(expr: proto.SpNameRef, name: Union[str, Iterable[str]]) -> None:
try:
build_sp_name(name, expr.name)
except ValueError as e:
raise ValueError("Invalid view name") from e


def build_function_expr(
Expand Down Expand Up @@ -1108,7 +1120,7 @@ def build_udf( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function i
ast.stage_location = stage_location
if imports is not None and len(imports) != 0:
for import_ in imports:
import_expr = proto.SpTableName()
import_expr = proto.SpNameRef()
build_sp_table_name(import_expr, import_)
ast.imports.append(import_expr)
if packages is not None and len(packages) != 0:
Expand Down Expand Up @@ -1197,7 +1209,7 @@ def build_udaf( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function
ast.stage_location.value = stage_location
if imports is not None and len(imports) != 0:
for import_ in imports:
import_expr = proto.SpTableName()
import_expr = proto.SpNameRef()
build_sp_table_name(import_expr, import_)
ast.imports.append(import_expr)
if packages is not None and len(packages) != 0:
Expand Down Expand Up @@ -1294,7 +1306,7 @@ def build_udtf( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function
ast.stage_location = stage_location
if imports is not None and len(imports) != 0:
for import_ in imports:
import_expr = proto.SpTableName()
import_expr = proto.SpNameRef()
build_sp_table_name(import_expr, import_)
ast.imports.append(import_expr)
if packages is not None and len(packages) != 0:
Expand Down Expand Up @@ -1406,7 +1418,7 @@ def build_sproc( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function
ast.stage_location = stage_location
if imports is not None and len(imports) != 0:
for import_ in imports:
import_expr = proto.SpTableName()
import_expr = proto.SpNameRef()
build_sp_table_name(import_expr, import_)
ast.imports.append(import_expr)
if packages is not None and len(packages) != 0:
Expand Down
Loading
Loading