Skip to content

Commit a67d787

Browse files
committed
Use irpc::channel::SendError as default sink error.
1 parent 2dac46c commit a67d787

File tree

7 files changed

+83
-57
lines changed

7 files changed

+83
-57
lines changed

src/api.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ pub mod downloader;
3030
pub mod proto;
3131
pub mod remote;
3232
pub mod tags;
33-
use crate::api::proto::WaitIdleRequest;
33+
use crate::{api::proto::WaitIdleRequest, provider::events::ProgressError};
3434
pub use crate::{store::util::Tag, util::temp_tag::TempTag};
3535

3636
pub(crate) type ApiClient = irpc::Client<proto::Request>;
@@ -97,6 +97,8 @@ pub enum ExportBaoError {
9797
ExportBaoIo { source: io::Error },
9898
#[snafu(display("encode error: {source}"))]
9999
ExportBaoInner { source: bao_tree::io::EncodeError },
100+
#[snafu(display("client error: {source}"))]
101+
ClientError { source: ProgressError },
100102
}
101103

102104
impl From<ExportBaoError> for Error {
@@ -107,6 +109,7 @@ impl From<ExportBaoError> for Error {
107109
ExportBaoError::Request { source, .. } => Self::Io(source.into()),
108110
ExportBaoError::ExportBaoIo { source, .. } => Self::Io(source),
109111
ExportBaoError::ExportBaoInner { source, .. } => Self::Io(source.into()),
112+
ExportBaoError::ClientError { source, .. } => Self::Io(source.into()),
110113
}
111114
}
112115
}
@@ -152,6 +155,12 @@ impl From<bao_tree::io::EncodeError> for ExportBaoError {
152155
}
153156
}
154157

158+
impl From<ProgressError> for ExportBaoError {
159+
fn from(value: ProgressError) -> Self {
160+
ClientSnafu.into_error(value)
161+
}
162+
}
163+
155164
pub type ExportBaoResult<T> = std::result::Result<T, ExportBaoError>;
156165

157166
#[derive(Debug, derive_more::Display, derive_more::From, Serialize, Deserialize)]

src/api/blobs.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ use super::{
5757
};
5858
use crate::{
5959
api::proto::{BatchRequest, ImportByteStreamUpdate},
60+
provider::events::ClientResult,
6061
store::IROH_BLOCK_SIZE,
6162
util::temp_tag::TempTag,
6263
BlobFormat, Hash, HashAndFormat,
@@ -1112,7 +1113,9 @@ impl ExportBaoProgress {
11121113
.write_chunk(leaf.data)
11131114
.await
11141115
.map_err(io::Error::from)?;
1115-
progress.notify_payload_write(index, leaf.offset, len).await;
1116+
progress
1117+
.notify_payload_write(index, leaf.offset, len)
1118+
.await?;
11161119
}
11171120
EncodedItem::Done => break,
11181121
EncodedItem::Error(cause) => return Err(cause.into()),
@@ -1158,7 +1161,7 @@ impl ExportBaoProgress {
11581161

11591162
pub(crate) trait WriteProgress {
11601163
/// Notify the progress writer that a payload write has happened.
1161-
async fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize);
1164+
async fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) -> ClientResult;
11621165

11631166
/// Log a write of some other data.
11641167
fn log_other_write(&mut self, len: usize);

src/api/downloader.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use std::{
33
collections::{HashMap, HashSet},
44
fmt::Debug,
55
future::{Future, IntoFuture},
6-
io,
76
sync::Arc,
87
};
98

@@ -113,7 +112,7 @@ async fn handle_download_impl(
113112
SplitStrategy::Split => handle_download_split_impl(store, pool, request, tx).await?,
114113
SplitStrategy::None => match request.request {
115114
FiniteRequest::Get(get) => {
116-
let sink = IrpcSenderRefSink(tx).with_map_err(io::Error::other);
115+
let sink = IrpcSenderRefSink(tx);
117116
execute_get(&pool, Arc::new(get), &request.providers, &store, sink).await?;
118117
}
119118
FiniteRequest::GetMany(_) => {
@@ -144,7 +143,7 @@ async fn handle_download_split_impl(
144143
let (tx, rx) = tokio::sync::mpsc::channel::<(usize, DownloadProgessItem)>(16);
145144
progress_tx.send(rx).await.ok();
146145
let sink = TokioMpscSenderSink(tx)
147-
.with_map_err(io::Error::other)
146+
.with_map_err(|_| irpc::channel::SendError::ReceiverClosed)
148147
.with_map(move |x| (id, x));
149148
let res = execute_get(&pool, Arc::new(request), &providers, &store, sink).await;
150149
(hash, res)
@@ -375,7 +374,7 @@ async fn split_request<'a>(
375374
providers: &Arc<dyn ContentDiscovery>,
376375
pool: &ConnectionPool,
377376
store: &Store,
378-
progress: impl Sink<DownloadProgessItem, Error = io::Error>,
377+
progress: impl Sink<DownloadProgessItem, Error = irpc::channel::SendError>,
379378
) -> anyhow::Result<Box<dyn Iterator<Item = GetRequest> + Send + 'a>> {
380379
Ok(match request {
381380
FiniteRequest::Get(req) => {
@@ -431,7 +430,7 @@ async fn execute_get(
431430
request: Arc<GetRequest>,
432431
providers: &Arc<dyn ContentDiscovery>,
433432
store: &Store,
434-
mut progress: impl Sink<DownloadProgessItem, Error = io::Error>,
433+
mut progress: impl Sink<DownloadProgessItem, Error = irpc::channel::SendError>,
435434
) -> anyhow::Result<()> {
436435
let remote = store.remote();
437436
let mut providers = providers.find_providers(request.content());

src/api/remote.rs

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::{
1818
GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType,
1919
MAX_MESSAGE_SIZE,
2020
},
21+
provider::events::{ClientResult, ProgressError},
2122
util::sink::{Sink, TokioMpscSenderSink},
2223
};
2324

@@ -478,9 +479,7 @@ impl Remote {
478479
let content = content.into();
479480
let (tx, rx) = tokio::sync::mpsc::channel(64);
480481
let tx2 = tx.clone();
481-
let sink = TokioMpscSenderSink(tx)
482-
.with_map(GetProgressItem::Progress)
483-
.with_map_err(io::Error::other);
482+
let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
484483
let this = self.clone();
485484
let fut = async move {
486485
let res = this.fetch_sink(conn, content, sink).await.into();
@@ -503,7 +502,7 @@ impl Remote {
503502
&self,
504503
mut conn: impl GetConnection,
505504
content: impl Into<HashAndFormat>,
506-
progress: impl Sink<u64, Error = io::Error>,
505+
progress: impl Sink<u64, Error = irpc::channel::SendError>,
507506
) -> GetResult<Stats> {
508507
let content = content.into();
509508
let local = self
@@ -556,9 +555,7 @@ impl Remote {
556555
pub fn execute_push(&self, conn: Connection, request: PushRequest) -> PushProgress {
557556
let (tx, rx) = tokio::sync::mpsc::channel(64);
558557
let tx2 = tx.clone();
559-
let sink = TokioMpscSenderSink(tx)
560-
.with_map(PushProgressItem::Progress)
561-
.with_map_err(io::Error::other);
558+
let sink = TokioMpscSenderSink(tx).with_map(PushProgressItem::Progress);
562559
let this = self.clone();
563560
let fut = async move {
564561
let res = this.execute_push_sink(conn, request, sink).await.into();
@@ -577,7 +574,7 @@ impl Remote {
577574
&self,
578575
conn: Connection,
579576
request: PushRequest,
580-
progress: impl Sink<u64, Error = io::Error>,
577+
progress: impl Sink<u64, Error = irpc::channel::SendError>,
581578
) -> anyhow::Result<Stats> {
582579
let hash = request.hash;
583580
debug!(%hash, "pushing");
@@ -632,9 +629,7 @@ impl Remote {
632629
pub fn execute_get_with_opts(&self, conn: Connection, request: GetRequest) -> GetProgress {
633630
let (tx, rx) = tokio::sync::mpsc::channel(64);
634631
let tx2 = tx.clone();
635-
let sink = TokioMpscSenderSink(tx)
636-
.with_map(GetProgressItem::Progress)
637-
.with_map_err(io::Error::other);
632+
let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
638633
let this = self.clone();
639634
let fut = async move {
640635
let res = this.execute_get_sink(&conn, request, sink).await.into();
@@ -658,7 +653,7 @@ impl Remote {
658653
&self,
659654
conn: &Connection,
660655
request: GetRequest,
661-
mut progress: impl Sink<u64, Error = io::Error>,
656+
mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
662657
) -> GetResult<Stats> {
663658
let store = self.store();
664659
let root = request.hash;
@@ -721,9 +716,7 @@ impl Remote {
721716
pub fn execute_get_many(&self, conn: Connection, request: GetManyRequest) -> GetProgress {
722717
let (tx, rx) = tokio::sync::mpsc::channel(64);
723718
let tx2 = tx.clone();
724-
let sink = TokioMpscSenderSink(tx)
725-
.with_map(GetProgressItem::Progress)
726-
.with_map_err(io::Error::other);
719+
let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
727720
let this = self.clone();
728721
let fut = async move {
729722
let res = this.execute_get_many_sink(conn, request, sink).await.into();
@@ -747,7 +740,7 @@ impl Remote {
747740
&self,
748741
conn: Connection,
749742
request: GetManyRequest,
750-
mut progress: impl Sink<u64, Error = io::Error>,
743+
mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
751744
) -> GetResult<Stats> {
752745
let store = self.store();
753746
let hash_seq = request.hashes.iter().copied().collect::<HashSeq>();
@@ -884,7 +877,7 @@ async fn get_blob_ranges_impl(
884877
header: AtBlobHeader,
885878
hash: Hash,
886879
store: &Store,
887-
mut progress: impl Sink<u64, Error = io::Error>,
880+
mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
888881
) -> GetResult<AtEndBlob> {
889882
let (mut content, size) = header.next().await?;
890883
let Some(size) = NonZeroU64::new(size) else {
@@ -1048,11 +1041,20 @@ struct StreamContext<S> {
10481041

10491042
impl<S> WriteProgress for StreamContext<S>
10501043
where
1051-
S: Sink<u64, Error = io::Error>,
1044+
S: Sink<u64, Error = irpc::channel::SendError>,
10521045
{
1053-
async fn notify_payload_write(&mut self, _index: u64, _offset: u64, len: usize) {
1046+
async fn notify_payload_write(
1047+
&mut self,
1048+
_index: u64,
1049+
_offset: u64,
1050+
len: usize,
1051+
) -> ClientResult {
10541052
self.payload_bytes_sent += len as u64;
1055-
self.sender.send(self.payload_bytes_sent).await.ok();
1053+
self.sender
1054+
.send(self.payload_bytes_sent)
1055+
.await
1056+
.map_err(|e| ProgressError::Internal { source: e.into() })?;
1057+
Ok(())
10561058
}
10571059

10581060
fn log_other_write(&mut self, _len: usize) {}

src/provider.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::{
2525
},
2626
hashseq::HashSeq,
2727
protocol::{GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request},
28-
provider::events::{ClientConnected, ConnectionClosed, RequestTracker},
28+
provider::events::{ClientConnected, ClientResult, ConnectionClosed, RequestTracker},
2929
Hash,
3030
};
3131
pub mod events;
@@ -264,11 +264,11 @@ impl WriterContext {
264264
}
265265

266266
impl WriteProgress for WriterContext {
267-
async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) {
267+
async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
268268
let len = len as u64;
269269
let end_offset = offset + len;
270270
self.payload_bytes_written += len;
271-
self.tracker.transfer_progress(len, end_offset).await.ok();
271+
self.tracker.transfer_progress(len, end_offset).await
272272
}
273273

274274
fn log_other_write(&mut self, len: usize) {

src/provider/events.rs

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{fmt::Debug, ops::Deref};
1+
use std::{fmt::Debug, io, ops::Deref};
22

33
use irpc::{
44
channel::{mpsc, none::NoSender, oneshot},
@@ -76,60 +76,70 @@ pub enum AbortReason {
7676
}
7777

7878
#[derive(Debug, Snafu)]
79-
pub enum ClientError {
80-
RateLimited,
79+
pub enum ProgressError {
80+
Limit,
8181
Permission,
8282
#[snafu(transparent)]
83-
Irpc {
83+
Internal {
8484
source: irpc::Error,
8585
},
8686
}
8787

88-
impl ClientError {
88+
impl From<ProgressError> for io::Error {
89+
fn from(value: ProgressError) -> Self {
90+
match value {
91+
ProgressError::Limit => io::ErrorKind::QuotaExceeded.into(),
92+
ProgressError::Permission => io::ErrorKind::PermissionDenied.into(),
93+
ProgressError::Internal { source } => source.into(),
94+
}
95+
}
96+
}
97+
98+
impl ProgressError {
8999
pub fn code(&self) -> quinn::VarInt {
90100
match self {
91-
ClientError::RateLimited => ERR_LIMIT,
92-
ClientError::Permission => ERR_PERMISSION,
93-
ClientError::Irpc { .. } => ERR_INTERNAL,
101+
ProgressError::Limit => ERR_LIMIT,
102+
ProgressError::Permission => ERR_PERMISSION,
103+
ProgressError::Internal { .. } => ERR_INTERNAL,
94104
}
95105
}
96106

97107
pub fn reason(&self) -> &'static [u8] {
98108
match self {
99-
ClientError::RateLimited => b"limit",
100-
ClientError::Permission => b"permission",
101-
ClientError::Irpc { .. } => b"internal",
109+
ProgressError::Limit => b"limit",
110+
ProgressError::Permission => b"permission",
111+
ProgressError::Internal { .. } => b"internal",
102112
}
103113
}
104114
}
105115

106-
impl From<AbortReason> for ClientError {
116+
impl From<AbortReason> for ProgressError {
107117
fn from(value: AbortReason) -> Self {
108118
match value {
109-
AbortReason::RateLimited => ClientError::RateLimited,
110-
AbortReason::Permission => ClientError::Permission,
119+
AbortReason::RateLimited => ProgressError::Limit,
120+
AbortReason::Permission => ProgressError::Permission,
111121
}
112122
}
113123
}
114124

115-
impl From<irpc::channel::RecvError> for ClientError {
125+
impl From<irpc::channel::RecvError> for ProgressError {
116126
fn from(value: irpc::channel::RecvError) -> Self {
117-
ClientError::Irpc {
127+
ProgressError::Internal {
118128
source: value.into(),
119129
}
120130
}
121131
}
122132

123-
impl From<irpc::channel::SendError> for ClientError {
133+
impl From<irpc::channel::SendError> for ProgressError {
124134
fn from(value: irpc::channel::SendError) -> Self {
125-
ClientError::Irpc {
135+
ProgressError::Internal {
126136
source: value.into(),
127137
}
128138
}
129139
}
130140

131141
pub type EventResult = Result<(), AbortReason>;
132-
pub type ClientResult = Result<(), ClientError>;
142+
pub type ClientResult = Result<(), ProgressError>;
133143

134144
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
135145
pub struct EventMask {
@@ -407,7 +417,7 @@ impl EventSender {
407417
f: impl FnOnce() -> Req,
408418
connection_id: u64,
409419
request_id: u64,
410-
) -> Result<RequestTracker, ClientError>
420+
) -> Result<RequestTracker, ProgressError>
411421
where
412422
ProviderProto: From<RequestReceived<Req>>,
413423
ProviderMessage: From<WithChannels<RequestReceived<Req>, ProviderProto>>,
@@ -466,7 +476,7 @@ impl EventSender {
466476
RequestUpdates::Active(tx)
467477
}
468478
RequestMode::Disabled => {
469-
return Err(ClientError::Permission);
479+
return Err(ProgressError::Permission);
470480
}
471481
_ => RequestUpdates::None,
472482
},

0 commit comments

Comments
 (0)