Skip to content

Commit

Permalink
[CALCITE-6593] NPE when outer joining tables with many fields and unm…
Browse files Browse the repository at this point in the history
…atching rows
  • Loading branch information
rorueda committed Sep 24, 2024
1 parent 8d3cb82 commit 3862a10
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.calcite.linq4j.tree.NewArrayExpression;
import org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.calcite.linq4j.tree.Primitive;
import org.apache.calcite.linq4j.tree.Statement;
import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.linq4j.tree.UnaryExpression;
import org.apache.calcite.rel.RelNode;
Expand Down Expand Up @@ -213,8 +214,15 @@ static Expression joinSelector(JoinRelType joinType, PhysType physType,
final Expression copyExpr =
Nullness.castNonNull(
inputPhysType.getFormat().copy(parameter, Nullness.castNonNull(compactOutputVar),
outputField, fieldCount));
compactCode.add(Expressions.statement(copyExpr));
outputField, fieldCount));
Statement copyStatement = Expressions.statement(copyExpr);
if (joinType.generatesNullsOn(ord.i)) {
// [CALCITE-6593] NPE when outer joining tables with many fields and unmatching rows
copyStatement =
Expressions.ifThen(Expressions.notEqual(parameter, Expressions.constant(null)),
copyStatement);
}
compactCode.add(copyStatement);
outputField += fieldCount;
continue;
}
Expand Down Expand Up @@ -243,8 +251,12 @@ static Expression joinSelector(JoinRelType joinType, PhysType physType,
// public String[] apply(org.apache.calcite.interpreter.Row left,
// org.apache.calcite.interpreter.Row right) {
// String[] outputArray = new String[left.length + right.length];
// System.arraycopy(left.copyValues(), 0, outputArray, 0, left.length);
// System.arraycopy(right.copyValues(), 0, outputArray, left.length, right.length);
// if (left != null) { // because left is null when left side is empty
// System.arraycopy(left.copyValues(), 0, outputArray, 0, left.length);
// }
// if (right != null) { // because right is null when right side is empty
// System.arraycopy(right.copyValues(), 0, outputArray, left.length, right.length);
// }
// return outputArray;
// }
// public String[] apply(Object left, Object right) {
Expand Down
217 changes: 177 additions & 40 deletions core/src/test/java/org/apache/calcite/test/LargeGeneratedJoinTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
*/
package org.apache.calcite.test;

import org.apache.calcite.adapter.enumerable.EnumerableRules;
import org.apache.calcite.adapter.java.AbstractQueryableTable;
import org.apache.calcite.config.CalciteConnectionProperty;
import org.apache.calcite.config.Lex;
Expand All @@ -26,10 +25,8 @@
import org.apache.calcite.linq4j.Linq4j;
import org.apache.calcite.linq4j.QueryProvider;
import org.apache.calcite.linq4j.Queryable;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.runtime.Hook;
import org.apache.calcite.schema.QueryableTable;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.schema.SchemaPlus;
Expand All @@ -46,10 +43,10 @@
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
Expand Down Expand Up @@ -85,7 +82,7 @@ static RowT row(FieldT... fields) {
};
}

private static QueryableTable tab(int fieldCount) {
private static QueryableTable tab(String table, int fieldCount) {
List<Row> lRow = new ArrayList<>();
for (int r = 0; r < 2; r++) {
Object[] current = new Object[fieldCount];
Expand All @@ -97,7 +94,7 @@ private static QueryableTable tab(int fieldCount) {

List<FieldT> fields = new ArrayList<>();
for (int i = 0; i < fieldCount; i++) {
fields.add(field("F_" + i));
fields.add(field(table + "_F_" + i));
}

final Enumerable<?> enumerable = Linq4j.asEnumerable(lRow);
Expand All @@ -114,65 +111,205 @@ private static QueryableTable tab(int fieldCount) {
};
}

@Test public void test() throws SQLException {
private static CalciteAssert.AssertQuery assertQuery(String sql) {
Schema rootSchema = new AbstractSchema() {
@Override protected Map<String, Table> getTableMap() {
return ImmutableMap.of("T0", tab(100),
"T1", tab(101));
return ImmutableMap.of("T0", tab("T0", 100),
"T1", tab("T1", 101));
}
};

final CalciteSchema sp = CalciteSchema.createRootSchema(false, true);
sp.add("ROOT", rootSchema);

final CalciteAssert.AssertThat ca = CalciteAssert.that()
.with(CalciteConnectionProperty.LEX, Lex.JAVA)
.withSchema("ROOT", rootSchema)
.withDefaultSchema("ROOT");

return ca.query(sql);
}

@Test public void test() {
String sql = "SELECT * \n"
+ "FROM ROOT.T0 \n"
+ "JOIN ROOT.T1 \n"
+ "ON TRUE";

sql = "select F_0||F_1, * from (" + sql + ")";
sql = "select T0_F_0||T0_F_1, * from (" + sql + ")";

final CalciteAssert.AssertQuery query = assertQuery(sql);
query.returns(rs -> {
try {
assertTrue(rs.next());
assertEquals(1 + 100 + 101, rs.getMetaData().getColumnCount());
long row = 0;
do {
++row;
for (int i = 1; i <= rs.getMetaData().getColumnCount(); ++i) {
// Rows have the format: v0v1, v0, v1, v2, ..., v99, v0, v1, v2, ..., v99, v100
if (i == 1) {
assertEquals("v0v1", rs.getString(i),
"Error at row: " + row + ", column: " + i);
} else if (i == rs.getMetaData().getColumnCount()) {
assertEquals("v100", rs.getString(i),
"Error at row: " + row + ", column: " + i);
} else {
assertEquals("v" + ((i - 2) % 100), rs.getString(i),
"Error at row: " + row + ", column: " + i);
}
}
} while (rs.next());
assertEquals(4, row);
} catch (SQLException e) {
throw new RuntimeException(e);
}
});
}

final CalciteAssert.AssertThat ca = CalciteAssert.that()
.with(CalciteConnectionProperty.LEX, Lex.JAVA)
.withSchema("ROOT", rootSchema)
.withDefaultSchema("ROOT");
/**
* Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6593">[CALCITE-6593]
* NPE when outer joining tables with many fields and unmatching rows</a>.
*/
@Test public void testLeftJoinWithEmptyRightSide() {
String sql = "SELECT * \n"
+ "FROM ROOT.T0 \n"
+ "LEFT JOIN (SELECT * FROM ROOT.T1 WHERE T1_F_0 = 'xyz') \n"
+ "ON TRUE";

sql = "select T0_F_0||T0_F_1, * from (" + sql + ")";

final CalciteAssert.AssertQuery query = ca.query(sql);
query.withHook(Hook.PLANNER, (Consumer<RelOptPlanner>) pl -> {
pl.removeRule(EnumerableRules.ENUMERABLE_CORRELATE_RULE);
pl.addRule(EnumerableRules.ENUMERABLE_BATCH_NESTED_LOOP_JOIN_RULE);
final CalciteAssert.AssertQuery query = assertQuery(sql);
query.returns(rs -> {
try {
assertTrue(rs.next());
assertEquals(1 + 100 + 101, rs.getMetaData().getColumnCount());
long row = 0;
do {
++row;
for (int i = 1; i <= rs.getMetaData().getColumnCount(); ++i) {
// Rows have the format: v0v1, v0, v1, v2, ..., v99, null, ..., null
if (i == 1) {
assertEquals("v0v1", rs.getString(i),
"Error at row: " + row + ", column: " + i);
} else if (i <= 101) {
assertEquals("v" + (i - 2), rs.getString(i),
"Error at row: " + row + ", column: " + i);
} else {
assertNull(rs.getString(i),
"Error at row: " + row + ", column: " + i);
}
}
} while (rs.next());
assertEquals(2, row);
} catch (SQLException e) {
throw new RuntimeException(e);
}
});
}

try {
query.returns(rs -> {
try {
assertTrue(rs.next());
assertEquals(101 + 100 + 1, rs.getMetaData().getColumnCount());
long row = 0;
do {
++row;
for (int i = 1; i <= rs.getMetaData().getColumnCount(); ++i) {
// Rows have the format: v0v1, v0, v1, v2, ..., v99, v0, v1, v2, ..., v99, v100
/**
* Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6593">[CALCITE-6593]
* NPE when outer joining tables with many fields and unmatching rows</a>.
*/
@Test public void testRightJoinWithEmptyLeftSide() {
String sql = "SELECT * \n"
+ "FROM (SELECT * FROM ROOT.T0 WHERE T0_F_0 = 'xyz') \n"
+ "RIGHT JOIN ROOT.T1 \n"
+ "ON TRUE";

sql = "select T1_F_0||T1_F_1, * from (" + sql + ")";

final CalciteAssert.AssertQuery query = assertQuery(sql);
query.returns(rs -> {
try {
assertTrue(rs.next());
assertEquals(1 + 100 + 101, rs.getMetaData().getColumnCount());
long row = 0;
do {
++row;
for (int i = 1; i <= rs.getMetaData().getColumnCount(); ++i) {
// Rows have the format: v0v1, null, ..., null, v0, v1, v2, ..., v100
if (i == 1) {
assertEquals("v0v1", rs.getString(i),
"Error at row: " + row + ", column: " + i);
} else if (i <= 101) {
assertNull(rs.getString(i),
"Error at row: " + row + ", column: " + i);
} else {
assertEquals("v" + (i - 2 - 100), rs.getString(i),
"Error at row: " + row + ", column: " + i);
}
}
} while (rs.next());
assertEquals(2, row);
} catch (SQLException e) {
throw new RuntimeException(e);
}
});
}

/**
* Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6593">[CALCITE-6593]
* NPE when outer joining tables with many fields and unmatching rows</a>.
*/
@Test public void testFullJoinWithUnmatchedRows() {
String sql = "SELECT * \n"
+ "FROM ROOT.T0 \n"
+ "FULL JOIN ROOT.T1 \n"
+ "ON T0_F_0 <> T1_F_0";

sql = "select T0_F_0||T0_F_1, T1_F_0||T1_F_1, * from (" + sql + ")";

final CalciteAssert.AssertQuery query = assertQuery(sql);
query.returns(rs -> {
try {
assertTrue(rs.next());
assertEquals(1 + 1 + 100 + 101, rs.getMetaData().getColumnCount());
long row = 0;
do {
++row;
for (int i = 1; i <= rs.getMetaData().getColumnCount(); ++i) {
if (row <= 2) {
// First 2 rows have the format: v0v1, null, v0, v1, v2, ..., v99, null, ..., null
if (i == 1) {
assertEquals("v0v1", rs.getString(i),
"Error at row: " + row + ", column: " + i);
} else if (i == rs.getMetaData().getColumnCount()) {
assertEquals("v100", rs.getString(i),
} else if (i == 2) {
assertNull(rs.getString(i),
"Error at row: " + row + ", column: " + i);
} else if (i <= 102) {
assertEquals("v" + (i - 3), rs.getString(i),
"Error at row: " + row + ", column: " + i);
} else {
assertEquals("v" + ((i - 2) % 100), rs.getString(i),
assertNull(rs.getString(i),
"Error at row: " + row + ", column: " + i);
}
} else {
// Last 2 rows have the format: null, v0v1, null, ..., null, v0, v1, v2, ..., v100
if (i == 1) {
assertNull(rs.getString(i),
"Error at row: " + row + ", column: " + i);
} else if (i == 2) {
assertEquals("v0v1", rs.getString(i),
"Error at row: " + row + ", column: " + i);
} else if (i <= 102) {
assertNull(rs.getString(i),
"Error at row: " + row + ", column: " + i);
} else {
assertEquals("v" + (i - 3 - 100), rs.getString(i),
"Error at row: " + row + ", column: " + i);
}
}
} while (rs.next());
assertEquals(4, row);
} catch (SQLException e) {
throw new RuntimeException(e);
}
});
} catch (RuntimeException ex) {
throw (SQLException) ex.getCause();
}
}
} while (rs.next());
assertEquals(4, row);
} catch (SQLException e) {
throw new RuntimeException(e);
}
});
}
}

0 comments on commit 3862a10

Please sign in to comment.