Skip to content

Commit 803ea95

Browse files
mihailoale-db authored and cloud-fan committed
[SPARK-54041][SQL] Refactor ParameterizedQuery arguments validation
### What changes were proposed in this pull request? In this issue I propose to refactor `ParameterizedQuery` arguments validation to `ParameterizedQueryArgumentsValidator` so it can be reused between single-pass and fixed-point analyzer implementations. We also remove one redundant case from the `ParameterizedQueryArgumentsValidator.isNotAllowed` (`Alias` shouldn't call `isNotAllowed` again as it introduces unnecessary overhead to the method) to improve performance. ### Why are the changes needed? To ease code maintenance between single-pass and fixed-point analyzer implementations. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests (refactor). ### Was this patch authored or co-authored using generative AI tooling? No. Closes #52744 from mihailoale-db/refactorparamcheckargs. Authored-by: mihailoale-db <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 029393a commit 803ea95

File tree

3 files changed

+73
-39
lines changed

3 files changed

+73
-39
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
import org.apache.spark.sql.catalyst.expressions.{
21+
Alias,
22+
CreateArray,
23+
CreateMap,
24+
CreateNamedStruct,
25+
Expression,
26+
Literal,
27+
MapFromArrays,
28+
MapFromEntries,
29+
VariableReference
30+
}
31+
32+
/**
33+
* Object used to validate arguments of [[ParameterizedQuery]] nodes.
34+
*/
35+
object ParameterizedQueryArgumentsValidator {
36+
37+
/**
38+
* Validates the list of provided arguments. In case there is an invalid argument, throws
39+
* `INVALID_SQL_ARG` exception.
40+
*/
41+
def apply(arguments: Iterable[(String, Expression)]): Unit = {
42+
arguments.find(arg => isNotAllowed(arg._2)).foreach { case (name, expr) =>
43+
expr.failAnalysis(
44+
errorClass = "INVALID_SQL_ARG",
45+
messageParameters = Map("name" -> name))
46+
}
47+
}
48+
49+
/**
50+
* Recursively checks the provided expression tree. Returns `true` if it contains a node that is
51+
* not one of the allowed expression types. Otherwise, returns `false`.
52+
*/
53+
private def isNotAllowed(expression: Expression): Boolean = expression.exists {
54+
case _: Literal | _: CreateArray | _: CreateNamedStruct | _: CreateMap | _: MapFromArrays |
55+
_: MapFromEntries | _: VariableReference | _: Alias =>
56+
false
57+
case _ => true
58+
}
59+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.analysis
1919

2020
import org.apache.spark.SparkException
21-
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, LeafExpression, Literal, MapFromArrays, MapFromEntries, SubqueryExpression, Unevaluable, VariableReference}
21+
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, SubqueryExpression, Unevaluable}
2222
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SupervisingCommand}
2323
import org.apache.spark.sql.catalyst.rules.Rule
2424
import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMAND, PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH}
@@ -173,19 +173,6 @@ object MoveParameterizedQueriesDown extends Rule[LogicalPlan] {
173173
* from the user-specified arguments.
174174
*/
175175
object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
176-
private def checkArgs(args: Iterable[(String, Expression)]): Unit = {
177-
def isNotAllowed(expr: Expression): Boolean = expr.exists {
178-
case _: Literal | _: CreateArray | _: CreateNamedStruct |
179-
_: CreateMap | _: MapFromArrays | _: MapFromEntries | _: VariableReference => false
180-
case a: Alias => isNotAllowed(a.child)
181-
case _ => true
182-
}
183-
args.find(arg => isNotAllowed(arg._2)).foreach { case (name, expr) =>
184-
expr.failAnalysis(
185-
errorClass = "INVALID_SQL_ARG",
186-
messageParameters = Map("name" -> name))
187-
}
188-
}
189176

190177
private def bind(p0: LogicalPlan)(f: PartialFunction[Expression, Expression]): LogicalPlan = {
191178
var stop = false
@@ -210,15 +197,15 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
210197
s"must be equal to the number of argument values ${argValues.length}.")
211198
}
212199
val args = argNames.zip(argValues).toMap
213-
checkArgs(args)
200+
ParameterizedQueryArgumentsValidator(args)
214201
bind(child) { case NamedParameter(name) if args.contains(name) => args(name) }
215202

216203
case PosParameterizedQuery(child, args)
217204
if !child.containsPattern(UNRESOLVED_WITH) &&
218205
args.forall(_.resolved) =>
219206

220207
val indexedArgs = args.zipWithIndex
221-
checkArgs(indexedArgs.map(arg => (s"_${arg._2}", arg._1)))
208+
ParameterizedQueryArgumentsValidator(indexedArgs.map(arg => (s"_${arg._2}", arg._1)))
222209

223210
val positions = scala.collection.mutable.Set.empty[Int]
224211
bind(child) { case p @ PosParameter(pos) => positions.add(pos); p }
@@ -238,7 +225,7 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
238225
val finalName = if (name.isEmpty) s"_$index" else name
239226
finalName -> arg
240227
}
241-
checkArgs(allArgs)
228+
ParameterizedQueryArgumentsValidator(allArgs)
242229

243230
// Collect parameter types used in the query to enforce invariants
244231
var hasNamedParam = false

sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ import org.apache.spark.sql
4040
import org.apache.spark.sql.{AnalysisException, Artifact, DataSourceRegistration, Encoder, Encoders, ExperimentalMethods, Row, SparkSessionBuilder, SparkSessionCompanion, SparkSessionExtensions, SparkSessionExtensionsProvider, UDTFRegistration}
4141
import org.apache.spark.sql.artifact.ArtifactManager
4242
import org.apache.spark.sql.catalyst._
43-
import org.apache.spark.sql.catalyst.analysis.{GeneralParameterizedQuery, NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation}
43+
import org.apache.spark.sql.catalyst.analysis.{GeneralParameterizedQuery, NameParameterizedQuery, ParameterizedQueryArgumentsValidator, PosParameterizedQuery, UnresolvedRelation}
4444
import org.apache.spark.sql.catalyst.encoders._
45-
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal, MapFromArrays, MapFromEntries, VariableReference}
45+
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, Literal}
4646
import org.apache.spark.sql.catalyst.parser.{HybridParameterContext, NamedParameterContext, ParserInterface, PositionalParameterContext}
4747
import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, OneRowRelation, Project, Range}
4848
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
@@ -453,27 +453,15 @@ class SparkSession private(
453453
// Analyze the plan to resolve expressions
454454
val analyzed = sessionState.analyzer.execute(fakePlan)
455455

456-
// Validate: the expression tree must only contain allowed expression types.
457-
// This mirrors the validation in BindParameters.checkArgs.
458-
// We check this BEFORE optimization to catch unsupported functions like str_to_map.
459-
def isNotAllowed(expr: Expression): Boolean = expr.exists {
460-
case _: Literal | _: CreateArray | _: CreateNamedStruct |
461-
_: CreateMap | _: MapFromArrays | _: MapFromEntries | _: VariableReference => false
462-
case a: Alias => isNotAllowed(a.child)
463-
case _ => true
464-
}
465-
466-
analyzed.asInstanceOf[Project].projectList.foreach { alias =>
467-
val optimizedExpr = alias.asInstanceOf[Alias].child
468-
if (isNotAllowed(optimizedExpr)) {
469-
// Both modern and legacy modes use INVALID_SQL_ARG for sql() API argument validation.
470-
// UNSUPPORTED_EXPR_FOR_PARAMETER is reserved for EXECUTE IMMEDIATE.
471-
throw new AnalysisException(
472-
errorClass = "INVALID_SQL_ARG",
473-
messageParameters = Map("name" -> alias.name),
474-
origin = optimizedExpr.origin)
475-
}
456+
val expressionsToValidate = analyzed.asInstanceOf[Project].projectList.map {
457+
case alias: Alias =>
458+
(alias.name, alias.child)
459+
case other =>
460+
throw SparkException.internalError(
461+
s"Expected an Alias, but got ${other.getClass.getSimpleName}"
462+
)
476463
}
464+
ParameterizedQueryArgumentsValidator(expressionsToValidate)
477465

478466
// Optimize to constant-fold expressions. After optimization, all allowed expressions
479467
// should be folded to Literals.

0 commit comments

Comments
 (0)