130 changes: 65 additions & 65 deletions python/pyspark/sql/connect/proto/expressions_pb2.py

Large diffs are not rendered by default.

73 changes: 54 additions & 19 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -496,6 +496,7 @@ class Expression(google.protobuf.message.Message):

ELEMENT_TYPE_FIELD_NUMBER: builtins.int
ELEMENTS_FIELD_NUMBER: builtins.int
DATA_TYPE_FIELD_NUMBER: builtins.int
@property
def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
"""(Deprecated) The element type of the array.
@@ -509,19 +510,37 @@ class Expression(google.protobuf.message.Message):
global___Expression.Literal
]:
"""The literal values that make up the array elements."""
@property
def data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Array:
"""The type of the array. You don't need to set this field if the type information is not needed.

If the element type can be inferred from the first element of the elements field,
then you don't need to set data_type.element_type to save space.

On the other hand, redundant type information is also acceptable.
"""
def __init__(
self,
*,
element_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
elements: collections.abc.Iterable[global___Expression.Literal] | None = ...,
data_type: pyspark.sql.connect.proto.types_pb2.DataType.Array | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["element_type", b"element_type"]
self,
field_name: typing_extensions.Literal[
"data_type", b"data_type", "element_type", b"element_type"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"element_type", b"element_type", "elements", b"elements"
"data_type",
b"data_type",
"element_type",
b"element_type",
"elements",
b"elements",
],
) -> None: ...

@@ -532,6 +551,7 @@ class Expression(google.protobuf.message.Message):
VALUE_TYPE_FIELD_NUMBER: builtins.int
KEYS_FIELD_NUMBER: builtins.int
VALUES_FIELD_NUMBER: builtins.int
DATA_TYPE_FIELD_NUMBER: builtins.int
@property
def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
"""(Deprecated) The key type of the map.
@@ -559,23 +579,35 @@ class Expression(google.protobuf.message.Message):
global___Expression.Literal
]:
"""The literal values that make up the map."""
@property
def data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Map:
"""The type of the map. You don't need to set this field if the type information is not needed.

If the key/value types can be inferred from the first element of the keys/values fields,
then you don't need to set data_type.key_type/data_type.value_type to save space.

On the other hand, redundant type information is also acceptable.
"""
def __init__(
self,
*,
key_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
value_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
keys: collections.abc.Iterable[global___Expression.Literal] | None = ...,
values: collections.abc.Iterable[global___Expression.Literal] | None = ...,
data_type: pyspark.sql.connect.proto.types_pb2.DataType.Map | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"key_type", b"key_type", "value_type", b"value_type"
"data_type", b"data_type", "key_type", b"key_type", "value_type", b"value_type"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"data_type",
b"data_type",
"key_type",
b"key_type",
"keys",
@@ -592,6 +624,7 @@ class Expression(google.protobuf.message.Message):

STRUCT_TYPE_FIELD_NUMBER: builtins.int
ELEMENTS_FIELD_NUMBER: builtins.int
DATA_TYPE_STRUCT_FIELD_NUMBER: builtins.int
@property
def struct_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
"""(Deprecated) The type of the struct.
@@ -606,19 +639,35 @@ class Expression(google.protobuf.message.Message):
global___Expression.Literal
]:
"""The literal values that make up the struct elements."""
@property
def data_type_struct(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Struct:
"""The type of the struct. You don't need to set this field if the type information is not needed.

Whether data_type_struct.fields.data_type should be set depends on
whether each field's type can be inferred from the elements field.
"""
def __init__(
self,
*,
struct_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
elements: collections.abc.Iterable[global___Expression.Literal] | None = ...,
data_type_struct: pyspark.sql.connect.proto.types_pb2.DataType.Struct | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["struct_type", b"struct_type"]
self,
field_name: typing_extensions.Literal[
"data_type_struct", b"data_type_struct", "struct_type", b"struct_type"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"elements", b"elements", "struct_type", b"struct_type"
"data_type_struct",
b"data_type_struct",
"elements",
b"elements",
"struct_type",
b"struct_type",
],
) -> None: ...

@@ -750,7 +799,6 @@ class Expression(google.protobuf.message.Message):
STRUCT_FIELD_NUMBER: builtins.int
SPECIALIZED_ARRAY_FIELD_NUMBER: builtins.int
TIME_FIELD_NUMBER: builtins.int
DATA_TYPE_FIELD_NUMBER: builtins.int
@property
def null(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
binary: builtins.bytes
@@ -784,14 +832,6 @@ class Expression(google.protobuf.message.Message):
def specialized_array(self) -> global___Expression.Literal.SpecializedArray: ...
@property
def time(self) -> global___Expression.Literal.Time: ...
@property
def data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
"""Data type information for the literal.
This field is required only in the root literal message for null values or
for data types (e.g., array, map, or struct) with non-trivial information.
If the data_type field is not set at the root level, the data type will be
inferred or retrieved from the deprecated data type fields using best efforts.
"""
def __init__(
self,
*,
@@ -817,7 +857,6 @@ class Expression(google.protobuf.message.Message):
struct: global___Expression.Literal.Struct | None = ...,
specialized_array: global___Expression.Literal.SpecializedArray | None = ...,
time: global___Expression.Literal.Time | None = ...,
data_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
) -> None: ...
def HasField(
self,
@@ -832,8 +871,6 @@ class Expression(google.protobuf.message.Message):
b"byte",
"calendar_interval",
b"calendar_interval",
"data_type",
b"data_type",
"date",
b"date",
"day_time_interval",
@@ -885,8 +922,6 @@ class Expression(google.protobuf.message.Message):
b"byte",
"calendar_interval",
b"calendar_interval",
"data_type",
b"data_type",
"date",
b"date",
"day_time_interval",
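As a rough illustration of the Array docstring above, here is a minimal Python sketch that builds an array literal through these generated messages. It is not taken from this PR: the import path, the scalar field name `integer`, and `contains_null` on `DataType.Array` are assumptions based on the surrounding stubs, and the element values are made up. `data_type.element_type` is left unset because, per the docstring, it can be inferred from the first element.

```python
# Minimal sketch (assumed module layout and scalar field names): build an
# Array literal that carries only the type information the docstring above
# says is still needed once element types are inferable.
from pyspark.sql.connect.proto import expressions_pb2, types_pb2

array_literal = expressions_pb2.Expression.Literal(
    array=expressions_pb2.Expression.Literal.Array(
        elements=[
            expressions_pb2.Expression.Literal(integer=1),
            expressions_pb2.Expression.Literal(integer=2),
        ],
        # element_type is omitted: it is inferable from the first element above.
        data_type=types_pb2.DataType.Array(contains_null=False),
    )
)
```

The Map literal follows the same pattern: `data_type.key_type` and `data_type.value_type` may be omitted whenever they are inferable from the first key/value pair.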
@@ -1562,13 +1562,13 @@ class ClientE2ETestSuite
val observedObservedDf = observedDf.observe("ob2", min("extra"), avg("extra"), max("extra"))

val ob1Schema = new StructType()
.add("min(id)", LongType)
.add("avg(id)", DoubleType)
.add("max(id)", LongType)
.add("min(id)", LongType, nullable = false)
.add("avg(id)", DoubleType, nullable = false)
.add("max(id)", LongType, nullable = false)
val ob2Schema = new StructType()
.add("min(extra)", LongType)
.add("avg(extra)", DoubleType)
.add("max(extra)", LongType)
.add("min(extra)", LongType, nullable = false)
.add("avg(extra)", DoubleType, nullable = false)
.add("max(extra)", LongType, nullable = false)
val ob1Metrics = Map("ob1" -> new GenericRowWithSchema(Array(0, 49, 98), ob1Schema))
val ob2Metrics = Map("ob2" -> new GenericRowWithSchema(Array(-1, 48, 97), ob2Schema))

@@ -83,14 +83,11 @@ class ColumnNodeToProtoConverterSuite extends ConnectFunSuite {
.addElements(proto.Expression.Literal.newBuilder().setString("north").build())
.addElements(proto.Expression.Literal.newBuilder().setDouble(60.0).build())
.addElements(proto.Expression.Literal.newBuilder().setString("west").build())
b.getLiteralBuilder.getDataTypeBuilder.setStruct(
proto.DataType.Struct
.newBuilder()
.addFields(structField("_1", ProtoDataTypes.DoubleType))
.addFields(structField("_2", stringTypeWithCollation))
.addFields(structField("_3", ProtoDataTypes.DoubleType))
.addFields(structField("_4", stringTypeWithCollation))
.build())
.getDataTypeStructBuilder
.addFields(structField("_1", ProtoDataTypes.DoubleType))
.addFields(structField("_2", stringTypeWithCollation))
.addFields(structField("_3", ProtoDataTypes.DoubleType))
.addFields(structField("_4", stringTypeWithCollation))
})
}

@@ -210,13 +210,6 @@ message Expression {
// Reserved for Geometry and Geography.
reserved 27, 28;

// Data type information for the literal.
// This field is required only in the root literal message for null values or
// for data types (e.g., array, map, or struct) with non-trivial information.
// If the data_type field is not set at the root level, the data type will be
// inferred or retrieved from the deprecated data type fields using best efforts.
DataType data_type = 100;

message Decimal {
// the string representation.
string value = 1;
Expand All @@ -241,6 +234,14 @@ message Expression {

// The literal values that make up the array elements.
repeated Literal elements = 2;

// The type of the array. You don't need to set this field if the type information is not needed.
//
// If the element type can be inferred from the first element of the elements field,
// then you don't need to set data_type.element_type to save space.
//
// On the other hand, redundant type information is also acceptable.
DataType.Array data_type = 3;
}

message Map {
@@ -260,6 +261,14 @@

// The literal values that make up the map.
repeated Literal values = 4;

// The type of the map. You don't need to set this field if the type information is not needed.
//
// If the key/value types can be inferred from the first element of the keys/values fields,
// then you don't need to set data_type.key_type/data_type.value_type to save space.
//
// On the other hand, redundant type information is also acceptable.
DataType.Map data_type = 5;
}

message Struct {
@@ -271,6 +280,12 @@

// The literal values that make up the struct elements.
repeated Literal elements = 2;

// The type of the struct. You don't need to set this field if the type information is not needed.
//
// Whether data_type_struct.fields.data_type should be set depends on
// whether each field's type can be inferred from the elements field.
DataType.Struct data_type_struct = 3;
}

message SpecializedArray {
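The converter suite change above already exercises the struct case from Scala; as a complementary sketch, here is the same construction from Python using the new `data_type_struct` field. The `DataType.StructField` message and its `name` field are assumptions inferred from the `structField(...)` helper in the test, the values are borrowed loosely from it, and each field's `data_type` is left unset on the premise that it can be inferred from the matching element.

```python
# Minimal sketch (assumed message and field names): a struct literal whose
# field names travel in data_type_struct while the field types are inferred
# from the elements themselves.
from pyspark.sql.connect.proto import expressions_pb2, types_pb2

struct_literal = expressions_pb2.Expression.Literal(
    struct=expressions_pb2.Expression.Literal.Struct(
        elements=[
            expressions_pb2.Expression.Literal(double=60.0),
            expressions_pb2.Expression.Literal(string="north"),
        ],
        data_type_struct=types_pb2.DataType.Struct(
            fields=[
                types_pb2.DataType.StructField(name="_1"),  # data_type omitted: inferable
                types_pb2.DataType.StructField(name="_2"),
            ]
        ),
    )
)
```

Keeping the per-element type information optional is what lets these container-level fields stand in for the root-level `data_type` field removed above without re-encoding types the elements already carry.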
@@ -36,8 +36,8 @@ import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, FromProtoToScalaConverter}
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.util.ArrowUtils

private[sql] class SparkResult[T](
@@ -422,17 +422,17 @@ private[sql] object SparkResult {

/** Return value is a Seq of pairs, to preserve the order of values. */
private[sql] def transformObservedMetrics(metric: ObservedMetrics): Row = {
assert(metric.getKeysCount == metric.getValuesCount)
var schema = new StructType()
require(metric.getKeysCount == metric.getValuesCount)
val fields = mutable.ArrayBuilder.make[StructField]
val values = mutable.ArrayBuilder.make[Any]
fields.sizeHint(metric.getKeysCount)
values.sizeHint(metric.getKeysCount)
(0 until metric.getKeysCount).foreach { i =>
val key = metric.getKeys(i)
val value = LiteralValueProtoConverter.toScalaValue(metric.getValues(i))
schema = schema.add(key, LiteralValueProtoConverter.getDataType(metric.getValues(i)))
Range(0, metric.getKeysCount).foreach { i =>
val (dataType, value) = FromProtoToScalaConverter.convert(metric.getValues(i))
fields += StructField(metric.getKeys(i), dataType, value == null)
values += value
}
new GenericRowWithSchema(values.result(), schema)
new GenericRowWithSchema(values.result(), new StructType(fields.result()))
}
}

@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.{toLiteralProtoBuilderWithOptions, ToLiteralProtoOptions}
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.{toLiteralProtoWithOptions, ToLiteralProtoOptions}
import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, LazyExpression, Literal, SortOrder, SqlExpression, SubqueryExpression, SubqueryType, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame}

@@ -67,7 +67,7 @@ object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) {
n match {
case Literal(value, dataTypeOpt, _) =>
builder.setLiteral(
toLiteralProtoBuilderWithOptions(
toLiteralProtoWithOptions(
value,
dataTypeOpt,
ToLiteralProtoOptions(useDeprecatedDataTypeFields = false)))
@@ -120,7 +120,7 @@ object DataTypeProtoConverter {
}
}

private def toCatalystArrayType(t: proto.DataType.Array): ArrayType = {
private[common] def toCatalystArrayType(t: proto.DataType.Array): ArrayType = {
ArrayType(toCatalystType(t.getElementType), t.getContainsNull)
}

@@ -140,7 +140,7 @@ object DataTypeProtoConverter {
StructType.apply(fields)
}

private def toCatalystMapType(t: proto.DataType.Map): MapType = {
private[common] def toCatalystMapType(t: proto.DataType.Map): MapType = {
MapType(toCatalystType(t.getKeyType), toCatalystType(t.getValueType), t.getValueContainsNull)
}
