Commit 931e9e3

[SPARK-53658][CONNECT] Accept literal data_type field in ExecutePlanResponse

1 parent: 984e16b

File tree: 9 files changed, +292 -204 lines

python/pyspark/sql/connect/proto/base_pb2.py

Lines changed: 119 additions & 117 deletions (large generated-file diff not rendered)

python/pyspark/sql/connect/proto/base_pb2.pyi

Lines changed: 39 additions & 1 deletion
@@ -1094,24 +1094,33 @@ class ExecutePlanRequest(google.protobuf.message.Message):

     REATTACH_OPTIONS_FIELD_NUMBER: builtins.int
     RESULT_CHUNKING_OPTIONS_FIELD_NUMBER: builtins.int
+    ACCEPT_RESPONSE_OPTIONS_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
     def reattach_options(self) -> global___ReattachOptions: ...
     @property
     def result_chunking_options(self) -> global___ResultChunkingOptions: ...
     @property
+    def accept_response_options(self) -> global___AcceptResponseOptions:
+        """Options to describe what responses (e.g. using a new field in the response)
+        can be accepted.
+        """
+    @property
     def extension(self) -> google.protobuf.any_pb2.Any:
         """Extension type for request options"""
     def __init__(
         self,
         *,
         reattach_options: global___ReattachOptions | None = ...,
         result_chunking_options: global___ResultChunkingOptions | None = ...,
+        accept_response_options: global___AcceptResponseOptions | None = ...,
         extension: google.protobuf.any_pb2.Any | None = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
+            "accept_response_options",
+            b"accept_response_options",
             "extension",
             b"extension",
             "reattach_options",
@@ -1125,6 +1134,8 @@ class ExecutePlanRequest(google.protobuf.message.Message):
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "accept_response_options",
+            b"accept_response_options",
             "extension",
             b"extension",
             "reattach_options",
@@ -1138,7 +1149,12 @@ class ExecutePlanRequest(google.protobuf.message.Message):
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["request_option", b"request_option"]
     ) -> (
-        typing_extensions.Literal["reattach_options", "result_chunking_options", "extension"]
+        typing_extensions.Literal[
+            "reattach_options",
+            "result_chunking_options",
+            "accept_response_options",
+            "extension",
+        ]
         | None
     ): ...
@@ -3049,6 +3065,28 @@ class ResultChunkingOptions(google.protobuf.message.Message):

 global___ResultChunkingOptions = ResultChunkingOptions

+class AcceptResponseOptions(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    ACCEPT_LITERAL_DATA_TYPE_FIELD_FIELD_NUMBER: builtins.int
+    accept_literal_data_type_field: builtins.bool
+    """When true, the client indicates it can handle Literal messages in responses
+    that include the data_type field.
+    """
+    def __init__(
+        self,
+        *,
+        accept_literal_data_type_field: builtins.bool = ...,
+    ) -> None: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "accept_literal_data_type_field", b"accept_literal_data_type_field"
+        ],
+    ) -> None: ...
+
+global___AcceptResponseOptions = AcceptResponseOptions
+
 class ReattachExecuteRequest(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor

sql/connect/common/src/main/protobuf/spark/connect/base.proto

Lines changed: 9 additions & 0 deletions
@@ -334,6 +334,9 @@ message ExecutePlanRequest {
   oneof request_option {
     ReattachOptions reattach_options = 1;
     ResultChunkingOptions result_chunking_options = 2;
+    // Options to describe what responses (e.g. using a new field in the response)
+    // can be accepted.
+    AcceptResponseOptions accept_response_options = 3;
     // Extension type for request options
     google.protobuf.Any extension = 999;
   }
@@ -846,6 +849,12 @@ message ResultChunkingOptions {
   optional int64 preferred_arrow_chunk_size = 2;
 }

+message AcceptResponseOptions {
+  // When true, the client indicates it can handle Literal messages in responses
+  // that include the data_type field.
+  bool accept_literal_data_type_field = 1;
+}
+
 message ReattachExecuteRequest {
   // (Required)
   //

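Note: the option above is the heart of the commit. When a client sets accept_literal_data_type_field = true, the server may return Literal messages that carry their Catalyst type in an explicit data_type field rather than in the deprecated per-type fields. A minimal sketch of the two encodings, assuming the generated Java/Scala bindings and the Literal/DataType shapes from expressions.proto and types.proto (the integer field names are illustrative, not taken from this diff):

import org.apache.spark.connect.proto

// Legacy shape: only the value is set; the receiver infers IntegerType
// from which oneof member is populated.
val legacy = proto.Expression.Literal.newBuilder().setInteger(42).build()

// New shape, sent only to clients that advertised the capability: the same
// value plus an explicit data_type message describing its type.
val typed = proto.Expression.Literal
  .newBuilder()
  .setInteger(42)
  .setDataType(proto.DataType.newBuilder().setInteger(proto.DataType.Integer.newBuilder()))
  .build()
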
sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala

Lines changed: 7 additions & 0 deletions
@@ -138,6 +138,7 @@ private[sql] class SparkConnectClient(
       .setSessionId(sessionId)
       .setClientType(userAgent)
       .addAllTags(tags.get.toSeq.asJava)
+      .addRequestOptions(SparkConnectClient.ACCEPT_RESPONSE_OPTIONS)
     serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
     operationId.foreach { opId =>
       require(
@@ -425,6 +426,12 @@ object SparkConnectClient {
   private val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
     Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)

+  private val ACCEPT_RESPONSE_OPTIONS = proto.ExecutePlanRequest.RequestOption
+    .newBuilder()
+    .setAcceptResponseOptions(
+      proto.AcceptResponseOptions.newBuilder().setAcceptLiteralDataTypeField(true).build())
+    .build()
+
   // for internal tests
   private[sql] def apply(channel: ManagedChannel): SparkConnectClient = {
     new SparkConnectClient(Configuration(), channel)

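Note on the design: ACCEPT_RESPONSE_OPTIONS is attached unconditionally to every ExecutePlanRequest rather than behind a version check. This should be safe against older servers: the option lives in the request_option oneof, so a server built without field 3 parses it as an unknown field, observes no set oneof case for that RequestOption, and ignores the entry.
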
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala

Lines changed: 2 additions & 1 deletion
@@ -257,7 +257,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
           executeHolder.sessionHolder.sessionId,
           executeHolder.sessionHolder.serverSessionId,
           executeHolder.allObservationAndPlanIds,
-          observedMetrics ++ accumulatedInPython))
+          observedMetrics ++ accumulatedInPython,
+          executeHolder.acceptLiteralDataTypeFieldInResponses))
     }

   // State transition should be atomic to prevent a situation in which a client of reattachable

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala

Lines changed: 10 additions & 5 deletions
@@ -31,7 +31,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.classic.{DataFrame, Dataset}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
-import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.{toLiteralProtoWithOptions, ToLiteralProtoOptions}
 import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_ARROW_MAX_BATCH_SIZE, CONNECT_SESSION_RESULT_CHUNKING_MAX_CHUNK_SIZE}
 import org.apache.spark.sql.connect.planner.{InvalidInputErrors, SparkConnectPlanner}
 import org.apache.spark.sql.connect.service.ExecuteHolder
@@ -331,7 +331,8 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
           sessionId,
           sessionHolder.serverSessionId,
           observationAndPlanIds,
-          observedMetrics))
+          observedMetrics,
+          executeHolder.acceptLiteralDataTypeFieldInResponses))
     } else None
   }
 }
@@ -352,17 +353,21 @@ object SparkConnectPlanExecution {
       sessionId: String,
       serverSessionId: String,
       observationAndPlanIds: Map[String, Long],
-      metrics: Map[String, Seq[(Option[String], Any, Option[DataType])]]): ExecutePlanResponse = {
+      metrics: Map[String, Seq[(Option[String], Any, Option[DataType])]],
+      acceptLiteralDataTypeFieldInResponses: Boolean): ExecutePlanResponse = {
+    val toLiteralProtoOptions =
+      ToLiteralProtoOptions(useDeprecatedDataTypeFields = !acceptLiteralDataTypeFieldInResponses)
     val observedMetrics = metrics.map { case (name, values) =>
       val metrics = ExecutePlanResponse.ObservedMetrics
         .newBuilder()
         .setName(name)
       values.foreach { case (keyOpt, value, dataTypeOpt) =>
         dataTypeOpt match {
           case Some(dataType) =>
-            metrics.addValues(toLiteralProto(value, dataType))
+            metrics.addValues(
+              toLiteralProtoWithOptions(value, Some(dataType), toLiteralProtoOptions))
           case None =>
-            metrics.addValues(toLiteralProto(value))
+            metrics.addValues(toLiteralProtoWithOptions(value, None, toLiteralProtoOptions))
         }
         keyOpt.foreach(metrics.addKeys)
       }

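Note: the useDeprecatedDataTypeFields flag above is the entire server-side switch. A minimal sketch of the two paths, reusing the converter API introduced in this diff (the value 42L and LongType are illustrative):

import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.{toLiteralProtoWithOptions, ToLiteralProtoOptions}
import org.apache.spark.sql.types.LongType

// Old client (capability not advertised): keep the deprecated per-type fields.
val legacyLiteral = toLiteralProtoWithOptions(
  42L, Some(LongType), ToLiteralProtoOptions(useDeprecatedDataTypeFields = true))

// New client (accept_literal_data_type_field = true): emit the data_type field.
val typedLiteral = toLiteralProtoWithOptions(
  42L, Some(LongType), ToLiteralProtoOptions(useDeprecatedDataTypeFields = false))
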
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala

Lines changed: 9 additions & 0 deletions
@@ -104,6 +104,15 @@ private[connect] class ExecuteHolder(
     }
   }

+  /**
+   * If the client can handle Literal messages in responses that include the data_type field.
+   */
+  lazy val acceptLiteralDataTypeFieldInResponses: Boolean = {
+    request.getRequestOptionsList.asScala.exists { option =>
+      option.getAcceptResponseOptions.getAcceptLiteralDataTypeField
+    }
+  }
+
   val responseObserver: ExecuteResponseObserver[proto.ExecutePlanResponse] =
     new ExecuteResponseObserver[proto.ExecutePlanResponse](this)

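Note: the bare exists-scan needs no hasAcceptResponseOptions guard because protobuf-java returns a default instance for an unset message field. A small sketch of that behavior (ReattachOptions stands in for any other request option):

import org.apache.spark.connect.proto

val other = proto.ExecutePlanRequest.RequestOption
  .newBuilder()
  .setReattachOptions(proto.ReattachOptions.newBuilder().setReattachable(true))
  .build()

// getAcceptResponseOptions yields the default AcceptResponseOptions here,
// so the flag reads as false instead of throwing.
assert(!other.getAcceptResponseOptions.getAcceptLiteralDataTypeField)
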
sql/connect/server/src/test/scala/org/apache/spark/sql/connect/dsl/package.scala

Lines changed: 14 additions & 17 deletions
@@ -135,33 +135,30 @@ package object dsl {
         .build()
     }

-    def proto_min(e: Expression): Expression =
+    private def unresolvedFunction(functionName: String, e: Expression): Expression =
       Expression
         .newBuilder()
         .setUnresolvedFunction(
-          Expression.UnresolvedFunction.newBuilder().setFunctionName("min").addArguments(e))
+          Expression.UnresolvedFunction
+            .newBuilder()
+            .setFunctionName(functionName)
+            .addArguments(e))
         .build()

+    def proto_struct(e: Expression): Expression =
+      unresolvedFunction("struct", e)
+
+    def proto_min(e: Expression): Expression =
+      unresolvedFunction("min", e)
+
     def proto_max(e: Expression): Expression =
-      Expression
-        .newBuilder()
-        .setUnresolvedFunction(
-          Expression.UnresolvedFunction.newBuilder().setFunctionName("max").addArguments(e))
-        .build()
+      unresolvedFunction("max", e)

     def proto_sum(e: Expression): Expression =
-      Expression
-        .newBuilder()
-        .setUnresolvedFunction(
-          Expression.UnresolvedFunction.newBuilder().setFunctionName("sum").addArguments(e))
-        .build()
+      unresolvedFunction("sum", e)

     def proto_explode(e: Expression): Expression =
-      Expression
-        .newBuilder()
-        .setUnresolvedFunction(
-          Expression.UnresolvedFunction.newBuilder().setFunctionName("explode").addArguments(e))
-        .build()
+      unresolvedFunction("explode", e)

     /**
      * Create an unresolved function from name parts.

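Note: the DSL change is a mechanical dedup except for one addition, proto_struct, which the test DSL previously lacked, presumably to exercise struct-typed observed metrics. A usage sketch (the UnresolvedAttribute shape follows expressions.proto; the column name "id" is illustrative):

import org.apache.spark.connect.proto.Expression

val col = Expression
  .newBuilder()
  .setUnresolvedAttribute(
    Expression.UnresolvedAttribute.newBuilder().setUnparsedIdentifier("id"))
  .build()

// min(struct(id)): a struct-typed observed metric is the kind of literal
// whose type information round-trips through the new data_type field.
val minOfStruct = proto_min(proto_struct(col))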