diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinSpec.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinSpec.java
index 04149cf84bc85..56c995c86467e 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinSpec.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinSpec.java
@@ -24,12 +24,14 @@
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnore;
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonInclude;
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
 
 import org.apache.calcite.rex.RexNode;
 
 import javax.annotation.Nullable;
 
+import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 
@@ -44,6 +46,9 @@ public class DeltaJoinSpec {
     public static final String FIELD_NAME_LOOKUP_TABLE = "lookupTable";
     public static final String FIELD_NAME_LOOKUP_KEYS = "lookupKeys";
     public static final String FIELD_NAME_REMAINING_CONDITION = "remainingCondition";
+    public static final String FIELD_NAME_PROJECTION_ON_TEMPORAL_TABLE =
+            "projectionOnTemporalTable";
+    public static final String FIELD_NAME_FILTER_ON_TEMPORAL_TABLE = "filterOnTemporalTable";
 
     @JsonProperty(FIELD_NAME_LOOKUP_TABLE)
     private final TemporalTableSourceSpec lookupTable;
@@ -56,15 +61,29 @@ public class DeltaJoinSpec {
     @JsonProperty(FIELD_NAME_REMAINING_CONDITION)
     private final @Nullable RexNode remainingCondition;
 
+    @JsonProperty(FIELD_NAME_PROJECTION_ON_TEMPORAL_TABLE)
+    @JsonInclude(JsonInclude.Include.NON_NULL)
+    private final @Nullable List<RexNode> projectionOnTemporalTable;
+
+    @JsonProperty(FIELD_NAME_FILTER_ON_TEMPORAL_TABLE)
+    @JsonInclude(JsonInclude.Include.NON_NULL)
+    private final @Nullable RexNode filterOnTemporalTable;
+
     @JsonCreator
     public DeltaJoinSpec(
             @JsonProperty(FIELD_NAME_LOOKUP_TABLE) TemporalTableSourceSpec lookupTable,
             @JsonProperty(FIELD_NAME_LOOKUP_KEYS) Map lookupKeyMap,
-            @JsonProperty(FIELD_NAME_REMAINING_CONDITION) @Nullable RexNode remainingCondition) {
+            @JsonProperty(FIELD_NAME_REMAINING_CONDITION) @Nullable RexNode remainingCondition,
+            @JsonProperty(FIELD_NAME_PROJECTION_ON_TEMPORAL_TABLE) @Nullable
+                    List<RexNode> projectionOnTemporalTable,
+            @JsonProperty(FIELD_NAME_FILTER_ON_TEMPORAL_TABLE) @Nullable
+                    RexNode filterOnTemporalTable) {
         this.lookupKeyMap = lookupKeyMap;
         this.lookupTable = lookupTable;
         this.remainingCondition = remainingCondition;
+        this.projectionOnTemporalTable = projectionOnTemporalTable;
+        this.filterOnTemporalTable = filterOnTemporalTable;
     }
 
     @JsonIgnore
@@ -81,4 +100,14 @@ public Map getLookupKeyMap() {
     public Optional<RexNode> getRemainingCondition() {
         return Optional.ofNullable(remainingCondition);
     }
+
+    @JsonIgnore
+    public Optional<List<RexNode>> getProjectionOnTemporalTable() {
+        return Optional.ofNullable(projectionOnTemporalTable);
+    }
+
+    @JsonIgnore
+    public Optional<RexNode> getFilterOnTemporalTable() {
+        return Optional.ofNullable(filterOnTemporalTable);
+    }
 }
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java index 07ae35a6e82cc..2836f0b601597 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java @@ -19,6 +19,7 @@ package org.apache.flink.table.planner.plan.nodes.exec.stream; import org.apache.flink.FlinkVersion; +import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.dag.Transformation; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.ReadableConfig; @@ -54,6 +55,7 @@ import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; import org.apache.flink.table.planner.utils.ShortcutUtils; import org.apache.flink.table.runtime.collector.TableFunctionResultFuture; +import org.apache.flink.table.runtime.generated.GeneratedFunction; import org.apache.flink.table.runtime.generated.GeneratedResultFuture; import org.apache.flink.table.runtime.keyselector.RowDataKeySelector; import org.apache.flink.table.runtime.operators.StreamingDeltaJoinOperatorFactory; @@ -389,6 +391,7 @@ private AsyncDeltaJoinRunner createAsyncDeltaJoinRunner( boolean treatRightAsLookupTable) { RelOptTable lookupTable = treatRightAsLookupTable ? rightTempTable : leftTempTable; RowType streamSideType = treatRightAsLookupTable ? leftStreamSideType : rightStreamSideType; + RowType lookupSideType = treatRightAsLookupTable ? rightStreamSideType : leftStreamSideType; AsyncTableFunction lookupSideAsyncTableFunction = getUnwrappedAsyncLookupFunction(lookupTable, lookupKeys.keySet(), classLoader); @@ -454,11 +457,36 @@ private AsyncDeltaJoinRunner createAsyncDeltaJoinRunner( JavaScalaConversionUtil.toScala(newCond)); } + GeneratedFunction> lookupSideGeneratedCalc = null; + if ((treatRightAsLookupTable + && lookupRightTableJoinSpec.getProjectionOnTemporalTable().isPresent()) + || (!treatRightAsLookupTable + && lookupLeftTableJoinSpec.getProjectionOnTemporalTable().isPresent())) { + // a projection or filter after lookup table + List projectionOnTemporalTable = + treatRightAsLookupTable + ? lookupRightTableJoinSpec.getProjectionOnTemporalTable().get() + : lookupLeftTableJoinSpec.getProjectionOnTemporalTable().get(); + RexNode filterOnTemporalTable = + treatRightAsLookupTable + ? 
lookupRightTableJoinSpec.getFilterOnTemporalTable().orElse(null) + : lookupLeftTableJoinSpec.getFilterOnTemporalTable().orElse(null); + lookupSideGeneratedCalc = + LookupJoinCodeGenerator.generateCalcMapFunction( + config, + planner.getFlinkContext().getClassLoader(), + JavaScalaConversionUtil.toScala(projectionOnTemporalTable), + filterOnTemporalTable, + lookupSideType, + lookupTableSourceRowType); + } + return new AsyncDeltaJoinRunner( lookupSideGeneratedFuncWithType.tableFunc(), (DataStructureConverter) lookupSideFetcherConverter, + lookupSideGeneratedCalc, lookupSideGeneratedResultFuture, - InternalSerializers.create(lookupTableSourceRowType), + InternalSerializers.create(lookupSideType), leftJoinKeySelector, leftUpsertKeySelector, rightJoinKeySelector, diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/DeltaJoinUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/DeltaJoinUtil.java index a2e356e714f13..7f8ac77f4e853 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/DeltaJoinUtil.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/DeltaJoinUtil.java @@ -18,6 +18,7 @@ package org.apache.flink.table.planner.plan.utils; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.table.catalog.Index; import org.apache.flink.table.catalog.ResolvedSchema; import org.apache.flink.table.connector.ChangelogMode; @@ -25,10 +26,16 @@ import org.apache.flink.table.connector.source.LookupTableSource; import org.apache.flink.table.functions.AsyncTableFunction; import org.apache.flink.table.functions.UserDefinedFunction; +import org.apache.flink.table.planner.plan.abilities.source.FilterPushDownSpec; +import org.apache.flink.table.planner.plan.abilities.source.PartitionPushDownSpec; +import org.apache.flink.table.planner.plan.abilities.source.ProjectPushDownSpec; +import org.apache.flink.table.planner.plan.abilities.source.ReadingMetadataSpec; +import org.apache.flink.table.planner.plan.abilities.source.SourceAbilitySpec; import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery; import org.apache.flink.table.planner.plan.nodes.exec.spec.DeltaJoinSpec; import org.apache.flink.table.planner.plan.nodes.exec.spec.JoinSpec; import org.apache.flink.table.planner.plan.nodes.exec.spec.TemporalTableSourceSpec; +import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalCalc; import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalDeltaJoin; import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalDropUpdateBefore; import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalExchange; @@ -51,23 +58,32 @@ import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.hep.HepRelVertex; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelVisitor; +import org.apache.calcite.rel.core.Calc; import org.apache.calcite.rel.core.JoinInfo; import org.apache.calcite.rel.core.TableScan; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexProgram; import org.apache.calcite.rex.RexUtil; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.mapping.IntPair; +import org.checkerframework.checker.nullness.qual.Nullable; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import 
java.util.Collections; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; + +import scala.Option; /** Utils for delta joins. */ public class DeltaJoinUtil { @@ -83,7 +99,22 @@ public class DeltaJoinUtil { Sets.newHashSet( StreamPhysicalTableSourceScan.class, StreamPhysicalExchange.class, - StreamPhysicalDropUpdateBefore.class); + StreamPhysicalDropUpdateBefore.class, + StreamPhysicalCalc.class); + + /** + * All supported {@link SourceAbilitySpec}s in sources. Only the sources with the following + * {@link SourceAbilitySpec} can be used as delta join sources. Otherwise, the regular join will + * not be optimized into the delta join. + */ + private static final Set> ALL_SUPPORTED_ABILITY_SPEC_IN_SOURCE = + Sets.newHashSet( + FilterPushDownSpec.class, + ProjectPushDownSpec.class, + PartitionPushDownSpec.class, + // TODO FLINK-38569 ReadingMetadataSpec should not be generated when there are + // no metadata keys to be read + ReadingMetadataSpec.class); private DeltaJoinUtil() {} @@ -112,6 +143,10 @@ public static boolean canConvertToDeltaJoin(StreamPhysicalJoin join) { return false; } + if (!areAllUpstreamCalcSupported(join)) { + return false; + } + return areAllJoinTableScansSupported(join); } @@ -135,13 +170,29 @@ public static DeltaJoinSpec getDeltaJoinSpec( Optional remainingCondition = condition.isAlwaysTrue() ? Optional.empty() : Optional.of(condition); + List joinPairs = joinInfo.pairs(); final RelOptTable lookupRelOptTable; - List streamToLookupJoinKeys = joinInfo.pairs(); + final IntPair[] streamToLookupJoinKeys; + final Optional calcOnLookupTable; + if (treatRightAsLookupSide) { + calcOnLookupTable = getRexProgramBetweenJoinAndTableScan(join.getRight()); + lookupRelOptTable = DeltaJoinUtil.getTableScanRelOptTable(join.getRight()); + streamToLookupJoinKeys = + TemporalJoinUtil.getTemporalTableJoinKeyPairs( + joinPairs.toArray(new IntPair[0]), + JavaScalaConversionUtil.toScala(calcOnLookupTable)); + } else { - streamToLookupJoinKeys = reverseIntPairs(streamToLookupJoinKeys); + calcOnLookupTable = getRexProgramBetweenJoinAndTableScan(join.getLeft()); + + joinPairs = reverseIntPairs(joinInfo.pairs()); lookupRelOptTable = DeltaJoinUtil.getTableScanRelOptTable(join.getLeft()); + streamToLookupJoinKeys = + TemporalJoinUtil.getTemporalTableJoinKeyPairs( + joinPairs.toArray(new IntPair[0]), + JavaScalaConversionUtil.toScala(calcOnLookupTable)); } Preconditions.checkState(lookupRelOptTable instanceof TableSourceTable); final TableSourceTable lookupTable = (TableSourceTable) lookupRelOptTable; @@ -149,10 +200,24 @@ public static DeltaJoinSpec getDeltaJoinSpec( Map allLookupKeys = analyzerDeltaJoinLookupKeys(streamToLookupJoinKeys); + List projectionOnTemporalTable = null; + RexNode filterOnTemporalTable = null; + + if (calcOnLookupTable.isPresent()) { + Tuple2, Option> projectionsAndFilter = + JavaScalaConversionUtil.toJava( + FlinkRexUtil.expandRexProgram(calcOnLookupTable.get())); + projectionOnTemporalTable = projectionsAndFilter.f0; + filterOnTemporalTable = + JavaScalaConversionUtil.toJava(projectionsAndFilter.f1).orElse(null); + } + return new DeltaJoinSpec( new TemporalTableSourceSpec(lookupTable), allLookupKeys, - remainingCondition.orElse(null)); + remainingCondition.orElse(null), + projectionOnTemporalTable, + filterOnTemporalTable); } /** @@ -212,7 +277,7 @@ public static boolean isJoinTypeSupported(FlinkJoinType 
flinkJoinType) { * @param streamToLookupJoinKeys the join keys from stream side to lookup side */ private static Map analyzerDeltaJoinLookupKeys( - List streamToLookupJoinKeys) { + IntPair[] streamToLookupJoinKeys) { Map allFieldRefLookupKeys = new LinkedHashMap<>(); for (IntPair intPair : streamToLookupJoinKeys) { allFieldRefLookupKeys.put( @@ -227,30 +292,31 @@ private static List reverseIntPairs(List intPairs) { .collect(Collectors.toList()); } - private static int[][] getColumnIndicesOfAllTableIndexes(TableSourceTable tableSourceTable) { - List> columnsOfIndexes = getAllIndexesColumnsOfTable(tableSourceTable); + private static int[][] getAllIndexesColumnsFromTableSchema(ResolvedSchema schema) { + List indexes = schema.getIndexes(); + List> columnsOfIndexes = + indexes.stream().map(Index::getColumns).collect(Collectors.toList()); int[][] results = new int[columnsOfIndexes.size()][]; for (int i = 0; i < columnsOfIndexes.size(); i++) { - List fieldNames = tableSourceTable.getRowType().getFieldNames(); + List fieldNames = schema.getColumnNames(); results[i] = columnsOfIndexes.get(i).stream().mapToInt(fieldNames::indexOf).toArray(); } return results; } - private static List> getAllIndexesColumnsOfTable( - TableSourceTable tableSourceTable) { - ResolvedSchema schema = tableSourceTable.contextResolvedTable().getResolvedSchema(); - List indexes = schema.getIndexes(); - return indexes.stream().map(Index::getColumns).collect(Collectors.toList()); - } - private static boolean areJoinConditionsSupported(StreamPhysicalJoin join) { JoinInfo joinInfo = join.analyzeCondition(); // there must be one pair of join key if (joinInfo.pairs().isEmpty()) { return false; } + JoinSpec joinSpec = join.joinSpec(); + Optional nonEquiCond = joinSpec.getNonEquiCondition(); + if (nonEquiCond.isPresent() + && !areAllRexNodeDeterministic(Collections.singletonList(nonEquiCond.get()))) { + return false; + } // if this join outputs cdc records and has non-equiv condition, the reference columns in // the non-equiv condition must come from the same set of upsert keys @@ -258,25 +324,59 @@ private static boolean areJoinConditionsSupported(StreamPhysicalJoin join) { if (changelogMode.containsOnly(RowKind.INSERT)) { return true; } - JoinSpec joinSpec = join.joinSpec(); - Optional nonEquiCond = joinSpec.getNonEquiCondition(); + if (nonEquiCond.isEmpty()) { return true; } - ImmutableBitSet fieldRefIndices = - ImmutableBitSet.of( - RexNodeExtractor.extractRefInputFields( - Collections.singletonList(nonEquiCond.get()))); + FlinkRelMetadataQuery fmq = FlinkRelMetadataQuery.reuseOrCreate(join.getCluster().getMetadataQuery()); Set upsertKeys = fmq.getUpsertKeys(join); + return isFilterOnOneSetOfUpsertKeys(nonEquiCond.get(), upsertKeys); + } + + private static boolean isFilterOnOneSetOfUpsertKeys( + RexNode filter, @Nullable Set upsertKeys) { + ImmutableBitSet fieldRefIndices = + ImmutableBitSet.of( + RexNodeExtractor.extractRefInputFields(Collections.singletonList(filter))); return upsertKeys.stream().anyMatch(uk -> uk.contains(fieldRefIndices)); } private static boolean areAllJoinTableScansSupported(StreamPhysicalJoin join) { - return isTableScanSupported(getTableScan(join.getLeft()), join.joinSpec().getLeftKeys()) + List left2RightJoinPairs = + JoinUtil.createJoinInfo( + join.getLeft(), + join.getRight(), + join.getCondition(), + new ArrayList<>()) + .pairs(); + + Optional calcOnLeftLookupTable = + getRexProgramBetweenJoinAndTableScan(join.getLeft()); + Optional calcOnRightLookupTable = + 
getRexProgramBetweenJoinAndTableScan(join.getRight()); + + List right2LeftJoinPair = reverseIntPairs(left2RightJoinPairs); + + int[] leftJoinKeyOnLeftLookupTable = + Arrays.stream( + TemporalJoinUtil.getTemporalTableJoinKeyPairs( + right2LeftJoinPair.toArray(new IntPair[0]), + JavaScalaConversionUtil.toScala(calcOnLeftLookupTable))) + .mapToInt(pair -> pair.target) + .toArray(); + int[] rightJoinKeyOnRightLookupTable = + Arrays.stream( + TemporalJoinUtil.getTemporalTableJoinKeyPairs( + left2RightJoinPairs.toArray(new IntPair[0]), + JavaScalaConversionUtil.toScala(calcOnRightLookupTable))) + .mapToInt(pair -> pair.target) + .toArray(); + + return isTableScanSupported(getTableScan(join.getLeft()), leftJoinKeyOnLeftLookupTable) && isTableScanSupported( - getTableScan(join.getRight()), join.joinSpec().getRightKeys()); + getTableScan(join.getRight()), rightJoinKeyOnRightLookupTable); } private static boolean isTableScanSupported(TableScan tableScan, int[] lookupKeys) { @@ -288,8 +388,8 @@ private static boolean isTableScanSupported(TableScan tableScan, int[] lookupKey TableSourceTable tableSourceTable = ((StreamPhysicalTableSourceScan) tableScan).tableSourceTable(); - // source with ability specs are not supported yet - if (tableSourceTable.abilitySpecs().length != 0) { + if (tableSourceTable.abilitySpecs().length != 0 + && !areAllSourceAbilitySpecsSupported(tableScan, tableSourceTable.abilitySpecs())) { return false; } @@ -299,40 +399,120 @@ private static boolean isTableScanSupported(TableScan tableScan, int[] lookupKey return false; } - int[][] idxsOfAllIndexes = getColumnIndicesOfAllTableIndexes(tableSourceTable); - if (idxsOfAllIndexes.length == 0) { - return false; - } - // the source must have at least one index, and the join key contains one index - Set lookupKeysSet = Arrays.stream(lookupKeys).boxed().collect(Collectors.toSet()); - - boolean lookupKeyContainsOneIndex = - Arrays.stream(idxsOfAllIndexes) - .peek(idxsOfIndex -> Preconditions.checkState(idxsOfIndex.length > 0)) - .anyMatch( - idxsOfIndex -> - Arrays.stream(idxsOfIndex) - .allMatch(lookupKeysSet::contains)); - if (!lookupKeyContainsOneIndex) { + Set lookupKeySet = Arrays.stream(lookupKeys).boxed().collect(Collectors.toSet()); + + if (!isLookupKeysContainsIndex(tableSourceTable, lookupKeySet)) { return false; } // the lookup source must support async lookup return LookupJoinUtil.isAsyncLookup( tableSourceTable, - lookupKeysSet, + lookupKeySet, null, // hint false, // upsertMaterialize false // preferCustomShuffle ); } + private static boolean isLookupKeysContainsIndex( + TableSourceTable tableSourceTable, Set lookupKeySet) { + // the source must have at least one index, and the join key contains one index + int[][] idxsOfAllIndexes = + getAllIndexesColumnsFromTableSchema( + tableSourceTable.contextResolvedTable().getResolvedSchema()); + if (idxsOfAllIndexes.length == 0) { + return false; + } + + final Set lookupKeySetPassThroughProjectPushDownSpec; + Optional projectPushDownSpec = + Arrays.stream(tableSourceTable.abilitySpecs()) + .filter(spec -> spec instanceof ProjectPushDownSpec) + .map(spec -> (ProjectPushDownSpec) spec) + .findFirst(); + + if (projectPushDownSpec.isEmpty()) { + lookupKeySetPassThroughProjectPushDownSpec = lookupKeySet; + } else { + Map mapOut2InPos = new HashMap<>(); + int[][] projectedFields = projectPushDownSpec.get().getProjectedFields(); + for (int i = 0; i < projectedFields.length; i++) { + int[] projectedField = projectedFields[i]; + // skip nested projection push-down spec + if 
(projectedField.length > 1) { + continue; + } + int input = projectedField[0]; + mapOut2InPos.put(i, input); + } + + lookupKeySetPassThroughProjectPushDownSpec = + lookupKeySet.stream() + .flatMap(out -> Stream.ofNullable(mapOut2InPos.get(out))) + .collect(Collectors.toSet()); + } + + return Arrays.stream(idxsOfAllIndexes) + .peek(idxsOfIndex -> Preconditions.checkState(idxsOfIndex.length > 0)) + .anyMatch( + idxsOfIndex -> + Arrays.stream(idxsOfIndex) + .allMatch( + lookupKeySetPassThroughProjectPushDownSpec + ::contains)); + } + + private static boolean areAllSourceAbilitySpecsSupported( + TableScan tableScan, SourceAbilitySpec[] sourceAbilitySpecs) { + if (!Arrays.stream(sourceAbilitySpecs) + .allMatch(spec -> ALL_SUPPORTED_ABILITY_SPEC_IN_SOURCE.contains(spec.getClass()))) { + return false; + } + + Optional metadataSpec = + Arrays.stream(sourceAbilitySpecs) + .filter(spec -> spec instanceof ReadingMetadataSpec) + .map(spec -> (ReadingMetadataSpec) spec) + .findFirst(); + if (metadataSpec.isPresent() && !metadataSpec.get().getMetadataKeys().isEmpty()) { + return false; + } + + // source with non-deterministic filter pushed down is not supported + Optional filterPushDownSpec = + Arrays.stream(sourceAbilitySpecs) + .filter(spec -> spec instanceof FilterPushDownSpec) + .map(spec -> (FilterPushDownSpec) spec) + .findFirst(); + if (filterPushDownSpec.isEmpty()) { + return true; + } + + List filtersOnSource = filterPushDownSpec.get().getPredicates(); + if (!areAllRexNodeDeterministic(filtersOnSource)) { + return false; + } + + ChangelogMode changelogMode = getChangelogMode((StreamPhysicalRel) tableScan); + if (changelogMode.containsOnly(RowKind.INSERT)) { + return true; + } + + FlinkRelMetadataQuery fmq = + FlinkRelMetadataQuery.reuseOrCreate(tableScan.getCluster().getMetadataQuery()); + Set upsertKeys = fmq.getUpsertKeys(tableScan); + return filtersOnSource.stream() + .allMatch(filter -> isFilterOnOneSetOfUpsertKeys(filter, upsertKeys)); + } + private static TableScan getTableScan(RelNode node) { node = unwrapNode(node, true); // support to get table across more nodes if we support more nodes in // `ALL_SUPPORTED_DELTA_JOIN_UPSTREAM_NODES` if (node instanceof StreamPhysicalExchange - || node instanceof StreamPhysicalDropUpdateBefore) { + || node instanceof StreamPhysicalDropUpdateBefore + || node instanceof StreamPhysicalCalc) { return getTableScan(node.getInput(0)); } @@ -340,6 +520,64 @@ private static TableScan getTableScan(RelNode node) { return (TableScan) node; } + private static boolean areAllUpstreamCalcSupported(StreamPhysicalJoin join) { + return areAllUpstreamCalcFromOneJoinInputSupported(join.getLeft()) + && areAllUpstreamCalcFromOneJoinInputSupported(join.getRight()); + } + + private static boolean areAllUpstreamCalcFromOneJoinInputSupported(RelNode joinInput) { + List calcListFromThisInput = collectCalcBetweenJoinAndTableScan(joinInput); + + // currently, at most one calc is allowed to appear between Join and TableScan + if (calcListFromThisInput.size() > 1) { + return false; + } + + if (calcListFromThisInput.isEmpty()) { + return true; + } + + Calc calc = calcListFromThisInput.get(0); + return isCalcSupported(calc); + } + + private static Optional getCalcBetweenJoinAndTableScan(RelNode joinInput) { + List calcListFromLeftInput = collectCalcBetweenJoinAndTableScan(joinInput); + Preconditions.checkState( + calcListFromLeftInput.size() <= 1, + "Should be validated before calling this function"); + if (calcListFromLeftInput.isEmpty()) { + return Optional.empty(); + } else { + 
return Optional.of(calcListFromLeftInput.get(0)); + } + } + + private static Optional getRexProgramBetweenJoinAndTableScan(RelNode joinInput) { + return getCalcBetweenJoinAndTableScan(joinInput).map(Calc::getProgram); + } + + private static List collectCalcBetweenJoinAndTableScan(RelNode joinInput) { + CalcCollector calcCollector = new CalcCollector(); + calcCollector.go(joinInput); + return calcCollector.collectResult; + } + + private static boolean isCalcSupported(Calc calc) { + RexProgram calcProgram = calc.getProgram(); + // calc with non-deterministic fields or filters is not supported + return calcProgram == null || areAllRexNodeDeterministic(calcProgram.getExprList()); + } + + private static boolean areAllRexNodeDeterministic(List rexNodes) { + // Delta joins may produce duplicate data, and when this data is sent downstream, we want it + // to be processed in an idempotent manner. However, the presence of non-deterministic + // functions can lead to unpredictable results, such as random filtering or the addition of + // non-deterministic columns. Therefore, we strictly prohibit the use of non-deterministic + // functions in this context to ensure consistent and reliable processing. + return rexNodes.stream().allMatch(RexUtil::isDeterministic); + } + private static boolean areAllJoinInputsInWhiteList(RelNode node) { for (RelNode input : node.getInputs()) { input = unwrapNode(input, true); @@ -407,4 +645,20 @@ private static StreamPhysicalRel unwrapNode(RelNode node, boolean transposeToChi } return (StreamPhysicalRel) node; } + + private static class CalcCollector extends RelVisitor { + + private final List collectResult = new ArrayList<>(); + + @Override + public void visit(RelNode node, int ordinal, @Nullable RelNode parent) { + node = unwrapNode(node, true); + + super.visit(node, ordinal, parent); + + if (node instanceof Calc) { + collectResult.add((Calc) node); + } + } + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala index 2f3ab1708b869..13b817b1fca4e 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala @@ -24,8 +24,7 @@ import org.apache.flink.table.api.ValidationException import org.apache.flink.table.catalog.DataTypeFactory import org.apache.flink.table.connector.source.{LookupTableSource, ScanTableSource} import org.apache.flink.table.data.{GenericRowData, RowData} -import org.apache.flink.table.data.utils.JoinedRowData -import org.apache.flink.table.functions.{AsyncLookupFunction, AsyncTableFunction, LookupFunction, TableFunction, UserDefinedFunction} +import org.apache.flink.table.functions._ import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.codegen.CodeGenUtils._ import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType @@ -36,7 +35,6 @@ import org.apache.flink.table.planner.functions.inference.FunctionCallContext import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala import org.apache.flink.table.runtime.collector.{ListenableCollector, TableFunctionResultFuture} -import 
org.apache.flink.table.runtime.collector.ListenableCollector.CollectListener
 import org.apache.flink.table.runtime.generated.{GeneratedCollector, GeneratedFunction, GeneratedResultFuture}
 import org.apache.flink.table.types.DataType
 import org.apache.flink.table.types.extraction.ExtractionUtils.extractSimpleGeneric
@@ -44,7 +42,6 @@ import org.apache.flink.table.types.inference.{TypeInference, TypeStrategies, Ty
 import org.apache.flink.table.types.logical.{LogicalType, RowType}
 import org.apache.flink.table.types.utils.DataTypeUtils.transform
 import org.apache.flink.types.Row
-import org.apache.flink.util.Collector
 
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rex.RexNode
@@ -358,10 +355,31 @@
       condition: RexNode,
       outputType: RelDataType,
       tableSourceRowType: RowType): GeneratedFunction[FlatMapFunction[RowData, RowData]] = {
+    generateCalcMapFunction(
+      tableConfig,
+      classLoader,
+      projection,
+      condition,
+      FlinkTypeFactory.toLogicalRowType(outputType),
+      tableSourceRowType
+    )
+  }
+
+  /**
+   * Generates a calc flatmap function for the temporal join, which is used to project/filter the
+   * dimension table results.
+   */
+  def generateCalcMapFunction(
+      tableConfig: ReadableConfig,
+      classLoader: ClassLoader,
+      projection: Seq[RexNode],
+      condition: RexNode,
+      outputType: RowType,
+      tableSourceRowType: RowType): GeneratedFunction[FlatMapFunction[RowData, RowData]] = {
     CalcCodeGenerator.generateFunction(
       tableSourceRowType,
       "TableCalcMapFunction",
-      FlinkTypeFactory.toLogicalRowType(outputType),
+      outputType,
       classOf[GenericRowData],
       projection,
       Option(condition),
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLookupJoin.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLookupJoin.scala
index c3550be71e501..646f48de2b098 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLookupJoin.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLookupJoin.scala
@@ -23,7 +23,6 @@ import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecLookupJoin
 import org.apache.flink.table.planner.plan.nodes.exec.spec.TemporalTableSourceSpec
 import org.apache.flink.table.planner.plan.nodes.physical.common.CommonPhysicalLookupJoin
 import org.apache.flink.table.planner.plan.utils.{FlinkRexUtil, JoinTypeUtil}
-import org.apache.flink.table.planner.utils.JavaScalaConversionUtil
 
 import org.apache.calcite.plan.{RelOptCluster, RelOptTable, RelTraitSet}
 import org.apache.calcite.rel.RelNode
@@ -79,7 +78,7 @@ class BatchPhysicalLookupJoin(
     val (projectionOnTemporalTable, filterOnTemporalTable) = calcOnTemporalTable match {
       case Some(program) =>
         val (projection, filter) = FlinkRexUtil.expandRexProgram(program)
-        (JavaScalaConversionUtil.toJava(projection), filter.orNull)
+        (projection, filter.orNull)
       case _ => (null, null)
     }
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLookupJoin.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLookupJoin.scala
index c7b1279f096b7..d66135676ecc0 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLookupJoin.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLookupJoin.scala @@ -24,7 +24,6 @@ import org.apache.flink.table.planner.plan.nodes.exec.spec.TemporalTableSourceSp import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecLookupJoin import org.apache.flink.table.planner.plan.nodes.physical.common.CommonPhysicalLookupJoin import org.apache.flink.table.planner.plan.utils.{FlinkRexUtil, JoinTypeUtil, UpsertKeyUtil} -import org.apache.flink.table.planner.utils.JavaScalaConversionUtil import org.apache.calcite.plan.{RelOptCluster, RelOptTable, RelTraitSet} import org.apache.calcite.rel.{RelNode, RelWriter} @@ -102,7 +101,7 @@ class StreamPhysicalLookupJoin( val (projectionOnTemporalTable, filterOnTemporalTable) = calcOnTemporalTable match { case Some(program) => val (projection, filter) = FlinkRexUtil.expandRexProgram(program) - (JavaScalaConversionUtil.toJava(projection), filter.orNull) + (projection, filter.orNull) case _ => (null, null) } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala index 49ac2a92df84e..1ea8006eb124b 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala @@ -17,7 +17,7 @@ */ package org.apache.flink.table.planner.plan.utils -import org.apache.flink.table.planner.{JInt, JMap} +import org.apache.flink.table.planner.{JInt, JList, JMap} import org.apache.flink.table.planner.functions.sql.SqlTryCastFunction import org.apache.flink.table.planner.plan.nodes.calcite.{LegacySink, Sink} import org.apache.flink.table.planner.plan.optimize.RelNodeBlock @@ -364,7 +364,7 @@ object FlinkRexUtil { }) /** Expands the RexProgram to projection list and condition. 
*/ - def expandRexProgram(program: RexProgram): (Seq[RexNode], Option[RexNode]) = { + def expandRexProgram(program: RexProgram): (JList[RexNode], Option[RexNode]) = { val projection = program.getProjectList.map(program.expandLocalRef) val filter = if (program.getCondition != null) { Some(program.expandLocalRef(program.getCondition)) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/TemporalJoinUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/TemporalJoinUtil.scala index bd467802e5aa5..47bed3b4b2487 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/TemporalJoinUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/TemporalJoinUtil.scala @@ -25,7 +25,6 @@ import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalJoin import org.apache.flink.table.runtime.types.PlannerTypeUtils import org.apache.flink.util.Preconditions.checkState -import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.core.{JoinInfo, JoinRelType} import org.apache.calcite.rex._ @@ -414,7 +413,12 @@ object TemporalJoinUtil { def getTemporalTableJoinKeyPairs( joinInfo: JoinInfo, calcOnTemporalTable: Option[RexProgram]): Array[IntPair] = { - val joinPairs = joinInfo.pairs().asScala.toArray + getTemporalTableJoinKeyPairs(joinInfo.pairs().asScala.toArray, calcOnTemporalTable) + } + + def getTemporalTableJoinKeyPairs( + joinPairs: Array[IntPair], + calcOnTemporalTable: Option[RexProgram]): Array[IntPair] = { calcOnTemporalTable match { case Some(program) => // the target key of joinInfo is the calc output fields, we have to remapping to table here diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/DeltaJoinTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/DeltaJoinTest.xml index 55249bedc21d0..12b51151c8f48 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/DeltaJoinTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/DeltaJoinTest.xml @@ -268,12 +268,52 @@ LogicalSink(table=[default_catalog.default_database.snk], fields=[a0, a1, a2, a3 (a1, 1.1)]) : +- TableSourceScan(table=[[default_catalog, default_database, src1, filter=[]]], fields=[a0, a1, a2, a3]) +- Exchange(distribution=[hash[b1, b2]]) +- TableSourceScan(table=[[default_catalog, default_database, src2]], fields=[b0, b2, b1]) +]]> + + + + + 1 + ) join ( + select b1, b2, b0 from src2 where b1 < 10 + ) + on a1 = b1 + and a2 = b2 + and b0 <> 0 +]]> + + + ($5, 0))], joinType=[inner]) + :- LogicalProject(a0=[$0], a2=[$2], a1=[$1]) + : +- LogicalFilter(condition=[>($0, 1)]) + : +- LogicalTableScan(table=[[default_catalog, default_database, src1]]) + +- LogicalProject(b1=[$2], b2=[$1], b0=[$0]) + +- LogicalFilter(condition=[<($2, 10)]) + +- LogicalTableScan(table=[[default_catalog, default_database, src2]]) +]]> + + + (a0, 1)], project=[a0, a2, a1], metadata=[]]], fields=[a0, a2, a1]) + +- Exchange(distribution=[hash[b1, b2]]) + +- Calc(select=[b1, b2, b0], where=[<(b1, 10)]) + +- TableSourceScan(table=[[default_catalog, default_database, src2, filter=[<>(b0, 0)]]], fields=[b0, b2, b1]) ]]> @@ -301,6 +341,59 @@ 
Sink(table=[default_catalog.default_database.snk_for_cdc_src], fields=[a0, a1, a +- Exchange(distribution=[hash[b1, b2]], changelogMode=[I,UA]) +- DropUpdateBefore(changelogMode=[I,UA]) +- TableSourceScan(table=[[default_catalog, default_database, no_delete_src2]], fields=[b0, b2, b1], changelogMode=[I,UB,UA]) +]]> + + + + + 1]]> + + + ($3, 1)]) + +- LogicalJoin(condition=[AND(=($1, $6), =($2, $5))], joinType=[inner]) + :- LogicalTableScan(table=[[default_catalog, default_database, no_delete_src1]]) + +- LogicalTableScan(table=[[default_catalog, default_database, no_delete_src2]]) +]]> + + + (a3, 1)], changelogMode=[I,UB,UA]) + : +- TableSourceScan(table=[[default_catalog, default_database, no_delete_src1, filter=[]]], fields=[a0, a1, a2, a3], changelogMode=[I,UB,UA]) + +- Exchange(distribution=[hash[b1, b2]], changelogMode=[I,UB,UA]) + +- TableSourceScan(table=[[default_catalog, default_database, no_delete_src2]], fields=[b0, b2, b1], changelogMode=[I,UB,UA]) +]]> + + + + + 1]]> + + + ($3, 1)]) + +- LogicalJoin(condition=[AND(=($1, $6), =($2, $5))], joinType=[inner]) + :- LogicalTableScan(table=[[default_catalog, default_database, no_delete_src1]]) + +- LogicalTableScan(table=[[default_catalog, default_database, no_delete_src2]]) +]]> + + + (a3, 1)]]], fields=[a0, a1, a2, a3], changelogMode=[I,UB,UA]) + +- Exchange(distribution=[hash[b1, b2]], changelogMode=[I,UA]) + +- DropUpdateBefore(changelogMode=[I,UA]) + +- TableSourceScan(table=[[default_catalog, default_database, no_delete_src2]], fields=[b0, b2, b1], changelogMode=[I,UB,UA]) ]]> @@ -328,6 +421,61 @@ Sink(table=[default_catalog.default_database.snk_for_cdc_src], fields=[a0, a1, a +- Exchange(distribution=[hash[b1, b2]], changelogMode=[I,UA]) +- DropUpdateBefore(changelogMode=[I,UA]) +- TableSourceScan(table=[[default_catalog, default_database, no_delete_src2]], fields=[b0, b2, b1], changelogMode=[I,UB,UA]) +]]> + + + + + 1]]> + + + ($0, 1)]) + +- LogicalJoin(condition=[AND(=($1, $6), =($2, $5))], joinType=[inner]) + :- LogicalTableScan(table=[[default_catalog, default_database, no_delete_src1]]) + +- LogicalTableScan(table=[[default_catalog, default_database, no_delete_src2]]) +]]> + + + (a0, 1)], changelogMode=[I,UA]) + : +- DropUpdateBefore(changelogMode=[I,UA]) + : +- TableSourceScan(table=[[default_catalog, default_database, no_delete_src1, filter=[]]], fields=[a0, a1, a2, a3], changelogMode=[I,UB,UA]) + +- Exchange(distribution=[hash[b1, b2]], changelogMode=[I,UA]) + +- DropUpdateBefore(changelogMode=[I,UA]) + +- TableSourceScan(table=[[default_catalog, default_database, no_delete_src2]], fields=[b0, b2, b1], changelogMode=[I,UB,UA]) +]]> + + + + + 1]]> + + + ($0, 1)]) + +- LogicalJoin(condition=[AND(=($1, $6), =($2, $5))], joinType=[inner]) + :- LogicalTableScan(table=[[default_catalog, default_database, no_delete_src1]]) + +- LogicalTableScan(table=[[default_catalog, default_database, no_delete_src2]]) +]]> + + + (a0, 1)]]], fields=[a0, a1, a2, a3], changelogMode=[I,UB,UA]) + +- Exchange(distribution=[hash[b1, b2]], changelogMode=[I,UA]) + +- DropUpdateBefore(changelogMode=[I,UA]) + +- TableSourceScan(table=[[default_catalog, default_database, no_delete_src2]], fields=[b0, b2, b1], changelogMode=[I,UB,UA]) ]]> @@ -595,31 +743,6 @@ Sink(table=[default_catalog.default_database.snk], fields=[a0, a1, a2, a3, b0, b Sink(table=[default_catalog.default_database.snk2], fields=[a0, a1, a2, a3, b0, b2, b1]) +- Reused(reference_id=[1]) -]]> - - - - - - - - - - - @@ -654,6 +777,145 @@ Sink(table=[default_catalog.default_database.snk], 
fields=[a0, a1, a2, a3, b0, b Sink(table=[default_catalog.default_database.snk2], fields=[a0, a1, a2, a3, b0, b2, b1]) +- Reused(reference_id=[1]) +]]> + + + + + rand(10) + ) join src2 + on a1 = b1 + and a2 = b2 +]]> + + + ($0, RAND(10))]) + : +- LogicalTableScan(table=[[default_catalog, default_database, src1]]) + +- LogicalTableScan(table=[[default_catalog, default_database, src2]]) +]]> + + + (a0, RAND(10))]) + : +- TableSourceScan(table=[[default_catalog, default_database, src1, filter=[], project=[a0, a1, a2], metadata=[]]], fields=[a0, a1, a2]) + +- Exchange(distribution=[hash[b1, b2]]) + +- TableSourceScan(table=[[default_catalog, default_database, src2]], fields=[b0, b2, b1]) +]]> + + + + + rand(10) + ) + on a1 = b1 + and a2 = b2 +]]> + + + ($0, RAND(10))]) + +- LogicalTableScan(table=[[default_catalog, default_database, src2]]) +]]> + + + (b0, RAND(10))]) + +- TableSourceScan(table=[[default_catalog, default_database, src2, filter=[]]], fields=[b0, b2, b1]) +]]> + + + + + rand(10) + ) join src2 + on a1 = b1 + and a2 = b2 +]]> + + + ($0, RAND(10))]) + : +- LogicalTableScan(table=[[default_catalog, default_database, src1]]) + +- LogicalTableScan(table=[[default_catalog, default_database, src2]]) +]]> + + + (a0, RAND(10))]) + : +- TableSourceScan(table=[[default_catalog, default_database, src1, filter=[], project=[a0, a1, a2], metadata=[]]], fields=[a0, a1, a2]) + +- Exchange(distribution=[hash[b1, b2]]) + +- TableSourceScan(table=[[default_catalog, default_database, src2]], fields=[b0, b2, b1]) +]]> + + + + + + + + + + + @@ -734,6 +996,60 @@ Sink(table=[default_catalog.default_database.snk_for_cdc_src], fields=[a0, a1, a +- Exchange(distribution=[hash[b1, b2]], changelogMode=[I,UA]) +- DropUpdateBefore(changelogMode=[I,UA]) +- TableSourceScan(table=[[default_catalog, default_database, no_delete_src2]], fields=[b0, b2, b1], changelogMode=[I,UB,UA]) +]]> + + + + + + + + + + + + + + + + + + + + + + @@ -780,7 +1096,7 @@ LogicalSink(table=[default_catalog.default_database.snk], targetColumns=[[0],[1] - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - + - + - + - + - + - + - + @@ -994,6 +1425,30 @@ Sink(table=[default_catalog.default_database.snk], fields=[a0, a1, a2, a3, b0, b : +- TableSourceScan(table=[[default_catalog, default_database, src1]], fields=[a0, a1, a2, a3]) +- Exchange(distribution=[single]) +- TableSourceScan(table=[[default_catalog, default_database, src2]], fields=[b0, b2, b1]) +]]> + + + + + + + + + + + @@ -1037,7 +1492,7 @@ LogicalSink(table=[default_catalog.default_database.snk], fields=[a0, a1, a2, a3 (b0, a0))], select=[a0, a1, a2, a3, b0, b2, b1], leftInputSpec=[NoUniqueKey], rightInputSpec=[NoUniqueKey]) ++- DeltaJoin(joinType=[InnerJoin], where=[AND(=(a1, b1), =(a2, b2), >(b0, a0))], select=[a0, a1, a2, a3, b0, b2, b1]) :- Exchange(distribution=[hash[a1, a2]]) : +- Calc(select=[a0, a1, a2, a3], where=[AND(>(a0, 99), <>(a2, 'Hello'))]) : +- TableSourceScan(table=[[default_catalog, default_database, src1, filter=[]]], fields=[a0, a1, a2, a3]) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/DeltaJoinTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/DeltaJoinTest.scala index 7cb04376bba89..8a9f25b707c19 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/DeltaJoinTest.scala +++ 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/DeltaJoinTest.scala @@ -17,12 +17,13 @@ */ package org.apache.flink.table.planner.plan.stream.sql +import org.apache.flink.shaded.guava33.com.google.common.collect.Lists import org.apache.flink.table.api.{DataTypes, ExplainDetail, Schema} import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions} import org.apache.flink.table.api.config.ExecutionConfigOptions.UpsertMaterialize import org.apache.flink.table.api.config.OptimizerConfigOptions.DeltaJoinStrategy import org.apache.flink.table.catalog.{CatalogTable, ObjectPath, ResolvedCatalogTable} -import org.apache.flink.table.planner.JMap +import org.apache.flink.table.planner.{JList, JMap} import org.apache.flink.table.planner.utils.{TableTestBase, TestingTableEnvironment} import org.assertj.core.api.Assertions.assertThatThrownBy @@ -171,7 +172,6 @@ class DeltaJoinTest extends TableTestBase { @Test def testWithNonEquiCondition2(): Unit = { - // could not optimize into delta join because there is a calc between join and source util.verifyRelPlanInsert( "insert into snk select * from src1 join src2 " + "on src1.a1 = src2.b1 " + @@ -181,6 +181,15 @@ class DeltaJoinTest extends TableTestBase { "and src1.a0 > 99") } + @Test + def testWithNonDeterministicInNonEquiCondition(): Unit = { + util.verifyRelPlanInsert( + "insert into snk select * from src1 join src2 " + + "on src1.a1 = src2.b1 " + + "and src1.a2 = src2.b2 " + + "and src1.a0 + rand(10) < src2.b0") + } + @Test def testJsonPlanWithTableHints(): Unit = { util.verifyJsonPlan( @@ -193,7 +202,6 @@ class DeltaJoinTest extends TableTestBase { @Test def testProjectFieldsBeforeJoin(): Unit = { - // could not optimize into delta join because the source has ProjectPushDownSpec util.verifyRelPlanInsert( "insert into snk(l0, l1, l2, r0, r2, r1) " + "select * from ( " + @@ -211,9 +219,96 @@ class DeltaJoinTest extends TableTestBase { "and src1.a2 = src2.b2") } + @Test + def testProjectFieldsBeforeJoinWhileModifyingOnIndex(): Unit = { + // a2 is modified + util.verifyRelPlanInsert( + "insert into snk(l0, l1, l2, r0, r2, r1) " + + "select * from ( " + + " select a0, a1, SUBSTRING(a2, 2) as a2 from src1" + + ") tmp join src2 " + + "on tmp.a1 = src2.b1 " + + "and tmp.a2 = src2.b2") + } + + @Test + def testProjectFieldsBeforeJoinWhileModifyingOnOneIndexButRetainingAnother(): Unit = { + addTable( + "src1WithMultiIndexes", + Schema + .newBuilder() + .column("a0", DataTypes.INT.notNull) + .column("a1", DataTypes.DOUBLE.notNull) + .column("a2", DataTypes.STRING) + .column("a3", DataTypes.INT) + .index("a1", "a2") + .index("a1") + .build() + ) + + // a2 is modified + util.verifyRelPlanInsert( + "insert into snk(l0, l1, l2, r0, r2, r1) " + + "select * from ( " + + " select a0, a1, SUBSTRING(a2, 2) as a2 from src1WithMultiIndexes" + + ") tmp join src2 " + + "on tmp.a1 = src2.b1 " + + "and tmp.a2 = src2.b2") + } + + @Test + def testProjectFieldsBeforeJoinWhileAliasAndReorderOnIndex(): Unit = { + util.verifyRelPlanInsert( + "insert into snk(l1, l2, l0, r0, r2, r1) " + + "select * from ( " + + " select a1 as a0, a2 as a1, a0 as a2 from src1" + + ") tmp join src2 " + + "on tmp.a0 = src2.b1 " + + "and tmp.a1 = src2.b2") + } + + @Test + def testSourceDDLContainsComputingCol(): Unit = { + addTable( + "src1WithComputingCol", + Schema + .newBuilder() + .column("a0", DataTypes.INT.notNull) + .column("a1", DataTypes.DOUBLE.notNull) + .column("a2", DataTypes.STRING) + .columnByExpression("new_a1", "a1 
+ 1") + .index("a1", "a2") + .build() + ) + util.verifyRelPlanInsert( + "insert into snk(l0, l1, r0) select a0, new_a1, b0 " + + "from src1WithComputingCol join src2 " + + "on a1 = b1 " + + "and a2 = b2") + } + + @Test + def testSourceDDLContainsNonDeterministicComputingCol(): Unit = { + addTable( + "src1WithComputingCol", + Schema + .newBuilder() + .column("a0", DataTypes.INT.notNull) + .column("a1", DataTypes.DOUBLE.notNull) + .column("a2", DataTypes.STRING) + .columnByExpression("new_a1", "a1 + rand(10)") + .index("a1", "a2") + .build() + ) + util.verifyRelPlanInsert( + "insert into snk(l0, l1, r0) select a0, new_a1, b0 " + + "from src1WithComputingCol join src2 " + + "on a1 = b1 " + + "and a2 = b2") + } + @Test def testFilterFieldsBeforeJoin(): Unit = { - // could not optimize into delta join because there is a calc between source and join util.verifyRelPlanInsert( "insert into snk select * from ( " + " select * from src1 where a1 > 1.1 " + @@ -231,6 +326,137 @@ class DeltaJoinTest extends TableTestBase { "where a3 > b0") } + @Test + def testFilterFieldsBeforeJoinWithFilterPushDown(): Unit = { + replaceTable("src1", "src1", Maps.newHashMap("filterable-fields", "a0")) + replaceTable("src2", "src2", Maps.newHashMap("filterable-fields", "b0")) + + util.verifyRelPlanInsert(""" + |insert into snk(l0, l1, r0, r2, r1) + | select a0, a1, b0, b2, b1 from ( + | select a0, a2, a1 from src1 where a0 > 1 + | ) join ( + | select b1, b2, b0 from src2 where b1 < 10 + | ) + | on a1 = b1 + | and a2 = b2 + | and b0 <> 0 + |""".stripMargin) + } + + @Test + def testNonDeterministicFilterFieldsBeforeJoin1(): Unit = { + util.verifyRelPlanInsert(""" + |insert into snk(l0, l1, r0, r2, r1) + | select a0, a1, b0, b2, b1 from ( + | select a0, a2, a1 from src1 where a0 > rand(10) + | ) join src2 + | on a1 = b1 + | and a2 = b2 + |""".stripMargin) + } + + @Test + def testNonDeterministicFilterFieldsBeforeJoin2(): Unit = { + util.verifyRelPlanInsert(""" + |insert into snk + | select * from src1 + | join ( + | select * from src2 where b0 > rand(10) + | ) + | on a1 = b1 + | and a2 = b2 + |""".stripMargin) + } + + @Test + def testNonDeterministicFilterFieldsBeforeJoinWithFilterPushDown(): Unit = { + replaceTable("src1", "src1", Maps.newHashMap("filterable-fields", "a0")) + + // actually, 'values' source will not push down filter 'a0 > rand(10)' into source + util.verifyRelPlanInsert(""" + |insert into snk(l0, l1, r0, r2, r1) + | select a0, a1, b0, b2, b1 from ( + | select a0, a2, a1 from src1 where a0 > rand(10) + | ) join src2 + | on a1 = b1 + | and a2 = b2 + |""".stripMargin) + } + + @Test + def testPartitionPushDown(): Unit = { + addTable( + "src1WithPartition", + Schema + .newBuilder() + .column("a0", DataTypes.INT.notNull) + .column("a1", DataTypes.DOUBLE.notNull) + .column("a2", DataTypes.STRING) + .column("pt", DataTypes.INT) + .index("a1", "a2") + .build(), + Maps.newHashMap("partition-list", "pt:1;pt:2"), + Lists.newArrayList("pt") + ) + + util.verifyRelPlanInsert(""" + |insert into snk(l0, r0, r2) + | select a0, b0, b2 from ( + | select a0, a2, a1 from src1WithPartition where pt = 1 + | ) join src2 + | on a1 = b1 + | and a2 = b2 + |""".stripMargin) + } + + @Test + def testReadingMetadata(): Unit = { + addTable( + "src1WithMetadata", + Schema + .newBuilder() + .columnByMetadata("a0", DataTypes.INT.notNull) + .column("a1", DataTypes.DOUBLE.notNull) + .column("a2", DataTypes.STRING) + .column("a3", DataTypes.INT) + .index("a1", "a2") + .build(), + Maps.newHashMap("readable-metadata", "a0:int") + ) + + 
util.verifyRelPlanInsert(""" + |insert into snk(l0, r0, r2) + | select a0, b0, b2 from ( + | select a0, a2, a1 from src1WithMetadata + | ) join src2 + | on a1 = b1 + | and a2 = b2 + |""".stripMargin) + } + + @Test + def testFilterOnNonUpsertKeysBeforeJoinWithCdcSourceWithoutDelete(): Unit = { + testInnerFilterOnNonUpsertKeysBeforeJoinWithCdcSourceWithoutDelete() + } + + @Test + def testFilterOnNonUpsertKeysBeforeJoinWithCdcSourceWithoutDeleteAndFilterPushDown(): Unit = { + replaceTable("no_delete_src1", "no_delete_src1", Maps.newHashMap("filterable-fields", "a3")) + testInnerFilterOnNonUpsertKeysBeforeJoinWithCdcSourceWithoutDelete() + } + + private def testInnerFilterOnNonUpsertKeysBeforeJoinWithCdcSourceWithoutDelete(): Unit = { + util.verifyRelPlanInsert( + "insert into snk_for_cdc_src select * from no_delete_src1 " + + "join no_delete_src2 " + + "on a1 = b1 " + + "and a2 = b2 " + + "where a3 > 1", + ExplainDetail.CHANGELOG_MODE + ) + } + @Test def testFilterOnNonUpsertKeysAfterJoinWithCdcSourceWithoutDelete(): Unit = { util.verifyRelPlanInsert( @@ -243,6 +469,28 @@ class DeltaJoinTest extends TableTestBase { ) } + @Test + def testFilterOnUpsertKeysBeforeJoinWithCdcSourceWithoutDelete(): Unit = { + testFilterOnUpsertKeysBeforeJoinWithCdcSourceWithoutDeleteInner() + } + + @Test + def testFilterOnUpsertKeysBeforeJoinWithCdcSourceWithoutDeleteAndFilterPushDown(): Unit = { + replaceTable("no_delete_src1", "no_delete_src1", Maps.newHashMap("filterable-fields", "a0")) + testFilterOnUpsertKeysBeforeJoinWithCdcSourceWithoutDeleteInner() + } + + private def testFilterOnUpsertKeysBeforeJoinWithCdcSourceWithoutDeleteInner(): Unit = { + util.verifyRelPlanInsert( + "insert into snk_for_cdc_src select * from no_delete_src1 " + + "join no_delete_src2 " + + "on a1 = b1 " + + "and a2 = b2 " + + "where a0 > 1", + ExplainDetail.CHANGELOG_MODE + ) + } + @Test def testFilterOnUpsertKeysAfterJoinWithCdcSourceWithoutDelete(): Unit = { util.verifyRelPlanInsert( @@ -410,9 +658,7 @@ class DeltaJoinTest extends TableTestBase { @Test def testWithoutLookupTable(): Unit = { - util.tableEnv.executeSql( - "create table non_lookup_src with ('disable-lookup' = 'true') " + - "like src2 (OVERWRITING OPTIONS)") + replaceTable("src2", "non_lookup_src", Maps.newHashMap("disable-lookup", "true")) util.verifyRelPlanInsert( "insert into snk select * from src1 join non_lookup_src " + @@ -444,9 +690,7 @@ class DeltaJoinTest extends TableTestBase { ExecutionConfigOptions.TABLE_EXEC_SINK_UPSERT_MATERIALIZE, UpsertMaterialize.NONE) - util.tableEnv.executeSql( - "create table cdc_src with ('changelog-mode' = 'I,UA,UB,D') " + - "like src2 (OVERWRITING OPTIONS)") + replaceTable("src2", "cdc_src", Maps.newHashMap("changelog-mode", "I,UA,UB,D")) util.verifyRelPlanInsert( "insert into snk select * from src1 join cdc_src " + @@ -482,11 +726,10 @@ class DeltaJoinTest extends TableTestBase { @Test def testPKContainJoinKeyAndOnlyOneSourceNoDelete(): Unit = { - util.tableEnv.executeSql(""" - |create table all_changelog_src with ( - | 'changelog-mode' = 'I,UA,UB,D' - |) like no_delete_src1 - |""".stripMargin) + replaceTable( + "no_delete_src1", + "all_changelog_src", + Maps.newHashMap("changelog-mode", "I,UA,UB,D")) util.verifyRelPlanInsert( "insert into snk_for_cdc_src " + @@ -501,17 +744,15 @@ class DeltaJoinTest extends TableTestBase { def testPKContainsJoinKeyAndSourceNoUBAndD(): Unit = { // FLINK-38489 Currently, ChangelogNormalize will always generate changelog mode with D, // and Join with D cannot be optimized into Delta Join - 
util.tableEnv.executeSql(""" - |create table no_delete_and_update_before_src1 with ( - | 'changelog-mode' = 'I,UA' - |) like no_delete_src1 - |""".stripMargin) + replaceTable( + "no_delete_src1", + "no_delete_and_update_before_src1", + Maps.newHashMap("changelog-mode", "I,UA")) - util.tableEnv.executeSql(""" - |create table no_delete_and_update_before_src2 with ( - | 'changelog-mode' = 'I,UA' - |) like no_delete_src2 - |""".stripMargin) + replaceTable( + "no_delete_src2", + "no_delete_and_update_before_src2", + Maps.newHashMap("changelog-mode", "I,UA")) util.verifyRelPlanInsert( "insert into snk_for_cdc_src " + @@ -547,19 +788,6 @@ class DeltaJoinTest extends TableTestBase { ExplainDetail.CHANGELOG_MODE) } - @Test - def testSourceWithSourceAbilities(): Unit = { - util.tableEnv.executeSql( - "create table filterable_src with ('filterable-fields' = 'a3') " + - "like src1 (OVERWRITING OPTIONS)") - - util.verifyRelPlanInsert( - "insert into snk select * from filterable_src join src2 " + - "on filterable_src.a1 = src2.b1 " + - "and filterable_src.a2 = src2.b2 " + - "and filterable_src.a3 = 1") - } - @Test def testWithAggregatingSourceTableBeforeJoin(): Unit = { util.tableConfig.set( @@ -590,7 +818,7 @@ class DeltaJoinTest extends TableTestBase { @Test def testWithCascadeJoin(): Unit = { - util.tableEnv.executeSql("create table src3 like src2") + replaceTable("src2", "src3", Collections.emptyMap(), dropOldTable = false) addTable( "tmp_snk", @@ -817,7 +1045,8 @@ class DeltaJoinTest extends TableTestBase { private def addTable( tableName: String, schema: Schema, - extraOptions: JMap[String, String] = Collections.emptyMap()): Unit = { + extraOptions: JMap[String, String] = Collections.emptyMap(), + partitionKeys: JList[String] = Collections.emptyList()): Unit = { val currentCatalog = tEnv.getCurrentCatalog val currentDatabase = tEnv.getCurrentDatabase val tablePath = new ObjectPath(currentDatabase, tableName) @@ -831,7 +1060,7 @@ class DeltaJoinTest extends TableTestBase { .newBuilder() .schema(schema) .comment(testComment) - .partitionKeys(Collections.emptyList()) + .partitionKeys(partitionKeys) .options(options) .build() val resolvedTable = new ResolvedCatalogTable(original, schemaResolver.resolve(schema)) @@ -839,4 +1068,41 @@ class DeltaJoinTest extends TableTestBase { catalog.createTable(tablePath, resolvedTable, false) } + /** TODO remove this after fix FLINK-38571. 
*/ + private def replaceTable( + oldTableName: String, + newTableName: String, + overridesOptions: JMap[String, String], + dropOldTable: Boolean = true): Unit = { + val currentCatalog = tEnv.getCurrentCatalog + val currentDatabase = tEnv.getCurrentDatabase + val oldTablePath = new ObjectPath(currentDatabase, oldTableName) + val newTablePath = new ObjectPath(currentDatabase, newTableName) + val catalog = tEnv.getCatalog(currentCatalog).get() + val schemaResolver = tEnv.getCatalogManager.getSchemaResolver + + val originalTable = catalog.getTable(oldTablePath).asInstanceOf[CatalogTable] + if (dropOldTable) { + catalog.dropTable(oldTablePath, false) + } + + val originalOptions = originalTable.getOptions + val newOptions = new JHashMap[String, String]() + newOptions.putAll(originalOptions) + newOptions.putAll(overridesOptions) + + val newTable = CatalogTable + .newBuilder() + .schema(originalTable.getUnresolvedSchema) + .comment(originalTable.getComment) + .partitionKeys(originalTable.getPartitionKeys) + .options(newOptions) + .build() + + val newResolvedTable = + new ResolvedCatalogTable(newTable, schemaResolver.resolve(originalTable.getUnresolvedSchema)) + + catalog.createTable(newTablePath, newResolvedTable, false) + } + } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCase.scala index 0c5d881d6608e..af8a866129d1e 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCase.scala @@ -23,6 +23,7 @@ import org.apache.flink.table.api.bridge.scala.internal.StreamTableEnvironmentIm import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions} import org.apache.flink.table.api.config.OptimizerConfigOptions.DeltaJoinStrategy import org.apache.flink.table.catalog.{CatalogTable, ObjectPath, ResolvedCatalogTable} +import org.apache.flink.table.planner.{JHashMap, JMap} import org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.AsyncTestValueLookupFunction import org.apache.flink.table.planner.factories.TestValuesTableFactory import org.apache.flink.table.planner.factories.TestValuesTableFactory.changelogRow @@ -31,16 +32,19 @@ import org.apache.flink.testutils.junit.extensions.parameterized.{ParameterizedT import org.apache.flink.types.Row import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.util.Maps import org.junit.jupiter.api.{BeforeEach, TestTemplate} import org.junit.jupiter.api.extension.ExtendWith import javax.annotation.Nullable import java.time.LocalDateTime +import java.util.Collections import java.util.Objects.requireNonNull import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ +import scala.collection.JavaConverters.mapAsScalaMapConverter @ExtendWith(Array(classOf[ParameterizedTestExtension])) class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { @@ -238,6 +242,197 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { .build()) } + @TestTemplate + def testWithNonEquiCondition2(): Unit = { + val data1 = List( + changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + changelogRow("+I", Double.box(2.0), Int.box(2), LocalDateTime.of(2023, 3, 3, 3, 3, 3)), 
+ // mismatch + changelogRow("+I", Double.box(3.0), Int.box(3), LocalDateTime.of(2033, 3, 3, 3, 3, 3)) + ) + + val data2 = List( + changelogRow("+I", Int.box(1), Double.box(1.0), LocalDateTime.of(2021, 1, 1, 1, 1, 11)), + changelogRow("+I", Int.box(2), Double.box(2.0), LocalDateTime.of(2022, 2, 2, 2, 2, 22)), + // mismatch + changelogRow("+I", Int.box(99), Double.box(99.0), LocalDateTime.of(2099, 2, 2, 2, 2, 2)) + ) + + // TestValuesRuntimeFunctions#KeyedUpsertingSinkFunction will change the RowKind from + // "+U" to "+I" + val expected = List("+I[2.0, 2, 2023-03-03T03:03:03, 2, 2.0, 2022-02-02T02:02:22]") + + testUpsertResult( + newTestSpecBuilder() + .withLeftIndex(List("a0")) + .withRightIndex(List("b0")) + .withLeftData(data1) + .withRightData(data2) + // the filter "a2 > TO_TIMESTAMP('2021-01-01 01:01:11')" will be pushed down to + // the right side + .withJoinCondition("a0 = b0 and a1 = b1 and a2 > TO_TIMESTAMP('2021-01-01 01:01:11')") + .withSinkPk(List("l0", "r0")) + .withExpectedData(expected) + .withExpectedLookupFunctionInvokeCount(5) + .build()) + + after() + before() + + testUpsertResult( + newTestSpecBuilder() + .withLeftIndex(List("a0")) + .withRightIndex(List("b0")) + .withLeftData(data1) + .withRightData(data2) + // the filter "b1 > 1.0" will be pushed down to the right side + .withJoinCondition("a0 = b0 and b1 > 1.0") + .withSinkPk(List("l0", "r0")) + .withExpectedData(expected) + .withExpectedLookupFunctionInvokeCount(5) + .build()) + } + + @TestTemplate + def testFilterProjectBeforeJoin(): Unit = { + testFilterProjectBeforeJoinInner(false) + } + + @TestTemplate + def testFilterProjectBeforeJoinWithFilterPushDownIntoSource(): Unit = { + testFilterProjectBeforeJoinInner(true) + } + + private def testFilterProjectBeforeJoinInner(filterPushDown: Boolean): Unit = { + val data1 = List( + changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + changelogRow("+I", Double.box(2.0), Int.box(2), LocalDateTime.of(2022, 2, 2, 2, 2, 2)), + changelogRow("+I", Double.box(3.0), Int.box(3), LocalDateTime.of(2033, 3, 3, 3, 3, 3)), + // mismatch + changelogRow("+I", Double.box(4.0), Int.box(4), LocalDateTime.of(2044, 4, 4, 4, 4, 4)) + ) + + val data2 = List( + changelogRow("+I", Int.box(1), Double.box(1.0), LocalDateTime.of(2021, 1, 1, 1, 1, 11)), + changelogRow("+I", Int.box(2), Double.box(2.0), LocalDateTime.of(2022, 2, 2, 2, 2, 22)), + changelogRow("+I", Int.box(3), Double.box(3.0), LocalDateTime.of(2023, 3, 3, 3, 3, 33)), + // mismatch + changelogRow("+I", Int.box(99), Double.box(99.0), LocalDateTime.of(2099, 2, 2, 2, 2, 2)) + ) + + // TestValuesRuntimeFunctions#KeyedUpsertingSinkFunction will change the RowKind from + // "+U" to "+I" + val expected1 = List("+I[null, 2, 2022-02-02T02:02:02, 2, null, 2022-02-02T02:02:22]") + + val (leftExtraOptions1, rightExtraOptions1): (JMap[String, String], JMap[String, String]) = + if (filterPushDown) { + (Maps.newHashMap("filterable-fields", "a2"), Maps.newHashMap("filterable-fields", "b0")) + } else { + (Collections.emptyMap(), Collections.emptyMap()) + } + + testUpsertResult( + newTestSpecBuilder() + .withLeftIndex(List("a0")) + .withRightIndex(List("b0")) + .withLeftData(data1) + .withRightData(data2) + .withLeftExtraOptions(leftExtraOptions1) + .withRightExtraOptions(rightExtraOptions1) + .withSinkPk(List("l0", "r0")) + .withFilterProjectOnLeft( + "select a0, a2 from testLeft where a2 > TO_TIMESTAMP('2021-01-01 01:01:11')") + .withFilterProjectOnRight("" + + "select b0, b2 from testRight where b0 < 3") + 
.withJoinCondition("a0 = b0") + .withPartialInsertCols(List("l0", "l2", "r0", "r2")) + .withExpectedData(expected1) + .withExpectedLookupFunctionInvokeCount(5) + .build()) + + after() + before() + + // TestValuesRuntimeFunctions#KeyedUpsertingSinkFunction will change the RowKind from + // "+U" to "+I" + val expected2 = List("+I[null, 3, 2033-03-03T03:03:03, 3, null, 2023-03-03T03:03:33]") + + val (leftExtraOptions2, rightExtraOptions2): (JMap[String, String], JMap[String, String]) = + if (filterPushDown) { + (Maps.newHashMap("filterable-fields", "a1"), Maps.newHashMap("filterable-fields", "b0")) + } else { + (Collections.emptyMap(), Collections.emptyMap()) + } + + testUpsertResult( + newTestSpecBuilder() + .withLeftIndex(List("a0")) + .withRightIndex(List("b0")) + .withLeftData(data1) + .withRightData(data2) + .withLeftExtraOptions(leftExtraOptions2) + .withRightExtraOptions(rightExtraOptions2) + .withSinkPk(List("l0", "r0")) + .withFilterProjectOnLeft("select a0, a2 from testLeft where a1 > cast(2.0 as double)") + .withFilterProjectOnRight("" + + "select b0, b2 from testRight where b0 < 4") + .withJoinCondition("a0 = b0") + .withPartialInsertCols(List("l0", "l2", "r0", "r2")) + .withExpectedData(expected2) + .withExpectedLookupFunctionInvokeCount(5) + .build()) + } + + @TestTemplate + def testPartitionPushDownIntoSource(): Unit = { + val data1 = List( + changelogRow("+I", Double.box(50.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + changelogRow("+I", Double.box(50.0), Int.box(2), LocalDateTime.of(2022, 2, 2, 2, 2, 2)), + changelogRow("+I", Double.box(100.0), Int.box(3), LocalDateTime.of(2033, 3, 3, 3, 3, 3)), + // mismatch + changelogRow("+I", Double.box(200.0), Int.box(4), LocalDateTime.of(2044, 4, 4, 4, 4, 4)) + ) + + val data2 = List( + changelogRow("+I", Int.box(1), Double.box(100.0), LocalDateTime.of(2021, 1, 1, 1, 1, 11)), + changelogRow("+I", Int.box(2), Double.box(100.0), LocalDateTime.of(2022, 2, 2, 2, 2, 22)), + changelogRow("+I", Int.box(3), Double.box(200.0), LocalDateTime.of(2023, 3, 3, 3, 3, 33)), + // mismatch + changelogRow("+I", Int.box(99), Double.box(300.0), LocalDateTime.of(2099, 2, 2, 2, 2, 2)) + ) + + // TestValuesRuntimeFunctions#KeyedUpsertingSinkFunction will change the RowKind from + // "+U" to "+I" + val expected = List("+I[100.0, 3, null, 3, 200.0, null]") + + val (leftExtraOptions1, rightExtraOptions1): (JMap[String, String], JMap[String, String]) = + ( + java.util.Map.of("partition-list", "a1:50.0;a1:100.0;a1:200.0"), + java.util.Map.of("partition-list", "b1:100.0;b1:200.0;b1:300.0") + ) + + testUpsertResult( + newTestSpecBuilder() + .withLeftIndex(List("a0")) + .withRightIndex(List("b0")) + .withLeftPartitionKeys(List("a1")) + .withRightPartitionKeys(List("b1")) + .withLeftExtraOptions(leftExtraOptions1) + .withRightExtraOptions(rightExtraOptions1) + .withSinkPk(List("l0", "r0")) + .withLeftData(data1) + .withRightData(data2) + .withFilterProjectOnLeft("select a0, a1 from testLeft " + + "where a1 = cast(100.0 as double) or a1 = cast(200.0 as double)") + .withFilterProjectOnRight("select b1, b0 from testRight " + + "where b1 = cast(200.0 as double) or b1 = cast(300.0 as double)") + .withJoinCondition("a0 = b0") + .withPartialInsertCols(List("l0", "l1", "r1", "r0")) + .withExpectedData(expected) + .withExpectedLookupFunctionInvokeCount(4) + .build()) + } + @TestTemplate def testCdcSourceWithoutDelete(): Unit = { val data1 = List( @@ -280,8 +475,8 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { .withLeftPk(List("a0")) 
.withRightPk(List("b0")) .withSinkPk(List("l0", "r0")) - .withLeftChangelogMode("I,UA,UB") - .withRightChangelogMode("I,UA,UB") + .withLeftExtraOptions(Maps.newHashMap("changelog-mode", "I,UA,UB")) + .withRightExtraOptions(Maps.newHashMap("changelog-mode", "I,UA,UB")) .withLeftData(data1) .withRightData(data2) .withJoinCondition("a0 = b0") @@ -358,10 +553,10 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { .withLeftPk(List("a0", "a1")) .withRightPk(List("b0", "b1")) .withSinkPk(List("l0", "r0", "l1", "r1")) + .withLeftExtraOptions(Maps.newHashMap("changelog-mode", "I,UA,UB")) + .withRightExtraOptions(Maps.newHashMap("changelog-mode", "I,UA,UB")) .withLeftData(data1) .withRightData(data2) - .withLeftChangelogMode("I,UA,UB") - .withRightChangelogMode("I,UA,UB") .withJoinCondition("a0 = b0") .withFilterAfterJoin("a1 < b0") .withExpectedData(expected) @@ -369,6 +564,170 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { .build()) } + @TestTemplate + def testFilterProjectBeforeJoinWithCdcSourceWithoutDelete(): Unit = { + testFilterProjectBeforeJoinWithCdcSourceWithoutDeleteInner(false) + } + + @TestTemplate + def testFilterProjectBeforeJoinWithCdcSourceWithoutDeleteAndFilterPushDownIntoSource(): Unit = { + testFilterProjectBeforeJoinWithCdcSourceWithoutDeleteInner(true) + } + + private def testFilterProjectBeforeJoinWithCdcSourceWithoutDeleteInner( + filterPushDown: Boolean): Unit = { + val data1 = List( + // pk1 + changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + changelogRow("-U", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + changelogRow("+U", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 2)), + // pk2 + changelogRow("+I", Double.box(2.0), Int.box(2), LocalDateTime.of(2022, 2, 2, 2, 2, 2)), + changelogRow("-U", Double.box(2.0), Int.box(2), LocalDateTime.of(2022, 2, 2, 2, 2, 2)), + changelogRow("+U", Double.box(2.0), Int.box(2), LocalDateTime.of(2022, 2, 2, 2, 2, 3)), + // mismatch + changelogRow("+I", Double.box(3.0), Int.box(3), LocalDateTime.of(2033, 3, 3, 3, 3, 3)) + ) + + val data2 = List( + // pk1 + changelogRow("+I", Int.box(1), Double.box(1.0), LocalDateTime.of(2021, 1, 1, 1, 1, 11)), + // pk2 + changelogRow("+I", Int.box(2), Double.box(2.0), LocalDateTime.of(2022, 2, 2, 2, 2, 22)), + // mismatch + changelogRow("+I", Int.box(99), Double.box(99.0), LocalDateTime.of(2099, 2, 2, 2, 2, 2)) + ) + + // TestValuesRuntimeFunctions#KeyedUpsertingSinkFunction will change the RowKind from + // "+U" to "+I" + val expected1 = List("+I[2.0, 2, 2022-02-02T02:02:03, 2, 2.0, null]") + + val (leftExtraOptions1, rightExtraOptions1): (JMap[String, String], JMap[String, String]) = + if (filterPushDown) { + ( + java.util.Map.of("filterable-fields", "a2", "changelog-mode", "I,UA,UB"), + java.util.Map.of("filterable-fields", "b0", "changelog-mode", "I,UA,UB")) + } else { + ( + java.util.Map.of("changelog-mode", "I,UA,UB"), + java.util.Map.of("changelog-mode", "I,UA,UB")) + } + + testUpsertResult( + newTestSpecBuilder() + .withLeftIndex(List("a0")) + .withRightIndex(List("b0")) + .withLeftPk(List("a0", "a1")) + .withRightPk(List("b0", "b1")) + .withLeftExtraOptions(leftExtraOptions1) + .withRightExtraOptions(rightExtraOptions1) + .withSinkPk(List("l0", "r0", "l1", "r1")) + .withLeftData(data1) + .withRightData(data2) + .withFilterProjectOnLeft("select a1, a2, a0 from testLeft where a1 <> cast(1.0 as double)") + .withFilterProjectOnRight("select b1, b0 from testRight") + 
.withJoinCondition("a0 = b0") + .withPartialInsertCols(List("l1", "l2", "l0", "r1", "r0")) + .withExpectedData(expected1) + .withExpectedLookupFunctionInvokeCount(if (enableCache) 5 else 6) + .build()) + + after() + before() + + // TestValuesRuntimeFunctions#KeyedUpsertingSinkFunction will change the RowKind from + // "+U" to "+I" + val expected2 = List("+I[1.0, 1, 2021-01-01T01:01:02, 1, 1.0, null]") + + val (leftExtraOptions2, rightExtraOptions2): (JMap[String, String], JMap[String, String]) = + if (filterPushDown) { + ( + java.util.Map.of("filterable-fields", "a1", "changelog-mode", "I,UA,UB"), + java.util.Map.of("filterable-fields", "b0", "changelog-mode", "I,UA,UB")) + } else { + ( + java.util.Map.of("changelog-mode", "I,UA,UB"), + java.util.Map.of("changelog-mode", "I,UA,UB")) + } + + testUpsertResult( + newTestSpecBuilder() + .withLeftIndex(List("a0")) + .withRightIndex(List("b0")) + .withLeftPk(List("a0", "a1")) + .withRightPk(List("b0", "b1")) + .withLeftExtraOptions(leftExtraOptions2) + .withRightExtraOptions(rightExtraOptions2) + .withSinkPk(List("l0", "r0", "l1", "r1")) + .withLeftData(data1) + .withRightData(data2) + .withFilterProjectOnLeft("select a0, a2, a1 from testLeft") + .withFilterProjectOnRight("select b1, b0 from testRight where b0 <> 2") + .withJoinCondition("a0 = b0") + .withPartialInsertCols(List("l0", "l2", "l1", "r1", "r0")) + .withExpectedData(expected2) + .withExpectedLookupFunctionInvokeCount(if (enableCache) 5 else 7) + .build()) + } + + @TestTemplate + def testPartitionPushDownIntoCdcSourceWithoutDelete(): Unit = { + val data1 = List( + // pk1 + changelogRow("+I", Double.box(50.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + changelogRow("-U", Double.box(50.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + changelogRow("+U", Double.box(50.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 2)), + // pk2 + changelogRow("+I", Double.box(50.0), Int.box(2), LocalDateTime.of(2022, 2, 2, 2, 2, 2)), + changelogRow("-U", Double.box(50.0), Int.box(2), LocalDateTime.of(2022, 2, 2, 2, 2, 2)), + changelogRow("+U", Double.box(50.0), Int.box(2), LocalDateTime.of(2022, 2, 2, 2, 2, 3)), + // mismatch + changelogRow("+I", Double.box(100.0), Int.box(3), LocalDateTime.of(2033, 3, 3, 3, 3, 3)) + ) + + val data2 = List( + // pk1 + changelogRow("+I", Int.box(1), Double.box(50.0), LocalDateTime.of(2021, 1, 1, 1, 1, 11)), + // pk2 + changelogRow("+I", Int.box(2), Double.box(500.0), LocalDateTime.of(2022, 2, 2, 2, 2, 22)), + // mismatch + changelogRow("+I", Int.box(99), Double.box(99.0), LocalDateTime.of(2099, 2, 2, 2, 2, 2)) + ) + + // TestValuesRuntimeFunctions#KeyedUpsertingSinkFunction will change the RowKind from + // "+U" to "+I" + val expected = List("+I[50.0, 2, 2022-02-02T02:02:03, 2, 500.0, null]") + + val (leftExtraOptions1, rightExtraOptions1): (JMap[String, String], JMap[String, String]) = + ( + java.util.Map.of("partition-list", "a1:50.0;a1:100.0", "changelog-mode", "I,UA,UB"), + java.util.Map.of("partition-list", "b1:50.0;b1:500.0;b1:99.0", "changelog-mode", "I,UA,UB") + ) + + testUpsertResult( + newTestSpecBuilder() + .withLeftIndex(List("a0")) + .withRightIndex(List("b0")) + .withLeftPk(List("a0", "a1")) + .withRightPk(List("b0", "b1")) + .withLeftPartitionKeys(List("a1")) + .withRightPartitionKeys(List("b1")) + .withLeftExtraOptions(leftExtraOptions1) + .withRightExtraOptions(rightExtraOptions1) + .withSinkPk(List("l0", "r0", "l1", "r1")) + .withLeftData(data1) + .withRightData(data2) + .withFilterProjectOnLeft("select a1, a2, a0 from testLeft " + 
+ "where a1 = cast(50.0 as double) or a1 = cast(100.0 as double)") + .withFilterProjectOnRight("select b1, b0 from testRight " + + "where b1 = cast(500.0 as double) or b1 = cast(99.0 as double)") + .withJoinCondition("a0 = b0") + .withPartialInsertCols(List("l1", "l2", "l0", "r1", "r0")) + .withExpectedData(expected) + .withExpectedLookupFunctionInvokeCount(if (enableCache) 5 else 7) + .build()) + } + @TestTemplate def testProjectFieldsAfterJoin(): Unit = { val data1 = List( @@ -492,18 +851,45 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { testSpec.leftPk.orNull, testSpec.rightPk.orNull, testSpec.sinkPk, + testSpec.leftPartitionKeys, + testSpec.rightPartitionKeys, testSpec.leftData, testSpec.rightData, testSpec.testFailingSource, - testSpec.leftChangelogMode, - testSpec.rightChangelogMode + testSpec.leftExtraOptions, + testSpec.rightExtraOptions ) + val partialInsertStr = if (testSpec.partialInsertCols.isEmpty) { + "" + } else { + s"(${testSpec.partialInsertCols.get.mkString(",")})" + } + + val queryOnLeft = if (testSpec.filterProjectOnLeft.isEmpty) { + "testLeft" + } else { + s"(${testSpec.filterProjectOnLeft.get})" + } + + val queryOnRight = if (testSpec.filterProjectOnRight.isEmpty) { + "testRight" + } else { + s"(${testSpec.filterProjectOnRight.get})" + } + + val filterAfterJoin = if (testSpec.filterAfterJoin.isEmpty) { + "" + } else { + s"where ${testSpec.filterAfterJoin.get}" + } + val sql = s""" - | insert into testSnk - | select * from testLeft join testRight on ${testSpec.joinCondition} - | ${if (testSpec.filterAfterJoin.isEmpty) "" else s"where ${testSpec.filterAfterJoin.get}"} + | insert into testSnk $partialInsertStr + | select * from $queryOnLeft join $queryOnRight + | on ${testSpec.joinCondition} + | $filterAfterJoin |""".stripMargin tEnv .executeSql(sql) @@ -528,11 +914,14 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { null, null, List("l0", "r0"), + List(), + List(), leftData, rightData, testFailingSource = false, - "I", - "I") + Collections.emptyMap(), + Collections.emptyMap() + ) } private def prepareTable( @@ -541,12 +930,30 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { @Nullable leftPk: List[String], @Nullable rightPk: List[String], sinkPk: List[String], + leftPartitionKeys: List[String], + rightPartitionKeys: List[String], leftData: List[Row], rightData: List[Row], testFailingSource: Boolean, - leftChangelogMode: String, - rightChangelogMode: String): Unit = { + leftExtraOptions: JMap[String, String], + rightExtraOptions: JMap[String, String]): Unit = { tEnv.executeSql("drop table if exists testLeft") + val leftExtraOptionsStr = + if (leftExtraOptions.isEmpty) { + "" + } else { + "," + leftExtraOptions.asScala + .map { case (key, value) => s"'$key' = '$value'" } + .mkString(", ") + } + + val leftPartitionStr = + if (leftPartitionKeys.isEmpty) { + "" + } else { + s"PARTITIONED BY (${leftPartitionKeys.mkString(",")})" + } + tEnv.executeSql( s""" |create table testLeft( @@ -554,18 +961,33 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { | a0 int, | a2 timestamp(3) | ${if (leftPk == null) "" else s", primary key (${leftPk.mkString(",")}) not enforced"} - |) with ( + |) $leftPartitionStr + |with ( | 'connector' = 'values', | 'bounded' = 'false', - | 'changelog-mode' = '$leftChangelogMode', | 'data-id' = '${TestValuesTableFactory.registerData(leftData)}', | 'async' = 'true', | 'failing-source' = '$testFailingSource' + | $leftExtraOptionsStr |) |""".stripMargin) 
addIndex("testLeft", leftIndex) tEnv.executeSql("drop table if exists testRight") + val rightExtraOptionsStr = + if (rightExtraOptions.isEmpty) { + "" + } else { + "," + rightExtraOptions.asScala + .map { case (key, value) => s"'$key' = '$value'" } + .mkString(", ") + } + val rightPartitionStr = + if (rightPartitionKeys.isEmpty) { + "" + } else { + s"PARTITIONED BY (${rightPartitionKeys.mkString(",")})" + } tEnv.executeSql( s""" |create table testRight( @@ -573,13 +995,14 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { | b1 double, | b2 timestamp(3) | ${if (rightPk == null) "" else s", primary key (${rightPk.mkString(",")}) not enforced"} - |) with ( + |) $rightPartitionStr + |with ( | 'connector' = 'values', | 'bounded' = 'false', - | 'changelog-mode' = '$rightChangelogMode', | 'data-id' = '${TestValuesTableFactory.registerData(rightData)}', | 'async' = 'true', | 'failing-source' = '$testFailingSource' + | $rightExtraOptionsStr |) |""".stripMargin) addIndex("testRight", rightIndex) @@ -611,16 +1034,21 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { rightIndex: List[String], leftPk: Option[List[String]], rightPk: Option[List[String]], + partialInsertCols: Option[List[String]], sinkPk: List[String], + leftPartitionKeys: List[String], + rightPartitionKeys: List[String], leftData: List[Row], rightData: List[Row], + filterProjectOnLeft: Option[String] = None, + filterProjectOnRight: Option[String] = None, joinCondition: String, filterAfterJoin: Option[String], expected: List[String], expectedLookupFunctionInvokeCount: Option[Int], testFailingSource: Boolean, - leftChangelogMode: String, - rightChangelogMode: String + leftExtraOptions: JMap[String, String], + rightExtraOptions: JMap[String, String] ) private class TestSpecBuilder { @@ -628,7 +1056,12 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { private var rightIndex: Option[List[String]] = None private var leftPk: Option[List[String]] = None private var rightPk: Option[List[String]] = None + private var partialInsertCols: Option[List[String]] = None private var sinkPk: Option[List[String]] = None + private var leftPartitionKeys: Option[List[String]] = None + private var rightPartitionKeys: Option[List[String]] = None + private var filterProjectOnLeft: Option[String] = None + private var filterProjectOnRight: Option[String] = None private var joinCondition: Option[String] = None private var filterAfterJoin: Option[String] = None private var leftData: Option[List[Row]] = None @@ -636,8 +1069,8 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { private var expectedData: Option[List[String]] = None private var expectedLookupFunctionInvokeCount: Option[Int] = None private var testFailingSource: Boolean = false - private var leftChangelogMode: String = "I" - private var rightChangelogMode: String = "I" + private val leftExtraOptions: JMap[String, String] = new JHashMap[String, String] + private val rightExtraOptions: JMap[String, String] = new JHashMap[String, String] def withLeftIndex(index: List[String]): TestSpecBuilder = { leftIndex = Some(requireNonNull(index)) @@ -659,11 +1092,26 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { this } + def withPartialInsertCols(cols: List[String]): TestSpecBuilder = { + partialInsertCols = Some(requireNonNull(cols)) + this + } + def withSinkPk(pk: List[String]): TestSpecBuilder = { sinkPk = Some(requireNonNull(pk)) this } + def withLeftPartitionKeys(partitionKeys: 
List[String]): TestSpecBuilder = { + leftPartitionKeys = Some(requireNonNull(partitionKeys)) + this + } + + def withRightPartitionKeys(partitionKeys: List[String]): TestSpecBuilder = { + rightPartitionKeys = Some(requireNonNull(partitionKeys)) + this + } + def withLeftData(data: List[Row]): TestSpecBuilder = { leftData = Some(requireNonNull(data)) this @@ -674,6 +1122,16 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { this } + def withFilterProjectOnLeft(query: String): TestSpecBuilder = { + filterProjectOnLeft = Some(requireNonNull(query)) + this + } + + def withFilterProjectOnRight(query: String): TestSpecBuilder = { + filterProjectOnRight = Some(requireNonNull(query)) + this + } + def withJoinCondition(condition: String): TestSpecBuilder = { joinCondition = Some(requireNonNull(condition)) this @@ -699,13 +1157,13 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { this } - def withLeftChangelogMode(mode: String): TestSpecBuilder = { - leftChangelogMode = requireNonNull(mode) + def withLeftExtraOptions(options: JMap[String, String]): TestSpecBuilder = { + leftExtraOptions.putAll(options) this } - def withRightChangelogMode(mode: String): TestSpecBuilder = { - rightChangelogMode = requireNonNull(mode) + def withRightExtraOptions(options: JMap[String, String]): TestSpecBuilder = { + rightExtraOptions.putAll(options) this } @@ -715,16 +1173,21 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { requireNonNull(rightIndex.orNull), leftPk, rightPk, + partialInsertCols, requireNonNull(sinkPk.orNull), + requireNonNull(leftPartitionKeys.getOrElse(List())), + requireNonNull(rightPartitionKeys.getOrElse(List())), requireNonNull(leftData.orNull), requireNonNull(rightData.orNull), + filterProjectOnLeft, + filterProjectOnRight, requireNonNull(joinCondition.orNull), filterAfterJoin, requireNonNull(expectedData.orNull), expectedLookupFunctionInvokeCount, testFailingSource, - leftChangelogMode, - rightChangelogMode + leftExtraOptions, + rightExtraOptions ) } diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/AsyncDeltaJoinRunner.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/AsyncDeltaJoinRunner.java index d73bfd315db35..41bcf8746a9fb 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/AsyncDeltaJoinRunner.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/AsyncDeltaJoinRunner.java @@ -19,6 +19,7 @@ package org.apache.flink.table.runtime.operators.join.deltajoin; import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.OpenContext; import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.streaming.api.functions.async.AsyncFunction; @@ -32,6 +33,7 @@ import org.apache.flink.table.runtime.generated.GeneratedFunction; import org.apache.flink.table.runtime.generated.GeneratedResultFuture; import org.apache.flink.table.runtime.keyselector.RowDataKeySelector; +import org.apache.flink.table.runtime.operators.join.lookup.CalcCollectionCollector; import org.apache.flink.table.runtime.typeutils.RowDataSerializer; import org.apache.flink.util.Preconditions; @@ -61,6 +63,8 @@ public class AsyncDeltaJoinRunner extends RichAsyncFunction { 
"deltaJoinRightCallAsyncFetchCostTime"; private final GeneratedFunction> generatedFetcher; private final DataStructureConverter fetcherConverter; + private final @Nullable GeneratedFunction> + lookupSideGeneratedCalc; private final GeneratedResultFuture> generatedResultFuture; private final int asyncBufferCapacity; @@ -105,6 +109,7 @@ public class AsyncDeltaJoinRunner extends RichAsyncFunction { public AsyncDeltaJoinRunner( GeneratedFunction> generatedFetcher, DataStructureConverter fetcherConverter, + @Nullable GeneratedFunction> lookupSideGeneratedCalc, GeneratedResultFuture> generatedResultFuture, RowDataSerializer lookupSideRowSerializer, RowDataKeySelector leftJoinKeySelector, @@ -116,6 +121,7 @@ public AsyncDeltaJoinRunner( boolean enableCache) { this.generatedFetcher = generatedFetcher; this.fetcherConverter = fetcherConverter; + this.lookupSideGeneratedCalc = lookupSideGeneratedCalc; this.generatedResultFuture = generatedResultFuture; this.lookupSideRowSerializer = lookupSideRowSerializer; this.leftJoinKeySelector = leftJoinKeySelector; @@ -139,7 +145,10 @@ public void open(OpenContext openContext) throws Exception { FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext()); FunctionUtils.openFunction(fetcher, openContext); - // try to compile the generated ResultFuture, fail fast if the code is corrupt. + // try to compile the generated Calc and ResultFuture, fail fast if the code is corrupt. + if (lookupSideGeneratedCalc != null) { + lookupSideGeneratedCalc.compile(getRuntimeContext().getUserCodeClassLoader()); + } generatedResultFuture.compile(getRuntimeContext().getUserCodeClassLoader()); fetcherConverter.open(getRuntimeContext().getUserCodeClassLoader()); @@ -155,11 +164,13 @@ public void open(OpenContext openContext) throws Exception { JoinedRowResultFuture rf = new JoinedRowResultFuture( resultFutureBuffer, + createCalcFunction(openContext), createFetcherResultFuture(openContext), fetcherConverter, treatRightAsLookupTable, leftUpsertKeySelector, rightUpsertKeySelector, + lookupSideRowSerializer, enableCache, cache); // add will throw exception immediately if the queue is full which should never happen @@ -198,7 +209,7 @@ public void asyncInvoke(RowData input, ResultFuture resultFuture) throw if (enableCache) { Optional> dataFromCache = tryGetDataFromCache(streamJoinKey); if (dataFromCache.isPresent()) { - outResultFuture.complete(dataFromCache.get()); + outResultFuture.complete(dataFromCache.get(), true); return; } } @@ -209,6 +220,20 @@ public void asyncInvoke(RowData input, ResultFuture resultFuture) throw callAsyncFetchCostTime = System.currentTimeMillis() - startTime; } + @Nullable + private FlatMapFunction createCalcFunction(OpenContext openContext) + throws Exception { + FlatMapFunction calc = null; + if (lookupSideGeneratedCalc != null) { + calc = + lookupSideGeneratedCalc.newInstance( + getRuntimeContext().getUserCodeClassLoader()); + FunctionUtils.setFunctionRuntimeContext(calc, getRuntimeContext()); + FunctionUtils.openFunction(calc, openContext); + } + return calc; + } + public TableFunctionResultFuture createFetcherResultFuture(OpenContext openContext) throws Exception { TableFunctionResultFuture resultFuture = @@ -271,6 +296,8 @@ private Optional> tryGetDataFromCache(RowData joinKey) { @VisibleForTesting public static final class JoinedRowResultFuture implements ResultFuture { private final BlockingQueue resultFutureBuffer; + private final @Nullable FlatMapFunction calcFunction; + private final CalcCollectionCollector calcCollector; private final 
TableFunctionResultFuture joinConditionResultFuture; private final DataStructureConverter resultConverter; @@ -289,14 +316,18 @@ public static final class JoinedRowResultFuture implements ResultFuture private JoinedRowResultFuture( BlockingQueue resultFutureBuffer, + @Nullable FlatMapFunction calcFunction, TableFunctionResultFuture joinConditionResultFuture, DataStructureConverter resultConverter, boolean treatRightAsLookupTable, RowDataKeySelector leftUpsertKeySelector, RowDataKeySelector rightUpsertKeySelector, + RowDataSerializer lookupSideRowSerializer, boolean enableCache, DeltaJoinCache cache) { this.resultFutureBuffer = resultFutureBuffer; + this.calcFunction = calcFunction; + this.calcCollector = new CalcCollectionCollector(lookupSideRowSerializer); this.joinConditionResultFuture = joinConditionResultFuture; this.resultConverter = resultConverter; this.enableCache = enableCache; @@ -314,43 +345,62 @@ public void reset( this.realOutput = realOutput; this.streamJoinKey = joinKey; this.streamRow = row; + joinConditionResultFuture.setInput(row); joinConditionResultFuture.setResultFuture(delegate); delegate.reset(); + calcCollector.reset(); } @Override public void complete(Collection result) { + complete(result, false); + } + + public void complete(Collection result, boolean fromCache) { if (result == null) { result = Collections.emptyList(); } + Collection rowDataCollection = convertToInternalData(result); + + Collection lookupRowsAfterCalc = rowDataCollection; + if (!fromCache && calcFunction != null && rowDataCollection != null) { + for (RowData row : rowDataCollection) { + try { + calcFunction.flatMap(row, calcCollector); + } catch (Exception e) { + completeExceptionally(e); + } + } + lookupRowsAfterCalc = calcCollector.getCollection(); + } + // now we have received the rows from the lookup table, try to set them to the cache try { - updateCacheIfNecessary(result); + updateCacheIfNecessary(lookupRowsAfterCalc); } catch (Throwable t) { LOG.info("Failed to update the cache", t); completeExceptionally(t); return; } - Collection rowDataCollection = convertToInternalData(result); - // call condition collector first, + // call join condition collector, // the filtered result will be routed to the delegateCollector try { - joinConditionResultFuture.complete(rowDataCollection); + joinConditionResultFuture.complete(lookupRowsAfterCalc); } catch (Throwable t) { // we should catch the exception here to let the framework know completeExceptionally(t); return; } - Collection lookupRows = delegate.collection; - if (lookupRows == null || lookupRows.isEmpty()) { + Collection lookupRowsAfterJoin = delegate.collection; + if (lookupRowsAfterJoin == null || lookupRowsAfterJoin.isEmpty()) { realOutput.complete(Collections.emptyList()); } else { List outRows = new ArrayList<>(); - for (RowData lookupRow : lookupRows) { + for (RowData lookupRow : lookupRowsAfterJoin) { RowData outRow; if (treatRightAsLookupTable) { outRow = new JoinedRowData(streamRow.getRowKind(), streamRow, lookupRow); @@ -388,7 +438,7 @@ public void close() throws Exception { joinConditionResultFuture.close(); } - private void updateCacheIfNecessary(Collection lookupRows) throws Exception { + private void updateCacheIfNecessary(Collection lookupRows) throws Exception { if (!enableCache) { return; } @@ -419,7 +469,7 @@ private void updateCacheIfNecessary(Collection lookupRows) throws Except } private LinkedHashMap buildMapWithUkAsKeys( - Collection lookupRows, boolean treatRightAsLookupTable) throws Exception { + Collection lookupRows, 
boolean treatRightAsLookupTable) throws Exception { LinkedHashMap map = new LinkedHashMap<>(); for (Object lookupRow : lookupRows) { RowData rowData = convertToInternalData(lookupRow); diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinWithCalcRunner.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinWithCalcRunner.java index 4d6e035b95ec5..b14c34815621c 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinWithCalcRunner.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinWithCalcRunner.java @@ -32,9 +32,7 @@ import org.apache.flink.table.runtime.generated.GeneratedFunction; import org.apache.flink.table.runtime.generated.GeneratedResultFuture; import org.apache.flink.table.runtime.typeutils.RowDataSerializer; -import org.apache.flink.util.Collector; -import java.util.ArrayList; import java.util.Collection; /** The async join runner with an additional calculate function on the dimension table. */ @@ -94,7 +92,8 @@ private class TemporalTableCalcResultFuture extends TableFunctionResultFuture calc; private final TableFunctionResultFuture joinConditionResultFuture; - private final CalcCollectionCollector calcCollector = new CalcCollectionCollector(); + private final CalcCollectionCollector calcCollector = + new CalcCollectionCollector(rightRowSerializer); private TemporalTableCalcResultFuture( FlatMapFunction calc, @@ -116,7 +115,7 @@ public void setResultFuture(ResultFuture resultFuture) { @Override public void complete(Collection result) { - if (result == null || result.size() == 0) { + if (result == null || result.isEmpty()) { joinConditionResultFuture.complete(result); } else { for (RowData row : result) { @@ -126,7 +125,7 @@ public void complete(Collection result) { joinConditionResultFuture.completeExceptionally(e); } } - joinConditionResultFuture.complete(calcCollector.collection); + joinConditionResultFuture.complete(calcCollector.getCollection()); } } @@ -137,21 +136,4 @@ public void close() throws Exception { FunctionUtils.closeFunction(calc); } } - - private class CalcCollectionCollector implements Collector { - - Collection collection; - - public void reset() { - this.collection = new ArrayList<>(); - } - - @Override - public void collect(RowData record) { - this.collection.add(rightRowSerializer.copy(record)); - } - - @Override - public void close() {} - } } diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/CalcCollectionCollector.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/CalcCollectionCollector.java new file mode 100644 index 0000000000000..59dd86d77c884 --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/CalcCollectionCollector.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.join.lookup; + +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.typeutils.RowDataSerializer; +import org.apache.flink.util.Collector; + +import java.util.ArrayList; +import java.util.Collection; + +/** A {@link Collector} used to store data after calc function. */ +public class CalcCollectionCollector implements Collector { + + private final RowDataSerializer lookupResultRowSerializer; + + private Collection collection; + + public CalcCollectionCollector(RowDataSerializer lookupResultRowSerializer) { + this.lookupResultRowSerializer = lookupResultRowSerializer; + } + + public void reset() { + this.collection = new ArrayList<>(); + } + + public Collection getCollection() { + return collection; + } + + @Override + public void collect(RowData record) { + this.collection.add(lookupResultRowSerializer.copy(record)); + } + + @Override + public void close() {} +} diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperatorTest.java index de52c8f2be2bd..bae1d5e1096e3 100644 --- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperatorTest.java +++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperatorTest.java @@ -18,6 +18,7 @@ package org.apache.flink.table.runtime.operators.join.deltajoin; +import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.OpenContext; import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.streaming.api.functions.async.ResultFuture; @@ -77,6 +78,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -131,7 +133,7 @@ public void afterEach() throws Exception { @TestTemplate void testJoinBothLogTables() throws Exception { - LogLogTableJoinTestSpec testSpec = LogLogTableJoinTestSpec.INSTANCE; + LogLogTableJoinTestSpec testSpec = LogLogTableJoinTestSpec.WITHOUT_FILTER_ON_TABLE; initTestHarness(testSpec); initAssertor(testSpec); @@ -239,9 +241,113 @@ void testJoinBothLogTables() throws Exception { } } + @TestTemplate + void testJoinBothLogTablesWhileFilterExistsOnBothTable() throws Exception { + LogLogTableJoinTestSpec testSpec = LogLogTableJoinTestSpec.WITH_FILTER_ON_TABLE; + initTestHarness(testSpec); + initAssertor(testSpec); + + StreamRecord leftRecord1 = insertRecord(100, true, "jklk1"); + testHarness.processElement1(leftRecord1); + + // will be filtered upstream + StreamRecord leftRecord2 = insertRecord(100, false, "jklk2"); + insertLeftTable(testSpec, leftRecord2); + + StreamRecord leftRecord3 = insertRecord(200, true, "jklk1"); + 
testHarness.processElement1(leftRecord3); + + // will be filtered upstream + StreamRecord leftRecord4 = insertRecord(200, false, "jklk2"); + insertLeftTable(testSpec, leftRecord4); + + StreamRecord rightRecord1 = insertRecord("jklk1", 300, true); + testHarness.processElement2(rightRecord1); + + // will be filtered upstream + StreamRecord rightRecord2 = insertRecord("jklk2", 300, false); + insertRightTable(testSpec, rightRecord2); + + // mismatch + StreamRecord leftRecord5 = insertRecord(200, true, "unknown1"); + testHarness.processElement1(leftRecord5); + + // mismatch and will be filtered upstream + StreamRecord rightRecord3 = insertRecord("unknown2", 300, false); + insertRightTable(testSpec, rightRecord3); + + StreamRecord leftRecord6 = insertRecord(800, true, "jklk1"); + testHarness.processElement1(leftRecord6); + + // will be filtered upstream + StreamRecord leftRecord7 = insertRecord(800, false, "jklk2"); + insertLeftTable(testSpec, leftRecord7); + + StreamRecord rightRecord4 = insertRecord("jklk1", 1000, true); + testHarness.processElement2(rightRecord4); + + // will be filtered upstream + StreamRecord rightRecord5 = insertRecord("jklk2", 1000, false); + insertRightTable(testSpec, rightRecord5); + + waitAllDataProcessed(); + + final ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + expectedOutput.add(insertRecord(100, true, "jklk1", "jklk1", 300, true)); + expectedOutput.add(insertRecord(200, true, "jklk1", "jklk1", 300, true)); + expectedOutput.add(insertRecord(800, true, "jklk1", "jklk1", 300, true)); + expectedOutput.add(insertRecord(100, true, "jklk1", "jklk1", 1000, true)); + expectedOutput.add(insertRecord(200, true, "jklk1", "jklk1", 1000, true)); + expectedOutput.add(insertRecord(800, true, "jklk1", "jklk1", 1000, true)); + + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + TableAsyncExecutionController aec = unwrapAEC(testHarness); + assertThat(aec.getBlockingSize()).isEqualTo(0); + assertThat(aec.getInFlightSize()).isEqualTo(0); + assertThat(aec.getFinishSize()).isEqualTo(0); + + DeltaJoinCache cache = unwrapCache(testHarness); + if (enableCache) { + RowType leftRowType = testSpec.getLeftInputRowType(); + RowType rightRowType = testSpec.getRightInputRowType(); + Map> expectedLeftCacheData = + newHashMap( + binaryrow(true, "jklk1"), + newHashMap( + toBinary(leftRecord1.getValue(), leftRowType), + leftRecord1.getValue(), + toBinary(leftRecord3.getValue(), leftRowType), + leftRecord3.getValue(), + toBinary(leftRecord6.getValue(), leftRowType), + leftRecord6.getValue())); + + Map> expectedRightCacheData = + newHashMap( + binaryrow(true, "jklk1"), + newHashMap( + toBinary(rightRecord1.getValue(), rightRowType), + rightRecord1.getValue(), + toBinary(rightRecord4.getValue(), rightRowType), + rightRecord4.getValue()), + binaryrow(true, "unknown1"), + Collections.emptyMap()); + + verifyCacheData(cache, expectedLeftCacheData, expectedRightCacheData, 2, 1, 4, 2); + assertThat(MyAsyncFunction.leftInvokeCount.get()).isEqualTo(2); + assertThat(MyAsyncFunction.rightInvokeCount.get()).isEqualTo(1); + } else { + verifyCacheData(cache, Collections.emptyMap(), Collections.emptyMap(), 0, 0, 0, 0); + assertThat(MyAsyncFunction.leftInvokeCount.get()).isEqualTo(4); + assertThat(MyAsyncFunction.rightInvokeCount.get()).isEqualTo(2); + } + } + @TestTemplate void testJoinBothPkTables() throws Exception { - PkPkTableJoinTestSpec testSpec = PkPkTableJoinTestSpec.INSTANCE; + PkPkTableJoinTestSpec testSpec = 
PkPkTableJoinTestSpec.WITHOUT_FILTER_ON_TABLE; initTestHarness(testSpec); initAssertor(testSpec); @@ -329,9 +435,90 @@ void testJoinBothPkTables() throws Exception { } } + @TestTemplate + void testJoinBothPkTablesWhileFilterExistsOnBothTable() throws Exception { + PkPkTableJoinTestSpec testSpec = PkPkTableJoinTestSpec.WITH_FILTER_ON_TABLE; + initTestHarness(testSpec); + initAssertor(testSpec); + + StreamRecord leftRecordK1V1 = insertRecord(100, true, "Tom"); + testHarness.processElement1(leftRecordK1V1); + + // will be filtered upstream + StreamRecord leftRecordK2V1 = insertRecord(101, false, "Tom"); + insertLeftTable(testSpec, leftRecordK2V1); + + // mismatch and will be filtered upstream + StreamRecord leftRecordK3V1 = insertRecord(1999, false, "Jim"); + insertLeftTable(testSpec, leftRecordK3V1); + + waitAllDataProcessed(); + final ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + // will be filtered upstream + StreamRecord rightRecordK1V1 = insertRecord("Tom", 200, true); + insertRightTable(testSpec, rightRecordK1V1); + + StreamRecord rightRecordK2V1 = insertRecord("Tom", 201, false); + testHarness.processElement2(rightRecordK2V1); + + // mismatch + StreamRecord rightRecordK3V1 = insertRecord("Sam", 2999, true); + testHarness.processElement2(rightRecordK3V1); + + waitAllDataProcessed(); + expectedOutput.add(insertRecord(100, true, "Tom", "Tom", 201, false)); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + StreamRecord leftRecordK1V2 = updateAfterRecord(1000, true, "Tom"); + testHarness.processElement1(leftRecordK1V2); + + waitAllDataProcessed(); + expectedOutput.add(updateAfterRecord(1000, true, "Tom", "Tom", 201, false)); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + // will be filtered upstream + StreamRecord rightRecordK1V2 = updateAfterRecord("Tom", 2000, true); + insertRightTable(testSpec, rightRecordK1V2); + + StreamRecord rightRecordK2V2 = updateAfterRecord("Tom", 2001, false); + testHarness.processElement2(rightRecordK2V2); + + waitAllDataProcessed(); + expectedOutput.add(updateAfterRecord(1000, true, "Tom", "Tom", 2001, false)); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + DeltaJoinCache cache = unwrapCache(testHarness); + if (enableCache) { + Map> expectedLeftCacheData = + newHashMap( + binaryrow("Tom"), + newHashMap(binaryrow(true, "Tom"), leftRecordK1V2.getValue()), + binaryrow("Sam"), + Collections.emptyMap()); + + Map> expectedRightCacheData = + newHashMap( + binaryrow("Tom"), + newHashMap(binaryrow("Tom", false), rightRecordK2V2.getValue())); + verifyCacheData(cache, expectedLeftCacheData, expectedRightCacheData, 3, 1, 2, 1); + assertThat(MyAsyncFunction.leftInvokeCount.get()).isEqualTo(1); + assertThat(MyAsyncFunction.rightInvokeCount.get()).isEqualTo(2); + } else { + verifyCacheData(cache, Collections.emptyMap(), Collections.emptyMap(), 0, 0, 0, 0); + assertThat(MyAsyncFunction.leftInvokeCount.get()).isEqualTo(2); + assertThat(MyAsyncFunction.rightInvokeCount.get()).isEqualTo(3); + } + } + @TestTemplate void testBlockingWithSameJoinKey() throws Exception { - LogLogTableJoinTestSpec testSpec = LogLogTableJoinTestSpec.INSTANCE; + LogLogTableJoinTestSpec testSpec = LogLogTableJoinTestSpec.WITHOUT_FILTER_ON_TABLE; initTestHarness(testSpec); initAssertor(testSpec); @@ -457,7 +644,7 
@@ void testBlockingWithSameJoinKey() throws Exception { */ @TestTemplate void testLogTableDataVisibleBeforeJoin() throws Exception { - LogLogTableJoinTestSpec testSpec = LogLogTableJoinTestSpec.INSTANCE; + LogLogTableJoinTestSpec testSpec = LogLogTableJoinTestSpec.WITHOUT_FILTER_ON_TABLE; initTestHarness(testSpec); initAssertor(testSpec); @@ -569,7 +756,7 @@ void testLogTableDataVisibleBeforeJoin() throws Exception { */ @TestTemplate void testPkTableDataVisibleBeforeJoin() throws Exception { - PkPkTableJoinTestSpec testSpec = PkPkTableJoinTestSpec.INSTANCE; + PkPkTableJoinTestSpec testSpec = PkPkTableJoinTestSpec.WITHOUT_FILTER_ON_TABLE; initTestHarness(testSpec); initAssertor(testSpec); @@ -686,7 +873,7 @@ void testPkTableDataVisibleBeforeJoin() throws Exception { @TestTemplate void testCheckpointAndRestore() throws Exception { - LogLogTableJoinTestSpec testSpec = LogLogTableJoinTestSpec.INSTANCE; + LogLogTableJoinTestSpec testSpec = LogLogTableJoinTestSpec.WITHOUT_FILTER_ON_TABLE; initTestHarness(testSpec); initAssertor(testSpec); @@ -804,7 +991,7 @@ void testCheckpointAndRestore() throws Exception { @TestTemplate void testClearLegacyStateWhenCheckpointing() throws Exception { - LogLogTableJoinTestSpec testSpec = LogLogTableJoinTestSpec.INSTANCE; + LogLogTableJoinTestSpec testSpec = LogLogTableJoinTestSpec.WITHOUT_FILTER_ON_TABLE; initTestHarness(testSpec); initAssertor(testSpec); @@ -859,7 +1046,7 @@ void testClearLegacyStateWhenCheckpointing() throws Exception { @TestTemplate void testMeetExceptionWhenLookup() throws Exception { - LogLogTableJoinTestSpec testSpec = LogLogTableJoinTestSpec.INSTANCE; + LogLogTableJoinTestSpec testSpec = LogLogTableJoinTestSpec.WITHOUT_FILTER_ON_TABLE; initTestHarness(testSpec); initAssertor(testSpec); @@ -1026,6 +1213,8 @@ public MyAsyncFunction newInstance(ClassLoader classLoader) { } }, leftFetcherConverter, + new MockGeneratedFlatMapFunction( + testSpec.getFilterOnLeftTable().orElse(null)), new GeneratedResultFutureWrapper<>(new TestingFetcherResultFuture()), testSpec.getLeftTypeInfo().toRowSerializer(), testSpec.getLeftJoinKeySelector(), @@ -1050,6 +1239,8 @@ public MyAsyncFunction newInstance(ClassLoader classLoader) { } }, rightFetcherConverter, + new MockGeneratedFlatMapFunction( + testSpec.getFilterOnRightTable().orElse(null)), new GeneratedResultFutureWrapper<>(new TestingFetcherResultFuture()), testSpec.getRightTypeInfo().toRowSerializer(), testSpec.getLeftJoinKeySelector(), @@ -1348,6 +1539,34 @@ public void complete(Collection result) { } } + private static class MockGeneratedFlatMapFunction + extends GeneratedFunction> { + + private static final long serialVersionUID = 1L; + + private final @Nullable Function condition; + + public MockGeneratedFlatMapFunction(@Nullable Function condition) { + super("", "", new Object[0]); + this.condition = condition; + } + + @Override + public FlatMapFunction newInstance(ClassLoader classLoader) { + return (value, out) -> { + if (condition == null || condition.apply(value)) { + out.collect(value); + } + }; + } + + @Override + public Class> compile(ClassLoader classLoader) { + // just avoid exceptions + return null; + } + } + private abstract static class AbstractTestSpec { abstract RowType getLeftInputRowType(); @@ -1406,6 +1625,10 @@ final int[] getOutputFieldIndices() { return IntStream.range(0, getOutputRowType().getFieldCount()).toArray(); } + abstract Optional> getFilterOnLeftTable(); + + abstract Optional> getFilterOnRightTable(); + private RowDataKeySelector getUpsertKeySelector( RowType rowType, 
@Nullable int[] upsertKey) { if (upsertKey == null) { @@ -1437,15 +1660,37 @@ private RowDataKeySelector getUpsertKeySelector( * ) * * + *

<p>If the flag {@link #filterOnTable} is false, the query is:
+     *
      * <pre>
      *     select * from leftSrc join rightSrc
      *      on leftSrc.left_jk1 = rightSrc.right_jk1_index
      *      and leftSrc.left_jk2_index = rightSrc.right_jk2
      * </pre>
+     *
+     * <p>If the flag {@link #filterOnTable} is true, the query is:
+     *
+     * <pre>
+     *     select * from (
+     *      select * from leftSrc where left_jk1 = 'true'
+     *     ) join (
+     *      select * from rightSrc where right_jk2 = 'jklk1'
+     *     ) on left_jk1 = right_jk1_index
+     *      and left_jk2_index = right_jk2
+     * </pre>
*/ private static class LogLogTableJoinTestSpec extends AbstractTestSpec { - private static final LogLogTableJoinTestSpec INSTANCE = new LogLogTableJoinTestSpec(); + private static final LogLogTableJoinTestSpec WITHOUT_FILTER_ON_TABLE = + new LogLogTableJoinTestSpec(false); + private static final LogLogTableJoinTestSpec WITH_FILTER_ON_TABLE = + new LogLogTableJoinTestSpec(true); + + private final boolean filterOnTable; + + public LogLogTableJoinTestSpec(boolean filterOnTable) { + this.filterOnTable = filterOnTable; + } @Override RowType getLeftInputRowType() { @@ -1480,6 +1725,22 @@ int[] getLeftJoinKeyIndices() { int[] getRightJoinKeyIndices() { return new int[] {2, 0}; } + + @Override + Optional> getFilterOnLeftTable() { + if (filterOnTable) { + return Optional.of((rowData -> rowData.getBoolean(1))); + } + return Optional.empty(); + } + + @Override + Optional> getFilterOnRightTable() { + if (filterOnTable) { + return Optional.of((rowData -> "jklk1".equals(rowData.getString(0).toString()))); + } + return Optional.empty(); + } } /** @@ -1505,14 +1766,35 @@ int[] getRightJoinKeyIndices() { * ) * * + *

<p>If the flag {@link #filterOnTable} is false, the query is:
+     *
      * <pre>
      *     select * from leftSrc join rightSrc
      *      on leftSrc.left_pk2_jk_index = rightSrc.right_pk2_jk_index
      * </pre>
+     *
+     * <p>If the flag {@link #filterOnTable} is true, the query is:
+     *
+     * <pre>
+     *     select * from (
+     *       select * from leftSrc where left_pk1 = 'true'
+     *     ) join (
+     *       select * from rightSrc where right_pk1 = 'false'
+     *     ) on left_pk2_jk_index = right_pk2_jk_index
+     * </pre>
*/ private static class PkPkTableJoinTestSpec extends AbstractTestSpec { - private static final PkPkTableJoinTestSpec INSTANCE = new PkPkTableJoinTestSpec(); + private static final PkPkTableJoinTestSpec WITHOUT_FILTER_ON_TABLE = + new PkPkTableJoinTestSpec(false); + private static final PkPkTableJoinTestSpec WITH_FILTER_ON_TABLE = + new PkPkTableJoinTestSpec(true); + + private final boolean filterOnTable; + + public PkPkTableJoinTestSpec(boolean filterOnTable) { + this.filterOnTable = filterOnTable; + } @Override RowType getLeftInputRowType() { @@ -1547,5 +1829,21 @@ int[] getLeftJoinKeyIndices() { int[] getRightJoinKeyIndices() { return new int[] {0}; } + + @Override + Optional> getFilterOnLeftTable() { + if (filterOnTable) { + return Optional.of((rowData -> rowData.getBoolean(1))); + } + return Optional.empty(); + } + + @Override + Optional> getFilterOnRightTable() { + if (filterOnTable) { + return Optional.of((rowData -> !rowData.getBoolean(2))); + } + return Optional.empty(); + } } }
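
As a minimal, self-contained sketch of the mechanism this patch wires into the delta join runtime: rows fetched from the lookup table are first passed through a calc function (the projection/filter pushed onto the temporal table) via a collector, and only the surviving rows reach the join-condition result future and the cache. The class and member names below (LookupSideCalcSketch, ListCollector) and the concrete filter are made up for illustration; the real runner uses a code-generated FlatMapFunction and CalcCollectionCollector, which additionally deep-copies each row with a RowDataSerializer.

import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.StringData;
import org.apache.flink.util.Collector;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

public class LookupSideCalcSketch {

    /** Collects calc results into a list, similar in spirit to CalcCollectionCollector. */
    static final class ListCollector implements Collector<RowData> {
        final List<RowData> rows = new ArrayList<>();

        @Override
        public void collect(RowData record) {
            rows.add(record);
        }

        @Override
        public void close() {}
    }

    public static void main(String[] args) throws Exception {
        // Hypothetical filter standing in for the code-generated calc:
        // keep only lookup rows whose first field equals "jklk1".
        FlatMapFunction<RowData, RowData> calc =
                (row, out) -> {
                    if ("jklk1".equals(row.getString(0).toString())) {
                        out.collect(row);
                    }
                };

        // Rows as they might come back from the async lookup.
        Collection<RowData> fetchedFromLookupTable =
                Arrays.asList(
                        GenericRowData.of(StringData.fromString("jklk1"), 300),
                        GenericRowData.of(StringData.fromString("jklk2"), 300));

        // Apply the calc before the join condition sees the rows.
        ListCollector collector = new ListCollector();
        for (RowData row : fetchedFromLookupTable) {
            calc.flatMap(row, collector);
        }

        System.out.println(collector.rows.size()); // 1: only the "jklk1" row survives
    }
}

In the patched runner the output of this step feeds both the cache update and the join-condition result future, which is why rows served from the cache skip the calc: they were already stored in their post-calc form.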