|
| 1 | +//! Parallel data generation support: [`Source`] and [`Sink`] and [`generate_in_chunks`] |
| 2 | +
|
| 3 | +use futures::StreamExt; |
| 4 | +use std::collections::VecDeque; |
| 5 | +use std::io; |
| 6 | +use std::sync::{Arc, Mutex}; |
| 7 | +use tokio::task::JoinSet; |
| 8 | + |
/// Something that knows how to generate data into a buffer
///
/// For example, this is implemented for the different generators in the tpchgen
/// crate
pub trait Source: Send {
    /// Generates the data for this generator into `buffer`, returning the
    /// buffer (possibly reallocated) so the caller can reuse its allocation.
    fn create(self, buffer: Vec<u8>) -> Vec<u8>;
}
| 17 | + |
/// Something that can write the contents of a buffer somewhere
///
/// For example, this is implemented for a file writer
pub trait Sink: Send {
    /// Write all data from the buffer to the sink
    fn sink(&mut self, buffer: &[u8]);

    /// Consumes the sink, completing the write and returning any final
    /// I/O error encountered while finishing.
    fn finish(self) -> Result<(), io::Error>;
}
| 27 | + |
| 28 | +/// Creates data from the Generators in parallel and invokes the provided |
| 29 | +/// function on each generated buffer |
| 30 | +/// |
| 31 | +/// |
| 32 | +/// G: Generator |
| 33 | +/// I: Iterator<Item = G> |
| 34 | +/// S: Sink that writes buffers somewhere |
| 35 | +pub async fn generate_in_chunks<G, I, S>(mut sink: S, generators: I) -> Result<(), io::Error> |
| 36 | +where |
| 37 | + G: Source + 'static, |
| 38 | + I: Iterator<Item = G>, |
| 39 | + S: Sink + 'static, |
| 40 | +{ |
| 41 | + let recycler = BufferRecycler::new(); |
| 42 | + |
| 43 | + // use all cores to make data |
| 44 | + let num_tasks = num_cpus::get(); |
| 45 | + println!("Using {num_tasks} parallel tasks"); |
| 46 | + |
| 47 | + // create a channel to communicate between the generator tasks and the writer task |
| 48 | + let (tx, mut rx) = tokio::sync::mpsc::channel(num_tasks); |
| 49 | + |
| 50 | + let generators_and_recyclers = generators.map(|generator| (generator, recycler.clone())); |
| 51 | + |
| 52 | + // convert to an async stream to run on tokio |
| 53 | + let mut stream = futures::stream::iter(generators_and_recyclers) |
| 54 | + // each generator writes to a buffer |
| 55 | + .map(async |(generator, recycler)| { |
| 56 | + let buffer = recycler.new_buffer(1024 * 1024 * 8); |
| 57 | + // do the work in a task (on a different thread) |
| 58 | + let mut join_set = JoinSet::new(); |
| 59 | + join_set.spawn(async move { generator.create(buffer) }); |
| 60 | + // wait for the task to be done and return the result |
| 61 | + let buffer = join_set |
| 62 | + .join_next() |
| 63 | + .await |
| 64 | + .expect("had one item") |
| 65 | + .expect("join_next join is infallible unless task panics"); |
| 66 | + // send the buffer to the writer task, ignoring error (if the writer errored) |
| 67 | + tx.send(buffer).await.ok(); |
| 68 | + }) |
| 69 | + // run in parallel |
| 70 | + .buffered(num_tasks); |
| 71 | + |
| 72 | + let captured_recycler = recycler.clone(); |
| 73 | + let writer_task = tokio::task::spawn_blocking(move || { |
| 74 | + while let Some(buffer) = rx.blocking_recv() { |
| 75 | + //println!("writing buffer with {} bytes", buffer.len()); |
| 76 | + sink.sink(&buffer); |
| 77 | + captured_recycler.return_buffer(buffer); |
| 78 | + } |
| 79 | + sink.finish() |
| 80 | + }); |
| 81 | + |
| 82 | + // drive the stream to completion |
| 83 | + while stream.next().await.is_some() {} |
| 84 | + println!("stream done, dropping"); |
| 85 | + drop(stream); // drop any stream references |
| 86 | + drop(tx); // drop last tx reference to stop the writer |
| 87 | + |
| 88 | + // wait for writer to finish |
| 89 | + println!("waiting on writer task"); |
| 90 | + writer_task.await.expect("writer task panicked") |
| 91 | +} |
| 92 | + |
| 93 | +/// A simple buffer recycler to avoid allocating new buffers for each part |
| 94 | +/// |
| 95 | +/// Clones share the same undrlying recycler, so it is not thread safe |
| 96 | +#[derive(Debug, Clone)] |
| 97 | +struct BufferRecycler { |
| 98 | + buffers: Arc<Mutex<VecDeque<Vec<u8>>>>, |
| 99 | +} |
| 100 | + |
| 101 | +impl BufferRecycler { |
| 102 | + fn new() -> Self { |
| 103 | + Self { |
| 104 | + buffers: Arc::new(Mutex::new(VecDeque::new())), |
| 105 | + } |
| 106 | + } |
| 107 | + /// return a new empty buffer, with size bytes capacity |
| 108 | + fn new_buffer(&self, size: usize) -> Vec<u8> { |
| 109 | + let mut buffers = self.buffers.lock().unwrap(); |
| 110 | + if let Some(mut buffer) = buffers.pop_front() { |
| 111 | + buffer.clear(); |
| 112 | + if size > buffer.capacity() { |
| 113 | + buffer.reserve(size - buffer.capacity()); |
| 114 | + } |
| 115 | + buffer |
| 116 | + } else { |
| 117 | + Vec::with_capacity(size) |
| 118 | + } |
| 119 | + } |
| 120 | + |
| 121 | + fn return_buffer(&self, buffer: Vec<u8>) { |
| 122 | + let mut buffers = self.buffers.lock().unwrap(); |
| 123 | + buffers.push_back(buffer); |
| 124 | + } |
| 125 | +} |
0 commit comments