Skip to content

Commit

Permalink
[CALCITE-6435] SqlToRel conversion of IN expressions may lead to inco…
Browse files Browse the repository at this point in the history
…rrect simplifications

Conversion path for comparisions generated from IN expressions was handling types differently.
This may have lead to some over-simplification in some cases.

Altered the conversion to do the full SqlToRex conversion steps for these generated nodes as well.
Added an extra safeguard check to RexSimplify to prevent the bug from being triggered.
  • Loading branch information
kgyrtkirk committed Jul 11, 2024
1 parent a98508f commit 73846cc
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 44 deletions.
7 changes: 5 additions & 2 deletions core/src/main/java/org/apache/calcite/rex/RexSimplify.java
Original file line number Diff line number Diff line change
Expand Up @@ -1682,7 +1682,10 @@ private <C extends Comparable<C>> RexNode simplifyAnd2ForUnknownAsFalse(
final RexLiteral literal = comparison.literal;
final RexLiteral prevLiteral =
equalityConstantTerms.put(comparison.ref, literal);
if (prevLiteral != null && !literal.equals(prevLiteral)) {

if (prevLiteral != null
&& literal.getType().equals(prevLiteral.getType())
&& !literal.equals(prevLiteral)) {
return rexBuilder.makeLiteral(false);
}
} else if (RexUtil.isReferenceOrAccess(left, true)
Expand Down Expand Up @@ -1753,7 +1756,7 @@ private <C extends Comparable<C>> RexNode simplifyAnd2ForUnknownAsFalse(
if (literal2 == null) {
continue;
}
if (!literal1.equals(literal2)) {
if (literal1.getType().equals(literal2.getType()) && !literal1.equals(literal2)) {
// If an expression is equal to two different constants,
// it is not satisfiable
return rexBuilder.makeLiteral(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@
import org.apache.calcite.sql.validate.SqlValidatorTable;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.sql.validate.SqlWithItemTableRef;
import org.apache.calcite.sql2rel.SqlToRelConverter.Blackboard;
import org.apache.calcite.sql2rel.SqlToRelConverter.SqlIdentifierFinder;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
Expand Down Expand Up @@ -1185,16 +1187,16 @@ private void substituteSubQuery(Blackboard bb, SubQuery subQuery) {
}
final SqlNode leftKeyNode = call.operand(0);

final List<RexNode> leftKeys;
final List<SqlNode> leftSqlKeys;
switch (leftKeyNode.getKind()) {
case ROW:
leftKeys = new ArrayList<>();
leftSqlKeys = new ArrayList<>();
for (SqlNode sqlExpr : ((SqlBasicCall) leftKeyNode).getOperandList()) {
leftKeys.add(bb.convertExpression(sqlExpr));
leftSqlKeys.add(sqlExpr);
}
break;
default:
leftKeys = ImmutableList.of(bb.convertExpression(leftKeyNode));
leftSqlKeys = ImmutableList.of(leftKeyNode);
}

if (query instanceof SqlNodeList) {
Expand All @@ -1205,7 +1207,7 @@ private void substituteSubQuery(Blackboard bb, SubQuery subQuery) {
subQuery.expr =
convertInToOr(
bb,
leftKeys,
leftSqlKeys,
valueList,
(SqlInOperator) call.getOperator());
return;
Expand All @@ -1216,6 +1218,10 @@ private void substituteSubQuery(Blackboard bb, SubQuery subQuery) {
// reference to Q below.
}

final List<RexNode> leftKeys = leftSqlKeys.stream()
.map(bb::convertExpression)
.collect(toImmutableList());

// Project out the search columns from the left side

// Q1:
Expand Down Expand Up @@ -1719,12 +1725,11 @@ public RelNode convertToSingleValueSubq(
*/
private @Nullable RexNode convertInToOr(
final Blackboard bb,
final List<RexNode> leftKeys,
final List<SqlNode> leftKeys,
SqlNodeList valuesList,
SqlInOperator op) {
final List<RexNode> comparisons = new ArrayList<>();
for (SqlNode rightVals : valuesList) {
RexNode rexComparison;
final SqlOperator comparisonOp;
if (op instanceof SqlQuantifyOperator) {
comparisonOp =
Expand All @@ -1733,25 +1738,23 @@ public RelNode convertToSingleValueSubq(
} else {
comparisonOp = SqlStdOperatorTable.EQUALS;
}
RexNode rexComparison;
if (leftKeys.size() == 1) {
rexComparison =
rexBuilder.makeCall(comparisonOp,
leftKeys.get(0),
ensureSqlType(leftKeys.get(0).getType(),
bb.convertExpression(rightVals)));
SqlCall sqlCall =
comparisonOp.createCall(rightVals.getParserPosition(), leftKeys.get(0), rightVals);
rexComparison = bb.convertExpression(sqlCall);
} else {
assert rightVals instanceof SqlCall;
final SqlBasicCall call = (SqlBasicCall) rightVals;
assert (call.getOperator() instanceof SqlRowOperator)
&& call.operandCount() == leftKeys.size();
rexComparison =
RexUtil.composeConjunction(rexBuilder,
Util.transform(
Pair.zip(leftKeys, call.getOperandList()),
pair -> rexBuilder.makeCall(comparisonOp, pair.left,
// TODO: remove requireNonNull when checkerframework issue resolved
ensureSqlType(requireNonNull(pair.left, "pair.left").getType(),
bb.convertExpression(pair.right)))));
RexUtil.composeConjunction(
rexBuilder, Util.transform(
Pair.zip(leftKeys, call.getOperandList()),
pair -> bb.convertExpression(
comparisonOp.createCall(
rightVals.getParserPosition(), pair.left, pair.right))));
}
comparisons.add(rexComparison);
}
Expand All @@ -1770,18 +1773,6 @@ public RelNode convertToSingleValueSubq(
}
}

/** Ensures that an expression has a given {@link SqlTypeName}, applying a
* cast if necessary. If the expression already has the right type family,
* returns the expression unchanged. */
private RexNode ensureSqlType(RelDataType type, RexNode node) {
if (type.getSqlTypeName() == node.getType().getSqlTypeName()
|| (type.getSqlTypeName() == SqlTypeName.VARCHAR
&& node.getType().getSqlTypeName() == SqlTypeName.CHAR)) {
return node;
}
return rexBuilder.ensureType(type, node, true);
}

/**
* Gets the list size threshold under which {@link #convertInToOr} is used.
* Lists of this size or greater will instead be converted to use a join
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1662,6 +1662,15 @@ private void checkSemiOrAntiJoinProjectTranspose(JoinRelType type) {
.check();
}

@Test void testIncorrectInType() {
final String sql = "select ename from emp "
+ " where ename in ( 'Sebastian' ) and ename = 'Sebastian' and deptno < 100";
sql(sql)
.withTrim(true)
.withRule()
.checkUnchanged();
}

@Test void testSemiJoinRule() {
final String sql = "select dept.* from dept join (\n"
+ " select distinct deptno from emp\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4950,6 +4950,19 @@ LogicalUnion(all=[true])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(EXPR$0=[LOWER($1)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testIncorrectInType">
<Resource name="sql">
<![CDATA[select ename from emp where ename in ( 'Sebastian' ) and ename = 'Sebastian' and deptno < 100]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(ENAME=[$0])
LogicalFilter(condition=[AND(=($0, 'Sebastian'), <($1, 100))])
LogicalProject(ENAME=[$1], DEPTNO=[$7])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ GROUP by deptno, job]]>
</Resource>
<Resource name="plan">
<![CDATA[
LogicalProject(JOB_NAME=[CASE(SEARCH($1, Sarg['810000', '820000']:CHAR(6)), $1, 'error':VARCHAR(10))], EXPR$1=[$2])
LogicalProject(JOB_NAME=[CASE(SEARCH($1, Sarg['810000':VARCHAR(10), '820000':VARCHAR(10)]:VARCHAR(10)), $1, 'error':VARCHAR(10))], EXPR$1=[$2])
LogicalAggregate(group=[{0, 1}], EXPR$1=[COUNT()])
LogicalProject(DEPTNO=[$7], JOB=[$2], EMPNO=[$0])
LogicalFilter(condition=[OR(<>($2, ''), =($2, '810000'), =($2, '820000'))])
Expand Down Expand Up @@ -561,7 +561,7 @@ GROUP BY GROUPING SETS ((empno, derived_col),(empno))]]>
<Resource name="plan">
<![CDATA[
LogicalAggregate(group=[{0, 1}], groups=[[{0, 1}, {0}]])
LogicalProject(EMPNO=[$0], DERIVED_COL=[CASE(SEARCH($1, Sarg['Eric', 'Fred']:CHAR(4)), 'CEO ', 'Other')])
LogicalProject(EMPNO=[$0], DERIVED_COL=[CASE(SEARCH($1, Sarg['Eric':VARCHAR(20), 'Fred':VARCHAR(20)]:VARCHAR(20)), 'CEO ', 'Other')])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
Expand All @@ -579,7 +579,7 @@ GROUP BY GROUPING SETS (
<Resource name="plan">
<![CDATA[
LogicalAggregate(group=[{0, 1}], groups=[[{0, 1}, {0}]])
LogicalProject(EMPNO=[$0], EXPR$1=[CASE(SEARCH($1, Sarg['Eric', 'Fred']:CHAR(4)), 'Manager', 'Other ')])
LogicalProject(EMPNO=[$0], EXPR$1=[CASE(SEARCH($1, Sarg['Eric':VARCHAR(20), 'Fred':VARCHAR(20)]:VARCHAR(20)), 'Manager', 'Other ')])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
Expand Down Expand Up @@ -2240,7 +2240,7 @@ group by case when coalesce(ename, 'a') in ('1', '2') then 'CKA' else 'QT' END]]
<Resource name="plan">
<![CDATA[
LogicalAggregate(group=[{0}], EXPR$1=[COUNT(DISTINCT $1)])
LogicalProject(EXPR$0=[CASE(SEARCH($1, Sarg['1', '2']:CHAR(1)), 'CKA', 'QT ')], DEPTNO=[$7])
LogicalProject(EXPR$0=[CASE(SEARCH($1, Sarg['1':VARCHAR(20), '2':VARCHAR(20)]:VARCHAR(20)), 'CKA', 'QT ')], DEPTNO=[$7])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
Expand Down
33 changes: 30 additions & 3 deletions core/src/test/resources/sql/sub-query.iq
Original file line number Diff line number Diff line change
Expand Up @@ -3149,7 +3149,7 @@ select * from "scott".emp where comm in (300, 500, null);

!ok

EnumerableCalc(expr#0..7=[{inputs}], expr#8=[Sarg[300:DECIMAL(7, 2), 500:DECIMAL(7, 2)]:DECIMAL(7, 2)], expr#9=[SEARCH($t6, $t8)], proj#0..7=[{exprs}], $condition=[$t9])
EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(12, 2)], expr#9=[Sarg[300:DECIMAL(12, 2), 500:DECIMAL(12, 2)]:DECIMAL(12, 2)], expr#10=[SEARCH($t8, $t9)], proj#0..7=[{exprs}], $condition=[$t10])
EnumerableTableScan(table=[[scott, EMP]])
!plan

Expand Down Expand Up @@ -3177,7 +3177,7 @@ select *, comm in (300, 500, null) as i from "scott".emp;

!ok

EnumerableCalc(expr#0..7=[{inputs}], expr#8=[Sarg[300:DECIMAL(7, 2), 500:DECIMAL(7, 2)]:DECIMAL(7, 2)], expr#9=[SEARCH($t6, $t8)], expr#10=[null:BOOLEAN], expr#11=[OR($t9, $t10)], proj#0..7=[{exprs}], I=[$t11])
EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(12, 2)], expr#9=[Sarg[300:DECIMAL(12, 2), 500:DECIMAL(12, 2)]:DECIMAL(12, 2)], expr#10=[SEARCH($t8, $t9)], expr#11=[null:BOOLEAN], expr#12=[OR($t10, $t11)], proj#0..7=[{exprs}], I=[$t12])
EnumerableTableScan(table=[[scott, EMP]])
!plan

Expand Down Expand Up @@ -3218,7 +3218,34 @@ select *, comm not in (300, 500, null) as i from "scott".emp;

!ok

EnumerableCalc(expr#0..7=[{inputs}], expr#8=[Sarg[(-∞..300:DECIMAL(7, 2)), (300:DECIMAL(7, 2)..500:DECIMAL(7, 2)), (500:DECIMAL(7, 2)..+∞)]:DECIMAL(7, 2)], expr#9=[SEARCH($t6, $t8)], expr#10=[null:BOOLEAN], expr#11=[AND($t9, $t10)], proj#0..7=[{exprs}], I=[$t11])
EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(12, 2)], expr#9=[Sarg[(-∞..300:DECIMAL(12, 2)), (300:DECIMAL(12, 2)..500:DECIMAL(12, 2)), (500:DECIMAL(12, 2)..+∞)]:DECIMAL(12, 2)], expr#10=[SEARCH($t8, $t9)], expr#11=[null:BOOLEAN], expr#12=[AND($t10, $t11)], proj#0..7=[{exprs}], I=[$t12])
EnumerableTableScan(table=[[scott, EMP]])
!plan

# Previous NOT IN expressions in conjunction form
select *, (comm <> 300 and comm <> 500 and comm <> null) as i from "scott".emp;
+-------+--------+-----------+------+------------+---------+---------+--------+-------+
| EMPNO | ENAME | JOB | MGR | HIREDATE | SAL | COMM | DEPTNO | I |
+-------+--------+-----------+------+------------+---------+---------+--------+-------+
| 7369 | SMITH | CLERK | 7902 | 1980-12-17 | 800.00 | | 20 | |
| 7499 | ALLEN | SALESMAN | 7698 | 1981-02-20 | 1600.00 | 300.00 | 30 | false |
| 7521 | WARD | SALESMAN | 7698 | 1981-02-22 | 1250.00 | 500.00 | 30 | false |
| 7566 | JONES | MANAGER | 7839 | 1981-02-04 | 2975.00 | | 20 | |
| 7654 | MARTIN | SALESMAN | 7698 | 1981-09-28 | 1250.00 | 1400.00 | 30 | |
| 7698 | BLAKE | MANAGER | 7839 | 1981-01-05 | 2850.00 | | 30 | |
| 7782 | CLARK | MANAGER | 7839 | 1981-06-09 | 2450.00 | | 10 | |
| 7788 | SCOTT | ANALYST | 7566 | 1987-04-19 | 3000.00 | | 20 | |
| 7839 | KING | PRESIDENT | | 1981-11-17 | 5000.00 | | 10 | |
| 7844 | TURNER | SALESMAN | 7698 | 1981-09-08 | 1500.00 | 0.00 | 30 | |
| 7876 | ADAMS | CLERK | 7788 | 1987-05-23 | 1100.00 | | 20 | |
| 7900 | JAMES | CLERK | 7698 | 1981-12-03 | 950.00 | | 30 | |
| 7902 | FORD | ANALYST | 7566 | 1981-12-03 | 3000.00 | | 20 | |
| 7934 | MILLER | CLERK | 7782 | 1982-01-23 | 1300.00 | | 10 | |
+-------+--------+-----------+------+------------+---------+---------+--------+-------+
(14 rows)

!ok
EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(12, 2)], expr#9=[Sarg[(-∞..300:DECIMAL(12, 2)), (300:DECIMAL(12, 2)..500:DECIMAL(12, 2)), (500:DECIMAL(12, 2)..+∞)]:DECIMAL(12, 2)], expr#10=[SEARCH($t8, $t9)], expr#11=[null:BOOLEAN], expr#12=[AND($t10, $t11)], proj#0..7=[{exprs}], I=[$t12])
EnumerableTableScan(table=[[scott, EMP]])
!plan

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ private void checkGroupBySingleSortLimit(boolean approx) {
+ "intervals=[[1900-01-09T00:00:00.000Z/2992-01-10T00:00:00.000Z]], "
+ "filter=[AND("
+ "=($3, 'High Top Dried Mushrooms'), "
+ "SEARCH($87, Sarg['Q2', 'Q3']:CHAR(2)), "
+ "SEARCH($87, Sarg['Q2':VARCHAR, 'Q3':VARCHAR]:VARCHAR), "
+ "=($30, 'WA'))], "
+ "projects=[[$30, $29, $3]], groups=[{0, 1, 2}], aggs=[[]])\n";
sql(sql)
Expand Down Expand Up @@ -1072,7 +1072,7 @@ private void checkGroupBySingleSortLimit(boolean approx) {
+ "intervals=[[1900-01-09T00:00:00.000Z/2992-01-10T00:00:00.000Z]], "
+ "filter=[AND("
+ "=($3, 'High Top Dried Mushrooms'), "
+ "SEARCH($87, Sarg['Q2', 'Q3']:CHAR(2)), "
+ "SEARCH($87, Sarg['Q2':VARCHAR, 'Q3':VARCHAR]:VARCHAR), "
+ "=($30, 'WA'))], "
+ "projects=[[$30, $29, $3]])\n";
sql(sql)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,7 @@ private void checkGroupBySingleSortLimit(boolean approx) {
+ "intervals=[[1900-01-09T00:00:00.000Z/2992-01-10T00:00:00.000Z]], "
+ "filter=[AND("
+ "=($3, 'High Top Dried Mushrooms'), "
+ "SEARCH($87, Sarg['Q2', 'Q3']:CHAR(2)), "
+ "SEARCH($87, Sarg['Q2':VARCHAR, 'Q3':VARCHAR]:VARCHAR), "
+ "=($30, 'WA'))], "
+ "projects=[[$30, $29, $3]], groups=[{0, 1, 2}], aggs=[[]])\n";
sql(sql)
Expand Down Expand Up @@ -1347,7 +1347,7 @@ private void checkGroupBySingleSortLimit(boolean approx) {
+ "intervals=[[1900-01-09T00:00:00.000Z/2992-01-10T00:00:00.000Z]], "
+ "filter=[AND("
+ "=($3, 'High Top Dried Mushrooms'), "
+ "SEARCH($87, Sarg['Q2', 'Q3']:CHAR(2)), "
+ "SEARCH($87, Sarg['Q2':VARCHAR, 'Q3':VARCHAR]:VARCHAR), "
+ "=($30, 'WA'))], "
+ "projects=[[$30, $29, $3]])\n";
sql(sql)
Expand Down

0 comments on commit 73846cc

Please sign in to comment.