Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1829870: Allow structured types to be enabled by default #2727

Merged
merged 24 commits into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
d9bf2cb
SNOW-1829870: Allow structured types to be enabled by default
sfc-gh-jrose Dec 5, 2024
ec43e1a
type checking
sfc-gh-jrose Dec 6, 2024
7f3a5fd
lint
sfc-gh-jrose Dec 6, 2024
2e0dce9
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 16, 2024
ed232de
Move flag to context
sfc-gh-jrose Dec 16, 2024
0dd7b91
typo
sfc-gh-jrose Dec 16, 2024
13c1424
SNOW-1852779 Fix AST encoding for Column `in_`, `asc`, and `desc` (#2…
sfc-gh-vbudati Dec 16, 2024
a787e74
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 16, 2024
b32806f
merge main and fix test
sfc-gh-jrose Dec 17, 2024
c3db223
make feature flag thread safe
sfc-gh-jrose Dec 17, 2024
1c262d7
typo
sfc-gh-jrose Dec 17, 2024
869931f
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 17, 2024
0caef58
Fix ast test
sfc-gh-jrose Dec 17, 2024
2380040
move lock
sfc-gh-jrose Dec 18, 2024
995e519
test coverage
sfc-gh-jrose Dec 18, 2024
1b89027
remove context manager
sfc-gh-jrose Dec 18, 2024
4fc61d4
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 19, 2024
26fd29e
switch to using patch
sfc-gh-jrose Dec 19, 2024
9295e11
move test to other module
sfc-gh-jrose Dec 19, 2024
fcd16d7
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 19, 2024
77a57a6
fix broken import
sfc-gh-jrose Dec 19, 2024
4769169
another broken import
sfc-gh-jrose Dec 19, 2024
af5af87
another test fix
sfc-gh-jrose Dec 19, 2024
dea741b
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,16 @@ def to_sql(value: Any, datatype: DataType, from_values_statement: bool = False)
return f"'{binascii.hexlify(bytes(value)).decode()}' :: BINARY"

if isinstance(value, (list, tuple, array)) and isinstance(datatype, ArrayType):
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: ARRAY"
type_str = "ARRAY"
if datatype.structured:
type_str = convert_sp_to_sf_type(datatype)
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: {type_str}"

if isinstance(value, dict) and isinstance(datatype, MapType):
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: OBJECT"
type_str = "OBJECT"
if datatype.structured:
type_str = convert_sp_to_sf_type(datatype)
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: {type_str}"

if isinstance(datatype, VariantType):
# PARSE_JSON returns VARIANT, so no need to append :: VARIANT here explicitly.
Expand Down Expand Up @@ -214,11 +220,14 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str:
return "to_timestamp('2020-09-16 06:30:00')"
if isinstance(data_type, ArrayType):
if data_type.structured:
assert isinstance(data_type.element_type, DataType)
element = schema_expression(data_type.element_type, is_nullable)
return f"to_array({element}) :: {convert_sp_to_sf_type(data_type)}"
return "to_array(0)"
if isinstance(data_type, MapType):
if data_type.structured:
assert isinstance(data_type.key_type, DataType)
assert isinstance(data_type.value_type, DataType)
key = schema_expression(data_type.key_type, is_nullable)
value = schema_expression(data_type.value_type, is_nullable)
return f"object_construct_keep_null({key}, {value}) :: {convert_sp_to_sf_type(data_type)}"
Expand Down
19 changes: 13 additions & 6 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
get_origin,
)

import snowflake.snowpark.context as context
import snowflake.snowpark.types # type: ignore
from snowflake.connector.constants import FIELD_ID_TO_NAME
from snowflake.connector.cursor import ResultMetadata
Expand Down Expand Up @@ -183,12 +184,15 @@ def convert_sf_to_sp_type(
max_string_size: int,
) -> DataType:
"""Convert the Snowflake logical type to the Snowpark type."""
semi_structured_fill = (
None if context._should_use_structured_type_semantics else StringType()
)
if column_type_name == "ARRAY":
return ArrayType(StringType())
return ArrayType(semi_structured_fill)
if column_type_name == "VARIANT":
return VariantType()
if column_type_name in {"OBJECT", "MAP"}:
return MapType(StringType(), StringType())
return MapType(semi_structured_fill, semi_structured_fill)
if column_type_name == "GEOGRAPHY":
return GeographyType()
if column_type_name == "GEOMETRY":
Expand Down Expand Up @@ -530,7 +534,10 @@ def merge_type(a: DataType, b: DataType, name: Optional[str] = None) -> DataType
return a


def python_value_str_to_object(value, tp: DataType) -> Any:
def python_value_str_to_object(value, tp: Optional[DataType]) -> Any:
if tp is None:
return None

if isinstance(tp, StringType):
return value

Expand Down Expand Up @@ -639,7 +646,7 @@ def python_type_to_snow_type(
element_type = (
python_type_to_snow_type(tp_args[0], is_return_type_of_sproc)[0]
if tp_args
else StringType()
else None
)
return ArrayType(element_type), False

Expand All @@ -649,12 +656,12 @@ def python_type_to_snow_type(
key_type = (
python_type_to_snow_type(tp_args[0], is_return_type_of_sproc)[0]
if tp_args
else StringType()
else None
)
value_type = (
python_type_to_snow_type(tp_args[1], is_return_type_of_sproc)[0]
if tp_args
else StringType()
else None
)
return MapType(key_type, value_type), False

Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
_should_continue_registration: Optional[Callable[..., bool]] = None


# Global flag that determines if structured type semantics should be used
_should_use_structured_type_semantics = False


def get_active_session() -> "snowflake.snowpark.Session":
"""Returns the current active Snowpark session.

Expand Down
57 changes: 44 additions & 13 deletions src/snowflake/snowpark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from enum import Enum
from typing import Generic, List, Optional, Type, TypeVar, Union, Dict, Any

import snowflake.snowpark.context as context
import snowflake.snowpark._internal.analyzer.expression as expression
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto

Expand Down Expand Up @@ -333,10 +334,16 @@ class ArrayType(DataType):
def __init__(
self,
element_type: Optional[DataType] = None,
structured: bool = False,
structured: Optional[bool] = None,
) -> None:
self.structured = structured
self.element_type = element_type if element_type else StringType()
if context._should_use_structured_type_semantics:
self.structured = (
structured if structured is not None else element_type is not None
)
self.element_type = element_type
else:
Comment on lines +343 to +344
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can element_type be None here? What does it mean for the column type to be so?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If element_type is None, then this is a semi-structured array column and could contain anything.

self.structured = structured or False
self.element_type = element_type if element_type else StringType()

def __repr__(self) -> str:
return f"ArrayType({repr(self.element_type) if self.element_type else ''})"
Expand Down Expand Up @@ -379,14 +386,30 @@ def __init__(
self,
key_type: Optional[DataType] = None,
value_type: Optional[DataType] = None,
structured: bool = False,
structured: Optional[bool] = None,
) -> None:
self.structured = structured
self.key_type = key_type if key_type else StringType()
self.value_type = value_type if value_type else StringType()
if context._should_use_structured_type_semantics:
if (key_type is None and value_type is not None) or (
key_type is not None and value_type is None
):
raise ValueError(
"Must either set both key_type and value_type or leave both unset."
)
self.structured = (
structured if structured is not None else key_type is not None
)
self.key_type = key_type
self.value_type = value_type
else:
self.structured = structured or False
self.key_type = key_type if key_type else StringType()
self.value_type = value_type if value_type else StringType()

def __repr__(self) -> str:
return f"MapType({repr(self.key_type) if self.key_type else ''}, {repr(self.value_type) if self.value_type else ''})"
type_str = ""
if self.key_type and self.value_type:
type_str = f"{repr(self.key_type)}, {repr(self.value_type)}"
return f"MapType({type_str})"

def is_primitive(self):
return False
Expand Down Expand Up @@ -617,12 +640,20 @@ class StructType(DataType):
"""Represents a table schema or structured column. Contains :class:`StructField` for each field."""

def __init__(
self, fields: Optional[List["StructField"]] = None, structured=False
self,
fields: Optional[List["StructField"]] = None,
structured: Optional[bool] = False,
) -> None:
self.structured = structured
if fields is None:
fields = []
self.fields = fields
if context._should_use_structured_type_semantics:
self.structured = (
structured if structured is not None else fields is not None
)
self.fields = fields or []
else:
self.structured = structured or False
if fields is None:
fields = []
self.fields = fields

def add(
self,
Expand Down
11 changes: 10 additions & 1 deletion src/snowflake/snowpark/udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@
TempObjectType,
parse_positional_args_to_list,
publicapi,
warning,
)
from snowflake.snowpark.column import Column
from snowflake.snowpark.types import DataType
from snowflake.snowpark.types import DataType, MapType

# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
# Python 3.9 can use both
Expand Down Expand Up @@ -710,6 +711,14 @@ def _do_register_udaf(
name,
)

if isinstance(return_type, MapType):
if return_type.structured:
warning(
"_do_register_udaf",
"Snowflake does not support structured maps as return type for UDAFs. Downcasting to semi-structured object.",
)
return_type = MapType()

# Capture original parameters.
if _emit_ast:
stmt = self._session._ast_batch.assign()
Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,10 @@ def _do_register_udtf(
output_schema=output_schema,
)

# Structured Struct is interpreted as Object by function registration
# Force unstructured to ensure Table return type.
output_schema.structured = False

# Capture original parameters.
if _emit_ast:
stmt = self._session._ast_batch.assign()
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_column_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def test_literal(session, local_testing_mode):
BooleanType(),
# snowflake doesn't enforce the inner type of ArrayType, so it is expected that
# the returned ArrayType has no inner type set.
ArrayType(LongType()) if local_testing_mode else ArrayType(StringType()),
ArrayType(LongType()) if local_testing_mode else ArrayType(),
]
verify_column_result(
session,
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2797,7 +2797,7 @@ def test_save_as_table_with_table_sproc_output(session, save_mode, table_type):
lambda session_: session_.sql("SELECT 1 as A"),
packages=["snowflake-snowpark-python"],
name=temp_sp_name,
return_type=StructType([StructField("A", IntegerType())]),
return_type=StructType([StructField("A", IntegerType())], structured=False),
input_types=[],
replace=True,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_df_to_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def test_df_to_pandas_df(session):
StructField("n", TimestampType()),
StructField("o", TimeType()),
StructField("p", VariantType()),
StructField("q", MapType(StringType(), StringType())),
StructField("q", MapType()),
]
),
)
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,9 +1291,9 @@ def test_to_date_to_array_to_variant_to_object(session, local_testing_mode):
Utils.assert_rows(res1, expected)

assert df1.schema.fields[0].datatype == DateType()
assert df1.schema.fields[1].datatype == ArrayType(StringType())
assert df1.schema.fields[1].datatype == ArrayType()
assert df1.schema.fields[2].datatype == VariantType()
assert df1.schema.fields[3].datatype == MapType(StringType(), StringType())
assert df1.schema.fields[3].datatype == MapType()


def test_to_binary(session):
Expand Down
Loading