Skip to content

Commit

Permalink
[d3d12] get num_workgroups builtin working for indirect dispatches
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed May 22, 2024
1 parent 36281af commit 59891d9
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 12 deletions.
5 changes: 2 additions & 3 deletions tests/tests/dispatch_workgroups_indirect.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext};
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext};

/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12).
#[gpu_test]
Expand All @@ -8,8 +8,7 @@ static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new(
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits::downlevel_defaults())
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
.limits(wgpu::Limits::downlevel_defaults()),
)
.run_async(|ctx| async move {
let num_workgroups = [1, 2, 3];
Expand Down
4 changes: 2 additions & 2 deletions wgpu-core/src/command/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ impl Global {
device_id,
&crate::resource::BufferDescriptor {
label: None,
size: 4 * 3,
size: 4 * 3 * 2,
usage: wgt::BufferUsages::INDIRECT | wgt::BufferUsages::STORAGE,
mapped_at_creation: false,
},
Expand Down Expand Up @@ -956,7 +956,7 @@ impl Global {
buffer_id: dst_buffer_id,
offset: 0,
size: Some(
std::num::NonZeroU64::new(4 * 3).unwrap(),
std::num::NonZeroU64::new(4 * 3 * 2).unwrap(),
),
},
),
Expand Down
14 changes: 11 additions & 3 deletions wgpu-core/src/device/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1798,11 +1798,17 @@ impl Global {
@group(0) @binding(0)
var<uniform> src: vec3<u32>;
@group(0) @binding(1)
var<storage, read_write> dst: vec3<u32>;
var<storage, read_write> dst: array<u32, 6>;
@compute @workgroup_size(1)
fn main() {{
dst = select(src, vec3<u32>(), src > vec3({max_compute_workgroups_per_dimension}u));
let res = select(src, vec3<u32>(), src > vec3({max_compute_workgroups_per_dimension}u));
dst[0] = res.x;
dst[1] = res.y;
dst[2] = res.z;
dst[3] = res.x;
dst[4] = res.y;
dst[5] = res.z;
}}
");

Expand Down Expand Up @@ -1840,7 +1846,9 @@ impl Global {
ty: wgt::BindingType::Buffer {
ty: wgt::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
min_binding_size: Some(
std::num::NonZeroU64::new(4 * 3 * 2).unwrap(),
),
},
count: None,
},
Expand Down
3 changes: 2 additions & 1 deletion wgpu-core/src/device/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2681,7 +2681,8 @@ impl<A: HalApi> Device<A> {

let hal_desc = hal::PipelineLayoutDescriptor {
label: desc.label.to_hal(self.instance_flags),
flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE,
flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE
| hal::PipelineLayoutFlags::NUM_WORK_GROUPS,
bind_group_layouts: &raw_bind_group_layouts,
push_constant_ranges: desc.push_constant_ranges.as_ref(),
};
Expand Down
12 changes: 9 additions & 3 deletions wgpu-hal/src/dx12/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1193,11 +1193,17 @@ impl crate::CommandEncoder for super::CommandEncoder {
self.list.as_ref().unwrap().dispatch(count);
}
unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) {
self.prepare_dispatch([0; 3]);
//TODO: update special constants indirectly
self.update_root_elements();
let cmd_signature = if let Some(cmd_signatures) =
self.pass.layout.special_constants_cmd_signatures.as_mut()
{
cmd_signatures.dispatch.as_mut_ptr()
} else {
self.shared.cmd_signatures.dispatch.as_mut_ptr()
};
unsafe {
self.list.as_ref().unwrap().ExecuteIndirect(
self.shared.cmd_signatures.dispatch.as_mut_ptr(),
cmd_signature,
1,
buffer.resource.as_mut_ptr(),
offset,
Expand Down
45 changes: 45 additions & 0 deletions wgpu-hal/src/dx12/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,50 @@ impl crate::Device for super::Device {
.create_root_signature(blob, 0)
.into_device_result("Root signature creation")?;

let special_constants_cmd_signatures =
if let Some(root_index) = special_constants_root_index {
Some(super::CommandSignatures {
draw: self
.raw
.create_command_signature(
raw.clone(),
&[
d3d12::IndirectArgument::constant(root_index, 0, 3),
d3d12::IndirectArgument::draw(),
],
12 + mem::size_of::<wgt::DrawIndirectArgs>() as u32,
0,
)
.into_device_result("Command (draw) signature creation")?,
draw_indexed: self
.raw
.create_command_signature(
raw.clone(),
&[
d3d12::IndirectArgument::constant(root_index, 0, 3),
d3d12::IndirectArgument::draw_indexed(),
],
12 + mem::size_of::<wgt::DrawIndexedIndirectArgs>() as u32,
0,
)
.into_device_result("Command (draw_indexed) signature creation")?,
dispatch: self
.raw
.create_command_signature(
raw.clone(),
&[
d3d12::IndirectArgument::constant(root_index, 0, 3),
d3d12::IndirectArgument::dispatch(),
],
12 + mem::size_of::<wgt::DispatchIndirectArgs>() as u32,
0,
)
.into_device_result("Command (dispatch) signature creation")?,
})
} else {
None
};

log::debug!("\traw = {:?}", raw);

if let Some(label) = desc.label {
Expand All @@ -1072,6 +1116,7 @@ impl crate::Device for super::Device {
signature: raw,
total_root_elements: parameters.len() as super::RootIndex,
special_constants_root_index,
special_constants_cmd_signatures,
root_constant_info,
},
bind_group_infos,
Expand Down
3 changes: 3 additions & 0 deletions wgpu-hal/src/dx12/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ struct Idler {
event: d3d12::Event,
}

#[derive(Debug, Clone)]
struct CommandSignatures {
draw: d3d12::CommandSignature,
draw_indexed: d3d12::CommandSignature,
Expand Down Expand Up @@ -344,6 +345,7 @@ impl PassState {
signature: d3d12::RootSignature::null(),
total_root_elements: 0,
special_constants_root_index: None,
special_constants_cmd_signatures: None,
root_constant_info: None,
},
root_elements: [RootElement::Empty; MAX_ROOT_ELEMENTS],
Expand Down Expand Up @@ -555,6 +557,7 @@ struct PipelineLayoutShared {
signature: d3d12::RootSignature,
total_root_elements: RootIndex,
special_constants_root_index: Option<RootIndex>,
special_constants_cmd_signatures: Option<CommandSignatures>,
root_constant_info: Option<RootConstantInfo>,
}

Expand Down

0 comments on commit 59891d9

Please sign in to comment.