Skip to content

Commit d764dc0

Browse files
committed
Genericize provider side a bit
1 parent 2f9ebd5 commit d764dc0

File tree

4 files changed

+37
-56
lines changed

4 files changed

+37
-56
lines changed

src/api/blobs.rs

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,12 @@ use bao_tree::{
2323
};
2424
use bytes::Bytes;
2525
use genawaiter::sync::Gen;
26-
use iroh_io::{AsyncStreamReader, TokioStreamReader};
26+
use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
2727
use irpc::channel::{mpsc, oneshot};
2828
use n0_future::{future, stream, Stream, StreamExt};
29-
use quinn::SendStream;
3029
use range_collections::{range_set::RangeSetRange, RangeSet2};
3130
use ref_cast::RefCast;
3231
use serde::{Deserialize, Serialize};
33-
use tokio::io::AsyncWriteExt;
3432
use tracing::trace;
3533
mod reader;
3634
pub use reader::BlobReader;
@@ -431,7 +429,7 @@ impl Blobs {
431429
}
432430

433431
#[cfg_attr(feature = "hide-proto-docs", doc(hidden))]
434-
async fn import_bao_reader<R: AsyncStreamReader>(
432+
pub async fn import_bao_reader<R: AsyncStreamReader>(
435433
&self,
436434
hash: Hash,
437435
ranges: ChunkRanges,
@@ -468,18 +466,6 @@ impl Blobs {
468466
Ok(reader?)
469467
}
470468

471-
#[cfg_attr(feature = "hide-proto-docs", doc(hidden))]
472-
pub async fn import_bao_quinn(
473-
&self,
474-
hash: Hash,
475-
ranges: ChunkRanges,
476-
stream: &mut iroh::endpoint::RecvStream,
477-
) -> RequestResult<()> {
478-
let reader = TokioStreamReader::new(stream);
479-
self.import_bao_reader(hash, ranges, reader).await?;
480-
Ok(())
481-
}
482-
483469
#[cfg_attr(feature = "hide-proto-docs", doc(hidden))]
484470
pub async fn import_bao_bytes(
485471
&self,
@@ -1058,24 +1044,21 @@ impl ExportBaoProgress {
10581044
Ok(data)
10591045
}
10601046

1061-
pub async fn write_quinn(self, target: &mut quinn::SendStream) -> super::ExportBaoResult<()> {
1047+
pub async fn write<W: AsyncStreamWriter>(self, target: &mut W) -> super::ExportBaoResult<()> {
10621048
let mut rx = self.inner.await?;
10631049
while let Some(item) = rx.recv().await? {
10641050
match item {
10651051
EncodedItem::Size(size) => {
1066-
target.write_u64_le(size).await?;
1052+
target.write(&size.to_le_bytes()).await?;
10671053
}
10681054
EncodedItem::Parent(parent) => {
10691055
let mut data = vec![0u8; 64];
10701056
data[..32].copy_from_slice(parent.pair.0.as_bytes());
10711057
data[32..].copy_from_slice(parent.pair.1.as_bytes());
1072-
target.write_all(&data).await.map_err(io::Error::from)?;
1058+
target.write(&data).await?;
10731059
}
10741060
EncodedItem::Leaf(leaf) => {
1075-
target
1076-
.write_chunk(leaf.data)
1077-
.await
1078-
.map_err(io::Error::from)?;
1061+
target.write_bytes(leaf.data).await?;
10791062
}
10801063
EncodedItem::Done => break,
10811064
EncodedItem::Error(cause) => return Err(cause.into()),
@@ -1085,9 +1068,9 @@ impl ExportBaoProgress {
10851068
}
10861069

10871070
/// Write quinn variant that also feeds a progress writer.
1088-
pub(crate) async fn write_quinn_with_progress(
1071+
pub(crate) async fn write_with_progress<W: AsyncStreamWriter>(
10891072
self,
1090-
writer: &mut SendStream,
1073+
writer: &mut W,
10911074
progress: &mut impl WriteProgress,
10921075
hash: &Hash,
10931076
index: u64,
@@ -1097,22 +1080,19 @@ impl ExportBaoProgress {
10971080
match item {
10981081
EncodedItem::Size(size) => {
10991082
progress.send_transfer_started(index, hash, size).await;
1100-
writer.write_u64_le(size).await?;
1083+
writer.write(&size.to_le_bytes()).await?;
11011084
progress.log_other_write(8);
11021085
}
11031086
EncodedItem::Parent(parent) => {
11041087
let mut data = vec![0u8; 64];
11051088
data[..32].copy_from_slice(parent.pair.0.as_bytes());
11061089
data[32..].copy_from_slice(parent.pair.1.as_bytes());
1107-
writer.write_all(&data).await.map_err(io::Error::from)?;
1090+
writer.write(&data).await?;
11081091
progress.log_other_write(64);
11091092
}
11101093
EncodedItem::Leaf(leaf) => {
11111094
let len = leaf.data.len();
1112-
writer
1113-
.write_chunk(leaf.data)
1114-
.await
1115-
.map_err(io::Error::from)?;
1095+
writer.write_bytes(leaf.data).await?;
11161096
progress
11171097
.notify_payload_write(index, leaf.offset, len)
11181098
.await?;

src/api/remote.rs

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use crate::{
1616
get::{
1717
fsm::DecodeError,
1818
get_error::{BadRequestSnafu, LocalFailureSnafu},
19-
GetError, GetResult, Stats,
19+
GetError, GetResult, IrohStreamWriter, Stats,
2020
},
2121
protocol::{
2222
GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType,
@@ -594,15 +594,16 @@ impl Remote {
594594
let mut request_ranges = request.ranges.iter_infinite();
595595
let root = request.hash;
596596
let root_ranges = request_ranges.next().expect("infinite iterator");
597+
let mut send = IrohStreamWriter(send);
597598
if !root_ranges.is_empty() {
598599
self.store()
599600
.export_bao(root, root_ranges.clone())
600-
.write_quinn_with_progress(&mut send, &mut context, &root, 0)
601+
.write_with_progress(&mut send, &mut context, &root, 0)
601602
.await?;
602603
}
603604
if request.ranges.is_blob() {
604605
// we are done
605-
send.finish()?;
606+
send.0.finish()?;
606607
return Ok(Default::default());
607608
}
608609
let hash_seq = self.store().get_bytes(root).await?;
@@ -613,16 +614,11 @@ impl Remote {
613614
if !child_ranges.is_empty() {
614615
self.store()
615616
.export_bao(child_hash, child_ranges.clone())
616-
.write_quinn_with_progress(
617-
&mut send,
618-
&mut context,
619-
&child_hash,
620-
(child + 1) as u64,
621-
)
617+
.write_with_progress(&mut send, &mut context, &child_hash, (child + 1) as u64)
622618
.await?;
623619
}
624620
}
625-
send.finish()?;
621+
send.0.finish()?;
626622
Ok(Default::default())
627623
}
628624

src/get.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ pub mod request;
4141
pub(crate) use error::get_error;
4242
pub use error::{GetError, GetResult};
4343

44-
pub struct IrohStreamWriter(iroh::endpoint::SendStream);
44+
pub struct IrohStreamWriter(pub iroh::endpoint::SendStream);
4545

4646
impl AsyncStreamWriter for IrohStreamWriter {
4747
async fn write(&mut self, data: &[u8]) -> io::Result<()> {
@@ -57,7 +57,7 @@ impl AsyncStreamWriter for IrohStreamWriter {
5757
}
5858
}
5959

60-
pub struct IrohStreamReader(iroh::endpoint::RecvStream);
60+
pub struct IrohStreamReader(pub iroh::endpoint::RecvStream);
6161

6262
impl AsyncStreamReader for IrohStreamReader {
6363
async fn read<const N: usize>(&mut self) -> io::Result<[u8; N]> {

src/provider.rs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use std::{
1212
use anyhow::{Context, Result};
1313
use bao_tree::ChunkRanges;
1414
use iroh::endpoint::{self, RecvStream, SendStream};
15+
use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
1516
use n0_future::StreamExt;
1617
use quinn::{ClosedStream, ConnectionError, ReadToEndError};
1718
use serde::{de::DeserializeOwned, Deserialize, Serialize};
@@ -23,6 +24,7 @@ use crate::{
2324
blobs::{Bitfield, WriteProgress},
2425
ExportBaoResult, Store,
2526
},
27+
get::{IrohStreamReader, IrohStreamWriter},
2628
hashseq::HashSeq,
2729
protocol::{GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request},
2830
provider::events::{ClientConnected, ClientResult, ConnectionClosed, RequestTracker},
@@ -31,6 +33,9 @@ use crate::{
3133
pub mod events;
3234
use events::EventSender;
3335

36+
type DefaultWriter = IrohStreamWriter;
37+
type DefaultReader = IrohStreamReader;
38+
3439
/// Statistics about a successful or failed transfer.
3540
#[derive(Debug, Serialize, Deserialize)]
3641
pub struct TransferStats {
@@ -106,7 +111,7 @@ impl StreamPair {
106111
return Err(e);
107112
};
108113
Ok(ProgressWriter::new(
109-
self.writer,
114+
IrohStreamWriter(self.writer),
110115
WriterContext {
111116
t0: self.t0,
112117
other_bytes_read: self.other_bytes_read,
@@ -130,7 +135,7 @@ impl StreamPair {
130135
return Err(e);
131136
};
132137
Ok(ProgressReader {
133-
inner: self.reader,
138+
inner: IrohStreamReader(self.reader),
134139
context: ReaderContext {
135140
t0: self.t0,
136141
other_bytes_read: self.other_bytes_read,
@@ -282,14 +287,14 @@ impl WriteProgress for WriterContext {
282287

283288
/// Wrapper for a [`quinn::SendStream`] with additional per request information.
284289
#[derive(Debug)]
285-
pub struct ProgressWriter {
290+
pub struct ProgressWriter<W: AsyncStreamWriter = DefaultWriter> {
286291
/// The quinn::SendStream to write to
287-
pub inner: SendStream,
292+
pub inner: W,
288293
pub(crate) context: WriterContext,
289294
}
290295

291-
impl ProgressWriter {
292-
fn new(inner: SendStream, context: WriterContext) -> Self {
296+
impl<W: AsyncStreamWriter> ProgressWriter<W> {
297+
fn new(inner: W, context: WriterContext) -> Self {
293298
Self { inner, context }
294299
}
295300

@@ -465,7 +470,7 @@ pub async fn handle_push(
465470
if !root_ranges.is_empty() {
466471
// todo: send progress from import_bao_quinn or rename to import_bao_quinn_with_progress
467472
store
468-
.import_bao_quinn(hash, root_ranges.clone(), &mut reader.inner)
473+
.import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
469474
.await?;
470475
}
471476
if request.ranges.is_blob() {
@@ -480,7 +485,7 @@ pub async fn handle_push(
480485
continue;
481486
}
482487
store
483-
.import_bao_quinn(child_hash, child_ranges.clone(), &mut reader.inner)
488+
.import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
484489
.await?;
485490
}
486491
Ok(())
@@ -496,7 +501,7 @@ pub(crate) async fn send_blob(
496501
) -> ExportBaoResult<()> {
497502
store
498503
.export_bao(hash, ranges)
499-
.write_quinn_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
504+
.write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
500505
.await
501506
}
502507

@@ -527,7 +532,7 @@ pub async fn handle_observe(
527532
send_observe_item(writer, &diff).await?;
528533
old = new;
529534
}
530-
_ = writer.inner.stopped() => {
535+
_ = writer.inner.0.stopped() => {
531536
debug!("observer closed");
532537
break;
533538
}
@@ -539,13 +544,13 @@ pub async fn handle_observe(
539544
async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> Result<()> {
540545
use irpc::util::AsyncWriteVarintExt;
541546
let item = ObserveItem::from(item);
542-
let len = writer.inner.write_length_prefixed(item).await?;
547+
let len = writer.inner.0.write_length_prefixed(item).await?;
543548
writer.context.log_other_write(len);
544549
Ok(())
545550
}
546551

547-
pub struct ProgressReader {
548-
inner: RecvStream,
552+
pub struct ProgressReader<R: AsyncStreamReader = DefaultReader> {
553+
inner: R,
549554
context: ReaderContext,
550555
}
551556

0 commit comments

Comments
 (0)