diff --git a/Cargo.toml b/Cargo.toml index 7dda86c8f1418..e0a842921b76e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3020,6 +3020,17 @@ description = "A very simple compute shader that writes to a buffer that is read category = "Shaders" wasm = false +[[example]] +name = "compute_mesh" +path = "examples/shader_advanced/compute_mesh.rs" +doc-scrape-examples = true + +[package.metadata.example.compute_mesh] +name = "Compute Shader Mesh" +description = "A compute shader that generates a mesh that is controlled by a Handle" +category = "Shaders" +wasm = false + [[example]] name = "array_texture" path = "examples/shader/array_texture.rs" diff --git a/assets/shaders/compute_mesh.wgsl b/assets/shaders/compute_mesh.wgsl new file mode 100644 index 0000000000000..40d7a63aca885 --- /dev/null +++ b/assets/shaders/compute_mesh.wgsl @@ -0,0 +1,96 @@ +// This shader is used for the compute_mesh example +// The actual work it does is not important for the example and +// has been hardcoded to return a cube mesh + +// `vertex_start` is the starting offset of the mesh data in the *vertex_data* storage buffer +// `index_start` is the starting offset of the index data in the *index_data* storage buffer +struct DataRanges { + vertex_start: u32, + vertex_end: u32, + index_start: u32, + index_end: u32, +} + +@group(0) @binding(0) var data_range: DataRanges; +@group(0) @binding(1) var vertex_data: array; +@group(0) @binding(2) var index_data: array; + +@compute @workgroup_size(1) +fn main(@builtin(global_invocation_id) global_id: vec3) { + // this loop is iterating over the full list of (position, normal, uv) + // data what we have in `vertices`. + // `192` is used because arrayLength on const arrays doesn't work + for (var i = 0u; i < 192; i++) { + // The vertex_data buffer is bigger than just the mesh we're + // processing because Bevy stores meshes in the mesh_allocator + // which allocates slabs that each can contain multiple meshes. + // This buffer is one slab, and data_range.vertex_start is the + // starting offset for the mesh we care about. + // So the 0 starting value in the for loop is added to + // data_range.vertex_start which means we start writing at the + // correct offset. + // + // The "end" of the available space to write into is known by us + // ahead of time in this example, so we know this has enough space, + // but you may wish to also check to make sure you are not writing + // past the end of the range *because you should not write past the + // end of the range ever*. Doing this can overwrite a different + // mesh's data. + vertex_data[i + data_range.vertex_start] = vertices[i]; + } + // `36` is the length of the `indices` array + for (var i = 0u; i < 36; i++) { + // This is doing the same as the vertex_data offset described above + index_data[i + data_range.index_start] = u32(indices[i]); + } +} + +// hardcoded compute shader data. +// half_size is half the size of the cube +const half_size = vec3(1.5); +const min = -half_size; +const max = half_size; + +// Suppose Y-up right hand, and camera look from +Z to -Z +const vertices = array( + // xyz, normal.xyz, uv.xy + // Front + min.x, min.y, max.z, 0.0, 0.0, 1.0, 0.0, 0.0, + max.x, min.y, max.z, 0.0, 0.0, 1.0, 1.0, 0.0, + max.x, max.y, max.z, 0.0, 0.0, 1.0, 1.0, 1.0, + min.x, max.y, max.z, 0.0, 0.0, 1.0, 0.0, 1.0, + // Back + min.x, max.y, min.z, 0.0, 0.0, -1.0, 1.0, 0.0, + max.x, max.y, min.z, 0.0, 0.0, -1.0, 0.0, 0.0, + max.x, min.y, min.z, 0.0, 0.0, -1.0, 0.0, 1.0, + min.x, min.y, min.z, 0.0, 0.0, -1.0, 1.0, 1.0, + // Right + max.x, min.y, min.z, 1.0, 0.0, 0.0, 0.0, 0.0, + max.x, max.y, min.z, 1.0, 0.0, 0.0, 1.0, 0.0, + max.x, max.y, max.z, 1.0, 0.0, 0.0, 1.0, 1.0, + max.x, min.y, max.z, 1.0, 0.0, 0.0, 0.0, 1.0, + // Left + min.x, min.y, max.z, -1.0, 0.0, 0.0, 1.0, 0.0, + min.x, max.y, max.z, -1.0, 0.0, 0.0, 0.0, 0.0, + min.x, max.y, min.z, -1.0, 0.0, 0.0, 0.0, 1.0, + min.x, min.y, min.z, -1.0, 0.0, 0.0, 1.0, 1.0, + // Top + max.x, max.y, min.z, 0.0, 1.0, 0.0, 1.0, 0.0, + min.x, max.y, min.z, 0.0, 1.0, 0.0, 0.0, 0.0, + min.x, max.y, max.z, 0.0, 1.0, 0.0, 0.0, 1.0, + max.x, max.y, max.z, 0.0, 1.0, 0.0, 1.0, 1.0, + // Bottom + max.x, min.y, max.z, 0.0, -1.0, 0.0, 0.0, 0.0, + min.x, min.y, max.z, 0.0, -1.0, 0.0, 1.0, 0.0, + min.x, min.y, min.z, 0.0, -1.0, 0.0, 1.0, 1.0, + max.x, min.y, min.z, 0.0, -1.0, 0.0, 0.0, 1.0 +); + +const indices = array( + 0, 1, 2, 2, 3, 0, // front + 4, 5, 6, 6, 7, 4, // back + 8, 9, 10, 10, 11, 8, // right + 12, 13, 14, 14, 15, 12, // left + 16, 17, 18, 18, 19, 16, // top + 20, 21, 22, 22, 23, 20, // bottom +); diff --git a/examples/README.md b/examples/README.md index da86eac2fbdc1..972595e862351 100644 --- a/examples/README.md +++ b/examples/README.md @@ -467,6 +467,7 @@ Example | Description [Animated](../examples/shader/animate_shader.rs) | A shader that uses dynamic data like the time since startup [Array Texture](../examples/shader/array_texture.rs) | A shader that shows how to reuse the core bevy PBR shading functionality in a custom material that obtains the base color from an array texture. [Compute - Game of Life](../examples/shader/compute_shader_game_of_life.rs) | A compute shader that simulates Conway's Game of Life +[Compute Shader Mesh](../examples/shader_advanced/compute_mesh.rs) | A compute shader that generates a mesh that is controlled by a Handle [Custom Render Phase](../examples/shader_advanced/custom_render_phase.rs) | Shows how to make a complete render phase [Custom Vertex Attribute](../examples/shader_advanced/custom_vertex_attribute.rs) | A shader that reads a mesh's custom vertex attribute [Custom phase item](../examples/shader_advanced/custom_phase_item.rs) | Demonstrates how to enqueue custom draw commands in a render phase diff --git a/examples/shader_advanced/compute_mesh.rs b/examples/shader_advanced/compute_mesh.rs new file mode 100644 index 0000000000000..0d7b798f22df5 --- /dev/null +++ b/examples/shader_advanced/compute_mesh.rs @@ -0,0 +1,344 @@ +//! This example shows how to initialize an empty mesh with a Handle +//! and a render-world only usage. That buffer is then filled by a +//! compute shader on the GPU without transferring data back +//! to the CPU. +//! +//! The `mesh_allocator` is used to get references to the relevant slabs +//! that contain the mesh data we're interested in. +//! +//! This example does not remove the `GenerateMesh` component after +//! generating the mesh. + +use std::ops::Not; + +use bevy::{ + asset::RenderAssetUsages, + color::palettes::tailwind::{RED_400, SKY_400}, + mesh::Indices, + platform::collections::HashSet, + prelude::*, + render::{ + extract_component::{ExtractComponent, ExtractComponentPlugin}, + mesh::allocator::MeshAllocator, + render_graph::{self, RenderGraph, RenderLabel}, + render_resource::{ + binding_types::{storage_buffer, uniform_buffer}, + *, + }, + renderer::{RenderContext, RenderQueue}, + Render, RenderApp, RenderStartup, + }, +}; + +/// This example uses a shader source file from the assets subdirectory +const SHADER_ASSET_PATH: &str = "shaders/compute_mesh.wgsl"; + +fn main() { + App::new() + .add_plugins(( + DefaultPlugins, + ComputeShaderMeshGeneratorPlugin, + ExtractComponentPlugin::::default(), + )) + .insert_resource(ClearColor(Color::BLACK)) + .add_systems(Startup, setup) + .run(); +} + +// We need a plugin to organize all the systems and render node required for this example +struct ComputeShaderMeshGeneratorPlugin; +impl Plugin for ComputeShaderMeshGeneratorPlugin { + fn build(&self, app: &mut App) { + let Some(render_app) = app.get_sub_app_mut(RenderApp) else { + return; + }; + + render_app + .init_resource::() + .add_systems( + RenderStartup, + (init_compute_pipeline, add_compute_render_graph_node), + ) + .add_systems(Render, prepare_chunks); + } + fn finish(&self, app: &mut App) { + let Some(render_app) = app.get_sub_app_mut(RenderApp) else { + return; + }; + render_app + .world_mut() + .resource_mut::() + // This allows using the mesh allocator slabs as + // storage buffers directly in the compute shader. + // Which means that we can write from our compute + // shader directly to the allocated mesh slabs. + .extra_buffer_usages = BufferUsages::STORAGE; + } +} + +/// Holds a handle to the empty mesh that should be filled +/// by the compute shader. +#[derive(Component, ExtractComponent, Clone)] +struct GenerateMesh(Handle); + +fn setup( + mut commands: Commands, + mut meshes: ResMut>, + mut materials: ResMut>, +) { + // a truly empty mesh will error if used in Mesh3d + // so we set up the data to be what we want the compute shader to output + // We're using 36 indices and 24 vertices which is directly taken from + // the Bevy Cuboid mesh implementation. + // + // We allocate 50 spots for each attribute here because + // it is *very important* that the amount of data allocated here is + // *bigger* than (or exactly equal to) the amount of data we intend to + // write from the compute shader. This amount of data defines how big + // the buffer we get from the mesh_allocator will be, which in turn + // defines how big the buffer is when we're in the compute shader. + // + // If it turns out you don't need all of the space when the compute shader + // is writing data, you can write NaN to the rest of the data. + let empty_mesh = { + let mut mesh = Mesh::new( + PrimitiveTopology::TriangleList, + RenderAssetUsages::RENDER_WORLD, + ) + .with_inserted_attribute(Mesh::ATTRIBUTE_POSITION, vec![[0.; 3]; 50]) + .with_inserted_attribute(Mesh::ATTRIBUTE_NORMAL, vec![[0.; 3]; 50]) + .with_inserted_attribute(Mesh::ATTRIBUTE_UV_0, vec![[0.; 2]; 50]) + .with_inserted_indices(Indices::U32(vec![0; 50])); + + mesh.asset_usage = RenderAssetUsages::RENDER_WORLD; + mesh + }; + + let handle = meshes.add(empty_mesh); + + // we spawn two "users" of the mesh handle, + // but only insert `GenerateMesh` on one of them + // to show that the mesh handle works as usual + commands.spawn(( + GenerateMesh(handle.clone()), + Mesh3d(handle.clone()), + MeshMaterial3d(materials.add(StandardMaterial { + base_color: RED_400.into(), + ..default() + })), + Transform::from_xyz(-2.5, 1.5, 0.), + )); + + commands.spawn(( + Mesh3d(handle), + MeshMaterial3d(materials.add(StandardMaterial { + base_color: SKY_400.into(), + ..default() + })), + Transform::from_xyz(2.5, 1.5, 0.), + )); + + // some additional scene elements. + // This mesh specifically is here so that we don't assume + // mesh_allocator offsets that would only work if we had + // one mesh in the scene. + commands.spawn(( + Mesh3d(meshes.add(Circle::new(4.0))), + MeshMaterial3d(materials.add(Color::WHITE)), + Transform::from_rotation(Quat::from_rotation_x(-std::f32::consts::FRAC_PI_2)), + )); + commands.spawn(( + PointLight { + shadows_enabled: true, + ..default() + }, + Transform::from_xyz(4.0, 8.0, 4.0), + )); + // camera + commands.spawn(( + Camera3d::default(), + Transform::from_xyz(-2.5, 4.5, 9.0).looking_at(Vec3::ZERO, Vec3::Y), + )); +} + +fn add_compute_render_graph_node(mut render_graph: ResMut) { + render_graph.add_node(ComputeNodeLabel, ComputeNode::default()); + // add_node_edge guarantees that ComputeNodeLabel will run before CameraDriverLabel + render_graph.add_node_edge(ComputeNodeLabel, bevy::render::graph::CameraDriverLabel); +} + +/// This is called `ChunksToProcess` because this example originated +/// from a use case of generating chunks of landscape or voxels +/// It only exists in the render world. +#[derive(Resource, Default)] +struct ChunksToProcess(Vec>); + +/// `processed` is a `HashSet` contains the `AssetId`s that have been +/// processed. We use that to remove `AssetId`s that have already +/// been processed, which means each unique `GenerateMesh` will result +/// in one compute shader mesh generation process instead of generating +/// the mesh every frame. +fn prepare_chunks( + meshes_to_generate: Query<&GenerateMesh>, + mut chunks: ResMut, + pipeline_cache: Res, + pipeline: Res, + mut processed: Local>>, +) { + // If the pipeline isn't ready, then meshes + // won't be processed. So we want to wait until + // the pipeline is ready before considering any mesh processed. + if pipeline_cache + .get_compute_pipeline(pipeline.pipeline) + .is_some() + { + // get the AssetId for each Handle + // which we'll use later to get the relevant buffers + // from the mesh_allocator + let chunk_data: Vec> = meshes_to_generate + .iter() + .filter_map(|gmesh| { + let id = gmesh.0.id(); + processed.contains(&id).not().then_some(id) + }) + .collect(); + + // Cache any meshes we're going to process this frame + for id in &chunk_data { + processed.insert(*id); + } + + chunks.0 = chunk_data; + } +} + +#[derive(Resource)] +struct ComputePipeline { + layout: BindGroupLayoutDescriptor, + pipeline: CachedComputePipelineId, +} + +// init only happens once +fn init_compute_pipeline( + mut commands: Commands, + asset_server: Res, + pipeline_cache: Res, +) { + let layout = BindGroupLayoutDescriptor::new( + "", + &BindGroupLayoutEntries::sequential( + ShaderStages::COMPUTE, + ( + // offsets + uniform_buffer::(false), + // vertices + storage_buffer::>(false), + // indices + storage_buffer::>(false), + ), + ), + ); + let shader = asset_server.load(SHADER_ASSET_PATH); + let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + label: Some("Mesh generation compute shader".into()), + layout: vec![layout.clone()], + shader: shader.clone(), + ..default() + }); + commands.insert_resource(ComputePipeline { layout, pipeline }); +} + +/// Label to identify the node in the render graph +#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)] +struct ComputeNodeLabel; + +/// The node that will execute the compute shader +#[derive(Default)] +struct ComputeNode {} + +// A uniform that holds the vertex and index offsets +// for the vertex/index mesh_allocator buffer slabs +#[derive(ShaderType)] +struct DataRanges { + vertex_start: u32, + vertex_end: u32, + index_start: u32, + index_end: u32, +} + +impl render_graph::Node for ComputeNode { + fn run( + &self, + _graph: &mut render_graph::RenderGraphContext, + render_context: &mut RenderContext, + world: &World, + ) -> Result<(), render_graph::NodeRunError> { + let chunks = world.resource::(); + let mesh_allocator = world.resource::(); + + for mesh_id in &chunks.0 { + info!(?mesh_id, "processing mesh"); + let pipeline_cache = world.resource::(); + let pipeline = world.resource::(); + + if let Some(init_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.pipeline) { + // the mesh_allocator holds slabs of meshes, so the buffers we get here + // can contain more data than just the mesh we're asking for. + // That's why there is a range field. + // You should *not* touch data in these buffers that is outside of the range. + let vertex_buffer_slice = mesh_allocator.mesh_vertex_slice(mesh_id).unwrap(); + let index_buffer_slice = mesh_allocator.mesh_index_slice(mesh_id).unwrap(); + + let first = DataRanges { + // there are 8 vertex data values (pos, normal, uv) per vertex + // and the vertex_buffer_slice.range.start is in "vertex elements" + // which includes all of that data, so each index is worth 8 indices + // to our shader code. + vertex_start: vertex_buffer_slice.range.start * 8, + vertex_end: vertex_buffer_slice.range.end * 8, + // but each vertex index is a single value, so the index of the + // vertex indices is exactly what the value is + index_start: index_buffer_slice.range.start, + index_end: index_buffer_slice.range.end, + }; + + let mut uniforms = UniformBuffer::from(first); + uniforms.write_buffer( + render_context.render_device(), + world.resource::(), + ); + + // pass in the full mesh_allocator slabs as well as the first index + // offsets for the vertex and index buffers + let bind_group = render_context.render_device().create_bind_group( + None, + &pipeline_cache.get_bind_group_layout(&pipeline.layout), + &BindGroupEntries::sequential(( + &uniforms, + vertex_buffer_slice.buffer.as_entire_buffer_binding(), + index_buffer_slice.buffer.as_entire_buffer_binding(), + )), + ); + + let mut pass = + render_context + .command_encoder() + .begin_compute_pass(&ComputePassDescriptor { + label: Some("Mesh generation compute pass"), + ..default() + }); + pass.push_debug_group("compute_mesh"); + + pass.set_bind_group(0, &bind_group, &[]); + pass.set_pipeline(init_pipeline); + // we only dispatch 1,1,1 workgroup here, but a real compute shader + // would take advantage of more and larger size workgroups + pass.dispatch_workgroups(1, 1, 1); + + pass.pop_debug_group(); + } + } + + Ok(()) + } +}