diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index 0c466aeb67a0..f18e4305713f 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -40,7 +40,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\x92\x38\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12 \x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12V\n\x10unresolved_regex\x18\x08 \x01(\x0b\x32).spark.connect.Expression.UnresolvedRegexH\x00R\x0funresolvedRegex\x12\x44\n\nsort_order\x18\t \x01(\x0b\x32#.spark.connect.Expression.SortOrderH\x00R\tsortOrder\x12S\n\x0flambda_function\x18\n \x01(\x0b\x32(.spark.connect.Expression.LambdaFunctionH\x00R\x0elambdaFunction\x12:\n\x06window\x18\x0b \x01(\x0b\x32 .spark.connect.Expression.WindowH\x00R\x06window\x12l\n\x18unresolved_extract_value\x18\x0c \x01(\x0b\x32\x30.spark.connect.Expression.UnresolvedExtractValueH\x00R\x16unresolvedExtractValue\x12M\n\rupdate_fields\x18\r \x01(\x0b\x32&.spark.connect.Expression.UpdateFieldsH\x00R\x0cupdateFields\x12\x82\x01\n unresolved_named_lambda_variable\x18\x0e \x01(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableH\x00R\x1dunresolvedNamedLambdaVariable\x12~\n#common_inline_user_defined_function\x18\x0f \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x1f\x63ommonInlineUserDefinedFunction\x12\x42\n\rcall_function\x18\x10 \x01(\x0b\x32\x1b.spark.connect.CallFunctionH\x00R\x0c\x63\x61llFunction\x12\x64\n\x19named_argument_expression\x18\x11 \x01(\x0b\x32&.spark.connect.NamedArgumentExpressionH\x00R\x17namedArgumentExpression\x12?\n\x0cmerge_action\x18\x13 \x01(\x0b\x32\x1a.spark.connect.MergeActionH\x00R\x0bmergeAction\x12g\n\x1atyped_aggregate_expression\x18\x14 \x01(\x0b\x32\'.spark.connect.TypedAggregateExpressionH\x00R\x18typedAggregateExpression\x12T\n\x13subquery_expression\x18\x15 \x01(\x0b\x32!.spark.connect.SubqueryExpressionH\x00R\x12subqueryExpression\x12s\n\x1b\x64irect_shuffle_partition_id\x18\x16 \x01(\x0b\x32\x32.spark.connect.Expression.DirectShufflePartitionIDH\x00R\x18\x64irectShufflePartitionId\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\x8f\x06\n\x06Window\x12\x42\n\x0fwindow_function\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0ewindowFunction\x12@\n\x0epartition_spec\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x03 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12K\n\nframe_spec\x18\x04 
\x01(\x0b\x32,.spark.connect.Expression.Window.WindowFrameR\tframeSpec\x1a\xed\x03\n\x0bWindowFrame\x12U\n\nframe_type\x18\x01 \x01(\x0e\x32\x36.spark.connect.Expression.Window.WindowFrame.FrameTypeR\tframeType\x12P\n\x05lower\x18\x02 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05lower\x12P\n\x05upper\x18\x03 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05upper\x1a\x91\x01\n\rFrameBoundary\x12!\n\x0b\x63urrent_row\x18\x01 \x01(\x08H\x00R\ncurrentRow\x12\x1e\n\tunbounded\x18\x02 \x01(\x08H\x00R\tunbounded\x12\x31\n\x05value\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\x05valueB\n\n\x08\x62oundary"O\n\tFrameType\x12\x18\n\x14\x46RAME_TYPE_UNDEFINED\x10\x00\x12\x12\n\x0e\x46RAME_TYPE_ROW\x10\x01\x12\x14\n\x10\x46RAME_TYPE_RANGE\x10\x02\x1a\xa9\x03\n\tSortOrder\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12O\n\tdirection\x18\x02 \x01(\x0e\x32\x31.spark.connect.Expression.SortOrder.SortDirectionR\tdirection\x12U\n\rnull_ordering\x18\x03 \x01(\x0e\x32\x30.spark.connect.Expression.SortOrder.NullOrderingR\x0cnullOrdering"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"U\n\x0cNullOrdering\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x1aK\n\x18\x44irectShufflePartitionID\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x1a\xbb\x02\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStr\x12\x44\n\teval_mode\x18\x04 \x01(\x0e\x32\'.spark.connect.Expression.Cast.EvalModeR\x08\x65valMode"b\n\x08\x45valMode\x12\x19\n\x15\x45VAL_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10\x45VAL_MODE_LEGACY\x10\x01\x12\x12\n\x0e\x45VAL_MODE_ANSI\x10\x02\x12\x11\n\rEVAL_MODE_TRY\x10\x03\x42\x0e\n\x0c\x63\x61st_to_type\x1a\x9e\x11\n\x07Literal\x12-\n\x04null\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x39\n\x03map\x18\x17 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x42\n\x06struct\x18\x18 
\x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x61\n\x11specialized_array\x18\x19 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.SpecializedArrayH\x00R\x10specializedArray\x12<\n\x04time\x18\x1a \x01(\x0b\x32&.spark.connect.Expression.Literal.TimeH\x00R\x04time\x12\x34\n\tdata_type\x18\x64 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x08\x64\x61taType\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x86\x01\n\x05\x41rray\x12>\n\x0c\x65lement_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeB\x02\x18\x01R\x0b\x65lementType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lements\x1a\xeb\x01\n\x03Map\x12\x36\n\x08key_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeB\x02\x18\x01R\x07keyType\x12:\n\nvalue_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeB\x02\x18\x01R\tvalueType\x12\x35\n\x04keys\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x04keys\x12\x39\n\x06values\x18\x04 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\x85\x01\n\x06Struct\x12<\n\x0bstruct_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeB\x02\x18\x01R\nstructType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lements\x1a\xc0\x02\n\x10SpecializedArray\x12,\n\x05\x62ools\x18\x01 \x01(\x0b\x32\x14.spark.connect.BoolsH\x00R\x05\x62ools\x12)\n\x04ints\x18\x02 \x01(\x0b\x32\x13.spark.connect.IntsH\x00R\x04ints\x12,\n\x05longs\x18\x03 \x01(\x0b\x32\x14.spark.connect.LongsH\x00R\x05longs\x12/\n\x06\x66loats\x18\x04 \x01(\x0b\x32\x15.spark.connect.FloatsH\x00R\x06\x66loats\x12\x32\n\x07\x64oubles\x18\x05 \x01(\x0b\x32\x16.spark.connect.DoublesH\x00R\x07\x64oubles\x12\x32\n\x07strings\x18\x06 \x01(\x0b\x32\x16.spark.connect.StringsH\x00R\x07stringsB\x0c\n\nvalue_type\x1aK\n\x04Time\x12\x12\n\x04nano\x18\x01 \x01(\x03R\x04nano\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x42\x0c\n\n_precisionB\x0e\n\x0cliteral_typeJ\x04\x08\x1b\x10\x1cJ\x04\x08\x1c\x10\x1d\x1a\xba\x01\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x12\x31\n\x12is_metadata_column\x18\x03 \x01(\x08H\x01R\x10isMetadataColumn\x88\x01\x01\x42\n\n\x08_plan_idB\x15\n\x13_is_metadata_column\x1a\x82\x02\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x12$\n\x0bis_internal\x18\x05 \x01(\x08H\x00R\nisInternal\x88\x01\x01\x42\x0e\n\x0c_is_internal\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a|\n\x0eUnresolvedStar\x12,\n\x0funparsed_target\x18\x01 \x01(\tH\x00R\x0eunparsedTarget\x88\x01\x01\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x01R\x06planId\x88\x01\x01\x42\x12\n\x10_unparsed_targetB\n\n\x08_plan_id\x1aV\n\x0fUnresolvedRegex\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12\x1c\n\x07plan_id\x18\x02 
\x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id\x1a\x84\x01\n\x16UnresolvedExtractValue\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12\x39\n\nextraction\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nextraction\x1a\xbb\x01\n\x0cUpdateFields\x12\x46\n\x11struct_expression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x10structExpression\x12\x1d\n\nfield_name\x18\x02 \x01(\tR\tfieldName\x12\x44\n\x10value_expression\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0fvalueExpression\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x9e\x01\n\x0eLambdaFunction\x12\x35\n\x08\x66unction\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08\x66unction\x12U\n\targuments\x18\x02 \x03(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableR\targuments\x1a>\n\x1dUnresolvedNamedLambdaVariable\x12\x1d\n\nname_parts\x18\x01 \x03(\tR\tnamePartsB\x0b\n\texpr_type"A\n\x10\x45xpressionCommon\x12-\n\x06origin\x18\x01 \x01(\x0b\x32\x15.spark.connect.OriginR\x06origin"\x8d\x03\n\x1f\x43ommonInlineUserDefinedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x39\n\npython_udf\x18\x04 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\tpythonUdf\x12I\n\x10scalar_scala_udf\x18\x05 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\x0escalarScalaUdf\x12\x33\n\x08java_udf\x18\x06 \x01(\x0b\x32\x16.spark.connect.JavaUDFH\x00R\x07javaUdf\x12\x1f\n\x0bis_distinct\x18\x07 \x01(\x08R\nisDistinctB\n\n\x08\x66unction"\xcc\x01\n\tPythonUDF\x12\x38\n\x0boutput_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVer\x12/\n\x13\x61\x64\x64itional_includes\x18\x05 \x03(\tR\x12\x61\x64\x64itionalIncludes"\xd6\x01\n\x0eScalarScalaUDF\x12\x18\n\x07payload\x18\x01 \x01(\x0cR\x07payload\x12\x37\n\ninputTypes\x18\x02 \x03(\x0b\x32\x17.spark.connect.DataTypeR\ninputTypes\x12\x37\n\noutputType\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1a\n\x08nullable\x18\x04 \x01(\x08R\x08nullable\x12\x1c\n\taggregate\x18\x05 \x01(\x08R\taggregate"\x95\x01\n\x07JavaUDF\x12\x1d\n\nclass_name\x18\x01 \x01(\tR\tclassName\x12=\n\x0boutput_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\noutputType\x88\x01\x01\x12\x1c\n\taggregate\x18\x03 \x01(\x08R\taggregateB\x0e\n\x0c_output_type"c\n\x18TypedAggregateExpression\x12G\n\x10scalar_scala_udf\x18\x01 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFR\x0escalarScalaUdf"l\n\x0c\x43\x61llFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\\\n\x17NamedArgumentExpression\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value"\x80\x04\n\x0bMergeAction\x12\x46\n\x0b\x61\x63tion_type\x18\x01 \x01(\x0e\x32%.spark.connect.MergeAction.ActionTypeR\nactionType\x12<\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\tcondition\x88\x01\x01\x12G\n\x0b\x61ssignments\x18\x03 
\x03(\x0b\x32%.spark.connect.MergeAction.AssignmentR\x0b\x61ssignments\x1aj\n\nAssignment\x12+\n\x03key\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value"\xa7\x01\n\nActionType\x12\x17\n\x13\x41\x43TION_TYPE_INVALID\x10\x00\x12\x16\n\x12\x41\x43TION_TYPE_DELETE\x10\x01\x12\x16\n\x12\x41\x43TION_TYPE_INSERT\x10\x02\x12\x1b\n\x17\x41\x43TION_TYPE_INSERT_STAR\x10\x03\x12\x16\n\x12\x41\x43TION_TYPE_UPDATE\x10\x04\x12\x1b\n\x17\x41\x43TION_TYPE_UPDATE_STAR\x10\x05\x42\x0c\n\n_condition"\xc5\x05\n\x12SubqueryExpression\x12\x17\n\x07plan_id\x18\x01 \x01(\x03R\x06planId\x12S\n\rsubquery_type\x18\x02 \x01(\x0e\x32..spark.connect.SubqueryExpression.SubqueryTypeR\x0csubqueryType\x12\x62\n\x11table_arg_options\x18\x03 \x01(\x0b\x32\x31.spark.connect.SubqueryExpression.TableArgOptionsH\x00R\x0ftableArgOptions\x88\x01\x01\x12G\n\x12in_subquery_values\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x10inSubqueryValues\x1a\xea\x01\n\x0fTableArgOptions\x12@\n\x0epartition_spec\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12\x37\n\x15with_single_partition\x18\x03 \x01(\x08H\x00R\x13withSinglePartition\x88\x01\x01\x42\x18\n\x16_with_single_partition"\x90\x01\n\x0cSubqueryType\x12\x19\n\x15SUBQUERY_TYPE_UNKNOWN\x10\x00\x12\x18\n\x14SUBQUERY_TYPE_SCALAR\x10\x01\x12\x18\n\x14SUBQUERY_TYPE_EXISTS\x10\x02\x12\x1b\n\x17SUBQUERY_TYPE_TABLE_ARG\x10\x03\x12\x14\n\x10SUBQUERY_TYPE_IN\x10\x04\x42\x14\n\x12_table_arg_optionsB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\x9c\x39\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12 \x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12V\n\x10unresolved_regex\x18\x08 \x01(\x0b\x32).spark.connect.Expression.UnresolvedRegexH\x00R\x0funresolvedRegex\x12\x44\n\nsort_order\x18\t \x01(\x0b\x32#.spark.connect.Expression.SortOrderH\x00R\tsortOrder\x12S\n\x0flambda_function\x18\n \x01(\x0b\x32(.spark.connect.Expression.LambdaFunctionH\x00R\x0elambdaFunction\x12:\n\x06window\x18\x0b \x01(\x0b\x32 .spark.connect.Expression.WindowH\x00R\x06window\x12l\n\x18unresolved_extract_value\x18\x0c \x01(\x0b\x32\x30.spark.connect.Expression.UnresolvedExtractValueH\x00R\x16unresolvedExtractValue\x12M\n\rupdate_fields\x18\r \x01(\x0b\x32&.spark.connect.Expression.UpdateFieldsH\x00R\x0cupdateFields\x12\x82\x01\n unresolved_named_lambda_variable\x18\x0e 
\x01(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableH\x00R\x1dunresolvedNamedLambdaVariable\x12~\n#common_inline_user_defined_function\x18\x0f \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x1f\x63ommonInlineUserDefinedFunction\x12\x42\n\rcall_function\x18\x10 \x01(\x0b\x32\x1b.spark.connect.CallFunctionH\x00R\x0c\x63\x61llFunction\x12\x64\n\x19named_argument_expression\x18\x11 \x01(\x0b\x32&.spark.connect.NamedArgumentExpressionH\x00R\x17namedArgumentExpression\x12?\n\x0cmerge_action\x18\x13 \x01(\x0b\x32\x1a.spark.connect.MergeActionH\x00R\x0bmergeAction\x12g\n\x1atyped_aggregate_expression\x18\x14 \x01(\x0b\x32\'.spark.connect.TypedAggregateExpressionH\x00R\x18typedAggregateExpression\x12T\n\x13subquery_expression\x18\x15 \x01(\x0b\x32!.spark.connect.SubqueryExpressionH\x00R\x12subqueryExpression\x12s\n\x1b\x64irect_shuffle_partition_id\x18\x16 \x01(\x0b\x32\x32.spark.connect.Expression.DirectShufflePartitionIDH\x00R\x18\x64irectShufflePartitionId\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\x8f\x06\n\x06Window\x12\x42\n\x0fwindow_function\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0ewindowFunction\x12@\n\x0epartition_spec\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x03 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12K\n\nframe_spec\x18\x04 \x01(\x0b\x32,.spark.connect.Expression.Window.WindowFrameR\tframeSpec\x1a\xed\x03\n\x0bWindowFrame\x12U\n\nframe_type\x18\x01 \x01(\x0e\x32\x36.spark.connect.Expression.Window.WindowFrame.FrameTypeR\tframeType\x12P\n\x05lower\x18\x02 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05lower\x12P\n\x05upper\x18\x03 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05upper\x1a\x91\x01\n\rFrameBoundary\x12!\n\x0b\x63urrent_row\x18\x01 \x01(\x08H\x00R\ncurrentRow\x12\x1e\n\tunbounded\x18\x02 \x01(\x08H\x00R\tunbounded\x12\x31\n\x05value\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\x05valueB\n\n\x08\x62oundary"O\n\tFrameType\x12\x18\n\x14\x46RAME_TYPE_UNDEFINED\x10\x00\x12\x12\n\x0e\x46RAME_TYPE_ROW\x10\x01\x12\x14\n\x10\x46RAME_TYPE_RANGE\x10\x02\x1a\xa9\x03\n\tSortOrder\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12O\n\tdirection\x18\x02 \x01(\x0e\x32\x31.spark.connect.Expression.SortOrder.SortDirectionR\tdirection\x12U\n\rnull_ordering\x18\x03 \x01(\x0e\x32\x30.spark.connect.Expression.SortOrder.NullOrderingR\x0cnullOrdering"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"U\n\x0cNullOrdering\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x1aK\n\x18\x44irectShufflePartitionID\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x1a\xbb\x02\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStr\x12\x44\n\teval_mode\x18\x04 
\x01(\x0e\x32\'.spark.connect.Expression.Cast.EvalModeR\x08\x65valMode"b\n\x08\x45valMode\x12\x19\n\x15\x45VAL_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10\x45VAL_MODE_LEGACY\x10\x01\x12\x12\n\x0e\x45VAL_MODE_ANSI\x10\x02\x12\x11\n\rEVAL_MODE_TRY\x10\x03\x42\x0e\n\x0c\x63\x61st_to_type\x1a\xa8\x12\n\x07Literal\x12-\n\x04null\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x39\n\x03map\x18\x17 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x42\n\x06struct\x18\x18 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x61\n\x11specialized_array\x18\x19 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.SpecializedArrayH\x00R\x10specializedArray\x12<\n\x04time\x18\x1a \x01(\x0b\x32&.spark.connect.Expression.Literal.TimeH\x00R\x04time\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\xc2\x01\n\x05\x41rray\x12>\n\x0c\x65lement_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeB\x02\x18\x01R\x0b\x65lementType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lements\x12:\n\tdata_type\x18\x03 \x01(\x0b\x32\x1d.spark.connect.DataType.ArrayR\x08\x64\x61taType\x1a\xa5\x02\n\x03Map\x12\x36\n\x08key_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeB\x02\x18\x01R\x07keyType\x12:\n\nvalue_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeB\x02\x18\x01R\tvalueType\x12\x35\n\x04keys\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x04keys\x12\x39\n\x06values\x18\x04 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x38\n\tdata_type\x18\x05 \x01(\x0b\x32\x1b.spark.connect.DataType.MapR\x08\x64\x61taType\x1a\xcf\x01\n\x06Struct\x12<\n\x0bstruct_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeB\x02\x18\x01R\nstructType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lements\x12H\n\x10\x64\x61ta_type_struct\x18\x03 
\x01(\x0b\x32\x1e.spark.connect.DataType.StructR\x0e\x64\x61taTypeStruct\x1a\xc0\x02\n\x10SpecializedArray\x12,\n\x05\x62ools\x18\x01 \x01(\x0b\x32\x14.spark.connect.BoolsH\x00R\x05\x62ools\x12)\n\x04ints\x18\x02 \x01(\x0b\x32\x13.spark.connect.IntsH\x00R\x04ints\x12,\n\x05longs\x18\x03 \x01(\x0b\x32\x14.spark.connect.LongsH\x00R\x05longs\x12/\n\x06\x66loats\x18\x04 \x01(\x0b\x32\x15.spark.connect.FloatsH\x00R\x06\x66loats\x12\x32\n\x07\x64oubles\x18\x05 \x01(\x0b\x32\x16.spark.connect.DoublesH\x00R\x07\x64oubles\x12\x32\n\x07strings\x18\x06 \x01(\x0b\x32\x16.spark.connect.StringsH\x00R\x07stringsB\x0c\n\nvalue_type\x1aK\n\x04Time\x12\x12\n\x04nano\x18\x01 \x01(\x03R\x04nano\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x42\x0c\n\n_precisionB\x0e\n\x0cliteral_typeJ\x04\x08\x1b\x10\x1cJ\x04\x08\x1c\x10\x1d\x1a\xba\x01\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x12\x31\n\x12is_metadata_column\x18\x03 \x01(\x08H\x01R\x10isMetadataColumn\x88\x01\x01\x42\n\n\x08_plan_idB\x15\n\x13_is_metadata_column\x1a\x82\x02\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x12$\n\x0bis_internal\x18\x05 \x01(\x08H\x00R\nisInternal\x88\x01\x01\x42\x0e\n\x0c_is_internal\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a|\n\x0eUnresolvedStar\x12,\n\x0funparsed_target\x18\x01 \x01(\tH\x00R\x0eunparsedTarget\x88\x01\x01\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x01R\x06planId\x88\x01\x01\x42\x12\n\x10_unparsed_targetB\n\n\x08_plan_id\x1aV\n\x0fUnresolvedRegex\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id\x1a\x84\x01\n\x16UnresolvedExtractValue\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12\x39\n\nextraction\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nextraction\x1a\xbb\x01\n\x0cUpdateFields\x12\x46\n\x11struct_expression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x10structExpression\x12\x1d\n\nfield_name\x18\x02 \x01(\tR\tfieldName\x12\x44\n\x10value_expression\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0fvalueExpression\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x9e\x01\n\x0eLambdaFunction\x12\x35\n\x08\x66unction\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08\x66unction\x12U\n\targuments\x18\x02 \x03(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableR\targuments\x1a>\n\x1dUnresolvedNamedLambdaVariable\x12\x1d\n\nname_parts\x18\x01 \x03(\tR\tnamePartsB\x0b\n\texpr_type"A\n\x10\x45xpressionCommon\x12-\n\x06origin\x18\x01 \x01(\x0b\x32\x15.spark.connect.OriginR\x06origin"\x8d\x03\n\x1f\x43ommonInlineUserDefinedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x39\n\npython_udf\x18\x04 
\x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\tpythonUdf\x12I\n\x10scalar_scala_udf\x18\x05 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\x0escalarScalaUdf\x12\x33\n\x08java_udf\x18\x06 \x01(\x0b\x32\x16.spark.connect.JavaUDFH\x00R\x07javaUdf\x12\x1f\n\x0bis_distinct\x18\x07 \x01(\x08R\nisDistinctB\n\n\x08\x66unction"\xcc\x01\n\tPythonUDF\x12\x38\n\x0boutput_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVer\x12/\n\x13\x61\x64\x64itional_includes\x18\x05 \x03(\tR\x12\x61\x64\x64itionalIncludes"\xd6\x01\n\x0eScalarScalaUDF\x12\x18\n\x07payload\x18\x01 \x01(\x0cR\x07payload\x12\x37\n\ninputTypes\x18\x02 \x03(\x0b\x32\x17.spark.connect.DataTypeR\ninputTypes\x12\x37\n\noutputType\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1a\n\x08nullable\x18\x04 \x01(\x08R\x08nullable\x12\x1c\n\taggregate\x18\x05 \x01(\x08R\taggregate"\x95\x01\n\x07JavaUDF\x12\x1d\n\nclass_name\x18\x01 \x01(\tR\tclassName\x12=\n\x0boutput_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\noutputType\x88\x01\x01\x12\x1c\n\taggregate\x18\x03 \x01(\x08R\taggregateB\x0e\n\x0c_output_type"c\n\x18TypedAggregateExpression\x12G\n\x10scalar_scala_udf\x18\x01 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFR\x0escalarScalaUdf"l\n\x0c\x43\x61llFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\\\n\x17NamedArgumentExpression\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value"\x80\x04\n\x0bMergeAction\x12\x46\n\x0b\x61\x63tion_type\x18\x01 \x01(\x0e\x32%.spark.connect.MergeAction.ActionTypeR\nactionType\x12<\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\tcondition\x88\x01\x01\x12G\n\x0b\x61ssignments\x18\x03 \x03(\x0b\x32%.spark.connect.MergeAction.AssignmentR\x0b\x61ssignments\x1aj\n\nAssignment\x12+\n\x03key\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value"\xa7\x01\n\nActionType\x12\x17\n\x13\x41\x43TION_TYPE_INVALID\x10\x00\x12\x16\n\x12\x41\x43TION_TYPE_DELETE\x10\x01\x12\x16\n\x12\x41\x43TION_TYPE_INSERT\x10\x02\x12\x1b\n\x17\x41\x43TION_TYPE_INSERT_STAR\x10\x03\x12\x16\n\x12\x41\x43TION_TYPE_UPDATE\x10\x04\x12\x1b\n\x17\x41\x43TION_TYPE_UPDATE_STAR\x10\x05\x42\x0c\n\n_condition"\xc5\x05\n\x12SubqueryExpression\x12\x17\n\x07plan_id\x18\x01 \x01(\x03R\x06planId\x12S\n\rsubquery_type\x18\x02 \x01(\x0e\x32..spark.connect.SubqueryExpression.SubqueryTypeR\x0csubqueryType\x12\x62\n\x11table_arg_options\x18\x03 \x01(\x0b\x32\x31.spark.connect.SubqueryExpression.TableArgOptionsH\x00R\x0ftableArgOptions\x88\x01\x01\x12G\n\x12in_subquery_values\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x10inSubqueryValues\x1a\xea\x01\n\x0fTableArgOptions\x12@\n\x0epartition_spec\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12\x37\n\x15with_single_partition\x18\x03 
\x01(\x08H\x00R\x13withSinglePartition\x88\x01\x01\x42\x18\n\x16_with_single_partition"\x90\x01\n\x0cSubqueryType\x12\x19\n\x15SUBQUERY_TYPE_UNKNOWN\x10\x00\x12\x18\n\x14SUBQUERY_TYPE_SCALAR\x10\x01\x12\x18\n\x14SUBQUERY_TYPE_EXISTS\x10\x02\x12\x1b\n\x17SUBQUERY_TYPE_TABLE_ARG\x10\x03\x12\x14\n\x10SUBQUERY_TYPE_IN\x10\x04\x42\x14\n\x12_table_arg_optionsB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -68,7 +68,7 @@ "struct_type" ]._serialized_options = b"\030\001" _globals["_EXPRESSION"]._serialized_start = 133 - _globals["_EXPRESSION"]._serialized_end = 7319 + _globals["_EXPRESSION"]._serialized_end = 7457 _globals["_EXPRESSION_WINDOW"]._serialized_start = 2103 _globals["_EXPRESSION_WINDOW"]._serialized_end = 2886 _globals["_EXPRESSION_WINDOW_WINDOWFRAME"]._serialized_start = 2393 @@ -90,67 +90,67 @@ _globals["_EXPRESSION_CAST_EVALMODE"]._serialized_start = 3595 _globals["_EXPRESSION_CAST_EVALMODE"]._serialized_end = 3693 _globals["_EXPRESSION_LITERAL"]._serialized_start = 3712 - _globals["_EXPRESSION_LITERAL"]._serialized_end = 5918 - _globals["_EXPRESSION_LITERAL_DECIMAL"]._serialized_start = 4762 - _globals["_EXPRESSION_LITERAL_DECIMAL"]._serialized_end = 4879 - _globals["_EXPRESSION_LITERAL_CALENDARINTERVAL"]._serialized_start = 4881 - _globals["_EXPRESSION_LITERAL_CALENDARINTERVAL"]._serialized_end = 4979 - _globals["_EXPRESSION_LITERAL_ARRAY"]._serialized_start = 4982 - _globals["_EXPRESSION_LITERAL_ARRAY"]._serialized_end = 5116 - _globals["_EXPRESSION_LITERAL_MAP"]._serialized_start = 5119 - _globals["_EXPRESSION_LITERAL_MAP"]._serialized_end = 5354 - _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_start = 5357 - _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_end = 5490 - _globals["_EXPRESSION_LITERAL_SPECIALIZEDARRAY"]._serialized_start = 5493 - _globals["_EXPRESSION_LITERAL_SPECIALIZEDARRAY"]._serialized_end = 5813 - _globals["_EXPRESSION_LITERAL_TIME"]._serialized_start = 5815 - _globals["_EXPRESSION_LITERAL_TIME"]._serialized_end = 5890 - _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_start = 5921 - _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_end = 6107 - _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_start = 6110 - _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_end = 6368 - _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_start = 6370 - _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_end = 6420 - _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_start = 6422 - _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_end = 6546 - _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_start = 6548 - _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_end = 6634 - _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_start = 6637 - _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_end = 6769 - _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_start = 6772 - _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_end = 6959 - _globals["_EXPRESSION_ALIAS"]._serialized_start = 6961 - _globals["_EXPRESSION_ALIAS"]._serialized_end = 7081 - _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_start = 7084 - _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_end = 7242 - _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_start = 7244 - _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_end = 7306 - _globals["_EXPRESSIONCOMMON"]._serialized_start = 7321 - _globals["_EXPRESSIONCOMMON"]._serialized_end = 7386 - 
_globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_start = 7389 - _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_end = 7786 - _globals["_PYTHONUDF"]._serialized_start = 7789 - _globals["_PYTHONUDF"]._serialized_end = 7993 - _globals["_SCALARSCALAUDF"]._serialized_start = 7996 - _globals["_SCALARSCALAUDF"]._serialized_end = 8210 - _globals["_JAVAUDF"]._serialized_start = 8213 - _globals["_JAVAUDF"]._serialized_end = 8362 - _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_start = 8364 - _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_end = 8463 - _globals["_CALLFUNCTION"]._serialized_start = 8465 - _globals["_CALLFUNCTION"]._serialized_end = 8573 - _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_start = 8575 - _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_end = 8667 - _globals["_MERGEACTION"]._serialized_start = 8670 - _globals["_MERGEACTION"]._serialized_end = 9182 - _globals["_MERGEACTION_ASSIGNMENT"]._serialized_start = 8892 - _globals["_MERGEACTION_ASSIGNMENT"]._serialized_end = 8998 - _globals["_MERGEACTION_ACTIONTYPE"]._serialized_start = 9001 - _globals["_MERGEACTION_ACTIONTYPE"]._serialized_end = 9168 - _globals["_SUBQUERYEXPRESSION"]._serialized_start = 9185 - _globals["_SUBQUERYEXPRESSION"]._serialized_end = 9894 - _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_start = 9491 - _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_end = 9725 - _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_start = 9728 - _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_end = 9872 + _globals["_EXPRESSION_LITERAL"]._serialized_end = 6056 + _globals["_EXPRESSION_LITERAL_DECIMAL"]._serialized_start = 4708 + _globals["_EXPRESSION_LITERAL_DECIMAL"]._serialized_end = 4825 + _globals["_EXPRESSION_LITERAL_CALENDARINTERVAL"]._serialized_start = 4827 + _globals["_EXPRESSION_LITERAL_CALENDARINTERVAL"]._serialized_end = 4925 + _globals["_EXPRESSION_LITERAL_ARRAY"]._serialized_start = 4928 + _globals["_EXPRESSION_LITERAL_ARRAY"]._serialized_end = 5122 + _globals["_EXPRESSION_LITERAL_MAP"]._serialized_start = 5125 + _globals["_EXPRESSION_LITERAL_MAP"]._serialized_end = 5418 + _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_start = 5421 + _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_end = 5628 + _globals["_EXPRESSION_LITERAL_SPECIALIZEDARRAY"]._serialized_start = 5631 + _globals["_EXPRESSION_LITERAL_SPECIALIZEDARRAY"]._serialized_end = 5951 + _globals["_EXPRESSION_LITERAL_TIME"]._serialized_start = 5953 + _globals["_EXPRESSION_LITERAL_TIME"]._serialized_end = 6028 + _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_start = 6059 + _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_end = 6245 + _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_start = 6248 + _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_end = 6506 + _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_start = 6508 + _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_end = 6558 + _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_start = 6560 + _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_end = 6684 + _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_start = 6686 + _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_end = 6772 + _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_start = 6775 + _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_end = 6907 + _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_start = 6910 + _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_end = 7097 + _globals["_EXPRESSION_ALIAS"]._serialized_start = 7099 
+ _globals["_EXPRESSION_ALIAS"]._serialized_end = 7219 + _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_start = 7222 + _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_end = 7380 + _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_start = 7382 + _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_end = 7444 + _globals["_EXPRESSIONCOMMON"]._serialized_start = 7459 + _globals["_EXPRESSIONCOMMON"]._serialized_end = 7524 + _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_start = 7527 + _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_end = 7924 + _globals["_PYTHONUDF"]._serialized_start = 7927 + _globals["_PYTHONUDF"]._serialized_end = 8131 + _globals["_SCALARSCALAUDF"]._serialized_start = 8134 + _globals["_SCALARSCALAUDF"]._serialized_end = 8348 + _globals["_JAVAUDF"]._serialized_start = 8351 + _globals["_JAVAUDF"]._serialized_end = 8500 + _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_start = 8502 + _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_end = 8601 + _globals["_CALLFUNCTION"]._serialized_start = 8603 + _globals["_CALLFUNCTION"]._serialized_end = 8711 + _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_start = 8713 + _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_end = 8805 + _globals["_MERGEACTION"]._serialized_start = 8808 + _globals["_MERGEACTION"]._serialized_end = 9320 + _globals["_MERGEACTION_ASSIGNMENT"]._serialized_start = 9030 + _globals["_MERGEACTION_ASSIGNMENT"]._serialized_end = 9136 + _globals["_MERGEACTION_ACTIONTYPE"]._serialized_start = 9139 + _globals["_MERGEACTION_ACTIONTYPE"]._serialized_end = 9306 + _globals["_SUBQUERYEXPRESSION"]._serialized_start = 9323 + _globals["_SUBQUERYEXPRESSION"]._serialized_end = 10032 + _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_start = 9629 + _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_end = 9863 + _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_start = 9866 + _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_end = 10010 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index e2e23dd8c553..1061ec922aea 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -496,6 +496,7 @@ class Expression(google.protobuf.message.Message): ELEMENT_TYPE_FIELD_NUMBER: builtins.int ELEMENTS_FIELD_NUMBER: builtins.int + DATA_TYPE_FIELD_NUMBER: builtins.int @property def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: """(Deprecated) The element type of the array. @@ -509,19 +510,37 @@ class Expression(google.protobuf.message.Message): global___Expression.Literal ]: """The literal values that make up the array elements.""" + @property + def data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Array: + """The type of the array. You don't need to set this field if the type information is not needed. + + If the element type can be inferred from the first element of the elements field, + then you don't need to set data_type.element_type to save space. + + On the other hand, redundant type information is also acceptable. + """ def __init__( self, *, element_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., elements: collections.abc.Iterable[global___Expression.Literal] | None = ..., + data_type: pyspark.sql.connect.proto.types_pb2.DataType.Array | None = ..., ) -> None: ... 
def HasField( - self, field_name: typing_extensions.Literal["element_type", b"element_type"] + self, + field_name: typing_extensions.Literal[ + "data_type", b"data_type", "element_type", b"element_type" + ], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "element_type", b"element_type", "elements", b"elements" + "data_type", + b"data_type", + "element_type", + b"element_type", + "elements", + b"elements", ], ) -> None: ... @@ -532,6 +551,7 @@ class Expression(google.protobuf.message.Message): VALUE_TYPE_FIELD_NUMBER: builtins.int KEYS_FIELD_NUMBER: builtins.int VALUES_FIELD_NUMBER: builtins.int + DATA_TYPE_FIELD_NUMBER: builtins.int @property def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: """(Deprecated) The key type of the map. @@ -559,6 +579,15 @@ class Expression(google.protobuf.message.Message): global___Expression.Literal ]: """The literal values that make up the map.""" + @property + def data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Map: + """The type of the map. You don't need to set this field if the type information is not needed. + + If the key/value types can be inferred from the first element of the keys/values fields, + then you don't need to set data_type.key_type/data_type.value_type to save space. + + On the other hand, redundant type information is also acceptable. + """ def __init__( self, *, @@ -566,16 +595,19 @@ class Expression(google.protobuf.message.Message): value_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., keys: collections.abc.Iterable[global___Expression.Literal] | None = ..., values: collections.abc.Iterable[global___Expression.Literal] | None = ..., + data_type: pyspark.sql.connect.proto.types_pb2.DataType.Map | None = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ - "key_type", b"key_type", "value_type", b"value_type" + "data_type", b"data_type", "key_type", b"key_type", "value_type", b"value_type" ], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ + "data_type", + b"data_type", "key_type", b"key_type", "keys", @@ -592,6 +624,7 @@ class Expression(google.protobuf.message.Message): STRUCT_TYPE_FIELD_NUMBER: builtins.int ELEMENTS_FIELD_NUMBER: builtins.int + DATA_TYPE_STRUCT_FIELD_NUMBER: builtins.int @property def struct_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: """(Deprecated) The type of the struct. @@ -606,19 +639,35 @@ class Expression(google.protobuf.message.Message): global___Expression.Literal ]: """The literal values that make up the struct elements.""" + @property + def data_type_struct(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Struct: + """The type of the struct. You don't need to set this field if the type information is not needed. + + Whether data_type_struct.fields.data_type should be set depends on + whether each field's type can be inferred from the elements field. + """ def __init__( self, *, struct_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., elements: collections.abc.Iterable[global___Expression.Literal] | None = ..., + data_type_struct: pyspark.sql.connect.proto.types_pb2.DataType.Struct | None = ..., ) -> None: ... def HasField( - self, field_name: typing_extensions.Literal["struct_type", b"struct_type"] + self, + field_name: typing_extensions.Literal[ + "data_type_struct", b"data_type_struct", "struct_type", b"struct_type" + ], ) -> builtins.bool: ... 
def ClearField( self, field_name: typing_extensions.Literal[ - "elements", b"elements", "struct_type", b"struct_type" + "data_type_struct", + b"data_type_struct", + "elements", + b"elements", + "struct_type", + b"struct_type", ], ) -> None: ... @@ -750,7 +799,6 @@ class Expression(google.protobuf.message.Message): STRUCT_FIELD_NUMBER: builtins.int SPECIALIZED_ARRAY_FIELD_NUMBER: builtins.int TIME_FIELD_NUMBER: builtins.int - DATA_TYPE_FIELD_NUMBER: builtins.int @property def null(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ... binary: builtins.bytes @@ -784,14 +832,6 @@ class Expression(google.protobuf.message.Message): def specialized_array(self) -> global___Expression.Literal.SpecializedArray: ... @property def time(self) -> global___Expression.Literal.Time: ... - @property - def data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: - """Data type information for the literal. - This field is required only in the root literal message for null values or - for data types (e.g., array, map, or struct) with non-trivial information. - If the data_type field is not set at the root level, the data type will be - inferred or retrieved from the deprecated data type fields using best efforts. - """ def __init__( self, *, @@ -817,7 +857,6 @@ class Expression(google.protobuf.message.Message): struct: global___Expression.Literal.Struct | None = ..., specialized_array: global___Expression.Literal.SpecializedArray | None = ..., time: global___Expression.Literal.Time | None = ..., - data_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., ) -> None: ... def HasField( self, @@ -832,8 +871,6 @@ class Expression(google.protobuf.message.Message): b"byte", "calendar_interval", b"calendar_interval", - "data_type", - b"data_type", "date", b"date", "day_time_interval", @@ -885,8 +922,6 @@ class Expression(google.protobuf.message.Message): b"byte", "calendar_interval", b"calendar_interval", - "data_type", - b"data_type", "date", b"date", "day_time_interval", diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala index b9f72badd45f..2e0a8de95755 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala @@ -1562,13 +1562,13 @@ class ClientE2ETestSuite val observedObservedDf = observedDf.observe("ob2", min("extra"), avg("extra"), max("extra")) val ob1Schema = new StructType() - .add("min(id)", LongType) - .add("avg(id)", DoubleType) - .add("max(id)", LongType) + .add("min(id)", LongType, nullable = false) + .add("avg(id)", DoubleType, nullable = false) + .add("max(id)", LongType, nullable = false) val ob2Schema = new StructType() - .add("min(extra)", LongType) - .add("avg(extra)", DoubleType) - .add("max(extra)", LongType) + .add("min(extra)", LongType, nullable = false) + .add("avg(extra)", DoubleType, nullable = false) + .add("max(extra)", LongType, nullable = false) val ob1Metrics = Map("ob1" -> new GenericRowWithSchema(Array(0, 49, 98), ob1Schema)) val ob2Metrics = Map("ob2" -> new GenericRowWithSchema(Array(-1, 48, 97), ob2Schema)) diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala index 389b3a5c52ac..ded37873a477 100644 
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala
@@ -83,14 +83,11 @@ class ColumnNodeToProtoConverterSuite extends ConnectFunSuite {
         .addElements(proto.Expression.Literal.newBuilder().setString("north").build())
         .addElements(proto.Expression.Literal.newBuilder().setDouble(60.0).build())
         .addElements(proto.Expression.Literal.newBuilder().setString("west").build())
-      b.getLiteralBuilder.getDataTypeBuilder.setStruct(
-        proto.DataType.Struct
-          .newBuilder()
-          .addFields(structField("_1", ProtoDataTypes.DoubleType))
-          .addFields(structField("_2", stringTypeWithCollation))
-          .addFields(structField("_3", ProtoDataTypes.DoubleType))
-          .addFields(structField("_4", stringTypeWithCollation))
-          .build())
+        .getDataTypeStructBuilder
+        .addFields(structField("_1", ProtoDataTypes.DoubleType))
+        .addFields(structField("_2", stringTypeWithCollation))
+        .addFields(structField("_3", ProtoDataTypes.DoubleType))
+        .addFields(structField("_4", stringTypeWithCollation))
     })
   }
 
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
index f74c5af11782..2c882249ae7d 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -210,13 +210,6 @@ message Expression {
     // Reserved for Geometry and Geography.
     reserved 27, 28;
 
-    // Data type information for the literal.
-    // This field is required only in the root literal message for null values or
-    // for data types (e.g., array, map, or struct) with non-trivial information.
-    // If the data_type field is not set at the root level, the data type will be
-    // inferred or retrieved from the deprecated data type fields using best efforts.
-    DataType data_type = 100;
-
     message Decimal {
       // the string representation.
      string value = 1;
@@ -241,6 +234,14 @@ message Expression {
 
       // The literal values that make up the array elements.
       repeated Literal elements = 2;
+
+      // The type of the array. You don't need to set this field if the type information is not needed.
+      //
+      // If the element type can be inferred from the first element of the elements field,
+      // then you don't need to set data_type.element_type to save space.
+      //
+      // On the other hand, redundant type information is also acceptable.
+      DataType.Array data_type = 3;
     }
 
     message Map {
@@ -260,6 +261,14 @@ message Expression {
 
       // The literal values that make up the map.
       repeated Literal values = 4;
+
+      // The type of the map. You don't need to set this field if the type information is not needed.
+      //
+      // If the key/value types can be inferred from the first element of the keys/values fields,
+      // then you don't need to set data_type.key_type/data_type.value_type to save space.
+      //
+      // On the other hand, redundant type information is also acceptable.
+      DataType.Map data_type = 5;
     }
 
     message Struct {
@@ -271,6 +280,12 @@ message Expression {
 
       // The literal values that make up the struct elements.
       repeated Literal elements = 2;
+
+      // The type of the struct. You don't need to set this field if the type information is not needed.
+      //
+      // Whether data_type_struct.fields.data_type should be set depends on
+      // whether each field's type can be inferred from the elements field.
+      DataType.Struct data_type_struct = 3;
     }
 
     message SpecializedArray {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 43265e55a0ca..2c0e96cd9a7c 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -36,8 +36,8 @@ import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator}
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, FromProtoToScalaConverter}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.sql.util.ArrowUtils
 
 private[sql] class SparkResult[T](
@@ -422,17 +422,17 @@ private[sql] object SparkResult {
 
   /** Return value is a Seq of pairs, to preserve the order of values. */
   private[sql] def transformObservedMetrics(metric: ObservedMetrics): Row = {
-    assert(metric.getKeysCount == metric.getValuesCount)
-    var schema = new StructType()
+    require(metric.getKeysCount == metric.getValuesCount)
+    val fields = mutable.ArrayBuilder.make[StructField]
     val values = mutable.ArrayBuilder.make[Any]
+    fields.sizeHint(metric.getKeysCount)
     values.sizeHint(metric.getKeysCount)
-    (0 until metric.getKeysCount).foreach { i =>
-      val key = metric.getKeys(i)
-      val value = LiteralValueProtoConverter.toScalaValue(metric.getValues(i))
-      schema = schema.add(key, LiteralValueProtoConverter.getDataType(metric.getValues(i)))
+    Range(0, metric.getKeysCount).foreach { i =>
+      val (dataType, value) = FromProtoToScalaConverter.convert(metric.getValues(i))
+      fields += StructField(metric.getKeys(i), dataType, value == null)
       values += value
     }
-    new GenericRowWithSchema(values.result(), schema)
+    new GenericRowWithSchema(values.result(), new StructType(fields.result()))
   }
 }
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
index cbbec0599b77..56871446f404 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
 import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
-import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.{toLiteralProtoBuilderWithOptions, ToLiteralProtoOptions}
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.{toLiteralProtoWithOptions, ToLiteralProtoOptions}
 import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
 import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode,
ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, LazyExpression, Literal, SortOrder, SqlExpression, SubqueryExpression, SubqueryType, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame} @@ -67,7 +67,7 @@ object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) { n match { case Literal(value, dataTypeOpt, _) => builder.setLiteral( - toLiteralProtoBuilderWithOptions( + toLiteralProtoWithOptions( value, dataTypeOpt, ToLiteralProtoOptions(useDeprecatedDataTypeFields = false))) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala index ac69f084c307..df57e0f9400c 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala @@ -120,7 +120,7 @@ object DataTypeProtoConverter { } } - private def toCatalystArrayType(t: proto.DataType.Array): ArrayType = { + private[common] def toCatalystArrayType(t: proto.DataType.Array): ArrayType = { ArrayType(toCatalystType(t.getElementType), t.getContainsNull) } @@ -140,7 +140,7 @@ object DataTypeProtoConverter { StructType.apply(fields) } - private def toCatalystMapType(t: proto.DataType.Map): MapType = { + private[common] def toCatalystMapType(t: proto.DataType.Map): MapType = { MapType(toCatalystType(t.getKeyType), toCatalystType(t.getValueType), t.getValueContainsNull) } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index 63c43f956d78..1fba3dd5abde 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -23,323 +23,296 @@ import java.sql.{Date, Timestamp} import java.time._ import scala.collection.{immutable, mutable} -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag import scala.util.Try import com.google.protobuf.ByteString import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.Expression.Literal +import org.apache.spark.connect.proto.Expression.Literal.{Array => CArray, LiteralTypeCase, Map => CMap, Struct} import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils} -import org.apache.spark.sql.connect.common.DataTypeProtoConverter._ +import org.apache.spark.sql.connect.common.InferringDataTypeBuilder.mergeDataTypes import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval +/** + * Conversions from Scala literals to Connect literals. 
+ */ object LiteralValueProtoConverter { + private val missing: Any = new Object - private def setNullValue( - builder: proto.Expression.Literal.Builder, - dataType: DataType, - needDataType: Boolean): proto.Expression.Literal.Builder = { - if (needDataType) { - builder.setNull(toConnectProtoType(dataType)) - } else { - // No need data type but still set the null type to indicate that - // the value is null. - builder.setNull(ProtoDataTypes.NullType) - } - } - - private def toLiteralProtoBuilderInternal( - literal: Any, - options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { - val builder = proto.Expression.Literal.newBuilder() - - def decimalBuilder(precision: Int, scale: Int, value: String) = { - builder.getDecimalBuilder.setPrecision(precision).setScale(scale).setValue(value) - } - - def calendarIntervalBuilder(months: Int, days: Int, microseconds: Long) = { - builder.getCalendarIntervalBuilder - .setMonths(months) - .setDays(days) - .setMicroseconds(microseconds) - } - - def arrayBuilder(array: Array[_]) = { - val ab = builder.getArrayBuilder - array.foreach { x => - ab.addElements(toLiteralProtoBuilderInternal(x, options).build()) - } - if (options.useDeprecatedDataTypeFields) { - ab.setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType))) - } - ab - } - - literal match { - case v: Boolean => builder.setBoolean(v) - case v: Byte => builder.setByte(v) - case v: Short => builder.setShort(v) - case v: Int => builder.setInteger(v) - case v: Long => builder.setLong(v) - case v: Float => builder.setFloat(v) - case v: Double => builder.setDouble(v) - case v: BigDecimal => - builder.setDecimal(decimalBuilder(v.precision, v.scale, v.toString)) - case v: JBigDecimal => - builder.setDecimal(decimalBuilder(v.precision, v.scale, v.toString)) - case v: String => builder.setString(v) - case v: Char => builder.setString(v.toString) - case v: Array[Char] => builder.setString(String.valueOf(v)) - case v: Array[Byte] => builder.setBinary(ByteString.copyFrom(v)) - case v: mutable.ArraySeq[_] => toLiteralProtoBuilderInternal(v.array, options) - case v: immutable.ArraySeq[_] => - toLiteralProtoBuilderInternal(v.unsafeArray, options) - case v: LocalDate => builder.setDate(v.toEpochDay.toInt) - case v: Decimal => - builder.setDecimal(decimalBuilder(Math.max(v.precision, v.scale), v.scale, v.toString)) - case v: Instant => builder.setTimestamp(SparkDateTimeUtils.instantToMicros(v)) - case v: Timestamp => builder.setTimestamp(SparkDateTimeUtils.fromJavaTimestamp(v)) - case v: LocalDateTime => - builder.setTimestampNtz(SparkDateTimeUtils.localDateTimeToMicros(v)) - case v: Date => builder.setDate(SparkDateTimeUtils.fromJavaDate(v)) - case v: Duration => builder.setDayTimeInterval(SparkIntervalUtils.durationToMicros(v)) - case v: Period => builder.setYearMonthInterval(SparkIntervalUtils.periodToMonths(v)) - case v: LocalTime => - builder.setTime( - builder.getTimeBuilder - .setNano(SparkDateTimeUtils.localTimeToNanos(v)) - .setPrecision(TimeType.DEFAULT_PRECISION)) - case v: Array[_] => builder.setArray(arrayBuilder(v)) - case v: CalendarInterval => - builder.setCalendarInterval(calendarIntervalBuilder(v.months, v.days, v.microseconds)) - case null => builder.setNull(ProtoDataTypes.NullType) - case _ => throw new UnsupportedOperationException(s"literal $literal not supported (yet).") - } - } - - private def toLiteralProtoBuilderInternal( - literal: Any, - dataType: DataType, - options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { - val builder = 
proto.Expression.Literal.newBuilder() - - def arrayBuilder(scalaValue: Any, elementType: DataType) = { - val ab = builder.getArrayBuilder - scalaValue match { - case a: Array[_] => - a.foreach { item => - ab.addElements(toLiteralProtoBuilderInternal(item, elementType, options).build()) - } - case s: scala.collection.Seq[_] => - s.foreach { item => - ab.addElements(toLiteralProtoBuilderInternal(item, elementType, options).build()) - } - case other => - throw new IllegalArgumentException(s"literal $other not supported (yet).") - } - if (options.useDeprecatedDataTypeFields) { - ab.setElementType(toConnectProtoType(elementType)) - } - ab - } - - def mapBuilder(scalaValue: Any, keyType: DataType, valueType: DataType) = { - val mb = builder.getMapBuilder - scalaValue match { - case map: scala.collection.Map[_, _] => - map.foreach { case (k, v) => - mb.addKeys(toLiteralProtoBuilderInternal(k, keyType, options).build()) - mb.addValues(toLiteralProtoBuilderInternal(v, valueType, options).build()) - } - case other => - throw new IllegalArgumentException(s"literal $other not supported (yet).") - } - if (options.useDeprecatedDataTypeFields) { - mb.setKeyType(toConnectProtoType(keyType)) - mb.setValueType(toConnectProtoType(valueType)) - } - mb - } - - def structBuilder(scalaValue: Any, structType: StructType) = { - val sb = builder.getStructBuilder - val fields = structType.fields - - val iter = scalaValue match { - case p: Product => - p.productIterator - case r: Row => - r.toSeq.iterator - case other => - throw new IllegalArgumentException( - s"literal ${other.getClass.getName}($other) not supported (yet).") - } - - var idx = 0 - while (idx < structType.size) { - val field = fields(idx) - val literalProto = - toLiteralProtoBuilderInternal(iter.next(), field.dataType, options) - sb.addElements(literalProto) - idx += 1 - } - if (options.useDeprecatedDataTypeFields) { - sb.setStructType(toConnectProtoType(structType)) - } - - sb - } - - (literal, dataType) match { - case (v: Option[_], _) => - if (v.isDefined) { - toLiteralProtoBuilderInternal(v.get, dataType, options) - } else { - setNullValue(builder, dataType, options.useDeprecatedDataTypeFields) - } - case (null, _) => - setNullValue(builder, dataType, options.useDeprecatedDataTypeFields) - case (v: mutable.ArraySeq[_], ArrayType(_, _)) => - toLiteralProtoBuilderInternal(v.array, dataType, options) - case (v: immutable.ArraySeq[_], ArrayType(_, _)) => - toLiteralProtoBuilderInternal(v.unsafeArray, dataType, options) - case (v: Array[Byte], ArrayType(_, _)) => - toLiteralProtoBuilderInternal(v, options) - case (v, ArrayType(elementType, _)) => - builder.setArray(arrayBuilder(v, elementType)) - case (v, MapType(keyType, valueType, _)) => - builder.setMap(mapBuilder(v, keyType, valueType)) - case (v, structType: StructType) => - builder.setStruct(structBuilder(v, structType)) - case (v: LocalTime, timeType: TimeType) => - builder.setTime( - builder.getTimeBuilder - .setNano(SparkDateTimeUtils.localTimeToNanos(v)) - .setPrecision(timeType.precision)) - case _ => toLiteralProtoBuilderInternal(literal, options) - } - - } + case class ToLiteralProtoOptions(useDeprecatedDataTypeFields: Boolean) /** - * Transforms literal value to the `proto.Expression.Literal.Builder`. + * Transforms literal value to the `proto.Expression.Literal`. 
* * @return - * proto.Expression.Literal.Builder + * proto.Expression.Literal */ - def toLiteralProtoBuilder(literal: Any): proto.Expression.Literal.Builder = { - toLiteralProtoBuilderWithOptions( + def toLiteralProto(literal: Any): Literal = { + toLiteralProtoWithOptions( literal, None, ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)) } - def toLiteralProtoBuilder( - literal: Any, - dataType: DataType): proto.Expression.Literal.Builder = { - toLiteralProtoBuilderWithOptions( + def toLiteralProto(literal: Any, dataType: DataType): Literal = { + toLiteralProtoWithOptions( literal, Some(dataType), ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)) } - private def setDataTypeForRootLiteral( - builder: proto.Expression.Literal.Builder, - dataType: DataType): proto.Expression.Literal.Builder = { - if (builder.getLiteralTypeCase == - proto.Expression.Literal.LiteralTypeCase.LITERALTYPE_NOT_SET) { - throw new IllegalArgumentException("Literal type should be set first") - } - // To be compatible with the current Scala behavior, we should convert bytes to binary. - val protoDataType = toConnectProtoType(dataType, bytesToBinary = true) - // If the value is not null and the data type is trivial, we don't need to - // set the data type field, because it will be inferred from the literal value, saving space. - val needDataType = protoDataType.getKindCase match { - case proto.DataType.KindCase.ARRAY => true - case proto.DataType.KindCase.STRUCT => true - case proto.DataType.KindCase.MAP => true - case _ => builder.getLiteralTypeCase == proto.Expression.Literal.LiteralTypeCase.NULL - } - if (needDataType) { - builder.setDataType(protoDataType) - } - builder + private[connect] def toLiteralProtoWithOptions( + value: Any, + dataTypeOpt: Option[DataType], + options: ToLiteralProtoOptions): Literal = { + val dataTypeBuilder = dataTypeOpt + .map(dataType => FixedDataTypeBuilder(dataType, value == null || value == None)) + .getOrElse(InferringDataTypeBuilder()) + val (literal, _) = + toLiteralProtoBuilderInternal(value, dataTypeBuilder, options, enclosed = false) + literal } - def toLiteralProtoBuilderWithOptions( - literal: Any, - dataTypeOpt: Option[DataType], - options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { - dataTypeOpt match { - case Some(dataType) => - val builder = toLiteralProtoBuilderInternal(literal, dataType, options) - if (!options.useDeprecatedDataTypeFields) { - setDataTypeForRootLiteral(builder, dataType) + private def toLiteralProtoBuilderInternal( + value: Any, + dataTypeBuilder: DataTypeBuilder, + options: ToLiteralProtoOptions, + enclosed: Boolean): (Literal, DataTypeBuilder) = { + def result( + f: Literal.Builder => Literal.Builder, + dataType: DataType): (Literal, DataTypeBuilder) = { + val literal = f(Literal.newBuilder()).build() + (literal, dataTypeBuilder.merge(dataType, literal.hasNull)) + } + value match { + case null | None => + val dataType = dataTypeBuilder.result() + val protoDataType = if (enclosed) { + // Enclosed NULL value. The dataType is recorded in the enclosing dataType. + ProtoDataTypes.NullType + } else { + // Standalone NULL value. This needs the actual dataType. 
+ DataTypeProtoConverter.toConnectProtoType(dataType) } - builder - case None => - val builder = toLiteralProtoBuilderInternal(literal, options) - if (!options.useDeprecatedDataTypeFields) { - @scala.annotation.tailrec - def unwrapArraySeq(value: Any): Any = value match { - case arraySeq: mutable.ArraySeq[_] => unwrapArraySeq(arraySeq.array) - case arraySeq: immutable.ArraySeq[_] => unwrapArraySeq(arraySeq.unsafeArray) - case _ => value - } - unwrapArraySeq(literal) match { - case null => - setDataTypeForRootLiteral(builder, NullType) - case value => - setDataTypeForRootLiteral(builder, toDataType(value.getClass)) - } + result(_.setNull(protoDataType), dataType) + case v: Boolean => + result(_.setBoolean(v), BooleanType) + case v: Byte => + result(_.setByte(v), ByteType) + case v: Short => + result(_.setShort(v), ShortType) + case v: Int => + result(_.setInteger(v), IntegerType) + case v: Long => + result(_.setLong(v), LongType) + case v: Float => + result(_.setFloat(v), FloatType) + case v: Double => + result(_.setDouble(v), DoubleType) + case v: BigDecimal => + val dataType = DecimalType(v.precision, v.scale) + result(_.setDecimal(toProtoDecimal(v.toString(), dataType)), dataType) + case v: JBigDecimal => + val dataType = DecimalType(v.precision, v.scale) + result(_.setDecimal(toProtoDecimal(v.toString, dataType)), dataType) + case v: Decimal => + val dataType = DecimalType(Math.max(v.precision, v.scale), v.scale) + result(_.setDecimal(toProtoDecimal(v.toString, dataType)), dataType) + case v: String => + result(_.setString(v), StringType) + case v: Char => + result(_.setString(v.toString), StringType) + case v: Array[Char] => + result(_.setString(String.valueOf(v)), StringType) + case v: Array[Byte] => + result(_.setBinary(ByteString.copyFrom(v)), BinaryType) + case v: Instant => + result(_.setTimestamp(SparkDateTimeUtils.instantToMicros(v)), TimestampType) + case v: Timestamp => + result(_.setTimestamp(SparkDateTimeUtils.fromJavaTimestamp(v)), TimestampType) + case v: LocalDateTime => + result(_.setTimestampNtz(SparkDateTimeUtils.localDateTimeToMicros(v)), TimestampNTZType) + case v: LocalDate => + result(_.setDate(v.toEpochDay.toInt), DateType) + case v: Date => + result(_.setDate(SparkDateTimeUtils.fromJavaDate(v)), DateType) + case v: LocalTime => + val time = Literal.Time + .newBuilder() + .setNano(SparkDateTimeUtils.localTimeToNanos(v)) + .setPrecision(TimeType.DEFAULT_PRECISION) + .build() + result(_.setTime(time), TimeType()) + case v: Duration => + result( + _.setDayTimeInterval(SparkIntervalUtils.durationToMicros(v)), + DayTimeIntervalType()) + case v: Period => + result( + _.setYearMonthInterval(SparkIntervalUtils.periodToMonths(v)), + YearMonthIntervalType()) + case v: CalendarInterval => + val interval = Literal.CalendarInterval + .newBuilder() + .setMonths(v.months) + .setDays(v.days) + .setMicroseconds(v.microseconds) + .build() + result(_.setCalendarInterval(interval), CalendarIntervalType) + case v: mutable.ArraySeq[_] => + toProtoArrayOrBinary(v.array, dataTypeBuilder, options, enclosed) + case v: immutable.ArraySeq[_] => + toProtoArrayOrBinary(v.unsafeArray, dataTypeBuilder, options, enclosed) + case v: Array[_] => + toProtoArray(v, dataTypeBuilder, options, enclosed) + case s: scala.collection.Seq[_] => + toProtoArray(s, dataTypeBuilder, options, enclosed, () => None) + case map: scala.collection.Map[_, _] => + toProtoMap(map, dataTypeBuilder, options, enclosed) + case Some(value) => + toLiteralProtoBuilderInternal(value, dataTypeBuilder, options, enclosed) + case 
product: Product => + // If we don't have a schema, we could try to extract one from the class. + toProtoStruct(product.productIterator, dataTypeBuilder, options, enclosed) + case row: Row => + var structTypeBuilder = dataTypeBuilder + if (row.schema != null) { + structTypeBuilder = structTypeBuilder.merge(row.schema, isNullable = false) } - builder + toProtoStruct(row.toSeq.iterator, structTypeBuilder, options, enclosed) + case v => + throw new UnsupportedOperationException(s"literal $v not supported (yet).") } + } + private def toProtoDecimal(value: String, dataType: DecimalType): Literal.Decimal = { + Literal.Decimal + .newBuilder() + .setValue(value) + .setPrecision(dataType.precision) + .setScale(dataType.scale) + .build() } - def create[T: TypeTag](v: T): proto.Expression.Literal.Builder = Try { - val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T] - toLiteralProtoBuilder(v, dataType) - }.getOrElse { - toLiteralProtoBuilder(v) + private def toProtoMap( + map: scala.collection.Map[_, _], + mapTypeBuilder: DataTypeBuilder, + options: ToLiteralProtoOptions, + enclosed: Boolean): (Literal, DataTypeBuilder) = { + var (keyTypeBuilder, valueTypeBuilder) = mapTypeBuilder.keyValueTypeBuilder() + val builder = Literal.newBuilder() + val mapBuilder = builder.getMapBuilder + map.foreach { case (k, v) => + val (keyLiteral, updatedKeyTypeBuilder) = + toLiteralProtoBuilderInternal(k, keyTypeBuilder, options, enclosed = true) + mapBuilder.addKeys(keyLiteral) + keyTypeBuilder = updatedKeyTypeBuilder + val (valueLiteral, updatedValueTypeBuilder) = + toLiteralProtoBuilderInternal(v, valueTypeBuilder, options, enclosed = true) + mapBuilder.addValues(valueLiteral) + valueTypeBuilder = updatedValueTypeBuilder + } + val updatedMapTypeBuilder = + mapTypeBuilder.mergeKeyValueBuilder(keyTypeBuilder, valueTypeBuilder) + lazy val protoMapType = DataTypeProtoConverter + .toConnectProtoType(updatedMapTypeBuilder.result()) + .getMap + if (options.useDeprecatedDataTypeFields) { + mapBuilder.setKeyType(protoMapType.getKeyType) + mapBuilder.setValueType(protoMapType.getValueType) + } else if (!enclosed && updatedMapTypeBuilder.mustRecordDataType) { + mapBuilder.setDataType(protoMapType) + } + (builder.build(), updatedMapTypeBuilder) } - case class ToLiteralProtoOptions(useDeprecatedDataTypeFields: Boolean) + private def toProtoArrayOrBinary( + array: Array[_], + arrayTypeBuilder: DataTypeBuilder, + options: ToLiteralProtoOptions, + enclosed: Boolean): (Literal, DataTypeBuilder) = { + (array, arrayTypeBuilder) match { + case (_: Array[Byte], _: InferringDataTypeBuilder | FixedDataTypeBuilder(BinaryType, _)) => + toLiteralProtoBuilderInternal(array, arrayTypeBuilder, options, enclosed) + case _ => + toProtoArray(array, arrayTypeBuilder, options, enclosed) + } + } - /** - * Transforms literal value to the `proto.Expression.Literal`. 
- * - * @return - * proto.Expression.Literal - */ - def toLiteralProto(literal: Any): proto.Expression.Literal = { - toLiteralProtoWithOptions( - literal, - None, - ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)) + private def toProtoArray( + array: Array[_], + arrayTypeBuilder: DataTypeBuilder, + options: ToLiteralProtoOptions, + enclosed: Boolean): (Literal, DataTypeBuilder) = { + val fallback = () => Try(toDataType(array.getClass.getComponentType)).toOption + toProtoArray(array, arrayTypeBuilder, options, enclosed, fallback) } - def toLiteralProto(literal: Any, dataType: DataType): proto.Expression.Literal = { - toLiteralProtoWithOptions( - literal, - Some(dataType), - ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)) + private def toProtoArray( + iterable: Iterable[_], + arrayTypeBuilder: DataTypeBuilder, + options: ToLiteralProtoOptions, + enclosed: Boolean, + fallbackElementTypeGetter: () => Option[DataType]): (Literal, DataTypeBuilder) = { + var elementTypeBuilder = arrayTypeBuilder.elementTypeBuilder() + val builder = Literal.newBuilder() + val arrayBuilder = builder.getArrayBuilder + iterable.foreach { e => + val (literal, updatedElementTypeBuilder) = + toLiteralProtoBuilderInternal(e, elementTypeBuilder, options, enclosed = true) + arrayBuilder.addElements(literal) + elementTypeBuilder = updatedElementTypeBuilder + } + if (!elementTypeBuilder.isDefined) { + val fallbackElementType = fallbackElementTypeGetter() + if (fallbackElementType.isDefined) { + elementTypeBuilder.merge(fallbackElementType.get, isNullable = false) + } + } + val updatedArrayTypeBuilder = arrayTypeBuilder.mergeElementTypeBuilder(elementTypeBuilder) + lazy val protoArrayType = DataTypeProtoConverter + .toConnectProtoType(updatedArrayTypeBuilder.result()) + .getArray + if (options.useDeprecatedDataTypeFields) { + arrayBuilder.setElementType(protoArrayType.getElementType) + } else if (!enclosed && updatedArrayTypeBuilder.mustRecordDataType) { + arrayBuilder.setDataType(protoArrayType) + } + (builder.build(), updatedArrayTypeBuilder) } - def toLiteralProtoWithOptions( - literal: Any, - dataTypeOpt: Option[DataType], - options: ToLiteralProtoOptions): proto.Expression.Literal = { - toLiteralProtoBuilderWithOptions(literal, dataTypeOpt, options).build() + private def toProtoStruct( + fields: Iterator[_], + structTypeBuilder: DataTypeBuilder, + options: ToLiteralProtoOptions, + enclosed: Boolean): (Literal, DataTypeBuilder) = { + val structType = structTypeBuilder.result().asInstanceOf[StructType] + val builder = Literal.newBuilder() + val structBuilder = builder.getStructBuilder + fields.zipAll(structType.fields, missing, null).foreach { case (value, field: StructField) => + require(value != missing) + require(field != null) + val (literal, _) = toLiteralProtoBuilderInternal( + value, + FixedDataTypeBuilder(field.dataType, field.nullable), + options, + enclosed = true) + structBuilder.addElements(literal) + } + def protoStructType = DataTypeProtoConverter.toConnectProtoType(structType) + if (options.useDeprecatedDataTypeFields) { + structBuilder.setStructType(protoStructType) + } else if (!enclosed) { + structBuilder.setDataTypeStruct(protoStructType.getStruct) + } + (builder.build(), structTypeBuilder) } - private[sql] def toDataType(clz: Class[_]): DataType = clz match { + private def toDataType(clz: Class[_]): DataType = clz match { // primitive types case JShort.TYPE => ShortType case JInteger.TYPE => IntegerType @@ -379,285 +352,418 @@ object LiteralValueProtoConverter { case _ => throw new
UnsupportedOperationException(s"Unsupported component type $clz in arrays.") } +} - def toScalaValue(literal: proto.Expression.Literal): Any = { - getScalaConverter(getProtoDataType(literal))(literal) +/** + * Base trait for converting a [[proto.Expression.Literal]] into either its Scala representation + * or into its Catalyst representation. + */ +trait FromProtoConvertor { + def convertToValue(literal: proto.Expression.Literal): Any = { + convert(literal, EmptyDataTypeBuilder)._2 } - private def getScalaConverter(dataType: proto.DataType): proto.Expression.Literal => Any = { - val converter: proto.Expression.Literal => Any = dataType.getKindCase match { - case proto.DataType.KindCase.NULL => - v => throw InvalidPlanInput(s"Expected null value, but got ${v.getLiteralTypeCase}") - case proto.DataType.KindCase.SHORT => v => v.getShort.toShort - case proto.DataType.KindCase.INTEGER => v => v.getInteger - case proto.DataType.KindCase.LONG => v => v.getLong - case proto.DataType.KindCase.DOUBLE => v => v.getDouble - case proto.DataType.KindCase.BYTE => v => v.getByte.toByte - case proto.DataType.KindCase.FLOAT => v => v.getFloat - case proto.DataType.KindCase.BOOLEAN => v => v.getBoolean - case proto.DataType.KindCase.STRING => v => v.getString - case proto.DataType.KindCase.BINARY => v => v.getBinary.toByteArray - case proto.DataType.KindCase.DATE => - v => SparkDateTimeUtils.toJavaDate(v.getDate) - case proto.DataType.KindCase.TIMESTAMP => - v => SparkDateTimeUtils.toJavaTimestamp(v.getTimestamp) - case proto.DataType.KindCase.TIMESTAMP_NTZ => - v => SparkDateTimeUtils.microsToLocalDateTime(v.getTimestampNtz) - case proto.DataType.KindCase.DAY_TIME_INTERVAL => - v => SparkIntervalUtils.microsToDuration(v.getDayTimeInterval) - case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => - v => SparkIntervalUtils.monthsToPeriod(v.getYearMonthInterval) - case proto.DataType.KindCase.TIME => - v => SparkDateTimeUtils.nanosToLocalTime(v.getTime.getNano) - case proto.DataType.KindCase.DECIMAL => v => Decimal(v.getDecimal.getValue) - case proto.DataType.KindCase.CALENDAR_INTERVAL => - v => - val interval = v.getCalendarInterval - new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds) - case proto.DataType.KindCase.ARRAY => - v => toScalaArrayInternal(v, dataType.getArray) - case proto.DataType.KindCase.MAP => - v => toScalaMapInternal(v, dataType.getMap) - case proto.DataType.KindCase.STRUCT => - v => toScalaStructInternal(v, dataType.getStruct) - case _ => - throw InvalidPlanInput(s"Unsupported Literal Type: ${dataType.getKindCase}") - } - v => if (v.hasNull) null else converter(v) + def convert(literal: proto.Expression.Literal): (DataType, Any) = { + val (builder, value) = convert(literal, EmptyDataTypeBuilder) + (builder.result(), value) } - private def isCompatible( - literalTypeCase: proto.Expression.Literal.LiteralTypeCase, - dataTypeCase: proto.DataType.KindCase): Boolean = { - (literalTypeCase, dataTypeCase) match { - case (proto.Expression.Literal.LiteralTypeCase.NULL, _) => - true - case (proto.Expression.Literal.LiteralTypeCase.BINARY, proto.DataType.KindCase.BINARY) => - true - case (proto.Expression.Literal.LiteralTypeCase.BOOLEAN, proto.DataType.KindCase.BOOLEAN) => - true - case (proto.Expression.Literal.LiteralTypeCase.BYTE, proto.DataType.KindCase.BYTE) => - true - case (proto.Expression.Literal.LiteralTypeCase.SHORT, proto.DataType.KindCase.SHORT) => - true - case (proto.Expression.Literal.LiteralTypeCase.INTEGER, proto.DataType.KindCase.INTEGER) => - true - 
case (proto.Expression.Literal.LiteralTypeCase.LONG, proto.DataType.KindCase.LONG) => - true - case (proto.Expression.Literal.LiteralTypeCase.FLOAT, proto.DataType.KindCase.FLOAT) => - true - case (proto.Expression.Literal.LiteralTypeCase.DOUBLE, proto.DataType.KindCase.DOUBLE) => - true - case (proto.Expression.Literal.LiteralTypeCase.DECIMAL, proto.DataType.KindCase.DECIMAL) => - true - case (proto.Expression.Literal.LiteralTypeCase.STRING, proto.DataType.KindCase.STRING) => - true - case (proto.Expression.Literal.LiteralTypeCase.DATE, proto.DataType.KindCase.DATE) => - true - case ( - proto.Expression.Literal.LiteralTypeCase.TIMESTAMP, - proto.DataType.KindCase.TIMESTAMP) => - true - case ( - proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ, - proto.DataType.KindCase.TIMESTAMP_NTZ) => - true - case ( - proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL, - proto.DataType.KindCase.CALENDAR_INTERVAL) => - true - case ( - proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL, - proto.DataType.KindCase.DAY_TIME_INTERVAL) => - true - case ( - proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL, - proto.DataType.KindCase.YEAR_MONTH_INTERVAL) => - true - case (proto.Expression.Literal.LiteralTypeCase.TIME, proto.DataType.KindCase.TIME) => - true - case (proto.Expression.Literal.LiteralTypeCase.ARRAY, proto.DataType.KindCase.ARRAY) => - true - case (proto.Expression.Literal.LiteralTypeCase.MAP, proto.DataType.KindCase.MAP) => - true - case (proto.Expression.Literal.LiteralTypeCase.STRUCT, proto.DataType.KindCase.STRUCT) => - true - case _ => false + private def convert( + literal: proto.Expression.Literal, + dataTypeBuilder: DataTypeBuilder): (DataTypeBuilder, Any) = { + def result(dataType: DataType, value: Any): (DataTypeBuilder, Any) = { + (dataTypeBuilder.merge(dataType, value == null), value) } - } - - def getProtoDataType(literal: proto.Expression.Literal): proto.DataType = { - val dataType = if (literal.hasDataType) { - literal.getDataType - } else { - // For backward compatibility, we still support the old way to - // define the data type of the literal. 
- if (literal.getLiteralTypeCase == proto.Expression.Literal.LiteralTypeCase.NULL) { - literal.getNull - } else { - val builder = proto.DataType.newBuilder() - literal.getLiteralTypeCase match { - case proto.Expression.Literal.LiteralTypeCase.BINARY => - builder.setBinary(proto.DataType.Binary.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => - builder.setBoolean(proto.DataType.Boolean.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.BYTE => - builder.setByte(proto.DataType.Byte.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.SHORT => - builder.setShort(proto.DataType.Short.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.INTEGER => - builder.setInteger(proto.DataType.Integer.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.LONG => - builder.setLong(proto.DataType.Long.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.FLOAT => - builder.setFloat(proto.DataType.Float.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DOUBLE => - builder.setDouble(proto.DataType.Double.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DECIMAL => - val decimal = Decimal.apply(literal.getDecimal.getValue) - var precision = decimal.precision - if (literal.getDecimal.hasPrecision) { - precision = math.max(precision, literal.getDecimal.getPrecision) - } - var scale = decimal.scale - if (literal.getDecimal.hasScale) { - scale = math.max(scale, literal.getDecimal.getScale) - } - builder.setDecimal( - proto.DataType.Decimal - .newBuilder() - .setPrecision(math.max(precision, scale)) - .setScale(scale) - .build()) - case proto.Expression.Literal.LiteralTypeCase.STRING => - builder.setString(proto.DataType.String.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DATE => - builder.setDate(proto.DataType.Date.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP => - builder.setTimestamp(proto.DataType.Timestamp.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ => - builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL => - builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL => - builder.setYearMonthInterval(proto.DataType.YearMonthInterval.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => - builder.setDayTimeInterval(proto.DataType.DayTimeInterval.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.TIME => - val timeBuilder = proto.DataType.Time.newBuilder() - if (literal.getTime.hasPrecision) { - timeBuilder.setPrecision(literal.getTime.getPrecision) - } - builder.setTime(timeBuilder.build()) - case proto.Expression.Literal.LiteralTypeCase.ARRAY => - if (literal.getArray.hasElementType) { - builder.setArray( - proto.DataType.Array - .newBuilder() - .setElementType(literal.getArray.getElementType) - .setContainsNull(true) - .build()) - } else { - throw InvalidPlanInput("Data type information is missing in the array literal.") - } - case proto.Expression.Literal.LiteralTypeCase.MAP => - if (literal.getMap.hasKeyType && literal.getMap.hasValueType) { - builder.setMap( - proto.DataType.Map - .newBuilder() - .setKeyType(literal.getMap.getKeyType) - .setValueType(literal.getMap.getValueType) - .setValueContainsNull(true) - .build()) - 
} else { - throw InvalidPlanInput("Data type information is missing in the map literal.") - } - case proto.Expression.Literal.LiteralTypeCase.STRUCT => - if (literal.getStruct.hasStructType) { - builder.setStruct(literal.getStruct.getStructType.getStruct) - } else { - throw InvalidPlanInput("Data type information is missing in the struct literal.") - } + literal.getLiteralTypeCase match { + case LiteralTypeCase.NULL => + result(DataTypeProtoConverter.toCatalystType(literal.getNull), null) + case LiteralTypeCase.BOOLEAN => + result(BooleanType, literal.getBoolean) + case LiteralTypeCase.BYTE => + result(ByteType, literal.getByte.toByte) + case LiteralTypeCase.SHORT => + result(ShortType, literal.getShort.toShort) + case LiteralTypeCase.INTEGER => + result(IntegerType, literal.getInteger) + case LiteralTypeCase.LONG => + result(LongType, literal.getLong) + case LiteralTypeCase.FLOAT => + convertFloat(literal, dataTypeBuilder) + case LiteralTypeCase.DOUBLE => + result(DoubleType, literal.getDouble) + case LiteralTypeCase.DECIMAL => + val d = literal.getDecimal + val decimal = Decimal(d.getValue) + if (d.hasPrecision && d.getPrecision != decimal.precision) { + throw InvalidPlanInput("") + } + if (d.hasScale && d.getScale != decimal.scale) { + throw InvalidPlanInput("") + } + result(DecimalType(decimal.precision, decimal.scale), decimal) + case LiteralTypeCase.TIME => + val time = literal.getTime + val dataType = if (time.hasPrecision) { + TimeType(time.getPrecision) + } else { + TimeType() + } + result(dataType, convertTime(time)) + case LiteralTypeCase.DATE => + result(DateType, convertDate(literal.getDate)) + case LiteralTypeCase.TIMESTAMP => + result(TimestampType, convertTimestamp(literal.getTimestamp)) + case LiteralTypeCase.TIMESTAMP_NTZ => + result(TimestampNTZType, convertTimestampNTZ(literal.getTimestampNtz)) + case LiteralTypeCase.STRING => + // We do not support collations yet. 
+ result(StringType, convertString(literal.getStringBytes)) + case LiteralTypeCase.BINARY => + result(BinaryType, literal.getBinary.toByteArray) + case LiteralTypeCase.CALENDAR_INTERVAL => + val i = literal.getCalendarInterval + result( + CalendarIntervalType, + new CalendarInterval(i.getMonths, i.getDays, i.getMicroseconds)) + case LiteralTypeCase.DAY_TIME_INTERVAL => + result(DayTimeIntervalType.DEFAULT, convertDayTimeInterval(literal.getDayTimeInterval)) + case LiteralTypeCase.YEAR_MONTH_INTERVAL => + result( + YearMonthIntervalType.DEFAULT, + convertYearMonthInterval(literal.getYearMonthInterval)) + case LiteralTypeCase.ARRAY => + convertArray(literal.getArray, dataTypeBuilder) + case LiteralTypeCase.MAP => + convertMap(literal.getMap, dataTypeBuilder) + case LiteralTypeCase.STRUCT => + var structTypeBuilder = dataTypeBuilder + val struct = literal.getStruct + if (struct.hasDataTypeStruct) { + val structType = DataTypeProtoConverter.toCatalystStructType(struct.getDataTypeStruct) + structTypeBuilder = structTypeBuilder.merge(structType, isNullable = false) + } + if (struct.hasStructType) { + val structType = DataTypeProtoConverter.toCatalystType(struct.getStructType) + structTypeBuilder = structTypeBuilder.merge(structType, isNullable = false) + } + val result = structTypeBuilder.result() match { + case structType: StructType => + convertStruct(struct, structType) + case udt: UserDefinedType[_] => + convertUdt(struct, udt) case _ => - throw InvalidPlanInput( - s"Unsupported Literal Type: ${literal.getLiteralTypeCase.name}" + - s"(${literal.getLiteralTypeCase.getNumber})") + throw InvalidPlanInput("") } - builder.build() - } + (structTypeBuilder, result) + case literalTypeCase => + // At some point we may want to support specialized arrays... + throw InvalidPlanInput( + s"Unsupported Literal Type: ${literalTypeCase.name}(${literalTypeCase.getNumber})") } + } + + protected def convertFloat( + literal: Literal, + dataTypeBuilder: DataTypeBuilder): (DataTypeBuilder, Any) = { + (dataTypeBuilder.merge(FloatType, isNullable = false), literal.getFloat) + } - if (!isCompatible(literal.getLiteralTypeCase, dataType.getKindCase)) { - throw InvalidPlanInput( - s"Incompatible data type ${dataType.getKindCase} " + - s"for literal ${literal.getLiteralTypeCase}") + protected def convertArray( + array: CArray, + arrayTypeBuilder: DataTypeBuilder): (DataTypeBuilder, Any) = { + var elementTypeBuilder = arrayTypeBuilder.elementTypeBuilder() + if (array.hasDataType) { + val arrayType = DataTypeProtoConverter.toCatalystArrayType(array.getDataType) + elementTypeBuilder = elementTypeBuilder.merge(arrayType.elementType, arrayType.containsNull) + } + if (array.hasElementType) { + elementTypeBuilder = elementTypeBuilder.merge( + DataTypeProtoConverter.toCatalystType(array.getElementType), + isNullable = true) + } + val numElements = array.getElementsCount + val builder = arrayBuilder(numElements) + var i = 0 + while (i < numElements) { + val (updatedElementTypeBuilder, element) = convert(array.getElements(i), elementTypeBuilder) + elementTypeBuilder = updatedElementTypeBuilder + builder += element + i += 1 } + (arrayTypeBuilder.mergeElementTypeBuilder(elementTypeBuilder), builder.result()) + } - dataType + private def convertMap(map: CMap, dataTypeBuilder: DataTypeBuilder): (DataTypeBuilder, Any) = { + var (keyTypeBuilder, valueTypeBuilder): (DataTypeBuilder, DataTypeBuilder) = + dataTypeBuilder.keyValueTypeBuilder() + if (map.hasDataType) { + val mapType = DataTypeProtoConverter.toCatalystMapType(map.getDataType) 
+ keyTypeBuilder = keyTypeBuilder.merge(mapType.keyType, isNullable = false) + valueTypeBuilder = valueTypeBuilder.merge(mapType.valueType, mapType.valueContainsNull) + } + if (map.hasKeyType) { + val keyType = DataTypeProtoConverter.toCatalystType(map.getKeyType) + keyTypeBuilder = keyTypeBuilder.merge(keyType, isNullable = false) + } + if (map.hasValueType) { + val valueType = DataTypeProtoConverter.toCatalystType(map.getValueType) + valueTypeBuilder = valueTypeBuilder.merge(valueType, isNullable = true) + } + val numElements = map.getKeysCount + if (numElements != map.getValuesCount) { + throw InvalidPlanInput("") + } + val builder = mapBuilder(numElements) + var i = 0 + while (i < numElements) { + val (updatedKeyTypeBuilder, key) = convert(map.getKeys(i), keyTypeBuilder) + val (updatedValueTypeBuilder, value) = convert(map.getValues(i), valueTypeBuilder) + keyTypeBuilder = updatedKeyTypeBuilder + valueTypeBuilder = updatedValueTypeBuilder + builder += key -> value + i += 1 + } + (dataTypeBuilder.mergeKeyValueBuilder(keyTypeBuilder, valueTypeBuilder), builder.result()) } - private def toScalaArrayInternal( - literal: proto.Expression.Literal, - arrayType: proto.DataType.Array): Array[_] = { - if (!literal.hasArray) { - throw InvalidPlanInput("Array literal is not set.") + private def convertStruct(struct: Struct, structType: StructType): Any = { + val numFields = structType.length + if (numFields != struct.getElementsCount) { + throw InvalidPlanInput("") + } + val builder = structBuilder(structType) + var i = 0 + while (i < numFields) { + val field = structType(i) + val fieldDataTypeBuilder = FixedDataTypeBuilder(field.dataType, field.nullable) + val (_, value) = convert(struct.getElements(i), fieldDataTypeBuilder) + builder += value + i += 1 } - val array = literal.getArray - def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit - tag: ClassTag[T]): Array[T] = { - val size = array.getElementsCount - if (size > 0) { - Array.tabulate(size)(i => converter(array.getElements(i))) - } else { - Array.empty[T] + builder.result() + } + + protected def convertUdt(struct: Struct, udt: UserDefinedType[_]): Any = + throw InvalidPlanInput(" K, - valueConverter: proto.Expression.Literal => V)(implicit - tagK: ClassTag[K], - tagV: ClassTag[V]): mutable.Map[K, V] = { - val size = map.getKeysCount - if (size > 0) { - val m = mutable.LinkedHashMap.empty[K, V] - m.sizeHint(size) - m.addAll(Iterator.tabulate(size)(i => - (keyConverter(map.getKeys(i)), valueConverter(map.getValues(i))))) - } else { - mutable.Map.empty[K, V] - } + def keyValueTypeBuilder(): (DataTypeBuilder, DataTypeBuilder) = this match { + case FixedDataTypeBuilder(mapType: MapType, _) => + ( + FixedDataTypeBuilder(mapType.keyType, isNullable = false), + FixedDataTypeBuilder(mapType.valueType, mapType.valueContainsNull)) + case _: FixedDataTypeBuilder => throw InvalidPlanInput("") + case _ => (InferringDataTypeBuilder(), InferringDataTypeBuilder()) + } + + def mergeKeyValueBuilder( + keyTypeBuilder: DataTypeBuilder, + valueTypeBuilder: DataTypeBuilder): DataTypeBuilder = { + if (keyTypeBuilder.nullable()) { + throw InvalidPlanInput("") } + val mapType = + MapType(keyTypeBuilder.result(), valueTypeBuilder.result(), valueTypeBuilder.nullable()) + merge(mapType, isNullable = false) + } - makeMapData(getScalaConverter(mapType.getKeyType), getScalaConverter(mapType.getValueType)) + def elementTypeBuilder(): DataTypeBuilder = this match { + case FixedDataTypeBuilder(ArrayType(elementType, containsNull), _) => + 
FixedDataTypeBuilder(elementType, containsNull) + case _: FixedDataTypeBuilder => throw InvalidPlanInput("") + case _ => InferringDataTypeBuilder() } - private def toScalaStructInternal( - literal: proto.Expression.Literal, - structType: proto.DataType.Struct): Any = { - if (!literal.hasStruct) { - throw InvalidPlanInput("Struct literal is not set.") + def mergeElementTypeBuilder(elementTypeBuilder: DataTypeBuilder): DataTypeBuilder = { + val arrayType = ArrayType(elementTypeBuilder.result(), elementTypeBuilder.nullable()) + merge(arrayType, isNullable = false) + } +} + +object DataTypeBuilder { + def unapply(dt: DataTypeBuilder): Option[DataType] = Some(dt.result()) +} + +object EmptyDataTypeBuilder extends DataTypeBuilder { + override def merge(dataType: DataType, isNull: Boolean): DataTypeBuilder = + FixedDataTypeBuilder(dataType, isNull) + override def nullable(): Boolean = throw new NoSuchElementException("nullable()") + override def result(): DataType = throw new NoSuchElementException("result()") + override def isDefined: Boolean = false + override def mustRecordDataType: Boolean = false +} + +case class FixedDataTypeBuilder(dataType: DataType, isNullable: Boolean) extends DataTypeBuilder { + override def merge(mergeDataType: DataType, mergeIsNull: Boolean): DataTypeBuilder = { + if (!isNullable && mergeIsNull) { + throw InvalidPlanInput("") + } + val ok = (dataType, mergeDataType) match { + case (l: DecimalType, r: DecimalType) => + l.isWiderThan(r) + case (_, _: NullType) => + mergeIsNull + case _ => + // A merged dataType is allowed to be more strict than the dataType we are enforcing. + DataType.equalsIgnoreCompatibleNullability(mergeDataType, dataType) } - val struct = literal.getStruct - val structData = Array.tabulate(struct.getElementsCount) { i => - val element = struct.getElements(i) - val dataType = structType.getFields(i).getDataType - getScalaConverter(dataType)(element) + if (!ok) { + throwIncompatibleDataTypeException(dataType, mergeDataType) + } + this + } + override def nullable(): Boolean = isNullable + override def result(): DataType = dataType + override def isDefined: Boolean = true + + // This can be more subtle. We could also check if the inferred + // dataType matches the fixed dataType. + override def mustRecordDataType: Boolean = true +} + +case class InferringDataTypeBuilder() extends DataTypeBuilder { + private var dataType: DataType = NullType + private var isNullable: Boolean = false + override def isDefined: Boolean = dataType != NullType + + override def mustRecordDataType: Boolean = dataType.existsRecursively { dt => + dt.isInstanceOf[StructType] || dt == NullType + } + + override def merge(mergeDataType: DataType, mergeIsNull: Boolean): DataTypeBuilder = { + assert(mergeIsNull || mergeDataType != NullType) + dataType = mergeDataTypes(dataType, mergeDataType).getOrElse { + throwIncompatibleDataTypeException(dataType, mergeDataType) } - new GenericRowWithSchema(structData, DataTypeProtoConverter.toCatalystStructType(structType)) + isNullable |= mergeIsNull + this } + override def nullable(): Boolean = isNullable + override def result(): DataType = dataType +} - def getDataType(literal: proto.Expression.Literal): DataType = { - DataTypeProtoConverter.toCatalystType(getProtoDataType(literal)) +object InferringDataTypeBuilder { + def mergeDataTypes(left: DataType, right: DataType): Option[DataType] = { + (left, right) match { + // Add Decimal support... 
+ case (l, r) if l eq r => Some(l) + case (l: AtomicType, r: AtomicType) if l == r => Some(l) + case (l: DecimalType, r: DecimalType) => + if (l.isWiderThan(r)) { + Some(l) + } else if (r.isWiderThan(l)) { + Some(r) + } else { + None + } + case (_: AtomicType, _: AtomicType) => None + case (_: NullType, dt) if !dt.isInstanceOf[NullType] => Some(dt) + case (dt, _: NullType) if !dt.isInstanceOf[NullType] => Some(dt) + case (l: NullType, _: NullType) => Some(l) + + case (ArrayType(leftElemType, leftNulls), ArrayType(rightElemType, rightNulls)) => + mergeDataTypes(leftElemType, rightElemType).map { elemType => + ArrayType(elemType, leftNulls || rightNulls) + } + case ( + MapType(leftKeyType, leftValueType, leftValueNulls), + MapType(rightKeyType, rightValueType, rightValueNulls)) => + mergeDataTypes(leftKeyType, rightKeyType) + .zip(mergeDataTypes(leftValueType, rightValueType)) + .map { case (keyType, valueType) => + MapType(keyType, valueType, leftValueNulls || rightValueNulls) + } + case (StructType(leftFields), StructType(rightFields)) => + if (leftFields.length != rightFields.length) { + return None + } + val fields = leftFields.zip(rightFields).flatMap { + case (leftField, rightField) if leftField.name == rightField.name => + mergeDataTypes(leftField.dataType, rightField.dataType).map { fieldType => + val nullable = leftField.nullable || rightField.nullable + leftField.copy(dataType = fieldType, nullable = nullable) + } + case _ => None + } + if (fields.length == leftFields.length) { + Some(StructType(fields)) + } else { + None + } + case _ => None + } } } diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain index 0f4ae8813e89..74d512b6910c 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain @@ -1,2 +1,2 @@ -Project [[] AS ARRAY()#0, [[1],[2],[3]] AS ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))#0, [[[1]],[[2]],[[3]]] AS ARRAY(ARRAY(ARRAY(1)), ARRAY(ARRAY(2)), ARRAY(ARRAY(3)))#0, [true,false] AS ARRAY(true, false)#0, 0x434445 AS X'434445'#0, [9872,9873,9874] AS ARRAY(9872S, 9873S, 9874S)#0, [-8726532,8726532,-8726533] AS ARRAY(-8726532, 8726532, -8726533)#0, [7834609328726531,7834609328726532,7834609328726533] AS ARRAY(7834609328726531L, 7834609328726532L, 7834609328726533L)#0, [2.718281828459045,1.0,2.0] AS ARRAY(2.718281828459045D, 1.0D, 2.0D)#0, [-0.8,-0.7,-0.9] AS ARRAY(CAST('-0.8' AS FLOAT), CAST('-0.7' AS FLOAT), CAST('-0.9' AS FLOAT))#0, [89.976200000000000000,89.976210000000000000] AS ARRAY(89.976200000000000000BD, 89.976210000000000000BD)#0, [89889.766723100000000000,89889.766723100000000000] AS ARRAY(89889.766723100000000000BD, 89889.766723100000000000BD)#0, [connect!,disconnect!] 
AS ARRAY('connect!', 'disconnect!')#0, TF AS TF#0, [ABCDEFGHIJ,BCDEFGHIJK] AS ARRAY('ABCDEFGHIJ', 'BCDEFGHIJK')#0, [18545,18546] AS ARRAY(DATE '2020-10-10', DATE '2020-10-11')#0, [1677155519808000,1677155519809000] AS ARRAY(TIMESTAMP '2023-02-23 04:31:59.808', TIMESTAMP '2023-02-23 04:31:59.809')#0, [12345000,23456000] AS ARRAY(TIMESTAMP '1969-12-31 16:00:12.345', TIMESTAMP '1969-12-31 16:00:23.456')#0, [1677184560000000,1677188160000000] AS ARRAY(TIMESTAMP_NTZ '2023-02-23 20:36:00', TIMESTAMP_NTZ '2023-02-23 21:36:00')#0, [19411,19417] AS ARRAY(DATE '2023-02-23', DATE '2023-03-01')#0, [100000000,200000000] AS ARRAY(INTERVAL '0 00:01:40' DAY TO SECOND, INTERVAL '0 00:03:20' DAY TO SECOND)#0, [0,0] AS ARRAY(INTERVAL '0-0' YEAR TO MONTH, INTERVAL '0-0' YEAR TO MONTH)#0, [2 months 20 days 0.0001 seconds,2 months 21 days 0.0002 seconds] AS ARRAY(INTERVAL '2 months 20 days 0.0001 seconds', INTERVAL '2 months 21 days 0.0002 seconds')#0] +Project [[] AS ARRAY()#0, [[1],[2],[3]] AS ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))#0, [[[1]],[[2]],[[3]]] AS ARRAY(ARRAY(ARRAY(1)), ARRAY(ARRAY(2)), ARRAY(ARRAY(3)))#0, [true,false] AS ARRAY(true, false)#0, 0x434445 AS X'434445'#0, [9872,9873,9874] AS ARRAY(9872S, 9873S, 9874S)#0, [-8726532,8726532,-8726533] AS ARRAY(-8726532, 8726532, -8726533)#0, [7834609328726531,7834609328726532,7834609328726533] AS ARRAY(7834609328726531L, 7834609328726532L, 7834609328726533L)#0, [2.718281828459045,1.0,2.0] AS ARRAY(2.718281828459045D, 1.0D, 2.0D)#0, [-0.8,-0.7,-0.9] AS ARRAY(CAST('-0.8' AS FLOAT), CAST('-0.7' AS FLOAT), CAST('-0.9' AS FLOAT))#0, [89.97620,89.97621] AS ARRAY(89.97620BD, 89.97621BD)#0, [89889.7667231,89889.7667231] AS ARRAY(89889.7667231BD, 89889.7667231BD)#0, [connect!,disconnect!] AS ARRAY('connect!', 'disconnect!')#0, TF AS TF#0, [ABCDEFGHIJ,BCDEFGHIJK] AS ARRAY('ABCDEFGHIJ', 'BCDEFGHIJK')#0, [18545,18546] AS ARRAY(DATE '2020-10-10', DATE '2020-10-11')#0, [1677155519808000,1677155519809000] AS ARRAY(TIMESTAMP '2023-02-23 04:31:59.808', TIMESTAMP '2023-02-23 04:31:59.809')#0, [12345000,23456000] AS ARRAY(TIMESTAMP '1969-12-31 16:00:12.345', TIMESTAMP '1969-12-31 16:00:23.456')#0, [1677184560000000,1677188160000000] AS ARRAY(TIMESTAMP_NTZ '2023-02-23 20:36:00', TIMESTAMP_NTZ '2023-02-23 21:36:00')#0, [19411,19417] AS ARRAY(DATE '2023-02-23', DATE '2023-03-01')#0, [100000000,200000000] AS ARRAY(INTERVAL '0 00:01:40' DAY TO SECOND, INTERVAL '0 00:03:20' DAY TO SECOND)#0, [0,0] AS ARRAY(INTERVAL '0-0' YEAR TO MONTH, INTERVAL '0-0' YEAR TO MONTH)#0, [2 months 20 days 0.0001 seconds,2 months 21 days 0.0002 seconds] AS ARRAY(INTERVAL '2 months 20 days 0.0001 seconds', INTERVAL '2 months 21 days 0.0002 seconds')#0] +- LocalRelation , [id#0L, a#0, b#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain index 3c878be34143..66df19a45d18 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain @@ -1,2 +1,2 @@ -Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! 
AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, 2023-02-23 AS DATE '2023-02-23'#0, INTERVAL '0 00:03:20' DAY TO SECOND AS INTERVAL '0 00:03:20' DAY TO SECOND#0, INTERVAL '0-0' YEAR TO MONTH AS INTERVAL '0-0' YEAR TO MONTH#0, 23:59:59.999999999 AS TIME '23:59:59.999999999'#0, 2 months 20 days 0.0001 seconds AS INTERVAL '2 months 20 days 0.0001 seconds'#0, [18545,1677155519808000,12345000,1677184560000000,19411,200000000,0,86399999999999,2 months 20 days 0.0001 seconds] AS NAMED_STRUCT('_1', DATE '2020-10-10', '_2', TIMESTAMP '2023-02-23 04:31:59.808', '_3', TIMESTAMP '1969-12-31 16:00:12.345', '_4', TIMESTAMP_NTZ '2023-02-23 20:36:00', '_5', DATE '2023-02-23', '_6', INTERVAL '0 00:03:20' DAY TO SECOND, '_7', INTERVAL '0-0' YEAR TO MONTH, '_8', TIME '23:59:59.999999999', '_9', INTERVAL '2 months 20 days 0.0001 seconds')#0, 1 AS 1#0, [1,2,3] AS ARRAY(1, 2, 3)#0, [null,null] AS ARRAY(CAST(NULL AS INT), CAST(NULL AS INT))#0, [null,null,[1,a],[2,null]] AS ARRAY(NULL, NULL, NAMED_STRUCT('_1', 1, '_2', 'a'), NAMED_STRUCT('_1', 2, '_2', CAST(NULL AS STRING)))#0, [null,null,[1,a]] AS ARRAY(NULL, NULL, NAMED_STRUCT('_1', 1, '_2', 'a'))#0, [1,2,3] AS ARRAY(1, 2, 3)#0, map(keys: [a,b], values: [1,2]) AS MAP('a', 1, 'b', 2)#0, map(keys: [a,b], values: [null,null]) AS MAP('a', CAST(NULL AS INT), 'b', CAST(NULL AS INT))#0, [a,2,1.0] AS NAMED_STRUCT('_1', 'a', '_2', 2, '_3', 1.0D)#0, null AS NULL#0, [1] AS ARRAY(1)#0, map(keys: [1], values: [null]) AS MAP(1, CAST(NULL AS INT))#0, map(keys: [1], values: [null]) AS MAP(1, CAST(NULL AS INT))#0, map(keys: [1], values: [null]) AS MAP(1, CAST(NULL AS INT))#0, [[1,2,3],[4,5,6],[7,8,9]] AS ARRAY(ARRAY(1, 2, 3), ARRAY(4, 5, 6), ARRAY(7, 8, 9))#0, [[1,2,[3,4]],[5,6,[]]] AS ARRAY(NAMED_STRUCT('_1', 1, '_2', '2', '_3', ARRAY('3', '4')), NAMED_STRUCT('_1', 5, '_2', '6', '_3', ARRAY()))#0, [[1,2],[3,4],[5,6]] AS ARRAY(NAMED_STRUCT('a', 1, 'b', '2'), NAMED_STRUCT('a', 3, 'b', '4'), NAMED_STRUCT('a', 5, 'b', '6'))#0, [keys: [a,b], values: [1,2],keys: [a,b], values: [3,4],keys: [a,b], values: [5,6]] AS ARRAY(MAP('a', 1, 'b', 2), MAP('a', 3, 'b', 4), MAP('a', 5, 'b', 6))#0, [keys: [a,b], values: [[1,2],[3,4]],keys: [a,b], values: [[5,6],[7,8]],keys: [a,b], values: [[],[]]] AS ARRAY(MAP('a', ARRAY('1', '2'), 'b', ARRAY('3', '4')), MAP('a', ARRAY('5', '6'), 'b', ARRAY('7', '8')), MAP('a', ARRAY(), 'b', ARRAY()))#0, map(keys: [1,2], values: [keys: [a,b], values: [1,2],keys: [a,b], values: [3,4]]) AS MAP(1, MAP('a', 1, 'b', 2), 2, MAP('a', 3, 'b', 4))#0, [[1,2,3],keys: [a,b], values: [1,2],[a,keys: [1,2], values: [a,b]]] AS NAMED_STRUCT('_1', ARRAY(1, 2, 3), '_2', MAP('a', 1, 'b', 2), '_3', NAMED_STRUCT('_1', 'a', '_2', MAP(1, 'a', 2, 'b')))#0] +Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! 
AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, [8,6] AS ARRAY(8Y, 6Y)#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, 2023-02-23 AS DATE '2023-02-23'#0, INTERVAL '0 00:03:20' DAY TO SECOND AS INTERVAL '0 00:03:20' DAY TO SECOND#0, INTERVAL '0-0' YEAR TO MONTH AS INTERVAL '0-0' YEAR TO MONTH#0, 23:59:59.999999999 AS TIME '23:59:59.999999999'#0, 2 months 20 days 0.0001 seconds AS INTERVAL '2 months 20 days 0.0001 seconds'#0, [18545,1677155519808000,12345000,1677184560000000,19411,200000000,0,86399999999999,2 months 20 days 0.0001 seconds] AS NAMED_STRUCT('_1', DATE '2020-10-10', '_2', TIMESTAMP '2023-02-23 04:31:59.808', '_3', TIMESTAMP '1969-12-31 16:00:12.345', '_4', TIMESTAMP_NTZ '2023-02-23 20:36:00', '_5', DATE '2023-02-23', '_6', INTERVAL '0 00:03:20' DAY TO SECOND, '_7', INTERVAL '0-0' YEAR TO MONTH, '_8', TIME '23:59:59.999999999', '_9', INTERVAL '2 months 20 days 0.0001 seconds')#0, 1 AS 1#0, [1,2,3] AS ARRAY(1, 2, 3)#0, [null,null] AS ARRAY(CAST(NULL AS INT), CAST(NULL AS INT))#0, [null,null,[1,a],[2,null]] AS ARRAY(NULL, NULL, NAMED_STRUCT('_1', 1, '_2', 'a'), NAMED_STRUCT('_1', 2, '_2', CAST(NULL AS STRING)))#0, [null,null,[1,a]] AS ARRAY(NULL, NULL, NAMED_STRUCT('_1', 1, '_2', 'a'))#0, [1,2,3] AS ARRAY(1, 2, 3)#0, map(keys: [a,b], values: [1,2]) AS MAP('a', 1, 'b', 2)#0, map(keys: [a,b], values: [null,null]) AS MAP('a', CAST(NULL AS INT), 'b', CAST(NULL AS INT))#0, [a,2,1.0] AS NAMED_STRUCT('_1', 'a', '_2', 2, '_3', 1.0D)#0, null AS NULL#0, [1] AS ARRAY(1)#0, map(keys: [1], values: [null]) AS MAP(1, CAST(NULL AS INT))#0, map(keys: [1], values: [null]) AS MAP(1, CAST(NULL AS INT))#0, map(keys: [1], values: [null]) AS MAP(1, CAST(NULL AS INT))#0, [[1,2,3],[4,5,6],[7,8,9]] AS ARRAY(ARRAY(1, 2, 3), ARRAY(4, 5, 6), ARRAY(7, 8, 9))#0, [[1,2,[3,4]],[5,6,[]]] AS ARRAY(NAMED_STRUCT('_1', 1, '_2', '2', '_3', ARRAY('3', '4')), NAMED_STRUCT('_1', 5, '_2', '6', '_3', ARRAY()))#0, [[1,2],[3,4],[5,6]] AS ARRAY(NAMED_STRUCT('a', 1, 'b', '2'), NAMED_STRUCT('a', 3, 'b', '4'), NAMED_STRUCT('a', 5, 'b', '6'))#0, [keys: [a,b], values: [1,2],keys: [a,b], values: [3,4],keys: [a,b], values: [5,6]] AS ARRAY(MAP('a', 1, 'b', 2), MAP('a', 3, 'b', 4), MAP('a', 5, 'b', 6))#0, [keys: [a,b], values: [[1,2],[3,4]],keys: [a,b], values: [[5,6],[7,8]],keys: [a,b], values: [[],[]]] AS ARRAY(MAP('a', ARRAY('1', '2'), 'b', ARRAY('3', '4')), MAP('a', ARRAY('5', '6'), 'b', ARRAY('7', '8')), MAP('a', ARRAY(), 'b', ARRAY()))#0, map(keys: [1,2], values: [keys: [a,b], values: [1,2],keys: [a,b], values: [3,4]]) AS MAP(1, MAP('a', 1, 'b', 2), 2, MAP('a', 3, 'b', 4))#0, [[1,2,3],keys: [a,b], values: [1,2],[a,keys: [1,2], values: [a,b]]] AS NAMED_STRUCT('_1', ARRAY(1, 2, 3), '_2', MAP('a', 1, 'b', 2), '_3', NAMED_STRUCT('_1', 'a', '_2', MAP(1, 'a', 2, 'b')))#0] +- LocalRelation , [id#0L, a#0, b#0] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lag.json b/sql/connect/common/src/test/resources/query-tests/queries/function_lag.json index 89b9968dd33d..53d57913dd3a 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_lag.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_lag.json @@ -63,10 
+63,6 @@ "null": { "null": { } - }, - "dataType": { - "null": { - } } }, "common": { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lag.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_lag.proto.bin index 872188c946fd..c8030a0979c6 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_lag.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_lag.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json index 176aab1deda6..13f7d042876a 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json @@ -363,15 +363,6 @@ }, { "integer": 6 }] - }, - "dataType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } } }, "common": { @@ -396,10 +387,6 @@ "null": { "null": { } - }, - "dataType": { - "null": { - } } }, "common": { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.proto.bin index 6a296702f064..14540646a506 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json index 65902ad604b4..bcaa96e7818f 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json @@ -14,15 +14,6 @@ "expressions": [{ "literal": { "array": { - }, - "dataType": { - "array": { - "elementType": { - "double": { - } - }, - "containsNull": true - } } }, "common": { @@ -64,20 +55,6 @@ }] } }] - }, - "dataType": { - "array": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, - "containsNull": true - } } }, "common": { @@ -131,25 +108,6 @@ }] } }] - }, - "dataType": { - "array": { - "elementType": { - "array": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, - "containsNull": true - } - }, - "containsNull": true - } } }, "common": { @@ -177,15 +135,6 @@ }, { "boolean": false }] - }, - "dataType": { - "array": { - "elementType": { - "boolean": { - } - }, - "containsNull": true - } } }, "common": { @@ -236,15 +185,6 @@ }, { "short": 9874 }] - }, - "dataType": { - "array": { - "elementType": { - "short": { - } - }, - "containsNull": true - } } }, "common": { @@ -274,15 +214,6 @@ }, { "integer": -8726533 }] - }, - "dataType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } } }, "common": { @@ -312,15 +243,6 @@ }, { "long": "7834609328726533" }] - }, - "dataType": { - "array": { - "elementType": { - "long": { - } - }, - "containsNull": true - } } }, "common": { @@ -350,15 +272,6 @@ }, { "double": 2.0 }] - }, - "dataType": { - "array": { - "elementType": { - "double": { - } - }, - "containsNull": true - } } }, "common": { @@ -388,15 +301,6 @@ }, { "float": -0.9 }] - }, - "dataType": { - "array": { - "elementType": { - "float": { - } - }, - 
"containsNull": true - } } }, "common": { @@ -432,17 +336,6 @@ "scale": 5 } }] - }, - "dataType": { - "array": { - "elementType": { - "decimal": { - "scale": 18, - "precision": 38 - } - }, - "containsNull": true - } } }, "common": { @@ -478,17 +371,6 @@ "scale": 7 } }] - }, - "dataType": { - "array": { - "elementType": { - "decimal": { - "scale": 18, - "precision": 38 - } - }, - "containsNull": true - } } }, "common": { @@ -516,16 +398,6 @@ }, { "string": "disconnect!" }] - }, - "dataType": { - "array": { - "elementType": { - "string": { - "collation": "UTF8_BINARY" - } - }, - "containsNull": true - } } }, "common": { @@ -574,16 +446,6 @@ }, { "string": "BCDEFGHIJK" }] - }, - "dataType": { - "array": { - "elementType": { - "string": { - "collation": "UTF8_BINARY" - } - }, - "containsNull": true - } } }, "common": { @@ -611,15 +473,6 @@ }, { "date": 18546 }] - }, - "dataType": { - "array": { - "elementType": { - "date": { - } - }, - "containsNull": true - } } }, "common": { @@ -647,15 +500,6 @@ }, { "timestamp": "1677155519809000" }] - }, - "dataType": { - "array": { - "elementType": { - "timestamp": { - } - }, - "containsNull": true - } } }, "common": { @@ -683,15 +527,6 @@ }, { "timestamp": "23456000" }] - }, - "dataType": { - "array": { - "elementType": { - "timestamp": { - } - }, - "containsNull": true - } } }, "common": { @@ -719,15 +554,6 @@ }, { "timestampNtz": "1677188160000000" }] - }, - "dataType": { - "array": { - "elementType": { - "timestampNtz": { - } - }, - "containsNull": true - } } }, "common": { @@ -755,15 +581,6 @@ }, { "date": 19417 }] - }, - "dataType": { - "array": { - "elementType": { - "date": { - } - }, - "containsNull": true - } } }, "common": { @@ -791,17 +608,6 @@ }, { "dayTimeInterval": "200000000" }] - }, - "dataType": { - "array": { - "elementType": { - "dayTimeInterval": { - "startField": 0, - "endField": 3 - } - }, - "containsNull": true - } } }, "common": { @@ -829,17 +635,6 @@ }, { "yearMonthInterval": 0 }] - }, - "dataType": { - "array": { - "elementType": { - "yearMonthInterval": { - "startField": 0, - "endField": 1 - } - }, - "containsNull": true - } } }, "common": { @@ -875,15 +670,6 @@ "microseconds": "200" } }] - }, - "dataType": { - "array": { - "elementType": { - "calendarInterval": { - } - }, - "containsNull": true - } } }, "common": { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin index 7e2b7c3bf999..042ea988b778 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json index 41ca771596ef..aa2aff5c54b8 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json @@ -77,10 +77,6 @@ }, { "literal": { "null": { - "null": { - } - }, - "dataType": { "string": { "collation": "UTF8_BINARY" } @@ -386,7 +382,19 @@ } }, { "literal": { - "binary": "CAY=" + "array": { + "elements": [{ + "byte": 8 + }, { + "byte": 6 + }], + "dataType": { + "elementType": { + "byte": { + } + } + } + } }, "common": { "origin": { @@ -412,10 +420,8 @@ "integer": 8 }, { "integer": 6 - }] - }, - 
"dataType": { - "array": { + }], + "dataType": { "elementType": { "integer": { } @@ -445,10 +451,6 @@ "null": { "null": { } - }, - "dataType": { - "null": { - } } }, "common": { @@ -717,10 +719,8 @@ "days": 20, "microseconds": "100" } - }] - }, - "dataType": { - "struct": { + }], + "dataTypeStruct": { "fields": [{ "name": "_1", "dataType": { @@ -840,10 +840,8 @@ "integer": 2 }, { "integer": 3 - }] - }, - "dataType": { - "array": { + }], + "dataType": { "elementType": { "integer": { } @@ -881,10 +879,8 @@ "null": { } } - }] - }, - "dataType": { - "array": { + }], + "dataType": { "elementType": { "integer": { } @@ -942,10 +938,8 @@ } }] } - }] - }, - "dataType": { - "array": { + }], + "dataType": { "elementType": { "struct": { "fields": [{ @@ -1007,10 +1001,8 @@ "string": "a" }] } - }] - }, - "dataType": { - "array": { + }], + "dataType": { "elementType": { "struct": { "fields": [{ @@ -1060,10 +1052,8 @@ "integer": 2 }, { "integer": 3 - }] - }, - "dataType": { - "array": { + }], + "dataType": { "elementType": { "integer": { } @@ -1100,10 +1090,8 @@ "integer": 1 }, { "integer": 2 - }] - }, - "dataType": { - "map": { + }], + "dataType": { "keyType": { "string": { "collation": "UTF8_BINARY" @@ -1151,10 +1139,8 @@ "null": { } } - }] - }, - "dataType": { - "map": { + }], + "dataType": { "keyType": { "string": { "collation": "UTF8_BINARY" @@ -1194,10 +1180,8 @@ "integer": 2 }, { "double": 1.0 - }] - }, - "dataType": { - "struct": { + }], + "dataTypeStruct": { "fields": [{ "name": "_1", "dataType": { @@ -1242,10 +1226,6 @@ }, { "literal": { "null": { - "null": { - } - }, - "dataType": { "integer": { } } @@ -1272,10 +1252,8 @@ "array": { "elements": [{ "integer": 1 - }] - }, - "dataType": { - "array": { + }], + "dataType": { "elementType": { "integer": { } @@ -1312,10 +1290,8 @@ "null": { } } - }] - }, - "dataType": { - "map": { + }], + "dataType": { "keyType": { "integer": { } @@ -1356,10 +1332,8 @@ "null": { } } - }] - }, - "dataType": { - "map": { + }], + "dataType": { "keyType": { "integer": { } @@ -1400,10 +1374,8 @@ "null": { } } - }] - }, - "dataType": { - "map": { + }], + "dataType": { "keyType": { "integer": { } @@ -1466,10 +1438,8 @@ "integer": 9 }] } - }] - }, - "dataType": { - "array": { + }], + "dataType": { "elementType": { "array": { "elementType": { @@ -1529,10 +1499,8 @@ } }] } - }] - }, - "dataType": { - "array": { + }], + "dataType": { "elementType": { "struct": { "fields": [{ @@ -1613,10 +1581,8 @@ "string": "6" }] } - }] - }, - "dataType": { - "array": { + }], + "dataType": { "elementType": { "struct": { "fields": [{ @@ -1699,10 +1665,8 @@ "integer": 6 }] } - }] - }, - "dataType": { - "array": { + }], + "dataType": { "elementType": { "map": { "keyType": { @@ -1805,10 +1769,8 @@ } }] } - }] - }, - "dataType": { - "array": { + }], + "dataType": { "elementType": { "map": { "keyType": { @@ -1884,10 +1846,8 @@ "integer": 4 }] } - }] - }, - "dataType": { - "map": { + }], + "dataType": { "keyType": { "integer": { } @@ -1971,10 +1931,8 @@ } }] } - }] - }, - "dataType": { - "struct": { + }], + "dataTypeStruct": { "fields": [{ "name": "_1", "dataType": { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin index 5068b513a927..f690de3cd8cb 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin differ diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index addf94ed3460..ae602297efac 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -40,7 +40,7 @@ import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
 import org.apache.spark.ml.util.{ConnectHelper, HasTrainingSummary, Identifiable, MLReader, MLWritable}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.classic.Dataset
-import org.apache.spark.sql.connect.common.{LiteralValueProtoConverter, ProtoSpecializedArray}
+import org.apache.spark.sql.connect.common.ProtoSpecializedArray
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.SessionHolder
@@ -129,33 +129,16 @@ private[ml] object MLUtils {
   def setInstanceParams(instance: Params, params: proto.MlParams): Unit = {
     params.getParamsMap.asScala.foreach { case (name, literal) =>
       val p = instance.getParam(name)
-      val value = literal.getLiteralTypeCase match {
-        case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
-          val s = literal.getStruct
-          s.getStructType.getUdt.getJvmClass match {
-            case "org.apache.spark.ml.linalg.VectorUDT" => deserializeVector(s)
-            case "org.apache.spark.ml.linalg.MatrixUDT" => deserializeMatrix(s)
-            case _ =>
-              throw MlUnsupportedException(s"Unsupported struct ${literal.getStruct} for ${name}")
-          }
-
+      val paramValue = MLParamConverter.convertToValue(literal)
+      val value = paramValue match {
+        case _: Matrix | _: Vector => paramValue
+        case _: String | _: Boolean if p.dataClass == null => paramValue
+        case _ if p.dataClass != null => reconcileParam(p.dataClass, paramValue)
         case _ =>
-          val paramValue = LiteralValueProtoConverter.toScalaValue(literal)
-          val paramType: Class[_] = if (p.dataClass == null) {
-            if (paramValue.isInstanceOf[String]) {
-              classOf[String]
-            } else if (paramValue.isInstanceOf[Boolean]) {
-              classOf[Boolean]
-            } else {
-              throw MlUnsupportedException(
-                "Spark Connect ML requires the customized ML Param class setting 'dataClass' " +
-                  "parameter if the param value type is not String or Boolean type, " +
-                  s"but the param $name does not have the required dataClass.")
-            }
-          } else {
-            p.dataClass
-          }
-          reconcileParam(paramType, paramValue)
+          throw MlUnsupportedException(
+            "Spark Connect ML requires the customized ML Param class setting 'dataClass' " +
+              "parameter if the param value type is not String or Boolean type, " +
+              s"but the param $name does not have the required dataClass.")
       }
       instance.set(p, value)
     }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
index 6863818d00ef..98180bf8c6de 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
@@ -17,12 +17,16 @@
 package org.apache.spark.sql.connect.ml
 
+import scala.collection.mutable
+
 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.Expression.Literal
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.Params
 import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.connect.common.{LiteralValueProtoConverter, ProtoDataTypes}
+import org.apache.spark.sql.connect.common.{DataTypeBuilder, FromProtoToScalaConverter, LiteralValueProtoConverter, ProtoDataTypes}
 import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, UserDefinedType}
 
 private[ml] object Serializer {
@@ -142,53 +146,8 @@ private[ml] object Serializer {
       sessionHolder: SessionHolder): Array[(Object, Class[_])] = {
     args.map { arg =>
       if (arg.hasParam) {
-        val literal = arg.getParam
-        literal.getLiteralTypeCase match {
-          case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
-            val struct = literal.getStruct
-            struct.getStructType.getUdt.getJvmClass match {
-              case "org.apache.spark.ml.linalg.VectorUDT" =>
-                (MLUtils.deserializeVector(struct), classOf[Vector])
-              case "org.apache.spark.ml.linalg.MatrixUDT" =>
-                (MLUtils.deserializeMatrix(struct), classOf[Matrix])
-              case _ =>
-                throw MlUnsupportedException(s"Unsupported struct ${literal.getStruct}")
-            }
-          case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
-            (literal.getInteger.asInstanceOf[Object], classOf[Int])
-          case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
-            (literal.getFloat.toDouble.asInstanceOf[Object], classOf[Double])
-          case proto.Expression.Literal.LiteralTypeCase.STRING =>
-            (literal.getString, classOf[String])
-          case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
-            (literal.getDouble.asInstanceOf[Object], classOf[Double])
-          case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
-            (literal.getBoolean.asInstanceOf[Object], classOf[Boolean])
-          case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
-            val scalaArray =
-              LiteralValueProtoConverter.toScalaValue(literal).asInstanceOf[Array[_]]
-            val dataType = LiteralValueProtoConverter.getProtoDataType(literal)
-            dataType.getArray.getElementType.getKindCase match {
-              case proto.DataType.KindCase.DOUBLE =>
-                (MLUtils.reconcileArray(classOf[Double], scalaArray), classOf[Array[Double]])
-              case proto.DataType.KindCase.STRING =>
-                (MLUtils.reconcileArray(classOf[String], scalaArray), classOf[Array[String]])
-              case proto.DataType.KindCase.ARRAY =>
-                dataType.getArray.getElementType.getArray.getElementType.getKindCase match {
-                  case proto.DataType.KindCase.STRING =>
-                    (
-                      MLUtils.reconcileArray(classOf[Array[String]], scalaArray),
-                      classOf[Array[Array[String]]])
-                  case _ =>
-                    throw MlUnsupportedException(s"Unsupported inner array ${literal.getArray}")
-                }
-              case _ =>
-                throw MlUnsupportedException(s"Unsupported array $literal")
-            }
-
-          case other =>
-            throw MlUnsupportedException(s"$other not supported")
-        }
+        val value = MLParamConverter.convertToValue(arg.getParam).asInstanceOf[AnyRef]
+        (value, value.getClass)
       } else if (arg.hasInput) {
         (MLUtils.parseRelationProto(arg.getInput, sessionHolder), classOf[Dataset[_]])
       } else {
@@ -215,3 +174,39 @@ private[ml] object Serializer {
     builder.build()
   }
 }
+
+object MLParamConverter extends FromProtoToScalaConverter {
+
+  override protected def convertFloat(
+      literal: Literal,
+      dataTypeBuilder: DataTypeBuilder): (DataTypeBuilder, Any) = {
+    (dataTypeBuilder.merge(DoubleType, isNullable = false), literal.getFloat.toDouble)
+  }
+
+  override protected def convertUdt(struct: Literal.Struct, udt: UserDefinedType[_]): Any = {
+    if (udt.userClass == classOf[Vector]) {
+      MLUtils.deserializeVector(struct)
+    } else if (udt.userClass == classOf[Matrix]) {
+      MLUtils.deserializeMatrix(struct)
+    } else {
+      throw MlUnsupportedException(s"Unsupported UDT $struct")
+    }
+  }
+
+  override protected def arrayBuilder(size: Int): mutable.Builder[Any, Any] =
+    mutable.ArrayBuilder.make[Any]
+
+  override protected def convertArray(
+      array: Literal.Array,
+      arrayTypeBuilder: DataTypeBuilder): (DataTypeBuilder, Any) = {
+    super.convertArray(array, arrayTypeBuilder) match {
+      case (DataTypeBuilder(ArrayType(DoubleType, false)), value: Array[_]) =>
+        (arrayTypeBuilder, MLUtils.reconcileArray(classOf[Double], value))
+      case (DataTypeBuilder(ArrayType(ArrayType(_: StringType, _), _)), value: Array[_]) =>
+        (arrayTypeBuilder, MLUtils.reconcileArray(classOf[Array[String]], value))
+      case (DataTypeBuilder(ArrayType(_: StringType, _)), value: Array[_]) =>
+        (arrayTypeBuilder, MLUtils.reconcileArray(classOf[String], value))
+      case result => result
+    }
+  }
+}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
index 4c8911c88188..299702fe7150 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
@@ -17,10 +17,18 @@
 package org.apache.spark.sql.connect.planner
 
+import scala.collection.mutable
+
+import com.google.protobuf.ByteString
+
 import org.apache.spark.connect.proto
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
+import org.apache.spark.connect.proto.Expression.Literal
+import org.apache.spark.sql.catalyst.{expressions, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.connect.common.FromProtoConvertor
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
 
 object LiteralExpressionProtoConverter {
@@ -31,9 +39,95 @@ object LiteralExpressionProtoConverter {
   *   Expression
   */
  def toCatalystExpression(lit: proto.Expression.Literal): expressions.Literal = {
-    val dataType = LiteralValueProtoConverter.getDataType(lit)
-    val scalaValue = LiteralValueProtoConverter.toScalaValue(lit)
-    val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
-    expressions.Literal(convert(scalaValue), dataType)
+    val (dataType, value) = FromProtoToCatalystConverter.convert(lit)
+    expressions.Literal(value, dataType)
+  }
+
+  private object FromProtoToCatalystConverter extends FromProtoConvertor {
+    override protected def convertString(string: ByteString): UTF8String =
+      UTF8String.fromBytes(string.toByteArray)
+
+    override protected def convertTime(value: Literal.Time): Long = value.getNano
+
+    override protected def convertDate(value: Int): Int = value
+
+    override protected def convertTimestamp(value: Long): Long = value
+
+    override protected def convertTimestampNTZ(value: Long): Long = value
+
+    override protected def convertDayTimeInterval(value: Long): Long = value
+
+    override protected def convertYearMonthInterval(value: Int): Int = value
+
+    override protected def arrayBuilder(size: Int): mutable.Builder[Any, ArrayData] = {
+      new mutable.Builder[Any, ArrayData] {
+        private var index = 0
+        private var data: Array[Any] = _
+        clear()
+
+        override def clear(): Unit = {
+          index = 0
+          data = new Array[Any](size)
+        }
+
+        override def addOne(elem: Any): this.type = {
+          data(index) = elem
+          index += 1
+          this
+        }
+
+        override def result(): ArrayData = {
+          assert(index == data.length)
+          new GenericArrayData(data)
+        }
+      }
+    }
+
+    override protected def mapBuilder(size: Int): mutable.Builder[(Any, Any), Any] = {
+      new mutable.Builder[(Any, Any), MapData] {
+        private val keys: mutable.Builder[Any, ArrayData] = arrayBuilder(size)
+        private val values: mutable.Builder[Any, ArrayData] = arrayBuilder(size)
+        clear()
+
+        override def clear(): Unit = {
+          keys.clear()
+          values.clear()
+        }
+
+        override def addOne(elem: (Any, Any)): this.type = {
+          keys += elem._1
+          values += elem._2
+          this
+        }
+
+        override def result(): MapData = {
+          new ArrayBasedMapData(keys.result(), values.result())
+        }
+      }
+    }
+
+    override protected def structBuilder(schema: StructType): mutable.Builder[Any, Any] = {
+      new mutable.Builder[Any, InternalRow] {
+        private var index = 0
+        private var row: GenericInternalRow = _
+        clear()
+
+        override def clear(): Unit = {
+          index = 0
+          row = new GenericInternalRow(schema.length)
+        }
+
+        override def addOne(elem: Any): this.type = {
+          row.update(index, elem)
+          index += 1
+          this
+        }
+
+        override def result(): InternalRow = {
+          assert(index == row.values.length)
+          row
+        }
+      }
+    }
   }
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 7e17a935f599..9393e68b3510 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -59,7 +59,7 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
 import org.apache.spark.sql.classic.{Catalog, DataFrameWriter, Dataset, MergeIntoWriter, RelationalGroupedDataset, SparkSession, TypedAggUtils, UserDefinedFunctionUtils}
 import org.apache.spark.sql.classic.ClassicConversions._
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, FromProtoToScalaConverter, InvalidPlanInput, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
 import org.apache.spark.sql.connect.ml.MLHandler
 import org.apache.spark.sql.connect.pipelines.PipelinesHandler
@@ -464,13 +464,13 @@ class SparkConnectPlanner(
     val cols = rel.getColsList.asScala.toArray
     val values = rel.getValuesList.asScala.toArray
     if (values.length == 1) {
-      val value = LiteralValueProtoConverter.toScalaValue(values.head)
+      val value = FromProtoToScalaConverter.convertToValue(values.head)
       val columns = if (cols.nonEmpty) Some(cols.toImmutableArraySeq) else None
       dataset.na.fillValue(value, columns).logicalPlan
     } else {
       val valueMap = mutable.Map.empty[String, Any]
       cols.zip(values).foreach { case (col, value) =>
-        valueMap.update(col, LiteralValueProtoConverter.toScalaValue(value))
+        valueMap.update(col, FromProtoToScalaConverter.convertToValue(value))
       }
       dataset.na.fill(valueMap = valueMap.toMap).logicalPlan
     }
@@ -497,8 +497,8 @@ class SparkConnectPlanner(
     val replacement = mutable.Map.empty[Any, Any]
     rel.getReplacementsList.asScala.foreach { replace =>
       replacement.update(
-        LiteralValueProtoConverter.toScalaValue(replace.getOldValue),
-        LiteralValueProtoConverter.toScalaValue(replace.getNewValue))
+        FromProtoToScalaConverter.convertToValue(replace.getOldValue),
+        FromProtoToScalaConverter.convertToValue(replace.getNewValue))
     }
 
     if (rel.getColsCount == 0) {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
index 3eab2560bcc1..3b5e5b3f27b9 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
@@ -17,18 +17,15 @@
 package org.apache.spark.sql.connect.planner
 
-import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
-
+import org.apache.spark.SparkFunSuite
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.connect.common.InvalidPlanInput
-import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
+import org.apache.spark.sql.connect.common.{FromProtoToScalaConverter, LiteralValueProtoConverter}
 import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.ToLiteralProtoOptions
-import org.apache.spark.sql.connect.planner.LiteralExpressionProtoConverter
 import org.apache.spark.sql.types._
 
-class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:ignore funsuite
+class LiteralExpressionProtoConverterSuite extends SparkFunSuite {
 
   private def toLiteralProto(v: Any): proto.Expression.Literal = {
     LiteralValueProtoConverter
@@ -49,7 +46,7 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
   test("basic proto value and catalyst value conversion") {
     val values = Array(null, true, 1.toByte, 1.toShort, 1, 1L, 1.1d, 1.1f, "spark")
     for (v <- values) {
-      assertResult(v)(LiteralValueProtoConverter.toScalaValue(toLiteralProto(v)))
+      assertResult(v)(FromProtoToScalaConverter.convertToValue(toLiteralProto(v)))
     }
   }
 
@@ -124,7 +121,6 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
       Seq(1, 2, 3),
       Some(ArrayType(IntegerType, containsNull = false)),
       ToLiteralProtoOptions(useDeprecatedDataTypeFields = true))
-    assert(!literalProto.hasDataType)
     assert(literalProto.getArray.getElementsList.size == 3)
     assert(literalProto.getArray.getElementType.hasInteger)
 
@@ -147,7 +143,6 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
       Map[String, Int]("a" -> 1, "b" -> 2),
       Some(MapType(StringType, IntegerType, valueContainsNull = false)),
       ToLiteralProtoOptions(useDeprecatedDataTypeFields = true))
-    assert(!literalProto.hasDataType)
     assert(literalProto.getMap.getKeysList.size == 2)
     assert(literalProto.getMap.getValuesList.size == 2)
     assert(literalProto.getMap.getKeyType.hasString)
@@ -172,15 +167,14 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
   test("backward compatibility for struct literal proto") {
     // Test the old way of defining structs with structType field and elements
+    val schema = StructType(
+      Seq(
+        StructField("a", IntegerType, nullable = true),
+        StructField("b", StringType, nullable = false)))
     val structProto = LiteralValueProtoConverter.toLiteralProtoWithOptions(
       (1, "test"),
-      Some(
-        StructType(
-          Seq(
-            StructField("a", IntegerType, nullable = true),
-            StructField("b", StringType, nullable = false)))),
+      Some(schema),
       ToLiteralProtoOptions(useDeprecatedDataTypeFields = true))
-    assert(!structProto.hasDataType)
     assert(structProto.getStruct.getElementsList.size == 2)
     val structTypeProto = structProto.getStruct.getStructType.getStruct
     assert(structTypeProto.getFieldsList.size == 2)
@@ -189,8 +183,7 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
     assert(structTypeProto.getFieldsList.get(1).getName == "b")
     assert(structTypeProto.getFieldsList.get(1).getDataType.hasString)
 
-    val result = LiteralValueProtoConverter.toScalaValue(structProto)
-    val resultType = LiteralValueProtoConverter.getProtoDataType(structProto)
+    val (resultType, result) = FromProtoToScalaConverter.convert(structProto)
 
     // Verify the result is a GenericRowWithSchema with correct values
     assert(result.isInstanceOf[GenericRowWithSchema])
@@ -198,36 +191,25 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
     assert(row.length == 2)
     assert(row.get(0) == 1)
     assert(row.get(1) == "test")
-
-    // Verify the returned struct type matches the original
-    assert(resultType.getKindCase == proto.DataType.KindCase.STRUCT)
-    val structType = resultType.getStruct
-    assert(structType.getFieldsCount == 2)
-    assert(structType.getFields(0).getName == "a")
-    assert(structType.getFields(0).getDataType.hasInteger)
-    assert(structType.getFields(0).getNullable)
-    assert(structType.getFields(1).getName == "b")
-    assert(structType.getFields(1).getDataType.hasString)
-    assert(!structType.getFields(1).getNullable)
+    assert(row.schema == resultType)
+    assert(resultType == schema)
   }
 
-  test("an invalid array literal") {
+  test("an empty array literal") {
     val literalProto = proto.Expression.Literal
       .newBuilder()
       .setArray(proto.Expression.Literal.Array.newBuilder())
       .build()
-    intercept[InvalidPlanInput] {
-      LiteralValueProtoConverter.toScalaValue(literalProto)
-    }
+    val result = FromProtoToScalaConverter.convertToValue(literalProto)
+    assert(result == Seq.empty[Any])
  }
 
-  test("an invalid map literal") {
+  test("an empty map literal") {
     val literalProto = proto.Expression.Literal
       .newBuilder()
       .setMap(proto.Expression.Literal.Map.newBuilder())
       .build()
-    intercept[InvalidPlanInput] {
-      LiteralValueProtoConverter.toScalaValue(literalProto)
-    }
+    val result = FromProtoToScalaConverter.convertToValue(literalProto)
+    assert(result == Map.empty[Any, Any])
   }
 }