Skip to content

Commit 56bf7c8

Browse files
authored
[FLINK-38424][planner] Support to parse VECTOR_SEARCH function (#27039)
1 parent cc01169 commit 56bf7c8

File tree

6 files changed

+653
-14
lines changed

6 files changed

+653
-14
lines changed

flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
*/
1717
package org.apache.calcite.sql.validate;
1818

19+
import org.apache.flink.table.api.ValidationException;
1920
import org.apache.flink.table.planner.calcite.FlinkSqlCallBinding;
21+
import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction;
2022

2123
import com.google.common.base.Preconditions;
2224
import com.google.common.collect.ImmutableList;
@@ -177,10 +179,12 @@
177179
*
178180
* <p>Lines 2571 ~ 2588, CALCITE-7217, should be removed after upgrading Calcite to 1.41.0.
179181
*
180-
* <p>Lines 3895 ~ 3899, 6574 ~ 6580 Flink improves Optimize the retrieval of sub-operands in
182+
* <p>Lines 2618 ~ 2631, set the correct scope for VECTOR_SEARCH.
183+
*
184+
* <p>Lines 3920 ~ 3925, 6599 ~ 6606, Flink optimizes the retrieval of sub-operands in
181185
* SqlCall when using NamedParameters at {@link SqlValidatorImpl#checkRollUp}.
182186
*
183-
* <p>Lines 5315 ~ 5321, FLINK-24352 Add null check for temporal table check on SqlSnapshot.
187+
* <p>Lines 5340 ~ 5347, FLINK-24352 Add null check for temporal table check on SqlSnapshot.
184188
*/
185189
public class SqlValidatorImpl implements SqlValidatorWithHints {
186190
// ~ Static fields/initializers ---------------------------------------------
@@ -2570,6 +2574,10 @@ private SqlNode registerFrom(
25702574
case LATERAL:
25712575
// ----- FLINK MODIFICATION BEGIN -----
25722576
SqlBasicCall sbc = (SqlBasicCall) node;
2577+
// Put the usingScope which is a JoinScope,
2578+
// in order to make visible the left items
2579+
// of the JOIN tree.
2580+
scopes.put(node, usingScope);
25732581
registerFrom(
25742582
parentScope,
25752583
usingScope,
@@ -2580,10 +2588,6 @@ private SqlNode registerFrom(
25802588
extendList,
25812589
forceNullable,
25822590
true);
2583-
// Put the usingScope which is a JoinScope,
2584-
// in order to make visible the left items
2585-
// of the JOIN tree.
2586-
scopes.put(node, usingScope);
25872591
return sbc;
25882592
// ----- FLINK MODIFICATION END -----
25892593

@@ -2614,6 +2618,27 @@ private SqlNode registerFrom(
26142618
scopes.put(node, getSelectScope(call1.operand(0)));
26152619
return newNode;
26162620
}
2621+
2622+
// Related to CALCITE-4077
2623+
// ----- FLINK MODIFICATION BEGIN -----
2624+
FlinkSqlCallBinding binding =
2625+
new FlinkSqlCallBinding(this, getEmptyScope(), call1);
2626+
if (op instanceof SqlVectorSearchTableFunction
2627+
&& binding.operand(0)
2628+
.isA(
2629+
new HashSet<>(
2630+
Collections.singletonList(SqlKind.SELECT)))) {
2631+
boolean queryColumnIsNotLiteral =
2632+
binding.operand(2).getKind() != SqlKind.LITERAL;
2633+
if (!queryColumnIsNotLiteral && !lateral) {
2634+
throw new ValidationException(
2635+
"The query column is not literal, please use LATERAL TABLE to run VECTOR_SEARCH.");
2636+
}
2637+
SqlValidatorScope scope = getSelectScope((SqlSelect) binding.operand(0));
2638+
scopes.put(enclosingNode, scope);
2639+
return newNode;
2640+
}
2641+
// ----- FLINK MODIFICATION END -----
26172642
}
26182643
// Put the usingScope which can be a JoinScope
26192644
// or a SelectScope, in order to see the left items

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
2323
import org.apache.flink.table.planner.functions.sql.internal.SqlAuxiliaryGroupAggFunction;
2424
import org.apache.flink.table.planner.functions.sql.ml.SqlMLEvaluateTableFunction;
25+
import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction;
2526
import org.apache.flink.table.planner.plan.type.FlinkReturnTypes;
2627
import org.apache.flink.table.planner.plan.type.NumericExceptFirstOperandChecker;
2728

@@ -1328,6 +1329,9 @@ public List<SqlGroupedWindowFunction> getAuxiliaryFunctions() {
13281329
// MODEL TABLE FUNCTIONS
13291330
public static final SqlFunction ML_EVALUATE = new SqlMLEvaluateTableFunction();
13301331

1332+
// SEARCH FUNCTIONS
1333+
public static final SqlFunction VECTOR_SEARCH = new SqlVectorSearchTableFunction();
1334+
13311335
// Catalog Functions
13321336
public static final SqlFunction CURRENT_DATABASE =
13331337
BuiltInSqlFunction.newBuilder()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.planner.functions.sql.ml;
20+
21+
import org.apache.flink.table.api.ValidationException;
22+
import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;
23+
import org.apache.flink.table.types.logical.ArrayType;
24+
import org.apache.flink.table.types.logical.LogicalType;
25+
import org.apache.flink.table.types.logical.LogicalTypeRoot;
26+
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
27+
28+
import org.apache.calcite.rel.type.RelDataType;
29+
import org.apache.calcite.rel.type.RelDataTypeFactory;
30+
import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
31+
import org.apache.calcite.sql.SqlCall;
32+
import org.apache.calcite.sql.SqlCallBinding;
33+
import org.apache.calcite.sql.SqlFunction;
34+
import org.apache.calcite.sql.SqlFunctionCategory;
35+
import org.apache.calcite.sql.SqlIdentifier;
36+
import org.apache.calcite.sql.SqlKind;
37+
import org.apache.calcite.sql.SqlNode;
38+
import org.apache.calcite.sql.SqlOperandCountRange;
39+
import org.apache.calcite.sql.SqlOperator;
40+
import org.apache.calcite.sql.SqlOperatorBinding;
41+
import org.apache.calcite.sql.SqlTableFunction;
42+
import org.apache.calcite.sql.type.ReturnTypes;
43+
import org.apache.calcite.sql.type.SqlOperandCountRanges;
44+
import org.apache.calcite.sql.type.SqlOperandMetadata;
45+
import org.apache.calcite.sql.type.SqlReturnTypeInference;
46+
import org.apache.calcite.sql.type.SqlTypeName;
47+
import org.apache.calcite.sql.validate.SqlNameMatcher;
48+
import org.apache.calcite.util.Util;
49+
import org.checkerframework.checker.nullness.qual.Nullable;
50+
51+
import java.util.Arrays;
52+
import java.util.Collections;
53+
import java.util.List;
54+
import java.util.Optional;
55+
56+
import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;
57+
58+
/**
59+
* {@link SqlVectorSearchTableFunction} implements an operator for search.
60+
*
61+
* <p>It allows four parameters:
62+
*
63+
* <ol>
64+
* <li>a table
65+
* <li>a descriptor to provide a column name from the input table
66+
* <li>a query column from the left table
67+
* <li>a literal value for top k
68+
* </ol>
69+
*/
70+
public class SqlVectorSearchTableFunction extends SqlFunction implements SqlTableFunction {
71+
72+
private static final String PARAM_SEARCH_TABLE = "SEARCH_TABLE";
73+
private static final String PARAM_COLUMN_TO_SEARCH = "COLUMN_TO_SEARCH";
74+
private static final String PARAM_COLUMN_TO_QUERY = "COLUMN_TO_QUERY";
75+
private static final String PARAM_TOP_K = "TOP_K";
76+
77+
private static final String OUTPUT_SCORE = "score";
78+
79+
public SqlVectorSearchTableFunction() {
80+
super(
81+
"VECTOR_SEARCH",
82+
SqlKind.OTHER_FUNCTION,
83+
ReturnTypes.CURSOR,
84+
null,
85+
new OperandMetadataImpl(),
86+
SqlFunctionCategory.SYSTEM);
87+
}
88+
89+
@Override
90+
public SqlReturnTypeInference getRowTypeInference() {
91+
return new SqlReturnTypeInference() {
92+
@Override
93+
public @Nullable RelDataType inferReturnType(SqlOperatorBinding opBinding) {
94+
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
95+
final RelDataType inputRowType = opBinding.getOperandType(0);
96+
97+
return typeFactory
98+
.builder()
99+
.kind(inputRowType.getStructKind())
100+
.addAll(inputRowType.getFieldList())
101+
.addAll(
102+
SqlValidatorUtils.makeOutputUnique(
103+
inputRowType.getFieldList(),
104+
Collections.singletonList(
105+
new RelDataTypeFieldImpl(
106+
OUTPUT_SCORE,
107+
0,
108+
typeFactory.createSqlType(
109+
SqlTypeName.DOUBLE)))))
110+
.build();
111+
}
112+
};
113+
}
114+
115+
@Override
116+
public boolean argumentMustBeScalar(int ordinal) {
117+
return ordinal != 0;
118+
}
119+
120+
private static class OperandMetadataImpl implements SqlOperandMetadata {
121+
122+
private static final List<String> PARAMETERS =
123+
Collections.unmodifiableList(
124+
Arrays.asList(
125+
PARAM_SEARCH_TABLE,
126+
PARAM_COLUMN_TO_SEARCH,
127+
PARAM_COLUMN_TO_QUERY,
128+
PARAM_TOP_K));
129+
130+
@Override
131+
public List<RelDataType> paramTypes(RelDataTypeFactory relDataTypeFactory) {
132+
return Collections.nCopies(
133+
PARAMETERS.size(), relDataTypeFactory.createSqlType(SqlTypeName.ANY));
134+
}
135+
136+
@Override
137+
public List<String> paramNames() {
138+
return PARAMETERS;
139+
}
140+
141+
@Override
142+
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
143+
// check vector table contains descriptor columns
144+
if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1)) {
145+
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
146+
callBinding, throwOnFailure);
147+
}
148+
149+
List<SqlNode> operands = callBinding.operands();
150+
// check descriptor has one column
151+
SqlCall descriptor = (SqlCall) operands.get(1);
152+
List<SqlNode> descriptorCols = descriptor.getOperandList();
153+
if (descriptorCols.size() != 1) {
154+
return SqlValidatorUtils.throwExceptionOrReturnFalse(
155+
Optional.of(
156+
new ValidationException(
157+
String.format(
158+
"Expect parameter COLUMN_TO_SEARCH for VECTOR_SEARCH only contains one column, but multiple columns are found in operand %s.",
159+
descriptor))),
160+
throwOnFailure);
161+
}
162+
163+
// check descriptor type is ARRAY<FLOAT> or ARRAY<DOUBLE>
164+
RelDataType searchTableType = callBinding.getOperandType(0);
165+
SqlNameMatcher matcher = callBinding.getValidator().getCatalogReader().nameMatcher();
166+
SqlIdentifier columnName = (SqlIdentifier) descriptorCols.get(0);
167+
String descriptorColName =
168+
columnName.isSimple() ? columnName.getSimple() : Util.last(columnName.names);
169+
int index = matcher.indexOf(searchTableType.getFieldNames(), descriptorColName);
170+
RelDataType targetType = searchTableType.getFieldList().get(index).getType();
171+
LogicalType targetLogicalType = toLogicalType(targetType);
172+
173+
if (!(targetLogicalType.is(LogicalTypeRoot.ARRAY)
174+
&& ((ArrayType) (targetLogicalType))
175+
.getElementType()
176+
.isAnyOf(LogicalTypeRoot.FLOAT, LogicalTypeRoot.DOUBLE))) {
177+
return SqlValidatorUtils.throwExceptionOrReturnFalse(
178+
Optional.of(
179+
new ValidationException(
180+
String.format(
181+
"Expect search column `%s` type is ARRAY<FLOAT> or ARRAY<DOUBLE>, but its type is %s.",
182+
columnName, targetType))),
183+
throwOnFailure);
184+
}
185+
186+
// check query type is ARRAY<FLOAT> or ARRAY<DOUBLE>
187+
LogicalType sourceLogicalType = toLogicalType(callBinding.getOperandType(2));
188+
if (!LogicalTypeCasts.supportsImplicitCast(sourceLogicalType, targetLogicalType)) {
189+
return SqlValidatorUtils.throwExceptionOrReturnFalse(
190+
Optional.of(
191+
new ValidationException(
192+
String.format(
193+
"Can not cast the query column type %s to target type %s. Please keep the query column type is same to the search column type.",
194+
sourceLogicalType, targetType))),
195+
throwOnFailure);
196+
}
197+
198+
// check topK is literal
199+
LogicalType topKType = toLogicalType(callBinding.getOperandType(3));
200+
if (!operands.get(3).getKind().equals(SqlKind.LITERAL)
201+
|| !topKType.is(LogicalTypeRoot.INTEGER)) {
202+
return SqlValidatorUtils.throwExceptionOrReturnFalse(
203+
Optional.of(
204+
new ValidationException(
205+
String.format(
206+
"Expect parameter topK is integer literal in VECTOR_SEARCH, but it is %s with type %s.",
207+
operands.get(3), topKType))),
208+
throwOnFailure);
209+
}
210+
211+
return true;
212+
}
213+
214+
@Override
215+
public SqlOperandCountRange getOperandCountRange() {
216+
return SqlOperandCountRanges.between(4, 4);
217+
}
218+
219+
@Override
220+
public String getAllowedSignatures(SqlOperator op, String opName) {
221+
return opName + "(TABLE table_name, DESCRIPTOR(query_column), search_column, top_k)";
222+
}
223+
224+
@Override
225+
public Consistency getConsistency() {
226+
return Consistency.NONE;
227+
}
228+
229+
@Override
230+
public boolean isOptional(int i) {
231+
return false;
232+
}
233+
234+
@Override
235+
public boolean isFixedParameters() {
236+
return true;
237+
}
238+
}
239+
}

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,27 +160,33 @@ private static void adjustTypeForMultisetConstructor(
160160
/**
161161
* Make output field names unique from input field names by appending index. For example, Input
162162
* has field names {@code a, b, c} and output has field names {@code b, c, d}. After calling
163-
* this function, new output field names will be {@code b0, c0, d}. Duplicate names are not
164-
* checked inside input and output itself.
163+
* this function, new output field names will be {@code b0, c0, d}.
164+
*
165+
* <p>We assume that input fields in the input parameter are uniquely named, just as the output
166+
* fields in the output parameter are.
165167
*
166168
* @param input Input fields
167169
* @param output Output fields
168-
* @return
170+
* @return output fields with unique names.
169171
*/
170172
public static List<RelDataTypeField> makeOutputUnique(
171173
List<RelDataTypeField> input, List<RelDataTypeField> output) {
172-
final Set<String> inputFieldNames = new HashSet<>();
174+
final Set<String> uniqueNames = new HashSet<>();
173175
for (RelDataTypeField field : input) {
174-
inputFieldNames.add(field.getName());
176+
uniqueNames.add(field.getName());
175177
}
176178

177179
List<RelDataTypeField> result = new ArrayList<>();
178180
for (RelDataTypeField field : output) {
179181
String fieldName = field.getName();
180-
if (inputFieldNames.contains(fieldName)) {
181-
fieldName += "0"; // Append index to make it unique
182+
int count = 0;
183+
String candidate = fieldName;
184+
while (uniqueNames.contains(candidate)) {
185+
candidate = fieldName + count;
186+
count++;
182187
}
183-
result.add(new RelDataTypeFieldImpl(fieldName, field.getIndex(), field.getType()));
188+
uniqueNames.add(candidate);
189+
result.add(new RelDataTypeFieldImpl(candidate, field.getIndex(), field.getType()));
184190
}
185191
return result;
186192
}

0 commit comments

Comments
 (0)