@@ -20,13 +20,12 @@ use tracing::{debug, debug_span, warn, Instrument};
2020
2121use crate :: {
2222 api:: {
23- self ,
2423 blobs:: { Bitfield , WriteProgress } ,
25- Store ,
24+ ExportBaoResult , Store ,
2625 } ,
2726 hashseq:: HashSeq ,
2827 protocol:: { GetManyRequest , GetRequest , ObserveItem , ObserveRequest , PushRequest , Request } ,
29- provider:: events:: { ClientConnected , ClientError , ConnectionClosed , RequestTracker } ,
28+ provider:: events:: { ClientConnected , ConnectionClosed , RequestTracker } ,
3029 Hash ,
3130} ;
3231pub mod events;
@@ -94,7 +93,7 @@ impl StreamPair {
9493 }
9594
9695 /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id
97- pub async fn into_writer (
96+ async fn into_writer (
9897 mut self ,
9998 tracker : RequestTracker ,
10099 ) -> Result < ProgressWriter , ReadToEndError > {
@@ -118,7 +117,7 @@ impl StreamPair {
118117 ) )
119118 }
120119
121- pub async fn into_reader (
120+ async fn into_reader (
122121 mut self ,
123122 tracker : RequestTracker ,
124123 ) -> Result < ProgressReader , ClosedStream > {
@@ -141,39 +140,71 @@ impl StreamPair {
141140 }
142141
143142 pub async fn get_request (
144- & self ,
143+ mut self ,
145144 f : impl FnOnce ( ) -> GetRequest ,
146- ) -> Result < RequestTracker , ClientError > {
147- self . events
145+ ) -> anyhow:: Result < ProgressWriter > {
146+ let res = self
147+ . events
148148 . request ( f, self . connection_id , self . request_id )
149- . await
149+ . await ;
150+ match res {
151+ Err ( e) => {
152+ self . writer . reset ( e. code ( ) ) . ok ( ) ;
153+ Err ( e. into ( ) )
154+ }
155+ Ok ( tracker) => Ok ( self . into_writer ( tracker) . await ?) ,
156+ }
150157 }
151158
152159 pub async fn get_many_request (
153- & self ,
160+ mut self ,
154161 f : impl FnOnce ( ) -> GetManyRequest ,
155- ) -> Result < RequestTracker , ClientError > {
156- self . events
162+ ) -> anyhow:: Result < ProgressWriter > {
163+ let res = self
164+ . events
157165 . request ( f, self . connection_id , self . request_id )
158- . await
166+ . await ;
167+ match res {
168+ Err ( e) => {
169+ self . writer . reset ( e. code ( ) ) . ok ( ) ;
170+ Err ( e. into ( ) )
171+ }
172+ Ok ( tracker) => Ok ( self . into_writer ( tracker) . await ?) ,
173+ }
159174 }
160175
161176 pub async fn push_request (
162- & self ,
177+ mut self ,
163178 f : impl FnOnce ( ) -> PushRequest ,
164- ) -> Result < RequestTracker , ClientError > {
165- self . events
179+ ) -> anyhow:: Result < ProgressReader > {
180+ let res = self
181+ . events
166182 . request ( f, self . connection_id , self . request_id )
167- . await
183+ . await ;
184+ match res {
185+ Err ( e) => {
186+ self . writer . reset ( e. code ( ) ) . ok ( ) ;
187+ Err ( e. into ( ) )
188+ }
189+ Ok ( tracker) => Ok ( self . into_reader ( tracker) . await ?) ,
190+ }
168191 }
169192
170193 pub async fn observe_request (
171- & self ,
194+ mut self ,
172195 f : impl FnOnce ( ) -> ObserveRequest ,
173- ) -> Result < RequestTracker , ClientError > {
174- self . events
196+ ) -> anyhow:: Result < ProgressWriter > {
197+ let res = self
198+ . events
175199 . request ( f, self . connection_id , self . request_id )
176- . await
200+ . await ;
201+ match res {
202+ Err ( e) => {
203+ self . writer . reset ( e. code ( ) ) . ok ( ) ;
204+ Err ( e. into ( ) )
205+ }
206+ Ok ( tracker) => Ok ( self . into_writer ( tracker) . await ?) ,
207+ }
177208 }
178209
179210 fn stats ( & self ) -> TransferStats {
@@ -299,7 +330,8 @@ pub async fn handle_connection(
299330 } )
300331 . await
301332 {
302- debug ! ( "client not authorized to connect: {cause}" ) ;
333+ connection. close ( cause. code ( ) , cause. reason ( ) ) ;
334+ debug ! ( "closing connection: {cause}" ) ;
303335 return ;
304336 }
305337 while let Ok ( context) = StreamPair :: accept ( & connection, & progress) . await {
@@ -323,35 +355,32 @@ async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result<
323355
324356 match request {
325357 Request :: Get ( request) => {
326- let tracker = context. get_request ( || request. clone ( ) ) . await ?;
327- let mut writer = context . into_writer ( tracker ) . await ? ;
328- if handle_get ( store , request , & mut writer ) . await . is_ok ( ) {
358+ let mut writer = context. get_request ( || request. clone ( ) ) . await ?;
359+ let res = handle_get ( store , request , & mut writer ) . await ;
360+ if res . is_ok ( ) {
329361 writer. transfer_completed ( ) . await ;
330362 } else {
331363 writer. transfer_aborted ( ) . await ;
332364 }
333365 }
334366 Request :: GetMany ( request) => {
335- let tracker = context. get_many_request ( || request. clone ( ) ) . await ?;
336- let mut writer = context. into_writer ( tracker) . await ?;
367+ let mut writer = context. get_many_request ( || request. clone ( ) ) . await ?;
337368 if handle_get_many ( store, request, & mut writer) . await . is_ok ( ) {
338369 writer. transfer_completed ( ) . await ;
339370 } else {
340371 writer. transfer_aborted ( ) . await ;
341372 }
342373 }
343374 Request :: Observe ( request) => {
344- let tracker = context. observe_request ( || request. clone ( ) ) . await ?;
345- let mut writer = context. into_writer ( tracker) . await ?;
375+ let mut writer = context. observe_request ( || request. clone ( ) ) . await ?;
346376 if handle_observe ( store, request, & mut writer) . await . is_ok ( ) {
347377 writer. transfer_completed ( ) . await ;
348378 } else {
349379 writer. transfer_aborted ( ) . await ;
350380 }
351381 }
352382 Request :: Push ( request) => {
353- let tracker = context. push_request ( || request. clone ( ) ) . await ?;
354- let mut reader = context. into_reader ( tracker) . await ?;
383+ let mut reader = context. push_request ( || request. clone ( ) ) . await ?;
355384 if handle_push ( store, request, & mut reader) . await . is_ok ( ) {
356385 reader. transfer_completed ( ) . await ;
357386 } else {
@@ -464,11 +493,11 @@ pub(crate) async fn send_blob(
464493 hash : Hash ,
465494 ranges : ChunkRanges ,
466495 writer : & mut ProgressWriter ,
467- ) -> api :: Result < ( ) > {
468- Ok ( store
496+ ) -> ExportBaoResult < ( ) > {
497+ store
469498 . export_bao ( hash, ranges)
470499 . write_quinn_with_progress ( & mut writer. inner , & mut writer. context , & hash, index)
471- . await ? )
500+ . await
472501}
473502
474503/// Handle a single push request.
0 commit comments