Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1847626: Add support for value_contains_null to MapType #2771

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- Added support for `DataFrame.map`.
- Added support for `DataFrame.from_dict` and `DataFrame.from_records`.
- Added support for mixed case field names in struct type columns.
- Added support for `value_contains_null` parameter to MapType.

#### Improvements
- Improve performance of `DataFrame.map`, `Series.apply` and `Series.map` methods by mapping numpy functions to snowpark functions if possible.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,12 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str:
return "to_array(0)"
if isinstance(data_type, MapType):
if data_type.structured:
key = schema_expression(data_type.key_type, is_nullable)
value = schema_expression(data_type.value_type, is_nullable)
# Key values can never be null
key = schema_expression(data_type.key_type, False)
# Value nullability is variable. Defaults to True
value = schema_expression(
data_type.value_type, data_type.value_contains_null
)
return f"object_construct_keep_null({key}, {value}) :: {convert_sp_to_sf_type(data_type)}"
return "to_object(parse_json('0'))"
if isinstance(data_type, StructType):
Expand Down
4 changes: 3 additions & 1 deletion src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def convert_metadata_to_sp_type(
convert_metadata_to_sp_type(metadata.fields[0], max_string_size),
convert_metadata_to_sp_type(metadata.fields[1], max_string_size),
structured=True,
value_contains_null=metadata.fields[1]._is_nullable,
)
else:
assert all(
Expand Down Expand Up @@ -290,7 +291,8 @@ def convert_sp_to_sf_type(datatype: DataType) -> str:
return "ARRAY"
if isinstance(datatype, MapType):
if datatype.structured:
return f"MAP({convert_sp_to_sf_type(datatype.key_type)}, {convert_sp_to_sf_type(datatype.value_type)})"
nullable = "" if datatype.value_contains_null else " NOT NULL"
return f"MAP({convert_sp_to_sf_type(datatype.key_type)}, {convert_sp_to_sf_type(datatype.value_type)}{nullable})"
else:
return "OBJECT"
if isinstance(datatype, StructType):
Expand Down
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,12 @@ def __init__(
key_type: Optional[DataType] = None,
value_type: Optional[DataType] = None,
structured: bool = False,
value_contains_null: bool = True,
) -> None:
self.structured = structured
self.key_type = key_type if key_type else StringType()
self.value_type = value_type if value_type else StringType()
self.value_contains_null = value_contains_null

def __repr__(self) -> str:
return f"MapType({repr(self.key_type) if self.key_type else ''}, {repr(self.value_type) if self.value_type else ''})"
Expand Down
44 changes: 42 additions & 2 deletions tests/integ/scala/test_datatype_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,6 +991,45 @@ def test_structured_type_print_schema(
)


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="local testing does not fully support structured types yet.",
)
def test_structured_map_value_contains_null(
structured_type_session, structured_type_support
):
if not structured_type_support:
pytest.skip("Test requires structured type support.")

# SNOW-1862947 create DDL test once save as table supported
array_df = structured_type_session.sql(
"select {'test' : 'test'} :: MAP(STRING, STRING NOT NULL) AS M, {'test' : 'test'} :: MAP(STRING, STRING) AS M_N"
)
expected_schema = StructType(
[
StructField(
"M",
MapType(
StringType(),
StringType(),
structured=True,
value_contains_null=False,
),
),
StructField(
"M_N",
MapType(
StringType(),
StringType(),
structured=True,
value_contains_null=True,
),
),
]
)
assert array_df.schema == expected_schema


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="local testing does not fully support structured types yet.",
Expand Down Expand Up @@ -1104,12 +1143,13 @@ def test_structured_type_schema_expression(
assert table.union(table).schema == expected_schema
# Functions used in schema generation don't respect nested nullability so compare query string instead
non_null_union = non_null_table.union(non_null_table)
# __import__('pdb').set_trace()
assert non_null_union._plan.schema_query == (
"( SELECT object_construct_keep_null('a' :: STRING (16777216), 0 :: DOUBLE) :: "
"( SELECT object_construct_keep_null('a' :: STRING (16777216), NULL :: DOUBLE) :: "
'MAP(STRING(16777216), DOUBLE) AS "MAP", to_array(0 :: DOUBLE) :: ARRAY(DOUBLE) AS "ARR",'
" object_construct_keep_null('FIELD1', 'a' :: STRING (16777216), 'FIELD2', 0 :: "
'DOUBLE) :: OBJECT(FIELD1 STRING(16777216), FIELD2 DOUBLE) AS "OBJ") UNION ( SELECT '
"object_construct_keep_null('a' :: STRING (16777216), 0 :: DOUBLE) :: "
"object_construct_keep_null('a' :: STRING (16777216), NULL :: DOUBLE) :: "
'MAP(STRING(16777216), DOUBLE) AS "MAP", to_array(0 :: DOUBLE) :: ARRAY(DOUBLE) AS "ARR", '
"object_construct_keep_null('FIELD1', 'a' :: STRING (16777216), 'FIELD2', 0 :: "
'DOUBLE) :: OBJECT(FIELD1 STRING(16777216), FIELD2 DOUBLE) AS "OBJ")'
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,6 +964,18 @@ def test_convert_sp_to_sf_type():
assert convert_sp_to_sf_type(BinaryType()) == "BINARY"
assert convert_sp_to_sf_type(ArrayType()) == "ARRAY"
assert convert_sp_to_sf_type(MapType()) == "OBJECT"
assert (
convert_sp_to_sf_type(MapType(StringType(), StringType(), structured=True))
== "MAP(STRING, STRING)"
)
assert (
convert_sp_to_sf_type(
MapType(
StringType(), StringType(), structured=True, value_contains_null=False
)
)
== "MAP(STRING, STRING NOT NULL)"
)
assert convert_sp_to_sf_type(StructType()) == "OBJECT"
assert convert_sp_to_sf_type(VariantType()) == "VARIANT"
assert convert_sp_to_sf_type(GeographyType()) == "GEOGRAPHY"
Expand Down
Loading