Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3020,6 +3020,17 @@ description = "A very simple compute shader that writes to a buffer that is read
category = "Shaders"
wasm = false

[[example]]
name = "compute_mesh"
path = "examples/shader/compute_mesh.rs"
doc-scrape-examples = true

[package.metadata.example.compute_mesh]
name = "Compute Shader Mesh"
description = "A compute shader that generates a mesh that is controlled by a Handle"
category = "Shaders"
wasm = false

[[example]]
name = "array_texture"
path = "examples/shader/array_texture.rs"
Expand Down
71 changes: 71 additions & 0 deletions assets/shaders/compute_mesh.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// This shader is used for the gpu_readback example
// The actual work it does is not important for the example

struct FirstIndex {
first_vertex_index: u32,
first_index_index: u32,
}

// This is the data that lives in the gpu only buffer
@group(0) @binding(0) var<uniform> first_index: FirstIndex;
@group(0) @binding(1) var<storage, read_write> vertex_data: array<f32>;
@group(0) @binding(2) var<storage, read_write> index_data: array<u32>;

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
for (var i = 0u; i < 192; i++) {
vertex_data[i + first_index.first_vertex_index * 32 ] = vertices[i ];
}
for (var i = 0u; i < 36; i++) {
index_data[i + first_index.first_index_index * 6] = u32(indices[i]);
}
}

// hardcoded compute shader data.
const half_size = vec3(2.);
const min = -half_size;
const max = half_size;

// Suppose Y-up right hand, and camera look from +Z to -Z
const vertices = array(
// xyz, normal.xyz, uv.xy
// Front
min.x, min.y, max.z, 0.0, 0.0, 1.0, 0.0, 0.0,
max.x, min.y, max.z, 0.0, 0.0, 1.0, 1.0, 0.0,
max.x, max.y, max.z, 0.0, 0.0, 1.0, 1.0, 1.0,
min.x, max.y, max.z, 0.0, 0.0, 1.0, 0.0, 1.0,
// Back
min.x, max.y, min.z, 0.0, 0.0, -1.0, 1.0, 0.0,
max.x, max.y, min.z, 0.0, 0.0, -1.0, 0.0, 0.0,
max.x, min.y, min.z, 0.0, 0.0, -1.0, 0.0, 1.0,
min.x, min.y, min.z, 0.0, 0.0, -1.0, 1.0, 1.0,
// Right
max.x, min.y, min.z, 1.0, 0.0, 0.0, 0.0, 0.0,
max.x, max.y, min.z, 1.0, 0.0, 0.0, 1.0, 0.0,
max.x, max.y, max.z, 1.0, 0.0, 0.0, 1.0, 1.0,
max.x, min.y, max.z, 1.0, 0.0, 0.0, 0.0, 1.0,
// Left
min.x, min.y, max.z, -1.0, 0.0, 0.0, 1.0, 0.0,
min.x, max.y, max.z, -1.0, 0.0, 0.0, 0.0, 0.0,
min.x, max.y, min.z, -1.0, 0.0, 0.0, 0.0, 1.0,
min.x, min.y, min.z, -1.0, 0.0, 0.0, 1.0, 1.0,
// Top
max.x, max.y, min.z, 0.0, 1.0, 0.0, 1.0, 0.0,
min.x, max.y, min.z, 0.0, 1.0, 0.0, 0.0, 0.0,
min.x, max.y, max.z, 0.0, 1.0, 0.0, 0.0, 1.0,
max.x, max.y, max.z, 0.0, 1.0, 0.0, 1.0, 1.0,
// Bottom
max.x, min.y, max.z, 0.0, -1.0, 0.0, 0.0, 0.0,
min.x, min.y, max.z, 0.0, -1.0, 0.0, 1.0, 0.0,
min.x, min.y, min.z, 0.0, -1.0, 0.0, 1.0, 1.0,
max.x, min.y, min.z, 0.0, -1.0, 0.0, 0.0, 1.0
);

const indices = array(
0, 1, 2, 2, 3, 0, // front
4, 5, 6, 6, 7, 4, // back
8, 9, 10, 10, 11, 8, // right
12, 13, 14, 14, 15, 12, // left
16, 17, 18, 18, 19, 16, // top
20, 21, 22, 22, 23, 20, // bottom
);
270 changes: 270 additions & 0 deletions examples/shader/compute_mesh.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
//! Simple example demonstrating the use of the [`Readback`] component to read back data from the GPU
//! using both a storage buffer and texture.

use bevy::{
asset::RenderAssetUsages,
color::palettes::tailwind::RED_400,
mesh::{Indices, MeshVertexAttribute},
prelude::*,
render::{
extract_resource::{ExtractResource, ExtractResourcePlugin},
gpu_readback::{Readback, ReadbackComplete},
render_asset::RenderAssets,
render_graph::{self, RenderGraph, RenderLabel},
render_resource::{
binding_types::{storage_buffer, texture_storage_2d},
*,
},
renderer::{RenderContext, RenderDevice},
storage::{GpuShaderStorageBuffer, ShaderStorageBuffer},
texture::GpuImage,
Render, RenderApp, RenderStartup, RenderSystems,
},
};
use bevy_render::{
extract_component::{ExtractComponent, ExtractComponentPlugin},
mesh::{allocator::MeshAllocator, RenderMesh},
render_resource::binding_types::uniform_buffer,
renderer::RenderQueue,
};

/// This example uses a shader source file from the assets subdirectory
const SHADER_ASSET_PATH: &str = "shaders/compute_mesh.wgsl";

fn main() {
App::new()
.add_plugins((
DefaultPlugins,
ComputeShaderMeshGeneratorPlugin,
ExtractComponentPlugin::<GenerateMesh>::default(),
))
.insert_resource(ClearColor(Color::BLACK))
.add_systems(Startup, setup)
.run();
}

// We need a plugin to organize all the systems and render node required for this example
struct ComputeShaderMeshGeneratorPlugin;
impl Plugin for ComputeShaderMeshGeneratorPlugin {
fn build(&self, app: &mut App) {
let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
return;
};

render_app
.init_resource::<Chunks>()
.add_systems(
RenderStartup,
(init_compute_pipeline, add_compute_render_graph_node),
)
.add_systems(Render, prepare_chunks);
}
fn finish(&self, app: &mut App) {
let Some(render_app) = app.get_sub_app_mut(RenderApp) else {
return;
};
render_app
.world_mut()
.resource_mut::<MeshAllocator>()
.extra_buffer_usages = BufferUsages::STORAGE;
}
}

#[derive(Component, ExtractComponent, Clone)]
struct GenerateMesh(Handle<Mesh>);

fn setup(
mut commands: Commands,
mut images: ResMut<Assets<Image>>,
mut meshes: ResMut<Assets<Mesh>>,
mut materials: ResMut<Assets<StandardMaterial>>,
mut buffers: ResMut<Assets<ShaderStorageBuffer>>,
) {
// a truly empty mesh will error if used in Mesh3d
// so use a sphere for the example
let mut empty_mesh = Mesh::new(
PrimitiveTopology::TriangleList,
RenderAssetUsages::RENDER_WORLD,
);
// set up what we want to output from the compute shader.
// We're using 36 indices, 24 vertices which is directly taken from
// the Bevy Cuboid mesh
empty_mesh.insert_attribute(Mesh::ATTRIBUTE_POSITION, vec![[0.; 3]; 24]);
empty_mesh.insert_attribute(Mesh::ATTRIBUTE_NORMAL, vec![[0.; 3]; 24]);
empty_mesh.insert_attribute(Mesh::ATTRIBUTE_UV_0, vec![[0.; 2]; 24]);
empty_mesh.insert_indices(Indices::U32(vec![0; 36]));
empty_mesh.asset_usage = RenderAssetUsages::RENDER_WORLD;

let handle = meshes.add(empty_mesh);
commands.spawn((
GenerateMesh(handle.clone()),
Mesh3d(handle.clone()),
MeshMaterial3d(materials.add(StandardMaterial {
base_color: RED_400.into(),
..default()
})),
Transform::from_xyz(0., 1., 0.),
));

// commands.spawn((
// Mesh3d(handle),
// MeshMaterial3d(materials.add(StandardMaterial {
// base_color: RED_400.into(),
// ..default()
// })),
// Transform::from_xyz(2., 1., 0.),
// ));

// // spawn some scene
// commands.spawn((
// Mesh3d(meshes.add(Circle::new(4.0))),
// MeshMaterial3d(materials.add(Color::WHITE)),
// Transform::from_rotation(Quat::from_rotation_x(-std::f32::consts::FRAC_PI_2)),
// ));
commands.spawn((
PointLight {
shadows_enabled: true,
..default()
},
Transform::from_xyz(4.0, 8.0, 4.0),
));
// camera
commands.spawn((
Camera3d::default(),
Transform::from_xyz(-2.5, 4.5, 9.0).looking_at(Vec3::ZERO, Vec3::Y),
));
}

fn add_compute_render_graph_node(mut render_graph: ResMut<RenderGraph>) {
// Add the compute node as a top-level node to the render graph. This means it will only execute
// once per frame. Normally, adding a node would use the `RenderGraphApp::add_render_graph_node`
// method, but it does not allow adding as a top-level node.
render_graph.add_node(ComputeNodeLabel, ComputeNode::default());
}

#[derive(Resource, Default)]
struct Chunks(Vec<AssetId<Mesh>>);

fn prepare_chunks(
meshes_to_generate: Query<&GenerateMesh>,
mut chunks: ResMut<Chunks>,
mesh_handles: Res<RenderAssets<RenderMesh>>,
) {
let chunk_data: Vec<AssetId<Mesh>> = meshes_to_generate
.iter()
// sometimes RenderMesh doesn't exist yet!
.map(|gmesh| gmesh.0.id())
.collect();
// dbg!(chunk_data);
chunks.0 = chunk_data;
}

#[derive(Resource)]
struct ComputePipeline {
layout: BindGroupLayoutDescriptor,
pipeline: CachedComputePipelineId,
}

// init only happens once
fn init_compute_pipeline(
mut commands: Commands,
asset_server: Res<AssetServer>,
pipeline_cache: Res<PipelineCache>,
) {
let layout = BindGroupLayoutDescriptor::new(
"",
&BindGroupLayoutEntries::sequential(
ShaderStages::COMPUTE,
(
uniform_buffer::<FirstIndex>(false),
// vertices
storage_buffer::<Vec<u32>>(false),
// indices
storage_buffer::<Vec<u32>>(false),
),
),
);
let shader = asset_server.load(SHADER_ASSET_PATH);
let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("Mesh generation compute shader".into()),
layout: vec![layout.clone()],
shader: shader.clone(),
..default()
});
commands.insert_resource(ComputePipeline { layout, pipeline });
}

/// Label to identify the node in the render graph
#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
struct ComputeNodeLabel;

/// The node that will execute the compute shader
#[derive(Default)]
struct ComputeNode {}

#[derive(ShaderType)]
struct FirstIndex {
first_vertex_index: u32,
first_index_index: u32,
}

impl render_graph::Node for ComputeNode {
fn run(
&self,
_graph: &mut render_graph::RenderGraphContext,
render_context: &mut RenderContext,
world: &World,
) -> Result<(), render_graph::NodeRunError> {
let Some(chunks) = world.get_resource::<Chunks>() else {
info!("no chunks");
return Ok(());
};
let mesh_allocator = world.resource::<MeshAllocator>();

for mesh_id in &chunks.0 {
let pipeline_cache = world.resource::<PipelineCache>();
let pipeline = world.resource::<ComputePipeline>();

if let Some(init_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.pipeline) {
let vertex_buffer_slice = mesh_allocator.mesh_vertex_slice(mesh_id).unwrap();
let index_buffer_slice = mesh_allocator.mesh_index_slice(mesh_id).unwrap();

dbg!(&vertex_buffer_slice.range);
dbg!(&index_buffer_slice.range);

let first = FirstIndex {
first_vertex_index: vertex_buffer_slice.range.start * 4,
first_index_index: index_buffer_slice.range.start * 4,
};
let mut uniforms = UniformBuffer::from(first);
uniforms.write_buffer(
render_context.render_device(),
world.resource::<RenderQueue>(),
);
let bind_group = render_context.render_device().create_bind_group(
None,
&pipeline_cache.get_bind_group_layout(&pipeline.layout),
&BindGroupEntries::sequential((
&uniforms,
vertex_buffer_slice.buffer.as_entire_buffer_binding(),
index_buffer_slice.buffer.as_entire_buffer_binding(),
)),
);

let mut pass =
render_context
.command_encoder()
.begin_compute_pass(&ComputePassDescriptor {
label: Some("Mesh generation compute pass"),
..default()
});

pass.set_bind_group(0, &bind_group, &[]);
pass.set_pipeline(init_pipeline);
pass.dispatch_workgroups(1, 1, 1);
}
}

Ok(())
}
}
Loading