Skip to content

Commit 6a90b41

Browse files
committed
Move BatchTask to record
Signed-off-by: Rishabh Maurya <[email protected]>
1 parent 0fc1213 commit 6a90b41

File tree

3 files changed

+50
-161
lines changed

3 files changed

+50
-161
lines changed

plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/BatchTask.java

Lines changed: 0 additions & 135 deletions
This file was deleted.

plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightOutboundHandler.java

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ public void sendResponseBatch(
127127
compress,
128128
isHandshake,
129129
false,
130+
false,
131+
null,
130132
storedContext
131133
);
132134

@@ -145,23 +147,23 @@ public void sendResponseBatch(
145147
}
146148

147149
private void processBatchTask(BatchTask task) {
148-
task.storedContext.restore();
149-
if (!(task.channel instanceof FlightServerChannel flightChannel)) {
150-
Exception error = new IllegalStateException("Expected FlightServerChannel, got " + task.channel.getClass().getName());
151-
messageListener.onResponseSent(task.requestId, task.action, error);
150+
task.storedContext().restore();
151+
if (!(task.channel() instanceof FlightServerChannel flightChannel)) {
152+
Exception error = new IllegalStateException("Expected FlightServerChannel, got " + task.channel().getClass().getName());
153+
messageListener.onResponseSent(task.requestId(), task.action(), error);
152154
return;
153155
}
154156

155157
try {
156158
try (VectorStreamOutput out = new VectorStreamOutput(flightChannel.getAllocator(), flightChannel.getRoot())) {
157-
task.response.writeTo(out);
158-
flightChannel.sendBatch(getHeaderBuffer(task.requestId, task.nodeVersion, task.features), out);
159-
messageListener.onResponseSent(task.requestId, task.action, task.response);
159+
task.response().writeTo(out);
160+
flightChannel.sendBatch(getHeaderBuffer(task.requestId(), task.nodeVersion(), task.features()), out);
161+
messageListener.onResponseSent(task.requestId(), task.action(), task.response());
160162
}
161163
} catch (FlightRuntimeException e) {
162-
messageListener.onResponseSent(task.requestId, task.action, FlightErrorMapper.fromFlightException(e));
164+
messageListener.onResponseSent(task.requestId(), task.action(), FlightErrorMapper.fromFlightException(e));
163165
} catch (Exception e) {
164-
messageListener.onResponseSent(task.requestId, task.action, e);
166+
messageListener.onResponseSent(task.requestId(), task.action(), e);
165167
}
166168
}
167169

@@ -185,6 +187,8 @@ public void completeStream(
185187
false,
186188
false,
187189
true,
190+
false,
191+
null,
188192
storedContext
189193
);
190194

@@ -203,18 +207,18 @@ public void completeStream(
203207
}
204208

205209
private void processCompleteTask(BatchTask task) {
206-
task.storedContext.restore();
207-
if (!(task.channel instanceof FlightServerChannel flightChannel)) {
208-
Exception error = new IllegalStateException("Expected FlightServerChannel, got " + task.channel.getClass().getName());
209-
messageListener.onResponseSent(task.requestId, task.action, error);
210+
task.storedContext().restore();
211+
if (!(task.channel() instanceof FlightServerChannel flightChannel)) {
212+
Exception error = new IllegalStateException("Expected FlightServerChannel, got " + task.channel().getClass().getName());
213+
messageListener.onResponseSent(task.requestId(), task.action(), error);
210214
return;
211215
}
212216

213217
try {
214218
flightChannel.completeStream();
215-
messageListener.onResponseSent(task.requestId, task.action, TransportResponse.Empty.INSTANCE);
219+
messageListener.onResponseSent(task.requestId(), task.action(), TransportResponse.Empty.INSTANCE);
216220
} catch (Exception e) {
217-
messageListener.onResponseSent(task.requestId, task.action, e);
221+
messageListener.onResponseSent(task.requestId(), task.action(), e);
218222
}
219223
}
220224

@@ -235,8 +239,11 @@ public void sendErrorResponse(
235239
transportChannel,
236240
requestId,
237241
action,
242+
null,
243+
false,
238244
false,
239245
false,
246+
true,
240247
error,
241248
storedContext
242249
);
@@ -256,22 +263,22 @@ public void sendErrorResponse(
256263
}
257264

258265
private void processErrorTask(BatchTask task) {
259-
task.storedContext.restore();
260-
if (!(task.channel instanceof FlightServerChannel flightServerChannel)) {
261-
Exception error = new IllegalStateException("Expected FlightServerChannel, got " + task.channel.getClass().getName());
262-
messageListener.onResponseSent(task.requestId, task.action, error);
266+
task.storedContext().restore();
267+
if (!(task.channel() instanceof FlightServerChannel flightServerChannel)) {
268+
Exception error = new IllegalStateException("Expected FlightServerChannel, got " + task.channel().getClass().getName());
269+
messageListener.onResponseSent(task.requestId(), task.action(), error);
263270
return;
264271
}
265272

266273
try {
267-
Exception flightError = task.error;
268-
if (task.error instanceof StreamException) {
269-
flightError = FlightErrorMapper.toFlightException((StreamException) task.error);
274+
Exception flightError = task.error();
275+
if (task.error() instanceof StreamException) {
276+
flightError = FlightErrorMapper.toFlightException((StreamException) task.error());
270277
}
271-
flightServerChannel.sendError(getHeaderBuffer(task.requestId, task.nodeVersion, task.features), flightError);
272-
messageListener.onResponseSent(task.requestId, task.action, task.error);
278+
flightServerChannel.sendError(getHeaderBuffer(task.requestId(), task.nodeVersion(), task.features()), flightError);
279+
messageListener.onResponseSent(task.requestId(), task.action(), task.error());
273280
} catch (Exception e) {
274-
messageListener.onResponseSent(task.requestId, task.action, e);
281+
messageListener.onResponseSent(task.requestId(), task.action(), e);
275282
}
276283
}
277284

@@ -301,4 +308,19 @@ private ByteBuffer getHeaderBuffer(long requestId, Version nodeVersion, Set<Stri
301308
return ByteBuffer.wrap(headerBytes.toBytesRef().bytes);
302309
}
303310
}
311+
312+
record BatchTask(Version nodeVersion, Set<String> features, TcpChannel channel, FlightTransportChannel transportChannel, long requestId,
313+
String action, TransportResponse response, boolean compress, boolean isHandshake, boolean isComplete, boolean isError,
314+
Exception error, ThreadContext.StoredContext storedContext) implements AutoCloseable {
315+
316+
@Override
317+
public void close() {
318+
if (storedContext != null) {
319+
storedContext.close();
320+
}
321+
if ((isComplete || isError) && transportChannel != null) {
322+
transportChannel.releaseChannel(isError);
323+
}
324+
}
325+
}
304326
}

plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/transport/FlightTransportChannelTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ public void testCompleteStreamSuccess() {
156156
);
157157

158158
// Simulate async completion by manually creating and closing a BatchTask
159-
BatchTask completeTask = new BatchTask(
159+
FlightOutboundHandler.BatchTask completeTask = new FlightOutboundHandler.BatchTask(
160160
Version.CURRENT,
161161
Collections.emptySet(),
162162
mockTcpChannel,
@@ -167,6 +167,8 @@ public void testCompleteStreamSuccess() {
167167
false,
168168
false,
169169
true,
170+
false,
171+
null,
170172
null
171173
);
172174
completeTask.close();

0 commit comments

Comments
 (0)