@@ -127,6 +127,8 @@ public void sendResponseBatch(
127
127
compress ,
128
128
isHandshake ,
129
129
false ,
130
+ false ,
131
+ null ,
130
132
storedContext
131
133
);
132
134
@@ -145,23 +147,23 @@ public void sendResponseBatch(
145
147
}
146
148
147
149
private void processBatchTask (BatchTask task ) {
148
- task .storedContext .restore ();
149
- if (!(task .channel instanceof FlightServerChannel flightChannel )) {
150
- Exception error = new IllegalStateException ("Expected FlightServerChannel, got " + task .channel .getClass ().getName ());
151
- messageListener .onResponseSent (task .requestId , task .action , error );
150
+ task .storedContext () .restore ();
151
+ if (!(task .channel () instanceof FlightServerChannel flightChannel )) {
152
+ Exception error = new IllegalStateException ("Expected FlightServerChannel, got " + task .channel () .getClass ().getName ());
153
+ messageListener .onResponseSent (task .requestId () , task .action () , error );
152
154
return ;
153
155
}
154
156
155
157
try {
156
158
try (VectorStreamOutput out = new VectorStreamOutput (flightChannel .getAllocator (), flightChannel .getRoot ())) {
157
- task .response .writeTo (out );
158
- flightChannel .sendBatch (getHeaderBuffer (task .requestId , task .nodeVersion , task .features ), out );
159
- messageListener .onResponseSent (task .requestId , task .action , task .response );
159
+ task .response () .writeTo (out );
160
+ flightChannel .sendBatch (getHeaderBuffer (task .requestId () , task .nodeVersion () , task .features () ), out );
161
+ messageListener .onResponseSent (task .requestId () , task .action () , task .response () );
160
162
}
161
163
} catch (FlightRuntimeException e ) {
162
- messageListener .onResponseSent (task .requestId , task .action , FlightErrorMapper .fromFlightException (e ));
164
+ messageListener .onResponseSent (task .requestId () , task .action () , FlightErrorMapper .fromFlightException (e ));
163
165
} catch (Exception e ) {
164
- messageListener .onResponseSent (task .requestId , task .action , e );
166
+ messageListener .onResponseSent (task .requestId () , task .action () , e );
165
167
}
166
168
}
167
169
@@ -185,6 +187,8 @@ public void completeStream(
185
187
false ,
186
188
false ,
187
189
true ,
190
+ false ,
191
+ null ,
188
192
storedContext
189
193
);
190
194
@@ -203,18 +207,18 @@ public void completeStream(
203
207
}
204
208
205
209
private void processCompleteTask (BatchTask task ) {
206
- task .storedContext .restore ();
207
- if (!(task .channel instanceof FlightServerChannel flightChannel )) {
208
- Exception error = new IllegalStateException ("Expected FlightServerChannel, got " + task .channel .getClass ().getName ());
209
- messageListener .onResponseSent (task .requestId , task .action , error );
210
+ task .storedContext () .restore ();
211
+ if (!(task .channel () instanceof FlightServerChannel flightChannel )) {
212
+ Exception error = new IllegalStateException ("Expected FlightServerChannel, got " + task .channel () .getClass ().getName ());
213
+ messageListener .onResponseSent (task .requestId () , task .action () , error );
210
214
return ;
211
215
}
212
216
213
217
try {
214
218
flightChannel .completeStream ();
215
- messageListener .onResponseSent (task .requestId , task .action , TransportResponse .Empty .INSTANCE );
219
+ messageListener .onResponseSent (task .requestId () , task .action () , TransportResponse .Empty .INSTANCE );
216
220
} catch (Exception e ) {
217
- messageListener .onResponseSent (task .requestId , task .action , e );
221
+ messageListener .onResponseSent (task .requestId () , task .action () , e );
218
222
}
219
223
}
220
224
@@ -235,8 +239,11 @@ public void sendErrorResponse(
235
239
transportChannel ,
236
240
requestId ,
237
241
action ,
242
+ null ,
243
+ false ,
238
244
false ,
239
245
false ,
246
+ true ,
240
247
error ,
241
248
storedContext
242
249
);
@@ -256,22 +263,22 @@ public void sendErrorResponse(
256
263
}
257
264
258
265
private void processErrorTask (BatchTask task ) {
259
- task .storedContext .restore ();
260
- if (!(task .channel instanceof FlightServerChannel flightServerChannel )) {
261
- Exception error = new IllegalStateException ("Expected FlightServerChannel, got " + task .channel .getClass ().getName ());
262
- messageListener .onResponseSent (task .requestId , task .action , error );
266
+ task .storedContext () .restore ();
267
+ if (!(task .channel () instanceof FlightServerChannel flightServerChannel )) {
268
+ Exception error = new IllegalStateException ("Expected FlightServerChannel, got " + task .channel () .getClass ().getName ());
269
+ messageListener .onResponseSent (task .requestId () , task .action () , error );
263
270
return ;
264
271
}
265
272
266
273
try {
267
- Exception flightError = task .error ;
268
- if (task .error instanceof StreamException ) {
269
- flightError = FlightErrorMapper .toFlightException ((StreamException ) task .error );
274
+ Exception flightError = task .error () ;
275
+ if (task .error () instanceof StreamException ) {
276
+ flightError = FlightErrorMapper .toFlightException ((StreamException ) task .error () );
270
277
}
271
- flightServerChannel .sendError (getHeaderBuffer (task .requestId , task .nodeVersion , task .features ), flightError );
272
- messageListener .onResponseSent (task .requestId , task .action , task .error );
278
+ flightServerChannel .sendError (getHeaderBuffer (task .requestId () , task .nodeVersion () , task .features () ), flightError );
279
+ messageListener .onResponseSent (task .requestId () , task .action () , task .error () );
273
280
} catch (Exception e ) {
274
- messageListener .onResponseSent (task .requestId , task .action , e );
281
+ messageListener .onResponseSent (task .requestId () , task .action () , e );
275
282
}
276
283
}
277
284
@@ -301,4 +308,19 @@ private ByteBuffer getHeaderBuffer(long requestId, Version nodeVersion, Set<Stri
301
308
return ByteBuffer .wrap (headerBytes .toBytesRef ().bytes );
302
309
}
303
310
}
311
+
312
+ record BatchTask (Version nodeVersion , Set <String > features , TcpChannel channel , FlightTransportChannel transportChannel , long requestId ,
313
+ String action , TransportResponse response , boolean compress , boolean isHandshake , boolean isComplete , boolean isError ,
314
+ Exception error , ThreadContext .StoredContext storedContext ) implements AutoCloseable {
315
+
316
+ @ Override
317
+ public void close () {
318
+ if (storedContext != null ) {
319
+ storedContext .close ();
320
+ }
321
+ if ((isComplete || isError ) && transportChannel != null ) {
322
+ transportChannel .releaseChannel (isError );
323
+ }
324
+ }
325
+ }
304
326
}
0 commit comments