2 changes: 1 addition & 1 deletion webgraph/Cargo.toml
@@ -57,9 +57,9 @@ dary_heap = "0.3.6"
 rdst = { version = "0.20.14", features = ["multi-threaded"] }
 sealed = "0.6.0"
 serde = { workspace = true, optional = true }
+crossbeam-queue = "0.3.12"
 crossbeam-channel = "0.5"
 parallel_frontier = "0.1.1"
-thread_local = "1.1.8"
 
 # Fuzzing deps
 zip = { version = "6.0.0", optional = true }
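The dependency change swaps `thread_local` for `crossbeam-queue`, whose `SegQueue` is an unbounded, lock-free, multi-producer multi-consumer queue that can be shared through an `Arc` with no `Mutex`. A minimal sketch of the two calls the new code relies on, `push` and `pop`:

```rust
use std::sync::Arc;
use std::thread;

use crossbeam_queue::SegQueue;

fn main() {
    let queue = Arc::new(SegQueue::new());
    let handles: Vec<_> = (0..4)
        .map(|i| {
            let queue = Arc::clone(&queue);
            // Producers push concurrently; no locking is required.
            thread::spawn(move || queue.push(i))
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    // pop() returns Option<T>, yielding None once the queue is empty.
    let mut drained: Vec<i32> = std::iter::from_fn(|| queue.pop()).collect();
    drained.sort_unstable();
    assert_eq!(drained, vec![0, 1, 2, 3]);
}
```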
110 changes: 69 additions & 41 deletions webgraph/src/utils/par_sort_pairs.rs
@@ -21,16 +21,15 @@
 //! If your pairs are emitted by a sequence of sequential iterators, consider
 //! using [`ParSortIters`](crate::utils::par_sort_iters::ParSortIters) instead.
 
-use std::cell::RefCell;
 use std::num::NonZeroUsize;
 use std::path::Path;
+use std::sync::Arc;
 use std::sync::atomic::{AtomicUsize, Ordering};
 
 use anyhow::{Context, Result, ensure};
+use crossbeam_queue::SegQueue;
 use dsi_progress_logger::{ProgressLog, concurrent_progress_logger};
-use rayon::Yield;
 use rayon::prelude::*;
-use thread_local::ThreadLocal;
 
 use crate::utils::DefaultBatchCodec;
 
Expand Down Expand Up @@ -271,7 +270,7 @@ impl ParSortPairs {
let presort_tmp_dir =
tempfile::tempdir().context("Could not create temporary directory")?;

let sorter_thread_states = ThreadLocal::<RefCell<SorterThreadState<C>>>::new();
let sorter_thread_states = Arc::new(SegQueue::<SorterThreadState<C>>::new());

// iterators in partitioned_presorted_pairs[partition_id] contain all pairs (src, dst, label)
// where num_nodes_per_partition*partition_id <= src < num_nodes_per_partition*(partition_id+1)
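As context for the comment above: a pair is routed to the partition obtained by integer division of its source node. A hypothetical helper (not part of this PR) making the bounds concrete:

```rust
// Pair (src, dst) belongs to partition src / num_nodes_per_partition, so
// num_nodes_per_partition * id <= src < num_nodes_per_partition * (id + 1).
fn partition_id(src: usize, num_nodes_per_partition: usize) -> usize {
    src / num_nodes_per_partition
}

fn main() {
    let num_nodes_per_partition = 1000;
    assert_eq!(partition_id(0, num_nodes_per_partition), 0);
    assert_eq!(partition_id(999, num_nodes_per_partition), 0);
    assert_eq!(partition_id(1000, num_nodes_per_partition), 1);
}
```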
@@ -286,35 +285,20 @@
                 // Thus, we keep a pool of SorterThreadState values, so each state is
                 // reused across multiple sequential iterators.
                 || {
-                    (
-                        pl.clone(),
-                        loop {
-                            if let Ok(state) = sorter_thread_states
-                                .get_or(|| {
-                                    RefCell::new(SorterThreadState {
-                                        worker_id: worker_id.fetch_add(1, Ordering::Relaxed),
-                                        unsorted_buffers: (0..num_partitions)
-                                            .map(|_| Vec::with_capacity(batch_size))
-                                            .collect(),
-                                        sorted_pairs: (0..num_partitions).map(|_| Vec::new()).collect(),
-                                    })
-                                })
-                                .try_borrow_mut() {
-                                // usually succeeds on the first attempt
-                                break state;
-                            }
-                            // This thread is already borrowing its state higher in the call stack,
-                            // but rayon is calling us again because of work stealing.
-                            // We cannot work right now (without allocating a new state, that is),
-                            // so we yield back to rayon so it can resume the task that is already
-                            // running on this thread.
-                            match rayon::yield_now() {
-                                None => panic!("rayon::yield_now() claims we are not running in a thread pool"),
-                                Some(Yield::Idle) => panic!("Thread state is already borrowed, but there are no other tasks running"),
-                                Some(Yield::Executed) => (),
-                            }
-                        }
-                    )
+                    let mut state = sorter_thread_states
+                        .pop()
+                        .unwrap_or_else(|| SorterThreadState {
+                            worker_id: worker_id.fetch_add(1, Ordering::Relaxed),
+                            unsorted_buffers: (0..num_partitions)
+                                .map(|_| Vec::with_capacity(batch_size))
+                                .collect(),
+                            sorted_pairs: (0..num_partitions).map(|_| Vec::new()).collect(),
+                            queue: None,
+                        });
+
+                    // So that it adds itself back to the queue when dropped
+                    state.queue = Some(Arc::clone(&sorter_thread_states));
+                    (pl.clone(), state)
                 },
                 |(pl, thread_state), pair| -> Result<_> {
                     let ((src, dst), label) = pair.map_err(Into::into)?;
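The init closure replaced above runs once per rayon work split rather than once per thread, which is what made the old `ThreadLocal` + `RefCell` dance necessary and what the queue sidesteps: popping an existing state caps allocations at the number of states live at once. A standalone sketch demonstrating that a fold identity closure can run many times:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

use rayon::prelude::*;

fn main() {
    let inits = AtomicUsize::new(0);
    let sum: u64 = (0..1_000_000u64)
        .into_par_iter()
        .fold(
            || {
                // This identity closure may run far more often than the
                // number of threads in the pool: once per work split.
                inits.fetch_add(1, Ordering::Relaxed);
                0u64
            },
            |acc, x| acc + x,
        )
        .sum();
    assert_eq!(sum, 499_999_500_000);
    println!("identity closure ran {} times", inits.load(Ordering::Relaxed));
}
```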
@@ -328,7 +312,8 @@
                         worker_id,
                         sorted_pairs,
                         unsorted_buffers,
-                    } = &mut **thread_state;
+                        queue: _,
+                    } = thread_state;
 
                     let sorted_pairs = &mut sorted_pairs[partition_id];
                     let buf = &mut unsorted_buffers[partition_id];
@@ -352,17 +337,24 @@
                 },
             )?;
 
+        // Drain the per-thread states left in the pool into a Vec
+        let sorter_thread_states: Vec<_> = std::iter::repeat(())
+            .map_while(|()| sorter_thread_states.pop())
+            .collect();
+
         // flush remaining buffers
         let partitioned_presorted_pairs: Vec<Vec<CodecIter<C>>> = sorter_thread_states
-            .into_iter()
-            .collect::<Vec<_>>()
             .into_par_iter()
-            .map_with(pl.clone(), |pl, thread_state: RefCell<SorterThreadState<C>>| {
-                let thread_state = thread_state.into_inner();
+            .map_with(pl.clone(), |pl, mut thread_state: SorterThreadState<C>| {
+                let mut sorted_pairs = Vec::new();
+                std::mem::swap(&mut sorted_pairs, &mut thread_state.sorted_pairs);
+                let mut unsorted_buffers = Vec::new();
+                std::mem::swap(&mut unsorted_buffers, &mut thread_state.unsorted_buffers);
+
                 let mut partitioned_sorted_pairs = Vec::with_capacity(num_partitions);
-                assert_eq!(thread_state.sorted_pairs.len(), num_partitions);
-                assert_eq!(thread_state.unsorted_buffers.len(), num_partitions);
-                for (partition_id, (mut sorted_pairs, mut buf)) in thread_state.sorted_pairs.into_iter().zip(thread_state.unsorted_buffers.into_iter()).enumerate() {
+                assert_eq!(sorted_pairs.len(), num_partitions);
+                assert_eq!(unsorted_buffers.len(), num_partitions);
+                for (partition_id, (mut sorted_pairs, mut buf)) in sorted_pairs.into_iter().zip(unsorted_buffers.into_iter()).enumerate() {
                     let buf_len = buf.len();
                     flush_buffer(presort_tmp_dir.path(), batch_codec, thread_state.worker_id, partition_id, &mut sorted_pairs, &mut buf).context("Could not flush buffer at the end")?;
                     assert!(buf.is_empty(), "flush_buffer did not empty the buffer");
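The `std::mem::swap` calls above replace the old `into_iter()` over the struct's fields because `SorterThreadState` now implements `Drop`, and Rust forbids moving fields out of a type with a `Drop` impl (error E0509); swapping empty vectors in extracts the contents while leaving a valid value behind. A minimal sketch of the constraint, with hypothetical `Holder`/`consume` names:

```rust
struct Holder {
    data: Vec<u32>,
}

impl Drop for Holder {
    fn drop(&mut self) {
        assert!(self.data.is_empty(), "dropped without consuming data");
    }
}

fn consume(mut holder: Holder) -> Vec<u32> {
    // `let data = holder.data;` would not compile (E0509): fields cannot
    // be moved out of a type implementing Drop. Taking (or swapping)
    // leaves an empty Vec behind, so the Drop assertion still holds.
    std::mem::take(&mut holder.data)
}

fn main() {
    let holder = Holder { data: vec![1, 2, 3] };
    assert_eq!(consume(holder), vec![1, 2, 3]);
}
```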
@@ -419,6 +411,42 @@ struct SorterThreadState<C: BatchCodec> {
     worker_id: usize,
     sorted_pairs: Vec<Vec<CodecIter<C>>>,
     unsorted_buffers: Vec<Vec<((usize, usize), C::Label)>>,
+    /// Queue that this SorterThreadState puts itself back onto when dropped
+    queue: Option<Arc<SegQueue<Self>>>,
 }
 
+impl<C: BatchCodec> SorterThreadState<C> {
+    fn new_empty() -> Self {
+        SorterThreadState {
+            worker_id: usize::MAX,
+            sorted_pairs: Vec::new(),
+            unsorted_buffers: Vec::new(),
+            queue: None,
+        }
+    }
+}
+
+impl<C: BatchCodec> Drop for SorterThreadState<C> {
+    fn drop(&mut self) {
+        match self.queue.take() {
+            Some(queue) => {
+                // Put self back on the queue
+                let mut other_self = Self::new_empty();
+                std::mem::swap(&mut other_self, self);
+                queue.push(other_self);
+            }
+            None => {
+                assert!(
+                    self.sorted_pairs.iter().all(|vec| vec.is_empty()),
+                    "Dropped SorterThreadState without consuming sorted_pairs"
+                );
+                assert!(
+                    self.unsorted_buffers.iter().all(|vec| vec.is_empty()),
+                    "Dropped SorterThreadState without consuming unsorted_buffers"
+                );
+            }
+        }
+    }
+}
+
 pub(crate) fn flush_buffer<C: BatchCodec>(
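Taken together, the new `queue` field and the `Drop` impl form a drop-based object pool: when rayon drops a fold state, the state swaps an empty shell into its place and re-enqueues its former self, keeping its buffers allocated for the next checkout. A condensed, self-contained sketch of the same pattern around a hypothetical `Pooled` type:

```rust
use std::sync::Arc;

use crossbeam_queue::SegQueue;

struct Pooled {
    data: Vec<u8>,
    queue: Option<Arc<SegQueue<Pooled>>>,
}

impl Drop for Pooled {
    fn drop(&mut self) {
        // take() leaves None behind, so the shell swapped into `self`
        // below cannot re-enqueue anything when its fields are dropped.
        if let Some(queue) = self.queue.take() {
            let shell = Pooled { data: Vec::new(), queue: None };
            // Move the live value out of `self`; Drop::drop is not called
            // again on `self` after this method returns, only the drop
            // glue for the (now empty) fields.
            let returned = std::mem::replace(self, shell);
            queue.push(returned);
        }
    }
}

fn main() {
    let pool: Arc<SegQueue<Pooled>> = Arc::new(SegQueue::new());
    drop(Pooled { data: vec![1, 2, 3], queue: Some(Arc::clone(&pool)) });
    // The value returned itself to the pool, allocation intact.
    let recycled = pool.pop().unwrap();
    assert_eq!(recycled.data, vec![1, 2, 3]);
}
```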