Skip to content

Commit

Permalink
Proof of Concept: request routing shard through SQL partition
Browse files Browse the repository at this point in the history
Signed-off-by: acarbonetto <[email protected]>
  • Loading branch information
acarbonetto committed Jul 5, 2023
1 parent 1c0b35d commit 1a44d89
Show file tree
Hide file tree
Showing 26 changed files with 233 additions and 75 deletions.
9 changes: 6 additions & 3 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ public LogicalPlan analyze(UnresolvedPlan unresolved, AnalysisContext context) {
@Override
public LogicalPlan visitRelation(Relation node, AnalysisContext context) {
QualifiedName qualifiedName = node.getTableQualifiedName();
String partitionName = node.getTablePartitionKeys();
DataSourceSchemaIdentifierNameResolver dataSourceSchemaIdentifierNameResolver
= new DataSourceSchemaIdentifierNameResolver(dataSourceService, qualifiedName.getParts());
String tableName = dataSourceSchemaIdentifierNameResolver.getIdentifierName();
Expand All @@ -156,9 +157,11 @@ public LogicalPlan visitRelation(Relation node, AnalysisContext context) {
.getDataSource(dataSourceSchemaIdentifierNameResolver.getDataSourceName())
.getStorageEngine()
.getTable(new DataSourceSchemaName(
dataSourceSchemaIdentifierNameResolver.getDataSourceName(),
dataSourceSchemaIdentifierNameResolver.getSchemaName()),
dataSourceSchemaIdentifierNameResolver.getIdentifierName());
dataSourceSchemaIdentifierNameResolver.getDataSourceName(),
dataSourceSchemaIdentifierNameResolver.getSchemaName()
),
dataSourceSchemaIdentifierNameResolver.getIdentifierName(),
partitionName);
}
table.getFieldTypes().forEach((k, v) -> curEnv.define(new Symbol(Namespace.FIELD_NAME, k), v));
table.getReservedFieldTypes().forEach(
Expand Down
20 changes: 20 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Relation.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,26 @@ public Relation(UnresolvedExpression tableName) {
}

public Relation(UnresolvedExpression tableName, String alias) {
this(tableName, alias, null);
}

public Relation(UnresolvedExpression tableName, String alias, List<String> partitionKeys) {
this.tableName = Arrays.asList(tableName);
this.alias = alias;
this.partitionKeys = partitionKeys;
}

/**
* Optional alias name for the relation.
*/
private String alias;


/**
* Optional partition key(s) for the relation.
*/
private List<String> partitionKeys;

/**
* Return table name.
*
Expand Down Expand Up @@ -88,6 +99,15 @@ public QualifiedName getTableQualifiedName() {
}
}

/**
* Retrieve the partition keys associated with the table/relation
*
* @return TablePartitionKeys.
*/
public String getTablePartitionKeys() {
return String.join(COMMA, partitionKeys);
}

@Override
public List<UnresolvedPlan> getChild() {
return ImmutableList.of();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import java.util.Collection;
import java.util.Collections;
import javax.annotation.Nullable;
import org.opensearch.sql.DataSourceSchemaName;
import org.opensearch.sql.expression.function.FunctionResolver;

Expand All @@ -19,7 +20,7 @@ public interface StorageEngine {
/**
* Get {@link Table} from storage engine.
*/
Table getTable(DataSourceSchemaName dataSourceSchemaName, String tableName);
Table getTable(DataSourceSchemaName dataSourceSchemaName, String tableName, @Nullable String partition);

/**
* Get list of datasource related functions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public void onFailure(Exception e) {
}

// act 1, asserts in firstResponder
var t = new OpenSearchIndex(client, defaultSettings(), "test");
var t = new OpenSearchIndex(client, defaultSettings(), "test", "routingId");
LogicalPlan p = new LogicalPaginate(1, List.of(
new LogicalProject(
new LogicalRelation("test", t), List.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli

// Route request to new query engine if it's supported already
SQLQueryRequest newSqlRequest = new SQLQueryRequest(sqlRequest.getJsonContent(),
sqlRequest.getSql(), request.path(), request.params(), sqlRequest.cursor());
sqlRequest.getSql(), request.path(), request.params(), sqlRequest.cursor(), sqlRequest.routingIds());
return newSqlQueryHandler.prepareRequest(newSqlRequest,
(restChannel, exception) -> {
try{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import org.json.JSONException;
import org.json.JSONObject;
import org.opensearch.common.settings.Settings;
Expand All @@ -28,6 +29,7 @@ public class SqlRequest {
JSONObject jsonContent;
String cursor;
Integer fetchSize;
private List<String> routingIds;

public SqlRequest(final String sql, final JSONObject jsonContent) {
this.sql = sql;
Expand All @@ -38,10 +40,12 @@ public SqlRequest(final String cursor) {
this.cursor = cursor;
}

public SqlRequest(final String sql, final Integer fetchSize, final JSONObject jsonContent) {
public SqlRequest(final String sql, final Integer fetchSize, final JSONObject jsonContent,
final List<String> routingIds) {
this.sql = sql;
this.fetchSize = fetchSize;
this.jsonContent = jsonContent;
this.routingIds = routingIds;
}

private static boolean isValidJson(String json) {
Expand All @@ -65,6 +69,8 @@ public Integer fetchSize() {
return this.fetchSize;
}

public List<String> routingIds() { return this.routingIds; }

public JSONObject getJsonContent() {
return this.jsonContent;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class SqlRequestFactory {

public static final String SQL_CURSOR_FIELD_NAME = "cursor";
public static final String SQL_FETCH_FIELD_NAME = "fetch_size";
public static final String ROUTING_FIELD_NAME = "routing";

public static SqlRequest getSqlRequest(RestRequest request) {
switch (request.method()) {
Expand Down Expand Up @@ -63,7 +64,22 @@ private static SqlRequest parseSqlRequestFromPayload(RestRequest restRequest) {
List<PreparedStatementRequest.PreparedStatementParameter> parameters = parseParameters(paramArray);
return new PreparedStatementRequest(sql, validateAndGetFetchSize(jsonContent), jsonContent, parameters);
}
return new SqlRequest(sql, validateAndGetFetchSize(jsonContent), jsonContent);

List<String> routingIds = List.of();
if (jsonContent.has(ROUTING_FIELD_NAME)) {
try {
routingIds = List.of(jsonContent.getString(ROUTING_FIELD_NAME));
} catch (JSONException ignored) {
try {
JSONArray routingIdArray = jsonContent.getJSONArray(ROUTING_FIELD_NAME);
routingIds = parseRoutingIds(routingIdArray);
} catch (JSONException jsonException) {
throw new IllegalArgumentException(ROUTING_FIELD_NAME + " parameter must be defined as a string or array value", jsonException);
}
}
}

return new SqlRequest(sql, validateAndGetFetchSize(jsonContent), jsonContent, routingIds);
}


Expand All @@ -82,6 +98,14 @@ private static Integer validateAndGetFetchSize(JSONObject jsonContent) {
return fetchSize.orElse(0);
}

private static List<String> parseRoutingIds(JSONArray array) {
List<String> routingIds = List.of();
for (int i = 0; i < array.length(); i++) {
routingIds.add(array.getString(i));
}
return routingIds;
}

private static List<PreparedStatementRequest.PreparedStatementParameter> parseParameters(
JSONArray paramsJsonArray) {
List<PreparedStatementRequest.PreparedStatementParameter> parameters = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.sql.legacy.plugin.RestSqlAction.QUERY_API_ENDPOINT;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
Expand Down Expand Up @@ -98,8 +99,9 @@ private static SQLQueryRequest createSqlQueryRequest(String query, Optional<Stri
builder.endObject();
JSONObject jsonContent = new JSONObject(Strings.toString(builder));

// TODO pass through
return new SQLQueryRequest(jsonContent, query, QUERY_API_ENDPOINT,
Map.of("format", "jdbc"), cursorId.orElse(""));
Map.of("format", "jdbc"), cursorId.orElse(""), List.of());
}

boolean doesQueryFallback(SQLQueryRequest request) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,35 +57,48 @@ public class OpenSearchQueryRequest implements OpenSearchRequest {
*/
private boolean searchDone = false;

/**
*
*/
private final IndexName routingId;

/**
* Constructor of OpenSearchQueryRequest.
*/
public OpenSearchQueryRequest(String indexName, int size,
public OpenSearchQueryRequest(String indexName,
String routingId,
int size,
OpenSearchExprValueFactory factory) {
this(new IndexName(indexName), size, factory);
this(new IndexName(indexName), new IndexName(routingId), size, factory);
}

/**
* Constructor of OpenSearchQueryRequest.
*/
public OpenSearchQueryRequest(IndexName indexName, int size,
OpenSearchExprValueFactory factory) {
public OpenSearchQueryRequest(IndexName indexName,
IndexName routingId,
int size,
OpenSearchExprValueFactory factory) {
this.indexName = indexName;
this.sourceBuilder = new SearchSourceBuilder();
sourceBuilder.from(0);
sourceBuilder.size(size);
sourceBuilder.timeout(DEFAULT_QUERY_TIMEOUT);
this.exprValueFactory = factory;
this.routingId = routingId;
}

/**
* Constructor of OpenSearchQueryRequest.
*/
public OpenSearchQueryRequest(IndexName indexName, SearchSourceBuilder sourceBuilder,
public OpenSearchQueryRequest(IndexName indexName,
IndexName routingId,
SearchSourceBuilder sourceBuilder,
OpenSearchExprValueFactory factory) {
this.indexName = indexName;
this.sourceBuilder = sourceBuilder;
this.exprValueFactory = factory;
this.routingId = routingId;
}

@Override
Expand All @@ -101,8 +114,9 @@ public OpenSearchResponse search(Function<SearchRequest, SearchResponse> searchA
searchDone = true;
return new OpenSearchResponse(
searchAction.apply(new SearchRequest()
.indices(indexName.getIndexNames())
.source(sourceBuilder)), exprValueFactory, includes);
.indices(indexName.getIndexNames())
.source(sourceBuilder)
.routing(getRoutingId().getIndexNames())), exprValueFactory, includes);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,24 +96,25 @@ public OpenSearchRequestBuilder(int requestedTotalSize,
* @return query request or scroll request
*/
public OpenSearchRequest build(OpenSearchRequest.IndexName indexName,
OpenSearchRequest.IndexName routingId,
int maxResultWindow, TimeValue scrollTimeout) {
int size = requestedTotalSize;
if (pageSize == null) {
if (startFrom + size > maxResultWindow) {
sourceBuilder.size(maxResultWindow - startFrom);
return new OpenSearchScrollRequest(
indexName, scrollTimeout, sourceBuilder, exprValueFactory);
indexName, routingId, scrollTimeout, sourceBuilder, exprValueFactory);
} else {
sourceBuilder.from(startFrom);
sourceBuilder.size(requestedTotalSize);
return new OpenSearchQueryRequest(indexName, sourceBuilder, exprValueFactory);
return new OpenSearchQueryRequest(indexName, routingId, sourceBuilder, exprValueFactory);
}
} else {
if (startFrom != 0) {
throw new UnsupportedOperationException("Non-zero offset is not supported with pagination");
}
sourceBuilder.size(pageSize);
return new OpenSearchScrollRequest(indexName, scrollTimeout,
return new OpenSearchScrollRequest(indexName, routingId, scrollTimeout,
sourceBuilder, exprValueFactory);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ public class OpenSearchScrollRequest implements OpenSearchRequest {
*/
private final IndexName indexName;

/**
* Routing Ids used for the request
* {@link OpenSearchRequest.IndexName}.
*/
private final IndexName routingId;

/** Index name. */
@EqualsAndHashCode.Exclude
@ToString.Exclude
Expand All @@ -75,14 +81,17 @@ public class OpenSearchScrollRequest implements OpenSearchRequest {

/** Constructor. */
public OpenSearchScrollRequest(IndexName indexName,
IndexName routingId,
TimeValue scrollTimeout,
SearchSourceBuilder sourceBuilder,
OpenSearchExprValueFactory exprValueFactory) {
this.indexName = indexName;
this.routingId = routingId;
this.scrollTimeout = scrollTimeout;
this.exprValueFactory = exprValueFactory;
this.initialSearchRequest = new SearchRequest()
.indices(indexName.getIndexNames())
.routing(routingId.getIndexNames())
.scroll(scrollTimeout)
.source(sourceBuilder);

Expand Down Expand Up @@ -168,6 +177,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(scrollId);
out.writeStringCollection(includes);
indexName.writeTo(out);
routingId.writeTo(out);
}

/**
Expand All @@ -183,7 +193,11 @@ public OpenSearchScrollRequest(StreamInput in, OpenSearchStorageEngine engine)
scrollId = in.readString();
includes = in.readStringList();
indexName = new IndexName(in);
OpenSearchIndex index = (OpenSearchIndex) engine.getTable(null, indexName.toString());
routingId = new IndexName(in);
OpenSearchIndex index = (OpenSearchIndex) engine.getTable(
null,
indexName.toString(),
routingId.toString());
exprValueFactory = new OpenSearchExprValueFactory(index.getFieldOpenSearchTypes());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ public class OpenSearchIndex implements Table {
*/
private final OpenSearchRequest.IndexName indexName;

/**
* Stores the routing id for the request
* {@link OpenSearchRequest.IndexName}.
*/
private final OpenSearchRequest.IndexName routingId;

/**
* The cached mapping of field and type in index.
*/
Expand All @@ -84,10 +90,11 @@ public class OpenSearchIndex implements Table {
/**
* Constructor.
*/
public OpenSearchIndex(OpenSearchClient client, Settings settings, String indexName) {
public OpenSearchIndex(OpenSearchClient client, Settings settings, String indexName, String routingId) {
this.client = client;
this.settings = settings;
this.indexName = new OpenSearchRequest.IndexName(indexName);
this.routingId = new OpenSearchRequest.IndexName(routingId);
}

@Override
Expand Down Expand Up @@ -180,7 +187,7 @@ public TableScanBuilder createScanBuilder() {
createExprValueFactory());
Function<OpenSearchRequestBuilder, OpenSearchIndexScan> createScanOperator =
requestBuilder -> new OpenSearchIndexScan(client, requestBuilder.getMaxResponseSize(),
requestBuilder.build(indexName, getMaxResultWindow(), cursorKeepAlive));
requestBuilder.build(indexName, routingId, getMaxResultWindow(), cursorKeepAlive));
return new OpenSearchIndexScanBuilder(builder, createScanOperator);
}

Expand Down
Loading

0 comments on commit 1a44d89

Please sign in to comment.