[FLINK-38424][planner] Support to parse VECTOR_SEARCH function #27039
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,9 @@ | |
| */ | ||
| package org.apache.calcite.sql.validate; | ||
|
|
||
| import org.apache.flink.table.api.ValidationException; | ||
| import org.apache.flink.table.planner.calcite.FlinkSqlCallBinding; | ||
| import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction; | ||
|
|
||
| import com.google.common.base.Preconditions; | ||
| import com.google.common.collect.ImmutableList; | ||
|
|
@@ -177,10 +179,12 @@ | |
| * | ||
| * <p>Lines 2571 ~ 2588, CALCITE-7217, should be removed after upgrading Calcite to 1.41.0. | ||
| * | ||
| * <p>Lines 3895 ~ 3899, 6574 ~ 6580 Flink improves Optimize the retrieval of sub-operands in | ||
| * <p>Line 2618 ~2631, set the correct scope for VECTOR_SEARCH. | ||
| * | ||
| * <p>Lines 3920 ~ 3925, 6599 ~ 6606 Flink improves Optimize the retrieval of sub-operands in | ||
| * SqlCall when using NamedParameters at {@link SqlValidatorImpl#checkRollUp}. | ||
| * | ||
| * <p>Lines 5315 ~ 5321, FLINK-24352 Add null check for temporal table check on SqlSnapshot. | ||
| * <p>Lines 5340 ~ 5347, FLINK-24352 Add null check for temporal table check on SqlSnapshot. | ||
| */ | ||
| public class SqlValidatorImpl implements SqlValidatorWithHints { | ||
| // ~ Static fields/initializers --------------------------------------------- | ||
|
|
@@ -2570,6 +2574,10 @@ private SqlNode registerFrom( | |
| case LATERAL: | ||
| // ----- FLINK MODIFICATION BEGIN ----- | ||
| SqlBasicCall sbc = (SqlBasicCall) node; | ||
| // Put the usingScope which is a JoinScope, | ||
| // in order to make visible the left items | ||
| // of the JOIN tree. | ||
| scopes.put(node, usingScope); | ||
| registerFrom( | ||
| parentScope, | ||
| usingScope, | ||
|
|
@@ -2580,10 +2588,6 @@ private SqlNode registerFrom( | |
| extendList, | ||
| forceNullable, | ||
| true); | ||
| // Put the usingScope which is a JoinScope, | ||
| // in order to make visible the left items | ||
| // of the JOIN tree. | ||
| scopes.put(node, usingScope); | ||
| return sbc; | ||
| // ----- FLINK MODIFICATION END ----- | ||
|
|
||
|
|
@@ -2614,6 +2618,27 @@ private SqlNode registerFrom( | |
| scopes.put(node, getSelectScope(call1.operand(0))); | ||
| return newNode; | ||
| } | ||
|
|
||
| // Related to CALCITE-4077 | ||
| // ----- FLINK MODIFICATION BEGIN ----- | ||
| FlinkSqlCallBinding binding = | ||
| new FlinkSqlCallBinding(this, getEmptyScope(), call1); | ||
| if (op instanceof SqlVectorSearchTableFunction | ||
| && binding.operand(0) | ||
| .isA( | ||
| new HashSet<>( | ||
| Collections.singletonList(SqlKind.SELECT)))) { | ||
| boolean queryColumnIsNotLiteral = | ||
| binding.operand(2).getKind() != SqlKind.LITERAL; | ||
| if (!queryColumnIsNotLiteral && !lateral) { | ||
| throw new ValidationException( | ||
| "The query column is not literal, please use LATERAL TABLE to run VECTOR_SEARCH."); | ||
| } | ||
| SqlValidatorScope scope = getSelectScope((SqlSelect) binding.operand(0)); | ||
|
Contributor
nit: looks like there's no need to cast, based on line 2618?
Member
Author
The cast is required, because line 2618 uses SqlCall to get the operand (its return type is […]). We cannot use SqlCall to extract the operand because […]
||
| scopes.put(enclosingNode, scope); | ||
| return newNode; | ||
| } | ||
| // ----- FLINK MODIFICATION END ----- | ||
| } | ||
| // Put the usingScope which can be a JoinScope | ||
| // or a SelectScope, in order to see the left items | ||
|
|
||
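
The cast question raised in the review thread above comes down to how the operand is fetched. In Calcite, `SqlCall.operand(int)` is declared with a generic return type (`<S extends SqlNode> S`), so the target type is inferred at the call site, whereas `SqlCallBinding.operand(int)` returns a plain `SqlNode`. The following is a minimal sketch of that difference only; `Node`, `Select`, `Call`, `Binding` and `scopeFor` are hypothetical stand-ins, not the Calcite or Flink classes.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-ins for SqlNode, SqlSelect, SqlCall and SqlCallBinding.
class Node {}

class Select extends Node {}

class Call extends Node {
    private final List<Node> operands;

    Call(List<Node> operands) {
        this.operands = operands;
    }

    // Generic accessor: the caller's target type drives inference, so no
    // explicit cast appears at the call site (the cast here is unchecked).
    @SuppressWarnings("unchecked")
    <S extends Node> S operand(int i) {
        return (S) operands.get(i);
    }
}

class Binding {
    private final Call call;

    Binding(Call call) {
        this.call = call;
    }

    // Non-generic accessor: always returns the base type.
    Node operand(int i) {
        return call.operand(i);
    }
}

class CastDemo {
    // Hypothetical consumer that needs the narrower type, like getSelectScope.
    static String scopeFor(Select select) {
        return "scope(" + select.getClass().getSimpleName() + ")";
    }

    public static void main(String[] args) {
        Call call = new Call(Arrays.asList(new Select(), new Node()));
        Binding binding = new Binding(call);

        // Compiles without a cast because S is inferred as Select.
        System.out.println(scopeFor(call.operand(0)));

        // Needs an explicit cast because operand(int) returns Node.
        System.out.println(scopeFor((Select) binding.operand(0)));
    }
}
```
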
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,239 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.flink.table.planner.functions.sql.ml; | ||
|
|
||
| import org.apache.flink.table.api.ValidationException; | ||
| import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils; | ||
| import org.apache.flink.table.types.logical.ArrayType; | ||
| import org.apache.flink.table.types.logical.LogicalType; | ||
| import org.apache.flink.table.types.logical.LogicalTypeRoot; | ||
| import org.apache.flink.table.types.logical.utils.LogicalTypeCasts; | ||
|
|
||
| import org.apache.calcite.rel.type.RelDataType; | ||
| import org.apache.calcite.rel.type.RelDataTypeFactory; | ||
| import org.apache.calcite.rel.type.RelDataTypeFieldImpl; | ||
| import org.apache.calcite.sql.SqlCall; | ||
| import org.apache.calcite.sql.SqlCallBinding; | ||
| import org.apache.calcite.sql.SqlFunction; | ||
| import org.apache.calcite.sql.SqlFunctionCategory; | ||
| import org.apache.calcite.sql.SqlIdentifier; | ||
| import org.apache.calcite.sql.SqlKind; | ||
| import org.apache.calcite.sql.SqlNode; | ||
| import org.apache.calcite.sql.SqlOperandCountRange; | ||
| import org.apache.calcite.sql.SqlOperator; | ||
| import org.apache.calcite.sql.SqlOperatorBinding; | ||
| import org.apache.calcite.sql.SqlTableFunction; | ||
| import org.apache.calcite.sql.type.ReturnTypes; | ||
| import org.apache.calcite.sql.type.SqlOperandCountRanges; | ||
| import org.apache.calcite.sql.type.SqlOperandMetadata; | ||
| import org.apache.calcite.sql.type.SqlReturnTypeInference; | ||
| import org.apache.calcite.sql.type.SqlTypeName; | ||
| import org.apache.calcite.sql.validate.SqlNameMatcher; | ||
| import org.apache.calcite.util.Util; | ||
| import org.checkerframework.checker.nullness.qual.Nullable; | ||
|
|
||
| import java.util.Arrays; | ||
| import java.util.Collections; | ||
| import java.util.List; | ||
| import java.util.Optional; | ||
|
|
||
| import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType; | ||
|
|
||
| /** | ||
| * {@link SqlVectorSearchTableFunction} implements an operator for search. | ||
| * | ||
| * <p>It allows four parameters: | ||
| * | ||
| * <ol> | ||
| * <li>a table | ||
| * <li>a descriptor to provide a column name from the input table | ||
| * <li>a query column from the left table | ||
| * <li>a literal value for top k | ||
| * </ol> | ||
| */ | ||
| public class SqlVectorSearchTableFunction extends SqlFunction implements SqlTableFunction { | ||
|
|
||
| private static final String PARAM_SEARCH_TABLE = "SEARCH_TABLE"; | ||
| private static final String PARAM_COLUMN_TO_SEARCH = "COLUMN_TO_SEARCH"; | ||
| private static final String PARAM_COLUMN_TO_QUERY = "COLUMN_TO_QUERY"; | ||
| private static final String PARAM_TOP_K = "TOP_K"; | ||
|
Contributor
Missing the optional config param from the FLIP?
Member
Author
Yes, but I plan to add this in FLINK-38430 (https://issues.apache.org/jira/browse/FLINK-38430).
||
|
|
||
| private static final String OUTPUT_SCORE = "score"; | ||
|
|
||
| public SqlVectorSearchTableFunction() { | ||
| super( | ||
| "VECTOR_SEARCH", | ||
| SqlKind.OTHER_FUNCTION, | ||
| ReturnTypes.CURSOR, | ||
| null, | ||
| new OperandMetadataImpl(), | ||
| SqlFunctionCategory.SYSTEM); | ||
| } | ||
|
|
||
| @Override | ||
| public SqlReturnTypeInference getRowTypeInference() { | ||
| return new SqlReturnTypeInference() { | ||
| @Override | ||
| public @Nullable RelDataType inferReturnType(SqlOperatorBinding opBinding) { | ||
| final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); | ||
| final RelDataType inputRowType = opBinding.getOperandType(0); | ||
|
|
||
| return typeFactory | ||
| .builder() | ||
| .kind(inputRowType.getStructKind()) | ||
| .addAll(inputRowType.getFieldList()) | ||
| .addAll( | ||
| SqlValidatorUtils.makeOutputUnique( | ||
| inputRowType.getFieldList(), | ||
| Collections.singletonList( | ||
| new RelDataTypeFieldImpl( | ||
| OUTPUT_SCORE, | ||
| 0, | ||
| typeFactory.createSqlType( | ||
| SqlTypeName.DOUBLE))))) | ||
| .build(); | ||
| } | ||
| }; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean argumentMustBeScalar(int ordinal) { | ||
| return ordinal != 0; | ||
| } | ||
|
|
||
| private static class OperandMetadataImpl implements SqlOperandMetadata { | ||
|
|
||
| private static final List<String> PARAMETERS = | ||
| Collections.unmodifiableList( | ||
| Arrays.asList( | ||
| PARAM_SEARCH_TABLE, | ||
| PARAM_COLUMN_TO_SEARCH, | ||
| PARAM_COLUMN_TO_QUERY, | ||
| PARAM_TOP_K)); | ||
|
Contributor
I can see what TOP_K is from googling. It would be useful to add documentation describing the parameters with this change, including the default. I wonder whether we should add paging to handle a large number of results, or is this not done with vector databases?
Member
Author
Of course, the documentation will be added when all tasks are finished. The current API doesn't support paging; I think we can leave this as future work.
||
|
|
||
| @Override | ||
| public List<RelDataType> paramTypes(RelDataTypeFactory relDataTypeFactory) { | ||
| return Collections.nCopies( | ||
| PARAMETERS.size(), relDataTypeFactory.createSqlType(SqlTypeName.ANY)); | ||
| } | ||
|
|
||
| @Override | ||
| public List<String> paramNames() { | ||
| return PARAMETERS; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { | ||
| // check vector table contains descriptor columns | ||
| if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1)) { | ||
| return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse( | ||
| callBinding, throwOnFailure); | ||
| } | ||
|
|
||
| List<SqlNode> operands = callBinding.operands(); | ||
| // check descriptor has one column | ||
| SqlCall descriptor = (SqlCall) operands.get(1); | ||
| List<SqlNode> descriptorCols = descriptor.getOperandList(); | ||
| if (descriptorCols.size() != 1) { | ||
| return SqlValidatorUtils.throwExceptionOrReturnFalse( | ||
| Optional.of( | ||
| new ValidationException( | ||
| String.format( | ||
| "Expect parameter COLUMN_TO_SEARCH for VECTOR_SEARCH only contains one column, but multiple columns are found in operand %s.", | ||
| descriptor))), | ||
| throwOnFailure); | ||
| } | ||
|
|
||
| // check descriptor type is ARRAY<FLOAT> or ARRAY<DOUBLE> | ||
| RelDataType searchTableType = callBinding.getOperandType(0); | ||
| SqlNameMatcher matcher = callBinding.getValidator().getCatalogReader().nameMatcher(); | ||
| SqlIdentifier columnName = (SqlIdentifier) descriptorCols.get(0); | ||
| String descriptorColName = | ||
| columnName.isSimple() ? columnName.getSimple() : Util.last(columnName.names); | ||
| int index = matcher.indexOf(searchTableType.getFieldNames(), descriptorColName); | ||
| RelDataType targetType = searchTableType.getFieldList().get(index).getType(); | ||
| LogicalType targetLogicalType = toLogicalType(targetType); | ||
|
|
||
| if (!(targetLogicalType.is(LogicalTypeRoot.ARRAY) | ||
| && ((ArrayType) (targetLogicalType)) | ||
| .getElementType() | ||
| .isAnyOf(LogicalTypeRoot.FLOAT, LogicalTypeRoot.DOUBLE))) { | ||
| return SqlValidatorUtils.throwExceptionOrReturnFalse( | ||
| Optional.of( | ||
| new ValidationException( | ||
| String.format( | ||
| "Expect search column `%s` type is ARRAY<FLOAT> or ARRAY<DOUBLE>, but its type is %s.", | ||
| columnName, targetType))), | ||
| throwOnFailure); | ||
| } | ||
|
|
||
| // check query type is ARRAY<FLOAT> or ARRAY<DOUBLE> | ||
| LogicalType sourceLogicalType = toLogicalType(callBinding.getOperandType(2)); | ||
| if (!LogicalTypeCasts.supportsImplicitCast(sourceLogicalType, targetLogicalType)) { | ||
| return SqlValidatorUtils.throwExceptionOrReturnFalse( | ||
| Optional.of( | ||
| new ValidationException( | ||
| String.format( | ||
| "Can not cast the query column type %s to target type %s. Please keep the query column type is same to the search column type.", | ||
| sourceLogicalType, targetType))), | ||
| throwOnFailure); | ||
| } | ||
|
|
||
| // check topK is literal | ||
| LogicalType topKType = toLogicalType(callBinding.getOperandType(3)); | ||
| if (!operands.get(3).getKind().equals(SqlKind.LITERAL) | ||
| || !topKType.is(LogicalTypeRoot.INTEGER)) { | ||
| return SqlValidatorUtils.throwExceptionOrReturnFalse( | ||
| Optional.of( | ||
| new ValidationException( | ||
| String.format( | ||
| "Expect parameter topK is integer literal in VECTOR_SEARCH, but it is %s with type %s.", | ||
| operands.get(3), topKType))), | ||
| throwOnFailure); | ||
| } | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| @Override | ||
| public SqlOperandCountRange getOperandCountRange() { | ||
| return SqlOperandCountRanges.between(4, 4); | ||
| } | ||
|
|
||
| @Override | ||
| public String getAllowedSignatures(SqlOperator op, String opName) { | ||
| return opName + "(TABLE table_name, DESCRIPTOR(query_column), search_column, top_k)"; | ||
| } | ||
|
|
||
| @Override | ||
| public Consistency getConsistency() { | ||
| return Consistency.NONE; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean isOptional(int i) { | ||
| return false; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean isFixedParameters() { | ||
| return true; | ||
| } | ||
| } | ||
| } | ||
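
The operand checks in `checkOperandTypes` above run as a sequence of guards: the descriptor names exactly one column, the search column is `ARRAY<FLOAT>` or `ARRAY<DOUBLE>`, the query value can be implicitly cast to the search column type, and `TOP_K` is an integer literal. A rough, self-contained sketch of that ordering follows. `VectorSearchArgsCheck`, `ElementType` and the simplified cast rule are assumptions made for illustration; the actual code relies on `LogicalTypeCasts.supportsImplicitCast` and Calcite's type system.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

class VectorSearchArgsCheck {

    // Hypothetical, simplified stand-in for the element type of an ARRAY column.
    enum ElementType { FLOAT, DOUBLE, INT, STRING }

    /**
     * Returns an error message for the first failed check, or empty if all pass.
     *
     * @param descriptorColumns columns named in DESCRIPTOR(...)
     * @param searchElement element type of the search column, or null if it is not an array
     * @param queryElement element type of the query value, or null if it is not an array
     * @param topKIsIntegerLiteral whether TOP_K was given as an integer literal
     */
    static Optional<String> validate(
            List<String> descriptorColumns,
            ElementType searchElement,
            ElementType queryElement,
            boolean topKIsIntegerLiteral) {
        if (descriptorColumns.size() != 1) {
            return Optional.of("COLUMN_TO_SEARCH must contain exactly one column");
        }
        if (!isFloatingPoint(searchElement)) {
            return Optional.of("search column must be ARRAY<FLOAT> or ARRAY<DOUBLE>");
        }
        if (!implicitlyCastable(queryElement, searchElement)) {
            return Optional.of("query column type cannot be cast to the search column type");
        }
        if (!topKIsIntegerLiteral) {
            return Optional.of("TOP_K must be an integer literal");
        }
        return Optional.empty();
    }

    private static boolean isFloatingPoint(ElementType type) {
        return type == ElementType.FLOAT || type == ElementType.DOUBLE;
    }

    // Assumption: only identical types and FLOAT -> DOUBLE widening are accepted.
    private static boolean implicitlyCastable(ElementType from, ElementType to) {
        return from != null
                && (from == to || (from == ElementType.FLOAT && to == ElementType.DOUBLE));
    }

    public static void main(String[] args) {
        // A two-column descriptor fails the first guard.
        System.out.println(
                validate(Arrays.asList("a", "b"), ElementType.FLOAT, ElementType.FLOAT, true)
                        .orElse("valid"));
        // A single FLOAT array column with a FLOAT query and literal TOP_K passes.
        System.out.println(
                validate(Arrays.asList("g"), ElementType.FLOAT, ElementType.FLOAT, true)
                        .orElse("valid"));
    }
}
```

Returning the first failure as an `Optional<String>` mirrors the way the original code exits on the first violated check, either by throwing or by returning `false` depending on `throwOnFailure`.
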
Isn't `lateral` always needed? What's the syntax for a literal?

`LATERAL` is not always needed if the SqlCall doesn't contain a correlation. For example, users can use the following statement to search. […] Here, the query input is `ARRAY[1.5, 2.0]`.
CC
flink/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java
Line 95 in 06c61b8

I thought it's the syntax, but your test expects an exception, so it's not valid SQL?

It's valid SQL. But we haven't added a literal-related rule in the physical phase, so the planner cannot translate the SQL correctly. The exception nevertheless indicates that the planner can parse the statement correctly.

Got it. I guess it would be clearer to have something like https://github.com/apache/flink/pull/26553/files#diff-19970e8600e459e820e1310beed925a10450f695698257d85648e8114b5e5aaeR92 to indicate that it's not an invalid case.
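
One way to make the distinction from this thread explicit in a test is to assert on which phase rejected the statement, rather than only on the fact that an exception was thrown. The sketch below illustrates that idea with plain Java. `ValidationError`, `PlanningError`, `Outcome` and `classify` are hypothetical names introduced here, not Flink classes or test utilities, and the exception types the planner actually throws are not taken from the PR.

```java
class PlanningPhaseDemo {

    // Hypothetical exception types standing in for a validation-time failure
    // and a planning-time failure.
    static class ValidationError extends RuntimeException {
        ValidationError(String message) {
            super(message);
        }
    }

    static class PlanningError extends RuntimeException {
        PlanningError(String message) {
            super(message);
        }
    }

    enum Outcome { PLANNED, REJECTED_BY_VALIDATOR, NOT_YET_PLANNABLE }

    /** Runs a planning action and reports which phase, if any, rejected it. */
    static Outcome classify(Runnable planAction) {
        try {
            planAction.run();
            return Outcome.PLANNED;
        } catch (ValidationError e) {
            return Outcome.REJECTED_BY_VALIDATOR;
        } catch (PlanningError e) {
            return Outcome.NOT_YET_PLANNABLE;
        }
    }

    public static void main(String[] args) {
        // Simulates a statement that parses and validates but has no physical rule yet.
        Outcome outcome =
                classify(
                        () -> {
                            throw new PlanningError("no physical rule for a literal query");
                        });
        System.out.println(outcome);
    }
}
```

A test written this way documents that the statement is valid SQL and that only the translation step is missing, which is the point the reviewer asks to make visible.
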