Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[opt](arrow-flight-sql) Support arrow-flight-sql protocol getStreamCatalogs, getStreamSchemas, getStreamTables #46217

Merged
merged 4 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -131,21 +131,19 @@ private void getStreamStatementResult(String handle, ServerStreamListener listen
String[] handleParts = handle.split(":");
String executedPeerIdentity = handleParts[0];
String queryId = handleParts[1];
// The tokens used for authentication between getStreamStatement and getFlightInfoStatement are different.
ConnectContext connectContext = flightSessionsManager.getConnectContext(executedPeerIdentity);
try {
// The tokens used for authentication between getStreamStatement and getFlightInfoStatement are different.
final FlightSqlResultCacheEntry flightSqlResultCacheEntry = Objects.requireNonNull(
connectContext.getFlightSqlChannel().getResult(queryId));
final VectorSchemaRoot vectorSchemaRoot = flightSqlResultCacheEntry.getVectorSchemaRoot();
listener.start(vectorSchemaRoot);
listener.putNext();
} catch (Exception e) {
listener.error(e);
String errMsg = "get stream statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(e)
+ ", error code: " + connectContext.getState().getErrorCode() + ", error msg: "
+ connectContext.getState().getErrorMessage();
LOG.warn(errMsg, e);
throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
handleStreamException(e, errMsg, listener);
} finally {
listener.completed();
// The result has been sent or sent failed, delete it.
Expand Down Expand Up @@ -280,7 +278,7 @@ private FlightInfo executeQueryStatement(String peerIdentity, ConnectContext con
String errMsg = "get flight info statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(e)
+ ", error code: " + connectContext.getState().getErrorCode() + ", error msg: "
+ connectContext.getState().getErrorMessage();
LOG.warn(errMsg, e);
LOG.error(errMsg, e);
throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
} finally {
connectContext.setCommand(MysqlCommand.COM_SLEEP);
Expand Down Expand Up @@ -361,7 +359,7 @@ public void createPreparedStatement(final ActionCreatePreparedStatementRequest r
String errMsg = "create prepared statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(
e) + ", error code: " + connectContext.getState().getErrorCode() + ", error msg: "
+ connectContext.getState().getErrorMessage();
LOG.warn(errMsg, e);
LOG.error(errMsg, e);
listener.onError(CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException());
return;
} catch (final Throwable t) {
Expand Down Expand Up @@ -407,7 +405,7 @@ public Runnable acceptPutPreparedStatementUpdate(CommandPreparedStatementUpdate
} catch (Exception e) {
String errMsg = "acceptPutPreparedStatementUpdate failed, " + e.getMessage() + ", "
+ Util.getRootCauseMessage(e);
LOG.warn(errMsg, e);
LOG.error(errMsg, e);
throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
}
};
Expand Down Expand Up @@ -451,7 +449,21 @@ public FlightInfo getFlightInfoCatalogs(final CommandGetCatalogs request, final

@Override
public void getStreamCatalogs(final CallContext context, final ServerStreamListener listener) {
throw CallStatus.UNIMPLEMENTED.withDescription("getStreamCatalogs unimplemented").toRuntimeException();
try {
ConnectContext connectContext = flightSessionsManager.getConnectContext(context.peerIdentity());
FlightSqlSchemaHelper flightSqlSchemaHelper = new FlightSqlSchemaHelper(connectContext);
final Schema schema = Schemas.GET_CATALOGS_SCHEMA;

try (final VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) {
listener.start(vectorSchemaRoot);
vectorSchemaRoot.allocateNew();
flightSqlSchemaHelper.getCatalogs(vectorSchemaRoot);
listener.putNext();
listener.completed();
}
} catch (final Exception e) {
handleStreamException(e, "", listener);
}
}

@Override
Expand All @@ -463,7 +475,22 @@ public FlightInfo getFlightInfoSchemas(final CommandGetDbSchemas request, final
@Override
public void getStreamSchemas(final CommandGetDbSchemas command, final CallContext context,
final ServerStreamListener listener) {
throw CallStatus.UNIMPLEMENTED.withDescription("getStreamSchemas unimplemented").toRuntimeException();
try {
ConnectContext connectContext = flightSessionsManager.getConnectContext(context.peerIdentity());
FlightSqlSchemaHelper flightSqlSchemaHelper = new FlightSqlSchemaHelper(connectContext);
flightSqlSchemaHelper.setParameterForGetDbSchemas(command);
final Schema schema = Schemas.GET_SCHEMAS_SCHEMA;

try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) {
listener.start(vectorSchemaRoot);
vectorSchemaRoot.allocateNew();
flightSqlSchemaHelper.getSchemas(vectorSchemaRoot);
listener.putNext();
listener.completed();
}
} catch (final Exception e) {
handleStreamException(e, "", listener);
}
}

@Override
Expand All @@ -479,7 +506,23 @@ public FlightInfo getFlightInfoTables(final CommandGetTables request, final Call
@Override
public void getStreamTables(final CommandGetTables command, final CallContext context,
final ServerStreamListener listener) {
throw CallStatus.UNIMPLEMENTED.withDescription("getStreamTables unimplemented").toRuntimeException();
try {
ConnectContext connectContext = flightSessionsManager.getConnectContext(context.peerIdentity());
FlightSqlSchemaHelper flightSqlSchemaHelper = new FlightSqlSchemaHelper(connectContext);
flightSqlSchemaHelper.setParameterForGetTables(command);
final Schema schema = command.getIncludeSchema() ? Schemas.GET_TABLES_SCHEMA
: Schemas.GET_TABLES_SCHEMA_NO_SCHEMA;

try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) {
listener.start(vectorSchemaRoot);
vectorSchemaRoot.allocateNew();
flightSqlSchemaHelper.getTables(vectorSchemaRoot);
listener.putNext();
listener.completed();
}
} catch (final Exception e) {
handleStreamException(e, "", listener);
}
}

@Override
Expand All @@ -502,7 +545,6 @@ public FlightInfo getFlightInfoPrimaryKeys(final CommandGetPrimaryKeys request,
@Override
public void getStreamPrimaryKeys(final CommandGetPrimaryKeys command, final CallContext context,
final ServerStreamListener listener) {

throw CallStatus.UNIMPLEMENTED.withDescription("getStreamPrimaryKeys unimplemented").toRuntimeException();
}

Expand Down Expand Up @@ -545,9 +587,14 @@ public void getStreamCrossReference(CommandGetCrossReference command, CallContex
private <T extends Message> FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor,
final Schema schema) {
final Ticket ticket = new Ticket(Any.pack(request).toByteArray());
// TODO Support multiple endpoints.
final List<FlightEndpoint> endpoints = Collections.singletonList(new FlightEndpoint(ticket, location));

return new FlightInfo(schema, descriptor, endpoints, -1, -1);
}

private static void handleStreamException(Exception e, String errMsg, ServerStreamListener listener) {
LOG.error(errMsg, e);
listener.error(CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException());
throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
}
}
Loading
Loading