diff --git a/webgraph/Cargo.toml b/webgraph/Cargo.toml index fb38a53c..ffa3ba47 100644 --- a/webgraph/Cargo.toml +++ b/webgraph/Cargo.toml @@ -57,9 +57,9 @@ dary_heap = "0.3.6" rdst = { version = "0.20.14", features = ["multi-threaded"] } sealed = "0.6.0" serde = { workspace = true, optional = true } +crossbeam-queue = "0.3.12" crossbeam-channel = "0.5" parallel_frontier = "0.1.1" -thread_local = "1.1.8" # Fuzzing deps zip = { version = "6.0.0", optional = true } diff --git a/webgraph/src/utils/par_sort_pairs.rs b/webgraph/src/utils/par_sort_pairs.rs index eed5bb2c..8c0e6921 100644 --- a/webgraph/src/utils/par_sort_pairs.rs +++ b/webgraph/src/utils/par_sort_pairs.rs @@ -21,16 +21,15 @@ //! If your pairs are emitted by a sequence of sequential iterators, consider //! using [`ParSortIters`](crate::utils::par_sort_iters::ParSortIters) instead. -use std::cell::RefCell; use std::num::NonZeroUsize; use std::path::Path; +use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use anyhow::{Context, Result, ensure}; +use crossbeam_queue::SegQueue; use dsi_progress_logger::{ProgressLog, concurrent_progress_logger}; -use rayon::Yield; use rayon::prelude::*; -use thread_local::ThreadLocal; use crate::utils::DefaultBatchCodec; @@ -271,7 +270,7 @@ impl ParSortPairs { let presort_tmp_dir = tempfile::tempdir().context("Could not create temporary directory")?; - let sorter_thread_states = ThreadLocal::>>::new(); + let sorter_thread_states = Arc::new(SegQueue::>::new()); // iterators in partitioned_presorted_pairs[partition_id] contain all pairs (src, dst, label) // where num_nodes_per_partition*partition_id <= src < num_nodes_per_partition*(partition_id+1) @@ -286,35 +285,20 @@ impl ParSortPairs { // Thus, we use ThreadLocal to have one SorterThreadState per thread, which is reused // across multiple sequential iterators. || { - ( - pl.clone(), - loop { - if let Ok(state) = sorter_thread_states - .get_or(|| { - RefCell::new(SorterThreadState { - worker_id: worker_id.fetch_add(1, Ordering::Relaxed), - unsorted_buffers: (0..num_partitions) - .map(|_| Vec::with_capacity(batch_size)) - .collect(), - sorted_pairs: (0..num_partitions).map(|_| Vec::new()).collect(), - }) - }) - .try_borrow_mut() { - // usually succeeds in the first attempt - break state; - } - // This thread is already borrowing its state higher in the call stack, - // but rayon is calling us again because of work stealing. - // But we cannot work right now (without allocating a new state, that is) - // so we yield back to rayon so it can resume the task that is already - // running in this thread. - match rayon::yield_now() { - None => panic!("rayon::yield_now() claims we are not running in a thread pool"), - Some(Yield::Idle) => panic!("Thread state is already borrowed, but there are no other tasks running"), - Some(Yield::Executed) => (), - } - } - ) + let mut state = sorter_thread_states + .pop() + .unwrap_or_else(|| SorterThreadState { + worker_id: worker_id.fetch_add(1, Ordering::Relaxed), + unsorted_buffers: (0..num_partitions) + .map(|_| Vec::with_capacity(batch_size)) + .collect(), + sorted_pairs: (0..num_partitions).map(|_| Vec::new()).collect(), + queue: None, + }); + + // So it adds itself back to the queue when dropped + state.queue = Some(Arc::clone(&sorter_thread_states)); + (pl.clone(), state) }, |(pl, thread_state), pair| -> Result<_> { let ((src, dst), label) = pair.map_err(Into::into)?; @@ -328,7 +312,8 @@ impl ParSortPairs { worker_id, sorted_pairs, unsorted_buffers, - } = &mut **thread_state; + queue: _, + } = thread_state; let sorted_pairs = &mut sorted_pairs[partition_id]; let buf = &mut unsorted_buffers[partition_id]; @@ -352,17 +337,24 @@ impl ParSortPairs { }, )?; + // Collect them into an iterable + let sorter_thread_states: Vec<_> = std::iter::repeat(()) + .map_while(|()| sorter_thread_states.pop()) + .collect(); + // flush remaining buffers let partitioned_presorted_pairs: Vec>> = sorter_thread_states - .into_iter() - .collect::>() .into_par_iter() - .map_with(pl.clone(), |pl, thread_state: RefCell>| { - let thread_state = thread_state.into_inner(); + .map_with(pl.clone(), |pl, mut thread_state: SorterThreadState| { + let mut sorted_pairs = Vec::new(); + std::mem::swap(&mut sorted_pairs, &mut thread_state.sorted_pairs); + let mut unsorted_buffers = Vec::new(); + std::mem::swap(&mut unsorted_buffers, &mut thread_state.unsorted_buffers); + let mut partitioned_sorted_pairs = Vec::with_capacity(num_partitions); - assert_eq!(thread_state.sorted_pairs.len(), num_partitions); - assert_eq!(thread_state.unsorted_buffers.len(), num_partitions); - for (partition_id, (mut sorted_pairs, mut buf)) in thread_state.sorted_pairs.into_iter().zip(thread_state.unsorted_buffers.into_iter()).enumerate() { + assert_eq!(sorted_pairs.len(), num_partitions); + assert_eq!(unsorted_buffers.len(), num_partitions); + for (partition_id, (mut sorted_pairs, mut buf)) in sorted_pairs.into_iter().zip(unsorted_buffers.into_iter()).enumerate() { let buf_len = buf.len(); flush_buffer(presort_tmp_dir.path(), batch_codec, thread_state.worker_id, partition_id, &mut sorted_pairs, &mut buf).context("Could not flush buffer at the end")?; assert!(buf.is_empty(), "flush_buffer did not empty the buffer"); @@ -419,6 +411,42 @@ struct SorterThreadState { worker_id: usize, sorted_pairs: Vec>>, unsorted_buffers: Vec>, + /// Where should this SorterThreadState put itself back to when dropped + queue: Option>>, +} + +impl SorterThreadState { + fn new_empty() -> Self { + SorterThreadState { + worker_id: usize::MAX, + sorted_pairs: Vec::new(), + unsorted_buffers: Vec::new(), + queue: None, + } + } +} + +impl Drop for SorterThreadState { + fn drop(&mut self) { + match self.queue.take() { + Some(queue) => { + // Put self back on the queue + let mut other_self = Self::new_empty(); + std::mem::swap(&mut other_self, self); + queue.push(other_self); + } + None => { + assert!( + self.sorted_pairs.iter().all(|vec| vec.is_empty()), + "Dropped SorterThreadState without consuming sorted_pairs" + ); + assert!( + self.unsorted_buffers.iter().all(|vec| vec.is_empty()), + "Dropped SorterThreadState without consuming unsorted_buffers" + ); + } + } + } } pub(crate) fn flush_buffer(