Skip to content

Commit

Permalink
[d3d12] get num_workgroups builtin working for indirect dispatches
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Aug 23, 2024
1 parent 18cbf48 commit 164a658
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 43 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ By @wumpf in [#6069](https://github.com/gfx-rs/wgpu/pull/6069), [#6099](https://
#### DX12

- Replace `winapi` code to use the `windows` crate. By @MarijnS95 in [#5956](https://github.com/gfx-rs/wgpu/pull/5956)
- Get `num_workgroups` builtin working for indirect dispatches. By @teoxoy in [#5730](https://github.com/gfx-rs/wgpu/pull/5730)

## 22.0.0 (2024-07-17)

Expand Down
8 changes: 3 additions & 5 deletions tests/tests/dispatch_workgroups_indirect.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext};
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext};

/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12).
#[gpu_test]
Expand All @@ -12,8 +12,7 @@ static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new(
.limits(wgpu::Limits {
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
})
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
}),
)
.run_async(|ctx| async move {
let num_workgroups = [1, 2, 3];
Expand All @@ -34,8 +33,7 @@ static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new()
max_compute_workgroups_per_dimension: 10,
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
})
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
}),
)
.run_async(|ctx| async move {
let max = ctx.device.limits().max_compute_workgroups_per_dimension;
Expand Down
3 changes: 2 additions & 1 deletion wgpu-core/src/device/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2637,7 +2637,8 @@ impl Device {

let hal_desc = hal::PipelineLayoutDescriptor {
label: desc.label.to_hal(self.instance_flags),
flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE,
flags: hal::PipelineLayoutFlags::FIRST_VERTEX_INSTANCE
| hal::PipelineLayoutFlags::NUM_WORK_GROUPS,
bind_group_layouts: &raw_bind_group_layouts,
push_constant_ranges: desc.push_constant_ranges.as_ref(),
};
Expand Down
21 changes: 13 additions & 8 deletions wgpu-core/src/indirect_validation.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::atomic::AtomicBool;
use std::{num::NonZeroU64, sync::atomic::AtomicBool};

use thiserror::Error;

Expand Down Expand Up @@ -61,7 +61,7 @@ impl IndirectValidation {

let src = format!("
@group(0) @binding(0)
var<storage, read_write> dst: array<u32, 3>;
var<storage, read_write> dst: array<u32, 6>;
@group(1) @binding(0)
var<storage, read> src: array<u32>;
struct OffsetPc {{
Expand All @@ -76,6 +76,9 @@ impl IndirectValidation {
dst[0] = res.x;
dst[1] = res.y;
dst[2] = res.z;
dst[3] = res.x;
dst[4] = res.y;
dst[5] = res.z;
}}
");

Expand Down Expand Up @@ -121,6 +124,8 @@ impl IndirectValidation {
}
})?;

const DST_BUFFER_SIZE: NonZeroU64 = unsafe { NonZeroU64::new_unchecked(4 * 3 * 2) };

let dst_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
label: None,
flags: hal::BindGroupLayoutFlags::empty(),
Expand All @@ -130,7 +135,7 @@ impl IndirectValidation {
ty: wgt::BindingType::Buffer {
ty: wgt::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
min_binding_size: Some(DST_BUFFER_SIZE),
},
count: None,
}],
Expand All @@ -150,7 +155,7 @@ impl IndirectValidation {
ty: wgt::BindingType::Buffer {
ty: wgt::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: true,
min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
min_binding_size: Some(NonZeroU64::new(4 * 3).unwrap()),
},
count: None,
}],
Expand Down Expand Up @@ -208,7 +213,7 @@ impl IndirectValidation {

let dst_buffer_desc = hal::BufferDescriptor {
label: None,
size: 4 * 3,
size: DST_BUFFER_SIZE.get(),
usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE,
memory_flags: hal::MemoryFlags::empty(),
};
Expand All @@ -228,7 +233,7 @@ impl IndirectValidation {
buffers: &[hal::BufferBinding {
buffer: dst_buffer_0.as_ref(),
offset: 0,
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
size: Some(DST_BUFFER_SIZE),
}],
samplers: &[],
textures: &[],
Expand All @@ -251,7 +256,7 @@ impl IndirectValidation {
buffers: &[hal::BufferBinding {
buffer: dst_buffer_1.as_ref(),
offset: 0,
size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()),
size: Some(DST_BUFFER_SIZE),
}],
samplers: &[],
textures: &[],
Expand Down Expand Up @@ -296,7 +301,7 @@ impl IndirectValidation {
buffers: &[hal::BufferBinding {
buffer,
offset: 0,
size: Some(std::num::NonZeroU64::new(binding_size).unwrap()),
size: Some(NonZeroU64::new(binding_size).unwrap()),
}],
samplers: &[],
textures: &[],
Expand Down
12 changes: 9 additions & 3 deletions wgpu-hal/src/dx12/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1198,11 +1198,17 @@ impl crate::CommandEncoder for super::CommandEncoder {
}

unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) {
self.prepare_dispatch([0; 3]);
//TODO: update special constants indirectly
self.update_root_elements();
let cmd_signature = if let Some(cmd_signatures) =
self.pass.layout.special_constants_cmd_signatures.as_ref()
{
&cmd_signatures.dispatch
} else {
&self.shared.cmd_signatures.dispatch
};
unsafe {
self.list.as_ref().unwrap().ExecuteIndirect(
&self.shared.cmd_signatures.dispatch,
cmd_signature,
1,
&buffer.resource,
offset,
Expand Down
114 changes: 88 additions & 26 deletions wgpu-hal/src/dx12/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,52 +93,32 @@ impl super::Device {
let capacity_views = limits.max_non_sampler_bindings as u64;
let capacity_samplers = 2_048;

fn create_command_signature(
raw: &Direct3D12::ID3D12Device,
byte_stride: usize,
arguments: &[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC],
node_mask: u32,
) -> Result<Direct3D12::ID3D12CommandSignature, crate::DeviceError> {
let mut signature = None;
unsafe {
raw.CreateCommandSignature(
&Direct3D12::D3D12_COMMAND_SIGNATURE_DESC {
ByteStride: byte_stride as u32,
NumArgumentDescs: arguments.len() as u32,
pArgumentDescs: arguments.as_ptr(),
NodeMask: node_mask,
},
None,
&mut signature,
)
}
.into_device_result("Command signature creation")?;
signature.ok_or(crate::DeviceError::ResourceCreationFailed)
}

let shared = super::DeviceShared {
zero_buffer,
cmd_signatures: super::CommandSignatures {
draw: create_command_signature(
draw: Self::create_command_signature(
&raw,
None,
mem::size_of::<wgt::DrawIndirectArgs>(),
&[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW,
..Default::default()
}],
0,
)?,
draw_indexed: create_command_signature(
draw_indexed: Self::create_command_signature(
&raw,
None,
mem::size_of::<wgt::DrawIndexedIndirectArgs>(),
&[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW_INDEXED,
..Default::default()
}],
0,
)?,
dispatch: create_command_signature(
dispatch: Self::create_command_signature(
&raw,
None,
mem::size_of::<wgt::DispatchIndirectArgs>(),
&[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH,
Expand Down Expand Up @@ -213,6 +193,30 @@ impl super::Device {
})
}

fn create_command_signature(
raw: &Direct3D12::ID3D12Device,
root_signature: Option<&Direct3D12::ID3D12RootSignature>,
byte_stride: usize,
arguments: &[Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC],
node_mask: u32,
) -> Result<Direct3D12::ID3D12CommandSignature, crate::DeviceError> {
let mut signature = None;
unsafe {
raw.CreateCommandSignature(
&Direct3D12::D3D12_COMMAND_SIGNATURE_DESC {
ByteStride: byte_stride as u32,
NumArgumentDescs: arguments.len() as u32,
pArgumentDescs: arguments.as_ptr(),
NodeMask: node_mask,
},
root_signature,
&mut signature,
)
}
.into_device_result("Command signature creation")?;
signature.ok_or(crate::DeviceError::ResourceCreationFailed)
}

// Blocks until the dedicated present queue is finished with all of its work.
//
// Once this method completes, the surface is able to be resized or deleted.
Expand Down Expand Up @@ -1112,6 +1116,63 @@ impl crate::Device for super::Device {
}
.into_device_result("Root signature creation")?;

let special_constants_cmd_signatures =
if let Some(root_index) = special_constants_root_index {
let constant_indirect_argument_desc = Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_CONSTANT,
Anonymous: Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC_0 {
Constant: Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC_0_1 {
RootParameterIndex: root_index,
DestOffsetIn32BitValues: 0,
Num32BitValuesToSet: 3,
},
},
};
Some(super::CommandSignatures {
draw: Self::create_command_signature(
&self.raw,
Some(&raw),
12 + mem::size_of::<wgt::DrawIndirectArgs>(),
&[
constant_indirect_argument_desc,
Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW,
..Default::default()
},
],
0,
)?,
draw_indexed: Self::create_command_signature(
&self.raw,
Some(&raw),
12 + mem::size_of::<wgt::DrawIndexedIndirectArgs>(),
&[
constant_indirect_argument_desc,
Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW_INDEXED,
..Default::default()
},
],
0,
)?,
dispatch: Self::create_command_signature(
&self.raw,
Some(&raw),
12 + mem::size_of::<wgt::DispatchIndirectArgs>(),
&[
constant_indirect_argument_desc,
Direct3D12::D3D12_INDIRECT_ARGUMENT_DESC {
Type: Direct3D12::D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH,
..Default::default()
},
],
0,
)?,
})
} else {
None
};

if let Some(label) = desc.label {
unsafe { raw.SetName(&windows::core::HSTRING::from(label)) }
.into_device_result("SetName")?;
Expand All @@ -1124,6 +1185,7 @@ impl crate::Device for super::Device {
signature: Some(raw),
total_root_elements: parameters.len() as super::RootIndex,
special_constants_root_index,
special_constants_cmd_signatures,
root_constant_info,
},
bind_group_infos,
Expand Down
3 changes: 3 additions & 0 deletions wgpu-hal/src/dx12/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ struct Idler {
event: Event,
}

#[derive(Debug, Clone)]
struct CommandSignatures {
draw: Direct3D12::ID3D12CommandSignature,
draw_indexed: Direct3D12::ID3D12CommandSignature,
Expand Down Expand Up @@ -634,6 +635,7 @@ impl PassState {
signature: None,
total_root_elements: 0,
special_constants_root_index: None,
special_constants_cmd_signatures: None,
root_constant_info: None,
},
root_elements: [RootElement::Empty; MAX_ROOT_ELEMENTS],
Expand Down Expand Up @@ -871,6 +873,7 @@ struct PipelineLayoutShared {
signature: Option<Direct3D12::ID3D12RootSignature>,
total_root_elements: RootIndex,
special_constants_root_index: Option<RootIndex>,
special_constants_cmd_signatures: Option<CommandSignatures>,
root_constant_info: Option<RootConstantInfo>,
}

Expand Down

0 comments on commit 164a658

Please sign in to comment.