Merged
flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
@@ -16,7 +16,9 @@
*/
package org.apache.calcite.sql.validate;

+ import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.planner.calcite.FlinkSqlCallBinding;
+ import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@@ -177,10 +179,12 @@
*
* <p>Lines 2571 ~ 2588, CALCITE-7217, should be removed after upgrading Calcite to 1.41.0.
*
- * <p>Lines 3895 ~ 3899, 6574 ~ 6580 Flink improves Optimize the retrieval of sub-operands in
+ * <p>Lines 2618 ~ 2631, set the correct scope for VECTOR_SEARCH.
+ *
+ * <p>Lines 3920 ~ 3925, 6599 ~ 6606 Flink improves Optimize the retrieval of sub-operands in
* SqlCall when using NamedParameters at {@link SqlValidatorImpl#checkRollUp}.
*
- * <p>Lines 5315 ~ 5321, FLINK-24352 Add null check for temporal table check on SqlSnapshot.
+ * <p>Lines 5340 ~ 5347, FLINK-24352 Add null check for temporal table check on SqlSnapshot.
*/
public class SqlValidatorImpl implements SqlValidatorWithHints {
// ~ Static fields/initializers ---------------------------------------------
@@ -2570,6 +2574,10 @@ private SqlNode registerFrom(
case LATERAL:
// ----- FLINK MODIFICATION BEGIN -----
SqlBasicCall sbc = (SqlBasicCall) node;
+ // Put the usingScope which is a JoinScope,
+ // in order to make visible the left items
+ // of the JOIN tree.
+ scopes.put(node, usingScope);
registerFrom(
parentScope,
usingScope,
@@ -2580,10 +2588,6 @@ private SqlNode registerFrom(
extendList,
forceNullable,
true);
- // Put the usingScope which is a JoinScope,
- // in order to make visible the left items
- // of the JOIN tree.
- scopes.put(node, usingScope);
return sbc;
// ----- FLINK MODIFICATION END -----

@@ -2614,6 +2618,27 @@ private SqlNode registerFrom(
scopes.put(node, getSelectScope(call1.operand(0)));
return newNode;
}

// Related to CALCITE-4077
// ----- FLINK MODIFICATION BEGIN -----
FlinkSqlCallBinding binding =
new FlinkSqlCallBinding(this, getEmptyScope(), call1);
if (op instanceof SqlVectorSearchTableFunction
&& binding.operand(0)
.isA(
new HashSet<>(
Collections.singletonList(SqlKind.SELECT)))) {
boolean queryColumnIsLiteral =
binding.operand(2).getKind() == SqlKind.LITERAL;
if (!queryColumnIsLiteral && !lateral) {
Contributor: Isn't LATERAL always needed? What's the syntax for the literal case?

Member Author: LATERAL is not always needed if the SqlCall doesn't contain a correlation. For example, users can use the following statement to search:

SELECT * FROM TABLE(VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))

Here, the query input is ARRAY[1.5, 2.0].

Contributor: I thought that was the syntax, but your test expects an exception, so is it not valid SQL?

Member Author: It's valid SQL, but we don't add the literal-related rule in the physical phase, so the planner cannot translate the statement correctly. The exception indicates that the planner can parse the statement correctly.

Contributor: Got it. I guess it's clearer to have something like https://github.com/apache/flink/pull/26553/files#diff-19970e8600e459e820e1310beed925a10450f695698257d85648e8114b5e5aaeR92 to indicate it's not an invalid case.
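To make the two shapes concrete, here is a sketch of both forms side by side; the outer table and column names (Queries, embedding) are invented for illustration and do not come from this PR:

// Literal query vector: no correlation, so a plain TABLE(...) call is accepted.
String literalForm =
        "SELECT * FROM TABLE("
                + "VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))";

// Query vector read from an outer table (hypothetical Queries.embedding): the
// correlated column is only visible when the call is wrapped in LATERAL TABLE,
// which is what the validation error below asks for.
String correlatedForm =
        "SELECT * FROM Queries, LATERAL TABLE("
                + "VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), Queries.embedding, 10))";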

throw new ValidationException(
"The query column is not literal, please use LATERAL TABLE to run VECTOR_SEARCH.");
}
SqlValidatorScope scope = getSelectScope((SqlSelect) binding.operand(0));
Contributor: nit: looks like there's no need to cast, based on line 2618?

Member Author: It's required. Line 2618 uses SqlCall to get the operand (its return type is <S extends SqlNode> S), but here we use SqlCallBinding to get the operand (its return type is SqlNode), so we still need the cast here.

We cannot use SqlCall to extract the operand, because VECTOR_SEARCH allows users to use named arguments, which means the operands can be out of order and we need the parameter names to reorder them.
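As a sketch of that named-argument point (the parameter names come from this PR, the statement itself is hypothetical): with param => value syntax the operands can arrive in any order, which is why the code above resolves them through FlinkSqlCallBinding instead of reading SqlCall operands positionally.

// Named arguments may be supplied out of positional order; FlinkSqlCallBinding
// reorders them by parameter name before operand(0), operand(2), ... are read.
String namedArgForm =
        "SELECT * FROM TABLE(VECTOR_SEARCH("
                + "TOP_K => 10, "
                + "SEARCH_TABLE => TABLE VectorTable, "
                + "COLUMN_TO_QUERY => ARRAY[1.5, 2.0], "
                + "COLUMN_TO_SEARCH => DESCRIPTOR(`g`)))";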

scopes.put(enclosingNode, scope);
return newNode;
}
// ----- FLINK MODIFICATION END -----
}
// Put the usingScope which can be a JoinScope
// or a SelectScope, in order to see the left items
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
@@ -22,6 +22,7 @@
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.functions.sql.internal.SqlAuxiliaryGroupAggFunction;
import org.apache.flink.table.planner.functions.sql.ml.SqlMLEvaluateTableFunction;
import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction;
import org.apache.flink.table.planner.plan.type.FlinkReturnTypes;
import org.apache.flink.table.planner.plan.type.NumericExceptFirstOperandChecker;

Expand Down Expand Up @@ -1328,6 +1329,9 @@ public List<SqlGroupedWindowFunction> getAuxiliaryFunctions() {
// MODEL TABLE FUNCTIONS
public static final SqlFunction ML_EVALUATE = new SqlMLEvaluateTableFunction();

// SEARCH FUNCTIONS
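// Registering the operator here exposes VECTOR_SEARCH as a built-in system
// table function that the validator can resolve by name.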
public static final SqlFunction VECTOR_SEARCH = new SqlVectorSearchTableFunction();

// Catalog Functions
public static final SqlFunction CURRENT_DATABASE =
BuiltInSqlFunction.newBuilder()
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlVectorSearchTableFunction.java
@@ -0,0 +1,239 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.functions.sql.ml;

import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;
import org.apache.flink.table.types.logical.ArrayType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.SqlTableFunction;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlOperandMetadata;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlNameMatcher;
import org.apache.calcite.util.Util;
import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;

/**
* {@link SqlVectorSearchTableFunction} implements a table operator for vector similarity search.
*
* <p>It takes four parameters:
*
* <ol>
* <li>a table
* <li>a descriptor to provide a column name from the input table
* <li>a query column from the left table, or a literal array value
* <li>a literal value for top k
* </ol>
*/
public class SqlVectorSearchTableFunction extends SqlFunction implements SqlTableFunction {

private static final String PARAM_SEARCH_TABLE = "SEARCH_TABLE";
private static final String PARAM_COLUMN_TO_SEARCH = "COLUMN_TO_SEARCH";
private static final String PARAM_COLUMN_TO_QUERY = "COLUMN_TO_QUERY";
private static final String PARAM_TOP_K = "TOP_K";
Contributor: Missing the optional config param from the FLIP?

Member Author: Yes, but I plan to add it in https://issues.apache.org/jira/browse/FLINK-38430.


private static final String OUTPUT_SCORE = "score";

public SqlVectorSearchTableFunction() {
super(
"VECTOR_SEARCH",
SqlKind.OTHER_FUNCTION,
ReturnTypes.CURSOR,
null,
new OperandMetadataImpl(),
SqlFunctionCategory.SYSTEM);
}

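// The inferred row type is the search-table row with a DOUBLE column named
// "score" appended; makeOutputUnique renames the appended column (e.g. to
// "score0") if the input row already contains a field called "score".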
@Override
public SqlReturnTypeInference getRowTypeInference() {
return new SqlReturnTypeInference() {
@Override
public @Nullable RelDataType inferReturnType(SqlOperatorBinding opBinding) {
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
final RelDataType inputRowType = opBinding.getOperandType(0);

return typeFactory
.builder()
.kind(inputRowType.getStructKind())
.addAll(inputRowType.getFieldList())
.addAll(
SqlValidatorUtils.makeOutputUnique(
inputRowType.getFieldList(),
Collections.singletonList(
new RelDataTypeFieldImpl(
OUTPUT_SCORE,
0,
typeFactory.createSqlType(
SqlTypeName.DOUBLE)))))
.build();
}
};
}

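// Only operand 0 (SEARCH_TABLE) may be a table; the descriptor, the query
// column, and TOP_K must be scalar expressions.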
@Override
public boolean argumentMustBeScalar(int ordinal) {
return ordinal != 0;
}

private static class OperandMetadataImpl implements SqlOperandMetadata {

private static final List<String> PARAMETERS =
Collections.unmodifiableList(
Arrays.asList(
PARAM_SEARCH_TABLE,
PARAM_COLUMN_TO_SEARCH,
PARAM_COLUMN_TO_QUERY,
PARAM_TOP_K));
Contributor (davidradl, Sep 25, 2025): I can see what TOP_K is from googling. It would be useful to add documentation describing the parameters with this change, including the default. I wonder if we should add paging to handle a large number of results, or is that not done with vector databases?

Member Author: Of course; the documentation will be added when all tasks are finished. The current API doesn't support paging; I think we can leave this as future work.


@Override
public List<RelDataType> paramTypes(RelDataTypeFactory relDataTypeFactory) {
return Collections.nCopies(
PARAMETERS.size(), relDataTypeFactory.createSqlType(SqlTypeName.ANY));
}

@Override
public List<String> paramNames() {
return PARAMETERS;
}

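// Illustration of the checks below (hypothetical table, assuming column g of
// VectorTable has type ARRAY<FLOAT>):
//   passes: VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(g), ARRAY[1.5, 2.0], 10)
//   fails:  DESCRIPTOR(g, h)          -- COLUMN_TO_SEARCH must name exactly one column
//   fails:  COLUMN_TO_QUERY of type ARRAY<STRING> -- not implicitly castable to ARRAY<FLOAT>
//   fails:  TOP_K => 10.5             -- TOP_K must be an INTEGER literal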
@Override
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
// check vector table contains descriptor columns
if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1)) {
return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
callBinding, throwOnFailure);
}

List<SqlNode> operands = callBinding.operands();
// check descriptor has one column
SqlCall descriptor = (SqlCall) operands.get(1);
List<SqlNode> descriptorCols = descriptor.getOperandList();
if (descriptorCols.size() != 1) {
return SqlValidatorUtils.throwExceptionOrReturnFalse(
Optional.of(
new ValidationException(
String.format(
"Expect parameter COLUMN_TO_SEARCH for VECTOR_SEARCH only contains one column, but multiple columns are found in operand %s.",
descriptor))),
throwOnFailure);
}

// check descriptor type is ARRAY<FLOAT> or ARRAY<DOUBLE>
RelDataType searchTableType = callBinding.getOperandType(0);
SqlNameMatcher matcher = callBinding.getValidator().getCatalogReader().nameMatcher();
SqlIdentifier columnName = (SqlIdentifier) descriptorCols.get(0);
String descriptorColName =
columnName.isSimple() ? columnName.getSimple() : Util.last(columnName.names);
int index = matcher.indexOf(searchTableType.getFieldNames(), descriptorColName);
RelDataType targetType = searchTableType.getFieldList().get(index).getType();
LogicalType targetLogicalType = toLogicalType(targetType);

if (!(targetLogicalType.is(LogicalTypeRoot.ARRAY)
&& ((ArrayType) (targetLogicalType))
.getElementType()
.isAnyOf(LogicalTypeRoot.FLOAT, LogicalTypeRoot.DOUBLE))) {
return SqlValidatorUtils.throwExceptionOrReturnFalse(
Optional.of(
new ValidationException(
String.format(
"Expect search column `%s` type is ARRAY<FLOAT> or ARRAY<DOUBLE>, but its type is %s.",
columnName, targetType))),
throwOnFailure);
}

// check the query operand can be implicitly cast to the search column type
LogicalType sourceLogicalType = toLogicalType(callBinding.getOperandType(2));
if (!LogicalTypeCasts.supportsImplicitCast(sourceLogicalType, targetLogicalType)) {
return SqlValidatorUtils.throwExceptionOrReturnFalse(
Optional.of(
new ValidationException(
String.format(
"Can not cast the query column type %s to target type %s. Please keep the query column type is same to the search column type.",
sourceLogicalType, targetType))),
throwOnFailure);
}

// check TOP_K is an integer literal
LogicalType topKType = toLogicalType(callBinding.getOperandType(3));
if (!operands.get(3).getKind().equals(SqlKind.LITERAL)
|| !topKType.is(LogicalTypeRoot.INTEGER)) {
return SqlValidatorUtils.throwExceptionOrReturnFalse(
Optional.of(
new ValidationException(
String.format(
"Expect parameter topK is integer literal in VECTOR_SEARCH, but it is %s with type %s.",
operands.get(3), topKType))),
throwOnFailure);
}

return true;
}

@Override
public SqlOperandCountRange getOperandCountRange() {
return SqlOperandCountRanges.between(4, 4);
}

@Override
public String getAllowedSignatures(SqlOperator op, String opName) {
return opName + "(TABLE search_table, DESCRIPTOR(column_to_search), column_to_query, top_k)";
}

@Override
public Consistency getConsistency() {
return Consistency.NONE;
}

@Override
public boolean isOptional(int i) {
return false;
}

@Override
public boolean isFixedParameters() {
return true;
}
}
}
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java
@@ -160,27 +160,33 @@ private static void adjustTypeForMultisetConstructor(
/**
* Make output field names unique from input field names by appending index. For example, Input
* has field names {@code a, b, c} and output has field names {@code b, c, d}. After calling
- * this function, new output field names will be {@code b0, c0, d}. Duplicate names are not
- * checked inside input and output itself.
+ * this function, new output field names will be {@code b0, c0, d}.
+ *
+ * <p>We assume that input fields in the input parameter are uniquely named, just as the output
+ * fields in the output parameter are.
*
* @param input Input fields
* @param output Output fields
- * @return
+ * @return output fields with unique names.
*/
public static List<RelDataTypeField> makeOutputUnique(
List<RelDataTypeField> input, List<RelDataTypeField> output) {
- final Set<String> inputFieldNames = new HashSet<>();
+ final Set<String> uniqueNames = new HashSet<>();
for (RelDataTypeField field : input) {
- inputFieldNames.add(field.getName());
+ uniqueNames.add(field.getName());
}

List<RelDataTypeField> result = new ArrayList<>();
for (RelDataTypeField field : output) {
String fieldName = field.getName();
- if (inputFieldNames.contains(fieldName)) {
- fieldName += "0"; // Append index to make it unique
+ int count = 0;
+ String candidate = fieldName;
+ while (uniqueNames.contains(candidate)) {
+ candidate = fieldName + count;
+ count++;
}
- result.add(new RelDataTypeFieldImpl(fieldName, field.getIndex(), field.getType()));
+ uniqueNames.add(candidate);
+ result.add(new RelDataTypeFieldImpl(candidate, field.getIndex(), field.getType()));
}
}
return result;
}
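A quick trace of the new collision loop, with hypothetical field names:

// Trace: input fields {a, b, b0}, output fields {b, d}. For "b": "b" is taken,
// candidate "b0" (count = 0) is also taken, so the loop settles on "b1" and adds
// it to uniqueNames; "d" passes through unchanged. Result: {b1, d}.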