Skip to content

Commit

Permalink
Add score functions to analyzer test
Browse files Browse the repository at this point in the history
Signed-off-by: acarbonetto <[email protected]>
  • Loading branch information
acarbonetto committed Aug 22, 2023
1 parent 675aa08 commit 32ce604
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 44 deletions.
44 changes: 0 additions & 44 deletions core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -277,50 +277,6 @@ public void filter_relation_with_multiple_tables() {
AstDSL.equalTo(AstDSL.field("integer_value"), AstDSL.intLiteral(1))));
}

@Test
public void analyze_filter_visit_score_function() {
UnresolvedPlan unresolvedPlan =
AstDSL.filter(
AstDSL.relation("schema"),
new ScoreFunction(
AstDSL.function(
"match_phrase_prefix",
AstDSL.unresolvedArg("field", stringLiteral("field_value1")),
AstDSL.unresolvedArg("query", stringLiteral("search query")),
AstDSL.unresolvedArg("boost", stringLiteral("3"))),
AstDSL.doubleLiteral(1.0)));
assertAnalyzeEqual(
LogicalPlanDSL.filter(
LogicalPlanDSL.relation("schema", table),
DSL.match_phrase_prefix(
DSL.namedArgument("field", "field_value1"),
DSL.namedArgument("query", "search query"),
DSL.namedArgument("boost", "3.0"))),
unresolvedPlan);

LogicalPlan logicalPlan = analyze(unresolvedPlan);
OpenSearchFunction relevanceQuery =
(OpenSearchFunction) ((LogicalFilter) logicalPlan).getCondition();
assertEquals(true, relevanceQuery.isScoreTracked());
}

@Test
public void analyze_filter_visit_score_function_with_unsupported_boost_SemanticCheckException() {
UnresolvedPlan unresolvedPlan =
AstDSL.filter(
AstDSL.relation("schema"),
new ScoreFunction(
AstDSL.function(
"match_phrase_prefix",
AstDSL.unresolvedArg("field", stringLiteral("field_value1")),
AstDSL.unresolvedArg("query", stringLiteral("search query")),
AstDSL.unresolvedArg("boost", stringLiteral("3"))),
AstDSL.stringLiteral("3.0")));
SemanticCheckException exception =
assertThrows(SemanticCheckException.class, () -> analyze(unresolvedPlan));
assertEquals("Expected boost type 'DOUBLE' but got 'STRING'", exception.getMessage());
}

@Test
public void head_relation() {
assertAnalyzeEqual(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.analysis;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral;

import java.util.List;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.expression.ScoreFunction;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.OpenSearchFunction;
import org.opensearch.sql.planner.logical.LogicalFilter;
import org.opensearch.sql.planner.logical.LogicalPlan;

@ExtendWith(MockitoExtension.class)
class OpenSearchAnalyzerTest extends AnalyzerTestBase {

@Mock
private BuiltinFunctionRepository builtinFunctionRepository;

@Override
protected ExpressionAnalyzer expressionAnalyzer() {
return new ExpressionAnalyzer(builtinFunctionRepository);
}

@BeforeEach
private void setup() {
this.expressionAnalyzer = expressionAnalyzer();
this.analyzer = analyzer(this.expressionAnalyzer, dataSourceService);
}

@Test
public void analyze_filter_visit_score_function() {

// setup
OpenSearchFunction scoreFunction = new OpenSearchFunction(
new FunctionName("match_phrase_prefix"), List.of());
when(builtinFunctionRepository.compile(any(), any(), any())).thenReturn(scoreFunction);

UnresolvedPlan unresolvedPlan =
AstDSL.filter(
AstDSL.relation("schema"),
new ScoreFunction(
AstDSL.function(
"match_phrase_prefix",
AstDSL.unresolvedArg("field", stringLiteral("field_value1")),
AstDSL.unresolvedArg("query", stringLiteral("search query")),
AstDSL.unresolvedArg("boost", stringLiteral("3"))),
AstDSL.doubleLiteral(1.0)));

// test
LogicalPlan logicalPlan = analyze(unresolvedPlan);
OpenSearchFunction relevanceQuery =
(OpenSearchFunction) ((LogicalFilter) logicalPlan).getCondition();

// verify
assertEquals(true, relevanceQuery.isScoreTracked());
}

@Test
public void analyze_filter_visit_score_function_with_unsupported_boost_SemanticCheckException() {

// setup
UnresolvedPlan unresolvedPlan =
AstDSL.filter(
AstDSL.relation("schema"),
new ScoreFunction(
AstDSL.function(
"match_phrase_prefix",
AstDSL.unresolvedArg("field", stringLiteral("field_value1")),
AstDSL.unresolvedArg("query", stringLiteral("search query")),
AstDSL.unresolvedArg("boost", stringLiteral("3"))),
AstDSL.stringLiteral("3.0")));

// Test
SemanticCheckException exception =
assertThrows(SemanticCheckException.class, () -> analyze(unresolvedPlan));

// Verify
assertEquals("Expected boost type 'DOUBLE' but got 'STRING'", exception.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ public void testGetDataSourceMetadataWithBasicAuth() {
@SneakyThrows
@Test
public void testGetDataSourceMetadataList() {
Mockito.when(clusterService.getClusterApplierService().isInitialClusterStateSet()).thenReturn(true);
Mockito.when(clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME))
.thenReturn(true);
Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture);
Expand All @@ -233,6 +234,7 @@ public void testGetDataSourceMetadataList() {
@SneakyThrows
@Test
public void testGetDataSourceMetadataListWithNoIndex() {
Mockito.when(clusterService.getClusterApplierService().isInitialClusterStateSet()).thenReturn(true);
Mockito.when(clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME))
.thenReturn(Boolean.FALSE);
Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
Expand Down

0 comments on commit 32ce604

Please sign in to comment.