Skip to content

Commit

Permalink
use 2 destination buffers for indirect dispatch validation
Browse files Browse the repository at this point in the history
This removes the required barrier prior to the validation dispatch.
  • Loading branch information
teoxoy committed Aug 23, 2024
1 parent 3a3cedf commit 18cbf48
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 26 deletions.
3 changes: 3 additions & 0 deletions tests/tests/dispatch_workgroups_indirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ async fn run_test(
if !forget_to_set_bind_group {
compute_pass.set_bind_group(0, &bind_group, &[]);
}
// Issue multiple dispatches to test the internal destination buffer switching
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
}

Expand Down
21 changes: 10 additions & 11 deletions wgpu-core/src/command/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -944,13 +944,6 @@ fn dispatch_indirect(
state.raw_encoder.transition_buffers(src_barrier.as_slice());
}

unsafe {
state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
buffer: params.dst_buffer,
usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE,
}]);
}

unsafe {
state.raw_encoder.dispatch([1, 1, 1]);
}
Expand Down Expand Up @@ -989,10 +982,16 @@ fn dispatch_indirect(
}

unsafe {
state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
buffer: params.dst_buffer,
usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT,
}]);
state.raw_encoder.transition_buffers(&[
hal::BufferBarrier {
buffer: params.dst_buffer,
usage: hal::BufferUses::STORAGE_READ_WRITE..hal::BufferUses::INDIRECT,
},
hal::BufferBarrier {
buffer: params.other_dst_buffer,
usage: hal::BufferUses::INDIRECT..hal::BufferUses::STORAGE_READ_WRITE,
},
]);
}

state.flush_states(None)?;
Expand Down
87 changes: 72 additions & 15 deletions wgpu-core/src/indirect_validation.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::atomic::AtomicBool;

use thiserror::Error;

use crate::{
Expand Down Expand Up @@ -33,14 +35,18 @@ pub struct IndirectValidation {
src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
pipeline_layout: Box<dyn hal::DynPipelineLayout>,
pipeline: Box<dyn hal::DynComputePipeline>,
dst_buffer: Box<dyn hal::DynBuffer>,
dst_bind_group: Box<dyn hal::DynBindGroup>,
dst_buffer_0: Box<dyn hal::DynBuffer>,
dst_buffer_1: Box<dyn hal::DynBuffer>,
dst_bind_group_0: Box<dyn hal::DynBindGroup>,
dst_bind_group_1: Box<dyn hal::DynBindGroup>,
is_next_dst_0: AtomicBool,
}

pub struct Params<'a> {
pub pipeline_layout: &'a dyn hal::DynPipelineLayout,
pub pipeline: &'a dyn hal::DynComputePipeline,
pub dst_buffer: &'a dyn hal::DynBuffer,
pub other_dst_buffer: &'a dyn hal::DynBuffer,
pub dst_bind_group: &'a dyn hal::DynBindGroup,
pub aligned_offset: u64,
pub offset_remainder: u64,
Expand Down Expand Up @@ -206,10 +212,12 @@ impl IndirectValidation {
usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE,
memory_flags: hal::MemoryFlags::empty(),
};
let dst_buffer =
let dst_buffer_0 =
unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from)?;
let dst_buffer_1 =
unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from)?;

let dst_bind_group_desc = hal::BindGroupDescriptor {
let dst_bind_group_desc_0 = hal::BindGroupDescriptor {
label: None,
layout: dst_bind_group_layout.as_ref(),
entries: &[hal::BindGroupEntry {
Expand All @@ -218,17 +226,40 @@ impl IndirectValidation {
count: 1,
}],
buffers: &[hal::BufferBinding {
buffer: dst_buffer.as_ref(),
buffer: dst_buffer_0.as_ref(),
offset: 0,
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
}],
samplers: &[],
textures: &[],
acceleration_structures: &[],
};
let dst_bind_group = unsafe {
let dst_bind_group_0 = unsafe {
device
.create_bind_group(&dst_bind_group_desc)
.create_bind_group(&dst_bind_group_desc_0)
.map_err(DeviceError::from)
}?;

let dst_bind_group_desc_1 = hal::BindGroupDescriptor {
label: None,
layout: dst_bind_group_layout.as_ref(),
entries: &[hal::BindGroupEntry {
binding: 0,
resource_index: 0,
count: 1,
}],
buffers: &[hal::BufferBinding {
buffer: dst_buffer_1.as_ref(),
offset: 0,
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
}],
samplers: &[],
textures: &[],
acceleration_structures: &[],
};
let dst_bind_group_1 = unsafe {
device
.create_bind_group(&dst_bind_group_desc_1)
.map_err(DeviceError::from)
}?;

Expand All @@ -238,8 +269,11 @@ impl IndirectValidation {
src_bind_group_layout,
pipeline_layout,
pipeline,
dst_buffer,
dst_bind_group,
dst_buffer_0,
dst_buffer_1,
dst_bind_group_0,
dst_bind_group_1,
is_next_dst_0: AtomicBool::new(false),
})
}

Expand Down Expand Up @@ -298,11 +332,29 @@ impl IndirectValidation {
let aligned_offset = aligned_offset.min(max_aligned_offset);
let offset_remainder = offset - aligned_offset;

let (dst_buffer, other_dst_buffer, dst_bind_group) = if self
.is_next_dst_0
.fetch_xor(true, core::sync::atomic::Ordering::AcqRel)
{
(
self.dst_buffer_0.as_ref(),
self.dst_buffer_1.as_ref(),
self.dst_bind_group_0.as_ref(),
)
} else {
(
self.dst_buffer_1.as_ref(),
self.dst_buffer_0.as_ref(),
self.dst_bind_group_1.as_ref(),
)
};

Params {
pipeline_layout: self.pipeline_layout.as_ref(),
pipeline: self.pipeline.as_ref(),
dst_buffer: self.dst_buffer.as_ref(),
dst_bind_group: self.dst_bind_group.as_ref(),
dst_buffer,
other_dst_buffer,
dst_bind_group,
aligned_offset,
offset_remainder,
}
Expand All @@ -315,13 +367,18 @@ impl IndirectValidation {
src_bind_group_layout,
pipeline_layout,
pipeline,
dst_buffer,
dst_bind_group,
dst_buffer_0,
dst_buffer_1,
dst_bind_group_0,
dst_bind_group_1,
is_next_dst_0: _,
} = self;

unsafe {
device.destroy_bind_group(dst_bind_group);
device.destroy_buffer(dst_buffer);
device.destroy_bind_group(dst_bind_group_0);
device.destroy_bind_group(dst_bind_group_1);
device.destroy_buffer(dst_buffer_0);
device.destroy_buffer(dst_buffer_1);
device.destroy_compute_pipeline(pipeline);
device.destroy_pipeline_layout(pipeline_layout);
device.destroy_bind_group_layout(src_bind_group_layout);
Expand Down

0 comments on commit 18cbf48

Please sign in to comment.