diff --git a/docs/api-specs/mesh_shading.md b/docs/api-specs/mesh_shading.md index 46d7948d55e..c3f80e79a67 100644 --- a/docs/api-specs/mesh_shading.md +++ b/docs/api-specs/mesh_shading.md @@ -2,13 +2,71 @@ 🧪Experimental🧪 -`wgpu` supports an experimental version of mesh shading. Currently `naga` has no support for mesh shaders beyond recognizing the additional shader stages. +`wgpu` supports an experimental version of mesh shading when `Features::EXPERIMENTAL_MESH_SHADER` is enabled. +Currently `naga` has no support for parsing or writing mesh shaders. For this reason, all shaders must be created with `Device::create_shader_module_passthrough`. **Note**: The features documented here may have major bugs in them and are expected to be subject to breaking changes, suggestions for the API exposed by this should be posted on [the mesh-shading issue](https://github.com/gfx-rs/wgpu/issues/7197). -***This is not*** a thorough explanation of mesh shading and how it works. Those wishing to understand mesh shading more broadly should look elsewhere first. +## Mesh shaders overview + +### What are mesh shaders? + +Mesh shaders are a new kind of rasterization pipeline intended to address some of the shortfalls with the vertex shader pipeline. The core idea of mesh shaders is that the GPU decides how to render the many small parts of a scene instead of the CPU issuing a draw call for every small part or issuing an inefficient monolithic draw call for a large part of the scene. + +Mesh shaders are specifically designed to be used with **meshlet rendering**, a technique where every object is split into many subobjects called meshlets that are each rendered with their own parameters. With the standard vertex pipeline, each draw call specifies an exact number of primitives to render and the same parameters for all vertex shaders on an entire object (or even multiple objects). 
This doesn't leave room for different LODs for different parts of an object, for example a closer part having more detail, nor does it allow culling smaller sections (or primitives) of objects. With mesh shaders, each task workgroup might get assigned to a single object. It can then analyze the different meshlets (sections) of that object, determine which are visible and should actually be rendered, and for those meshlets determine what LOD to use based on the distance from the camera. It can then dispatch a mesh workgroup for each meshlet, with each mesh workgroup then reading the data for that specific LOD of its meshlet, determining which and how many vertices and primitives to output, determining which remaining primitives need to be culled, and passing the resulting primitives to the rasterizer. + +Mesh shaders are most effective in scenes with many polygons. They can allow skipping processing of entire groups of primitives that are facing away from the camera or otherwise occluded, which reduces the number of primitives that need to be processed by more than half in most cases, and they can reduce the number of primitives that need to be processed for more distant objects. Scenes that are not bottlenecked by geometry (perhaps instead by fragment processing or post-processing) will not see much benefit from using them. + +Mesh shaders were first shown off in [NVIDIA's asteroids demo](https://www.youtube.com/watch?v=CRfZYJ_sk5E). Now, they form the basis for [Unreal Engine's Nanite](https://www.unrealengine.com/en-US/blog/unreal-engine-5-is-now-available-in-preview#Nanite). + +### Mesh shader pipeline + +With the current pipeline set to a mesh pipeline, a draw command like +`render_pass.draw_mesh_tasks(x, y, z)` takes the following steps: + +* If the pipeline has a task shader stage: + + * Dispatch a grid of task shader workgroups, where `x`, `y`, and `z` give + the number of workgroups along each axis of the grid. 
Each task shader + workgroup produces a mesh shader workgroup grid size `(mx, my, mz)` and a + task payload value `mp`. + + * For each task shader workgroup, dispatch a grid of mesh shader workgroups, + where `mx`, `my`, and `mz` give the number of workgroups along each axis + of the grid. Pass `mp` to each of these workgroups' mesh shader + invocations. + +* Alternatively, if the pipeline does not have a task shader stage: + + * Dispatch a single grid of mesh shader workgroups, where `x`, `y`, and `z` + give the number of workgroups along each axis of the grid. These mesh + shaders receive no task payload value. + +* Each mesh shader workgroup produces a list of output vertices, and a list of + primitives built from those vertices. The workgroup can supply per-primitive + values as well, if needed. Each primitive selects its vertices by index, like + an indexed draw call, from among the vertices generated by this workgroup. + + Unlike a grid of ordinary compute shader workgroups collaborating to build + vertex and index data in common storage buffers, the vertices and primitives + produced by a mesh shader workgroup are entirely private to that workgroup, + and are not accessible by other workgroups. + +* Primitives produced by a mesh shader workgroup can have a culling flag. If a + primitive's culling flag is true, it is skipped during rasterization. + +* The primitives produced by all mesh shader workgroups are then rasterized in + the usual way, with each fragment shader invocation handling one pixel. + + Attributes from the vertices produced by the mesh shader workgroup are + provided to the fragment shader with interpolation applied as appropriate. + + If the mesh shader workgroup supplied per-primitive values, these are + available to each primitive's fragment shader invocations. Per-primitive + values are never interpolated; fragment shaders simply receive the values + the mesh shader workgroup associated with their primitive. 
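The dispatch steps above can be modelled on the CPU to see how many mesh shader workgroups a single draw ends up launching. The following is only a rough sketch, not part of `wgpu` or `naga`: the names `Grid`, `workgroup_count`, and `mesh_workgroups_with_task` are hypothetical, and the `task_output` closure merely stands in for whatever grid size a real task shader workgroup would return.

```rust
/// Grid size in workgroups along the x, y, and z axes (hypothetical helper type).
type Grid = [u32; 3];

/// Number of workgroups in a grid. With no task stage, the `(x, y, z)` passed
/// to the draw call is the mesh shader grid itself, so this is the whole count.
fn workgroup_count(grid: Grid) -> u64 {
    grid.iter().map(|&n| u64::from(n)).product()
}

/// Total mesh shader workgroups when a task stage is present.
///
/// `task_output` is an assumed stand-in for the task shader: given a task
/// workgroup id, it returns the mesh grid size that workgroup requests.
/// A result of `[0, 0, 0]` means that task workgroup dispatches nothing.
fn mesh_workgroups_with_task(task_grid: Grid, task_output: impl Fn(Grid) -> Grid) -> u64 {
    let mut total = 0u64;
    for z in 0..task_grid[2] {
        for y in 0..task_grid[1] {
            for x in 0..task_grid[0] {
                // Each task workgroup dispatches its own independent mesh grid.
                total += workgroup_count(task_output([x, y, z]));
            }
        }
    }
    total
}
```

Under this model, a task grid of `[4, 1, 1]` in which even-numbered task workgroups request `[3, 1, 1]` and odd-numbered ones request `[0, 0, 0]` yields six mesh shader workgroups in total, while a draw of `[4, 2, 1]` without a task stage yields eight.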
## `wgpu` API @@ -75,36 +133,63 @@ Using any of these features in a `wgsl` program will require adding the `enable Two new shader stages will be added to `WGSL`. Fragment shaders are also modified slightly. Both task shaders and mesh shaders are allowed to use any compute-specific functionality, such as subgroup operations. ### Task shader -This shader stage can be selected by marking a function with `@task`. Task shaders must return a `vec3` as their output type. Similar to compute shaders, task shaders run in a workgroup. The output must be uniform across all threads in a workgroup. -The output of this determines how many workgroups of mesh shaders will be dispatched. Once dispatched, global id variables will be local to the task shader workgroup dispatch, and mesh shaders won't know the position of their dispatch among all mesh shader dispatches unless this is passed through the payload. The output may be zero to skip dispatching any mesh shader workgroups for the task shader workgroup. +A function with the `@task` attribute is a **task shader entry point**. A mesh shader pipeline may optionally specify a task shader entry point, and if it does, mesh draw commands using that pipeline dispatch a **task shader grid** of workgroups running the task shader entry point. Like compute shader dispatches, the three-component size passed to `draw_mesh_tasks`, or drawn from the indirect buffer for its indirect variants, specifies the size of the task shader grid as the number of workgroups along each of the grid's three axes. -If task shaders are marked with `@payload(someVar)`, where `someVar` is global variable declared like `var someVar: `, task shaders may write to `someVar`. This payload is passed to the mesh shader workgroup that is invoked. The mesh shader can skip declaring `@payload` to ignore this input. +A task shader entry point must have a `@workgroup_size` attribute, meeting the same requirements as one appearing on a compute shader entry point. 
+ +A task shader entry point must also have a `@payload(G)` property, where `G` is the name of a global variable in the `task_payload` address space. Each task shader workgroup has its own instance of this variable, visible to all invocations in the workgroup. Whatever value the workgroup collectively stores in that global variable becomes the **task payload**, and is provided to all invocations in the mesh shader grid dispatched for the workgroup. + +A task shader entry point must return a `vec3` value. The return value of each workgroup's first invocation (that is, the one whose `local_invocation_index` is `0`) is taken as the size of a **mesh shader grid** to dispatch, measured in workgroups. (If the task shader entry point returns `vec3(0, 0, 0)`, then no mesh shaders are dispatched.) Mesh shader grids are described in the next section. + +Each task shader workgroup dispatches an independent mesh shader grid: in mesh shader invocations, `@builtin` values like `workgroup_id` and `global_invocation_id` describe the position of the workgroup and invocation within that grid; +and `@builtin(num_workgroups)` matches the task shader workgroup's return value. Mesh shaders dispatched for other task shader workgroups are not included in the count. If it is necessary for a mesh shader to know which task shader workgroup dispatched it, the task shader can include its own workgroup id in the task payload. ### Mesh shader -This shader stage can be selected by marking a function with `@mesh`. Mesh shaders must not return anything. -Mesh shaders can be marked with `@payload(someVar)` similar to task shaders. Unlike task shaders, mesh shaders cannot write to this workgroup memory. Declaring `@payload` in a pipeline with no task shader, in a pipeline with a task shader that doesn't declare `@payload`, or in a task shader with an `@payload` that is statically sized and smaller than the mesh shader payload is illegal. 
+A function with the `@mesh` attribute is a **mesh shader entry point**. Mesh shaders must not return anything. + +Like compute shaders, mesh shaders are invoked in a grid of workgroups, called a **mesh shader grid**. If the mesh shader pipeline has a task shader, then each task shader workgroup determines the size of a mesh shader grid to be dispatched, as described above. Otherwise, the three-component size passed to `draw_mesh_tasks`, or drawn from the indirect buffer for its indirect variants, specifies the size of the mesh shader grid directly, as the number of workgroups along each of the grid's three axes. + +If the mesh shader pipeline has a task shader entry point, then the pipeline's mesh shader entry point must also have a `@payload(G)` attribute, naming the same variable, and the sizes must match. Mesh shader invocations can read, but not write, this variable, which is initialized to whatever value was written to it by the task shader workgroup that dispatched this mesh shader grid. + +If the mesh shader pipeline does not have a task shader entry point, then the mesh shader entry point must not have any `@payload` attribute. + +A mesh shader entry point must have the following attributes: + +- `@workgroup_size`: this has the same meaning as when it appears on a compute shader entry point. + +- `@vertex_output(V, NV)`: This indicates that the mesh shader workgroup will generate at most `NV` vertex values, each of type `V`. -Mesh shaders must be marked with `@vertex_output(OutputType, numOutputs)`, where `numOutputs` is the maximum number of vertices to be output by a mesh shader, and `OutputType` is the data associated with vertices, similar to a standard vertex shader output. +- `@primitive_output(P, NP)`: This indicates that the mesh shader workgroup will generate at most `NP` primitives, each of type `P`. 
-Mesh shaders must also be marked with `@primitive_output(OutputType, numOutputs)`, which is similar to `@vertex_output` except it describes the primitive outputs. +Before generating any results, each mesh shader entry point invocation must call the `setMeshOutputs(numVertices: u32, numPrimitives: u32)` builtin function exactly once, in uniform control flow. The values passed by each workgroup's first invocation (that is, the one whose `local_invocation_index` is `0`) determine how many vertices (values of type `V`) and primitives (values of type `P`) the workgroup must produce. This call essentially establishes two implicit arrays of vertex and primitive values, shared across the workgroup, for invocations to populate. -### Mesh shader outputs +The `numVertices` and `numPrimitives` arguments must be no greater than `NV` and `NP` from the `@vertex_output` and `@primitive_output` attributes. -Primitive outputs from mesh shaders have some additional builtins they can set. These include `@builtin(cull_primitive)`, which must be a boolean value. If this is set to true, then the primitive is skipped during rendering. +To produce vertex data, the workgroup as a whole must make `numVertices` calls to the `setVertex(i: u32, vertex: V)` builtin function. This establishes `vertex` as the value of the `i`'th vertex. `V` is the type given in the `@vertex_output` attribute. `V` must meet the same requirements as a struct type returned by a `@vertex` entry point: all members must have either `@builtin` or `@location` attributes, there must be a `@builtin(position)`, and so on. An invocation may only call `setVertex` after its call to `setMeshOutputs`. -Mesh shader primitive outputs must also specify exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`. This determines the output topology of the mesh shader, and must match the output topology of the pipeline descriptor the mesh shader is used with. 
These must be of type `vec3`, `vec2`, and `u32` respectively. When setting this, each of the indices must be less than the number of vertices declared in `setMeshOutputs`. +To produce primitives, the workgroup as a whole must make `numPrimitives` calls to the `setPrimitive(i: u32, primitive: P)` builtin function. This establishes `primitive` as the value of the `i`'th primitive. `P` is the type given in the `@primitive_output` attribute. `P` must be a struct type, every member of which either has a `@location` or `@builtin` attribute. The following `@builtin` attributes are allowed: -Additionally, the `@location` attributes from the vertex and primitive outputs can't overlap. +- `triangle_indices`, `line_indices`, or `point_index`: The annotated member must be of type `vec3`, `vec2`, or `u32`. -Before setting any vertices or indices, or exiting, the mesh shader must call `setMeshOutputs(numVertices: u32, numIndices: u32)`, which declares the number of vertices and indices that will be written to. These must be less than the corresponding maximums set in `@vertex_output` and `@primitive_output`. The mesh shader must then write to exactly these numbers of vertices and primitives. + The member's components are indices (or, its value is an index) into the list of vertices generated by this workgroup, identifying the vertices of the primitive to be drawn. These indices must be less than the value of `numVertices` passed to `setMeshOutputs`. -The mesh shader can write to vertices using the `setVertex(idx: u32, vertex: VertexOutput)` where `VertexOutput` is replaced with the vertex type declared in `@vertex_output`, and `idx` is the index of the vertex to write. Similarly, the mesh shader can write to vertices using `setPrimitive(idx: u32, primitive: PrimitiveOutput)`. These can be written to multiple times, however unsynchronized writes are undefined behavior. The primitives and indices are shared across the entire mesh shader workgroup. 
+ The type `P` must contain exactly one member with one of these attributes, determining what sort of primitives the mesh shader generates. + +- `cull_primitive`: The annotated member must be of type `bool`. If it is true, then the primitive is skipped during rendering. + +Every member of `P` with a `@location` attribute must either have a `@per_primitive` attribute, or be part of a struct type that appears in the primitive data as a struct member with the `@per_primitive` attribute. + +The `@location` attributes of `P` and `V` must not overlap, since they are merged to produce the user-defined inputs to the fragment shader. + +It is possible to write to the same vertex or primitive index repeatedly. Since the implicit arrays written by `setVertex` and `setPrimitive` are shared by the workgroup, data races on writes to the same index for a given type are undefined behavior. ### Fragment shader -Fragment shaders may now be passed the primitive info from a mesh shader the same was as they are passed vertex inputs, for example `fn fs_main(vertex: VertexOutput, primitive: PrimitiveOutput)`. The primitive state is part of the fragment input and must match the output of the mesh shader in the pipeline. +Fragment shaders can access vertex output data as if it is from a vertex shader. They can also access primitive output data, provided the input is decorated with `@per_primitive`. The `@per_primitive` attribute can be applied to a value directly, such as `@per_primitive @location(1) value: vec4`, to a struct such as `@per_primitive primitive_input: PrimitiveInput` where `PrimitiveInput` is a struct containing fields decorated with `@location` and `@builtin`, or to members of a struct that are themselves decorated with `@location` or `@builtin`. + +The primitive state is part of the fragment input and must match the output of the mesh shader in the pipeline. Using `@per_primitive` also requires enabling the mesh shader extension. 
Additionally, the locations of vertex and primitive input cannot overlap. ### Full example @@ -114,9 +199,9 @@ The following is a full example of WGSL shaders that could be used to create a m enable mesh_shading; const positions = array( - vec4(0.,-1.,0.,1.), - vec4(-1.,1.,0.,1.), - vec4(1.,1.,0.,1.) + vec4(0.,1.,0.,1.), + vec4(-1.,-1.,0.,1.), + vec4(1.,-1.,0.,1.) ); const colors = array( vec4(0.,1.,0.,1.), @@ -127,7 +212,7 @@ struct TaskPayload { colorMask: vec4, visible: bool, } -var taskPayload: TaskPayload; +var taskPayload: TaskPayload; var workgroupData: f32; struct VertexOutput { @builtin(position) position: vec4, @@ -136,14 +221,12 @@ struct VertexOutput { struct PrimitiveOutput { @builtin(triangle_indices) index: vec3, @builtin(cull_primitive) cull: bool, - @location(1) colorMask: vec4, + @per_primitive @location(1) colorMask: vec4, } struct PrimitiveInput { - @location(1) colorMask: vec4, + @per_primitive @location(1) colorMask: vec4, } -fn test_function(input: u32) { -} @task @payload(taskPayload) @workgroup_size(1) @@ -162,8 +245,6 @@ fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocati workgroupData = 2.0; var v: VertexOutput; - test_function(1); - v.position = positions[0]; v.color = colors[0] * taskPayload.colorMask; setVertex(0, v); @@ -186,4 +267,4 @@ fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocati fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { return vertex.color * primitive.colorMask; } -``` \ No newline at end of file +``` diff --git a/naga-cli/src/bin/naga.rs b/naga-cli/src/bin/naga.rs index b620ecc704f..e22e44baddc 100644 --- a/naga-cli/src/bin/naga.rs +++ b/naga-cli/src/bin/naga.rs @@ -64,6 +64,12 @@ struct Args { #[argh(option)] shader_model: Option, + /// the SPIR-V version to use if targeting SPIR-V + /// + /// For example, 1.0, 1.4, etc + #[argh(option)] + spirv_version: Option, + /// the shader stage, for example 'frag', 'vert', or 'compute'. 
/// if the shader stage is unspecified it will be derived from /// the file extension. @@ -189,6 +195,22 @@ impl FromStr for ShaderModelArg { } } +#[derive(Debug, Clone)] +struct SpirvVersionArg(u8, u8); + +impl FromStr for SpirvVersionArg { + type Err = String; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + let dot = s + .find(".") + .ok_or_else(|| "Missing dot separator".to_owned())?; + let major = s[..dot].parse::<u8>().map_err(|e| e.to_string())?; + let minor = s[dot + 1..].parse::<u8>().map_err(|e| e.to_string())?; + Ok(Self(major, minor)) + } +} + /// Newtype so we can implement [`FromStr`] for `ShaderSource`. #[derive(Debug, Clone, Copy)] struct ShaderStage(naga::ShaderStage); @@ -465,6 +487,9 @@ fn run() -> anyhow::Result<()> { if let Some(ref version) = args.metal_version { params.msl.lang_version = version.0; } + if let Some(ref version) = args.spirv_version { + params.spv_out.lang_version = (version.0, version.1); + } params.keep_coordinate_space = args.keep_coordinate_space; params.dot.cfg_only = args.dot_cfg_only; diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 826dad1c219..1f1396eccff 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -307,6 +307,25 @@ impl StatementGraph { crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } + S::MeshFunction(crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + }) => { + self.dependencies.push((id, vertex_count, "vertex_count")); + self.dependencies + .push((id, primitive_count, "primitive_count")); + "SetMeshOutputs" + } + S::MeshFunction(crate::MeshFunction::SetVertex { index, value }) => { + self.dependencies.push((id, index, "index")); + self.dependencies.push((id, value, "value")); + "SetVertex" + } + S::MeshFunction(crate::MeshFunction::SetPrimitive { index, value }) => { + self.dependencies.push((id, index, "index")); + self.dependencies.push((id, value, "value")); + "SetPrimitive" + } S::SubgroupBallot { result, predicate } => { if let 
Some(predicate) = predicate { self.dependencies.push((id, predicate, "predicate")); diff --git a/naga/src/back/glsl/features.rs b/naga/src/back/glsl/features.rs index a6dfe4e3100..b884f08ac39 100644 --- a/naga/src/back/glsl/features.rs +++ b/naga/src/back/glsl/features.rs @@ -610,6 +610,7 @@ impl Writer<'_, W> { interpolation, sampling, blend_src, + per_primitive: _, } => { if interpolation == Some(Interpolation::Linear) { self.features.request(Features::NOPERSPECTIVE_QUALIFIER); diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 4c5a9d8cbcb..37bf318c4f8 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -139,7 +139,8 @@ impl crate::AddressSpace { | crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } | crate::AddressSpace::Handle - | crate::AddressSpace::PushConstant => false, + | crate::AddressSpace::PushConstant + | crate::AddressSpace::TaskPayload => false, } } } @@ -1300,6 +1301,9 @@ impl<'a, W: Write> Writer<'a, W> { crate::AddressSpace::Storage { .. } => { self.write_interface_block(handle, global)?; } + crate::AddressSpace::TaskPayload => { + self.write_interface_block(handle, global)?; + } // A global variable in the `Function` address space is a // contradiction in terms. 
crate::AddressSpace::Function => unreachable!(), @@ -1614,6 +1618,7 @@ impl<'a, W: Write> Writer<'a, W> { interpolation, sampling, blend_src, + per_primitive: _, } => (location, interpolation, sampling, blend_src), crate::Binding::BuiltIn(built_in) => { match built_in { @@ -1732,6 +1737,7 @@ impl<'a, W: Write> Writer<'a, W> { interpolation: None, sampling: None, blend_src, + per_primitive: false, }, stage: self.entry_point.stage, options: VaryingOptions::from_writer_options(self.options, output), @@ -1873,7 +1879,7 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, ") {{")?; if self.options.zero_initialize_workgroup_memory - && ctx.ty.is_compute_entry_point(self.module) + && ctx.ty.is_compute_like_entry_point(self.module) { self.write_workgroup_variables_initialization(&ctx)?; } @@ -2669,6 +2675,11 @@ impl<'a, W: Write> Writer<'a, W> { self.write_image_atomic(ctx, image, coordinate, array_index, fun, value)? } Statement::RayQuery { .. } => unreachable!(), + Statement::MeshFunction( + crate::MeshFunction::SetMeshOutputs { .. } + | crate::MeshFunction::SetVertex { .. } + | crate::MeshFunction::SetPrimitive { .. 
}, + ) => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); @@ -5247,6 +5258,15 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s Bi::SubgroupId => "gl_SubgroupID", Bi::SubgroupSize => "gl_SubgroupSize", Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", + // mesh + // TODO: figure out how to map these to glsl things as glsl treats them as arrays + Bi::CullPrimitive + | Bi::PointIndex + | Bi::LineIndices + | Bi::TriangleIndices + | Bi::MeshTaskSize => { + unimplemented!() + } } } @@ -5262,6 +5282,7 @@ const fn glsl_storage_qualifier(space: crate::AddressSpace) -> Option<&'static s As::Handle => Some("uniform"), As::WorkGroup => Some("shared"), As::PushConstant => Some("uniform"), + As::TaskPayload => unreachable!(), } } diff --git a/naga/src/back/hlsl/conv.rs b/naga/src/back/hlsl/conv.rs index ed40cbe5102..d6ccc5ec6e4 100644 --- a/naga/src/back/hlsl/conv.rs +++ b/naga/src/back/hlsl/conv.rs @@ -183,6 +183,9 @@ impl crate::BuiltIn { Self::PointSize | Self::ViewIndex | Self::PointCoord | Self::DrawID => { return Err(Error::Custom(format!("Unsupported builtin {self:?}"))) } + Self::CullPrimitive => "SV_CullPrimitive", + Self::PointIndex | Self::LineIndices | Self::TriangleIndices => unimplemented!(), + Self::MeshTaskSize => unreachable!(), }) } } diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index ab95b9327f9..6f0ba814a52 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -507,7 +507,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_wrapped_functions(module, &ctx)?; - if ep.stage == ShaderStage::Compute { + if ep.stage.compute_like() { // HLSL is calling workgroup size "num threads" let num_threads = ep.workgroup_size; writeln!( @@ -967,6 +967,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_type(module, global.ty)?; "" } + crate::AddressSpace::TaskPayload => 
unimplemented!(), crate::AddressSpace::Uniform => { // constant buffer declarations are expected to be inlined, e.g. // `cbuffer foo: register(b0) { field1: type1; }` @@ -1764,7 +1765,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { module: &Module, ) -> bool { self.options.zero_initialize_workgroup_memory - && func_ctx.ty.is_compute_entry_point(module) + && func_ctx.ty.is_compute_like_entry_point(module) && module.global_variables.iter().any(|(handle, var)| { !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup }) @@ -2599,6 +2600,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, ".Abort();")?; } }, + Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + }) => { + write!(self.out, "{level}SetMeshOutputCounts(")?; + self.write_expr(module, vertex_count, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, primitive_count, func_ctx)?; + write!(self.out, ");")?; + } + Statement::MeshFunction( + crate::MeshFunction::SetVertex { .. } | crate::MeshFunction::SetPrimitive { .. }, + ) => unimplemented!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = Baked(result).to_string(); @@ -3076,7 +3090,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::AddressSpace::Function | crate::AddressSpace::Private | crate::AddressSpace::WorkGroup - | crate::AddressSpace::PushConstant, + | crate::AddressSpace::PushConstant + | crate::AddressSpace::TaskPayload, ) | None => true, Some(crate::AddressSpace::Uniform) => { diff --git a/naga/src/back/mod.rs b/naga/src/back/mod.rs index 0d13d63dd9b..8be763234e7 100644 --- a/naga/src/back/mod.rs +++ b/naga/src/back/mod.rs @@ -139,11 +139,11 @@ pub enum FunctionType { } impl FunctionType { - /// Returns true if the function is an entry point for a compute shader. 
- pub fn is_compute_entry_point(&self, module: &crate::Module) -> bool { + /// Returns true if the function is an entry point for a compute-like shader. + pub fn is_compute_like_entry_point(&self, module: &crate::Module) -> bool { match *self { FunctionType::EntryPoint(index) => { - module.entry_points[index as usize].stage == crate::ShaderStage::Compute + module.entry_points[index as usize].stage.compute_like() } FunctionType::Function(_) => false, } diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 44aedf686c4..aaec5e8094c 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -540,6 +540,7 @@ impl Options { interpolation, sampling, blend_src, + per_primitive: _, } => match mode { LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)), LocationMode::FragmentOutput => { @@ -697,6 +698,10 @@ impl ResolvedBinding { Bi::CullDistance | Bi::ViewIndex | Bi::DrawID => { return Err(Error::UnsupportedBuiltIn(built_in)) } + Bi::CullPrimitive => "primitive_culled", + // TODO: figure out how to make this written as a function call + Bi::PointIndex | Bi::LineIndices | Bi::TriangleIndices => unimplemented!(), + Bi::MeshTaskSize => unreachable!(), }; write!(out, "{name}")?; } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 6e51f90181e..484142630d2 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -594,7 +594,8 @@ impl crate::AddressSpace { | Self::Private | Self::WorkGroup | Self::PushConstant - | Self::Handle => true, + | Self::Handle + | Self::TaskPayload => true, Self::Function => false, } } @@ -607,6 +608,7 @@ impl crate::AddressSpace { // may end up with "const" even if the binding is read-write, // and that should be OK. Self::Storage { .. } => true, + Self::TaskPayload => unimplemented!(), // These should always be read-write. Self::Private | Self::WorkGroup => false, // These translate to `constant` address space, no need for qualifiers. 
@@ -623,6 +625,7 @@ impl crate::AddressSpace { Self::Storage { .. } => Some("device"), Self::Private | Self::Function => Some("thread"), Self::WorkGroup => Some("threadgroup"), + Self::TaskPayload => Some("object_data"), } } } @@ -4060,6 +4063,14 @@ impl Writer { } } } + // TODO: write emitters for these + crate::Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { .. }) => { + unimplemented!() + } + crate::Statement::MeshFunction( + crate::MeshFunction::SetVertex { .. } + | crate::MeshFunction::SetPrimitive { .. }, + ) => unimplemented!(), crate::Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = self.namer.call(""); @@ -6619,7 +6630,7 @@ template LocationMode::Uniform, false, ), - crate::ShaderStage::Task | crate::ShaderStage::Mesh => unreachable!(), + crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(), }; // Should this entry point be modified to do vertex pulling? @@ -6686,6 +6697,9 @@ template break; } } + crate::AddressSpace::TaskPayload => { + unimplemented!() + } crate::AddressSpace::Function | crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => {} @@ -7683,7 +7697,7 @@ mod workgroup_mem_init { fun_info: &valid::FunctionInfo, ) -> bool { options.zero_initialize_workgroup_memory - && ep.stage == crate::ShaderStage::Compute + && ep.stage.compute_like() && module.global_variables.iter().any(|(handle, var)| { !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup }) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index d2b3ed70eda..109cc591e74 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -39,6 +39,8 @@ pub enum PipelineConstantError { ValidationError(#[from] WithSpan), #[error("workgroup_size override isn't strictly positive")] NegativeWorkgroupSize, + #[error("max vertices or max primitives is negative")] + NegativeMeshOutputMax, } /// Compact `module` and replace all 
overrides with constants. @@ -243,6 +245,7 @@ pub fn process_overrides<'a>( for ep in entry_points.iter_mut() { process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?; process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?; + process_mesh_shader_overrides(&mut module, &adjusted_global_expressions, ep)?; } module.entry_points = entry_points; module.overrides = overrides; @@ -296,6 +299,28 @@ fn process_workgroup_size_override( Ok(()) } +fn process_mesh_shader_overrides( + module: &mut Module, + adjusted_global_expressions: &HandleVec>, + ep: &mut crate::EntryPoint, +) -> Result<(), PipelineConstantError> { + if let Some(ref mut mesh_info) = ep.mesh_info { + if let Some(r#override) = mesh_info.max_vertices_override { + mesh_info.max_vertices = module + .to_ctx() + .eval_expr_to_u32(adjusted_global_expressions[r#override]) + .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?; + } + if let Some(r#override) = mesh_info.max_primitives_override { + mesh_info.max_primitives = module + .to_ctx() + .eval_expr_to_u32(adjusted_global_expressions[r#override]) + .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?; + } + } + Ok(()) +} + /// Add a [`Constant`] to `module` for the override `old_h`. /// /// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`. 
@@ -835,6 +860,26 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S crate::RayQueryFunction::Terminate => {} } } + Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { + ref mut vertex_count, + ref mut primitive_count, + }) => { + adjust(vertex_count); + adjust(primitive_count); + } + Statement::MeshFunction( + crate::MeshFunction::SetVertex { + ref mut index, + ref mut value, + } + | crate::MeshFunction::SetPrimitive { + ref mut index, + ref mut value, + }, + ) => { + adjust(index); + adjust(value); + } Statement::Break | Statement::Continue | Statement::Kill diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 7758d86c414..258d6869c99 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -3654,6 +3654,7 @@ impl BlockContext<'_> { } => { self.write_subgroup_gather(mode, argument, result, &mut block)?; } + Statement::MeshFunction(_) => unreachable!(), } } diff --git a/naga/src/back/spv/helpers.rs b/naga/src/back/spv/helpers.rs index 84e130efaa3..f6d26794e70 100644 --- a/naga/src/back/spv/helpers.rs +++ b/naga/src/back/spv/helpers.rs @@ -54,6 +54,7 @@ pub(super) const fn map_storage_class(space: crate::AddressSpace) -> spirv::Stor crate::AddressSpace::Uniform => spirv::StorageClass::Uniform, crate::AddressSpace::WorkGroup => spirv::StorageClass::Workgroup, crate::AddressSpace::PushConstant => spirv::StorageClass::PushConstant, + crate::AddressSpace::TaskPayload => unreachable!(), } } diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index c86a53c6ef8..1e207fc7002 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1094,7 +1094,10 @@ impl Writer { super::ZeroInitializeWorkgroupMemoryMode::Polyfill, Some( ref mut interface @ FunctionInterface { - stage: crate::ShaderStage::Compute, + stage: + crate::ShaderStage::Compute + | crate::ShaderStage::Mesh + | crate::ShaderStage::Task, .. 
}, ), @@ -1991,6 +1994,7 @@ impl Writer { interpolation, sampling, blend_src, + per_primitive: _, } => { self.decorate(id, Decoration::Location, &[location]); @@ -2140,6 +2144,11 @@ impl Writer { )?; BuiltIn::SubgroupLocalInvocationId } + Bi::MeshTaskSize + | Bi::CullPrimitive + | Bi::PointIndex + | Bi::LineIndices + | Bi::TriangleIndices => unreachable!(), }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 225a63343bf..d1ebf62e6ee 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -207,7 +207,7 @@ impl Writer { Attribute::Stage(ShaderStage::Compute), Attribute::WorkGroupSize(ep.workgroup_size), ], - ShaderStage::Task | ShaderStage::Mesh => unreachable!(), + ShaderStage::Mesh | ShaderStage::Task => unreachable!(), }; self.write_attributes(&attributes)?; @@ -856,6 +856,7 @@ impl Writer { } } Statement::RayQuery { .. } => unreachable!(), + Statement::MeshFunction(..) => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); @@ -1822,6 +1823,7 @@ fn map_binding_to_attribute(binding: &crate::Binding) -> Vec { interpolation, sampling, blend_src: None, + per_primitive: _, } => vec![ Attribute::Location(location), Attribute::Interpolate(interpolation, sampling), @@ -1831,6 +1833,7 @@ fn map_binding_to_attribute(binding: &crate::Binding) -> Vec { interpolation, sampling, blend_src: Some(blend_src), + per_primitive: _, } => vec![ Attribute::Location(location), Attribute::BlendSrc(blend_src), diff --git a/naga/src/common/wgsl/to_wgsl.rs b/naga/src/common/wgsl/to_wgsl.rs index 72be441288f..e94fbf88796 100644 --- a/naga/src/common/wgsl/to_wgsl.rs +++ b/naga/src/common/wgsl/to_wgsl.rs @@ -188,7 +188,12 @@ impl TryToWgsl for crate::BuiltIn { | Bi::PointSize | Bi::DrawID | Bi::PointCoord - | Bi::WorkGroupSize => return None, + | Bi::WorkGroupSize + | Bi::CullPrimitive + | 
Bi::TriangleIndices + | Bi::LineIndices + | Bi::MeshTaskSize + | Bi::PointIndex => return None, }) } } @@ -352,6 +357,7 @@ pub const fn address_space_str( As::WorkGroup => "workgroup", As::Handle => return (None, None), As::Function => "function", + As::TaskPayload => return (None, None), }), None, ) diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index d059ba21e4f..a7d3d463f11 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -221,6 +221,45 @@ pub fn compact(module: &mut crate::Module, keep_unused: KeepUnused) { } } + for entry in &module.entry_points { + if let Some(task_payload) = entry.task_payload { + module_tracer.global_variables_used.insert(task_payload); + } + if let Some(ref mesh_info) = entry.mesh_info { + module_tracer + .types_used + .insert(mesh_info.vertex_output_type); + module_tracer + .types_used + .insert(mesh_info.primitive_output_type); + if let Some(max_vertices_override) = mesh_info.max_vertices_override { + module_tracer + .global_expressions_used + .insert(max_vertices_override); + } + if let Some(max_primitives_override) = mesh_info.max_primitives_override { + module_tracer + .global_expressions_used + .insert(max_primitives_override); + } + } + if entry.stage == crate::ShaderStage::Task || entry.stage == crate::ShaderStage::Mesh { + // u32 should always be there if the module is valid, as it is e.g. 
the type of some expressions + let u32_type = module + .types + .iter() + .find_map(|tuple| { + if tuple.1.inner == crate::TypeInner::Scalar(crate::Scalar::U32) { + Some(tuple.0) + } else { + None + } + }) + .unwrap(); + module_tracer.types_used.insert(u32_type); + } + } + module_tracer.type_expression_tandem(); // Now that we know what is used and what is never touched, @@ -342,6 +381,23 @@ pub fn compact(module: &mut crate::Module, keep_unused: KeepUnused) { &module_map, &mut reused_named_expressions, ); + if let Some(ref mut task_payload) = entry.task_payload { + module_map.globals.adjust(task_payload); + } + if let Some(ref mut mesh_info) = entry.mesh_info { + module_map.types.adjust(&mut mesh_info.vertex_output_type); + module_map + .types + .adjust(&mut mesh_info.primitive_output_type); + if let Some(ref mut max_vertices_override) = mesh_info.max_vertices_override { + module_map.global_expressions.adjust(max_vertices_override); + } + if let Some(ref mut max_primitives_override) = mesh_info.max_primitives_override { + module_map + .global_expressions + .adjust(max_primitives_override); + } + } } } diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 39d6065f5f0..b370501baca 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -117,6 +117,20 @@ impl FunctionTracer<'_> { self.expressions_used.insert(query); self.trace_ray_query_function(fun); } + St::MeshFunction(crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + }) => { + self.expressions_used.insert(vertex_count); + self.expressions_used.insert(primitive_count); + } + St::MeshFunction( + crate::MeshFunction::SetPrimitive { index, value } + | crate::MeshFunction::SetVertex { index, value }, + ) => { + self.expressions_used.insert(index); + self.expressions_used.insert(value); + } St::SubgroupBallot { result, predicate } => { if let Some(predicate) = predicate { self.expressions_used.insert(predicate); @@ -335,6 +349,26 @@ impl 
FunctionMap { adjust(query); self.adjust_ray_query_function(fun); } + St::MeshFunction(crate::MeshFunction::SetMeshOutputs { + ref mut vertex_count, + ref mut primitive_count, + }) => { + adjust(vertex_count); + adjust(primitive_count); + } + St::MeshFunction( + crate::MeshFunction::SetVertex { + ref mut index, + ref mut value, + } + | crate::MeshFunction::SetPrimitive { + ref mut index, + ref mut value, + }, + ) => { + adjust(index); + adjust(value); + } St::SubgroupBallot { ref mut result, ref mut predicate, diff --git a/naga/src/front/glsl/functions.rs b/naga/src/front/glsl/functions.rs index 7de7364cd40..ba096a82b3b 100644 --- a/naga/src/front/glsl/functions.rs +++ b/naga/src/front/glsl/functions.rs @@ -1377,6 +1377,8 @@ impl Frontend { result: ty.map(|ty| FunctionResult { ty, binding: None }), ..Default::default() }, + mesh_info: None, + task_payload: None, }); Ok(()) @@ -1446,6 +1448,7 @@ impl Context<'_> { interpolation, sampling: None, blend_src: None, + per_primitive: false, }; location += 1; @@ -1482,6 +1485,7 @@ impl Context<'_> { interpolation, sampling: None, blend_src: None, + per_primitive: false, }; location += 1; binding diff --git a/naga/src/front/glsl/mod.rs b/naga/src/front/glsl/mod.rs index 876add46a1c..e5eda6b3ad9 100644 --- a/naga/src/front/glsl/mod.rs +++ b/naga/src/front/glsl/mod.rs @@ -107,7 +107,7 @@ impl ShaderMetadata { self.version = 0; self.profile = Profile::Core; self.stage = stage; - self.workgroup_size = [u32::from(stage == ShaderStage::Compute); 3]; + self.workgroup_size = [u32::from(stage.compute_like()); 3]; self.early_fragment_tests = false; self.extensions.clear(); } diff --git a/naga/src/front/glsl/variables.rs b/naga/src/front/glsl/variables.rs index ef98143b769..98871bd2f81 100644 --- a/naga/src/front/glsl/variables.rs +++ b/naga/src/front/glsl/variables.rs @@ -465,6 +465,7 @@ impl Frontend { interpolation, sampling, blend_src, + per_primitive: false, }, handle, storage, diff --git a/naga/src/front/interpolator.rs 
b/naga/src/front/interpolator.rs index e23cae0e7c2..126e860426c 100644 --- a/naga/src/front/interpolator.rs +++ b/naga/src/front/interpolator.rs @@ -44,6 +44,7 @@ impl crate::Binding { interpolation: ref mut interpolation @ None, ref mut sampling, blend_src: _, + per_primitive: _, } = *self { match ty.scalar_kind() { diff --git a/naga/src/front/spv/function.rs b/naga/src/front/spv/function.rs index 67cbf05f04f..48b23e7c4c4 100644 --- a/naga/src/front/spv/function.rs +++ b/naga/src/front/spv/function.rs @@ -596,6 +596,8 @@ impl> super::Frontend { workgroup_size: ep.workgroup_size, workgroup_size_overrides: None, function, + mesh_info: None, + task_payload: None, }); Ok(()) diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 5e1b1146503..48f8264d5c4 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -265,6 +265,7 @@ impl Decoration { interpolation, sampling, blend_src: None, + per_primitive: false, }), _ => Err(Error::MissingDecoration(spirv::Decoration::Location)), } @@ -4659,6 +4660,7 @@ impl> Frontend { | S::Atomic { .. } | S::ImageAtomic { .. } | S::RayQuery { .. } + | S::MeshFunction(..) | S::SubgroupBallot { .. } | S::SubgroupCollectiveOperation { .. } | S::SubgroupGather { .. 
} => {} @@ -4940,6 +4942,8 @@ impl> Frontend { spirv::ExecutionModel::Vertex => crate::ShaderStage::Vertex, spirv::ExecutionModel::Fragment => crate::ShaderStage::Fragment, spirv::ExecutionModel::GLCompute => crate::ShaderStage::Compute, + spirv::ExecutionModel::TaskEXT => crate::ShaderStage::Task, + spirv::ExecutionModel::MeshEXT => crate::ShaderStage::Mesh, _ => return Err(Error::UnsupportedExecutionModel(exec_model as u32)), }, name, diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 17dab5cb0ea..004528dbe91 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -406,6 +406,19 @@ pub(crate) enum Error<'a> { accept_span: Span, accept_type: String, }, + MissingMeshShaderInfo { + mesh_attribute_span: Span, + }, + OneMeshShaderAttribute { + attribute_span: Span, + }, + ExpectedGlobalVariable { + name_span: Span, + }, + MeshPrimitiveNoDefinedTopology { + attribute_span: Span, + struct_span: Span, + }, StructMemberTooLarge { member_name_span: Span, }, @@ -1370,6 +1383,27 @@ impl<'a> Error<'a> { ], notes: vec![], }, + Error::MissingMeshShaderInfo { mesh_attribute_span} => ParseError { + message: "mesh shader entry point is missing both `@vertex_output` and `@primitive_output`".into(), + labels: vec![(mesh_attribute_span, "mesh shader entry declared here".into())], + notes: vec![], + }, + Error::OneMeshShaderAttribute { attribute_span } => ParseError { + message: "only one of `@vertex_output` or `@primitive_output` was given".into(), + labels: vec![(attribute_span, "only one of @vertex_output or @primitive_output is provided".into())], + notes: vec![], + }, + Error::ExpectedGlobalVariable { name_span } => ParseError { + message: "expected global variable".to_string(), + // TODO: I would like to also include the global declaration span + labels: vec![(name_span, "variable used here".into())], + notes: vec![], + }, + Error::MeshPrimitiveNoDefinedTopology { struct_span, attribute_span } => ParseError { + message: "mesh 
primitive struct must have exactly one of point indices, line indices, or triangle indices".to_string(), + labels: vec![(attribute_span, "primitive type declared here".into()), (struct_span, "primitive struct declared here".into())], + notes: vec![] + }, Error::StructMemberTooLarge { member_name_span } => ParseError { message: "struct member is too large".into(), labels: vec![(member_name_span, "this member exceeds the maximum size".into())], diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index e90d7eab0a8..ef63e6aaea7 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1479,47 +1479,147 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .collect(); if let Some(ref entry) = f.entry_point { - let workgroup_size_info = if let Some(workgroup_size) = entry.workgroup_size { - // TODO: replace with try_map once stabilized - let mut workgroup_size_out = [1; 3]; - let mut workgroup_size_overrides_out = [None; 3]; - for (i, size) in workgroup_size.into_iter().enumerate() { - if let Some(size_expr) = size { - match self.const_u32(size_expr, &mut ctx.as_const()) { - Ok(value) => { - workgroup_size_out[i] = value.0; - } - Err(err) => { - if let Error::ConstantEvaluatorError(ref ty, _) = *err { - match **ty { - proc::ConstantEvaluatorError::OverrideExpr => { - workgroup_size_overrides_out[i] = - Some(self.workgroup_size_override( - size_expr, - &mut ctx.as_override(), - )?); - } - _ => { - return Err(err); + let (workgroup_size, workgroup_size_overrides) = + if let Some(workgroup_size) = entry.workgroup_size { + // TODO: replace with try_map once stabilized + let mut workgroup_size_out = [1; 3]; + let mut workgroup_size_overrides_out = [None; 3]; + for (i, size) in workgroup_size.into_iter().enumerate() { + if let Some(size_expr) = size { + match self.const_u32(size_expr, &mut ctx.as_const()) { + Ok(value) => { + workgroup_size_out[i] = value.0; + } + Err(err) => { + if let Error::ConstantEvaluatorError(ref 
ty, _) = *err { + match **ty { + proc::ConstantEvaluatorError::OverrideExpr => { + workgroup_size_overrides_out[i] = + Some(self.workgroup_size_override( + size_expr, + &mut ctx.as_override(), + )?); + } + _ => { + return Err(err); + } } + } else { + return Err(err); } - } else { - return Err(err); } } } } - } - if workgroup_size_overrides_out.iter().all(|x| x.is_none()) { - (workgroup_size_out, None) + if workgroup_size_overrides_out.iter().all(|x| x.is_none()) { + (workgroup_size_out, None) + } else { + (workgroup_size_out, Some(workgroup_size_overrides_out)) + } } else { - (workgroup_size_out, Some(workgroup_size_overrides_out)) + ([0; 3], None) + }; + + let mesh_info = if let Some(mesh_info) = entry.mesh_shader_info { + let mut const_u32 = |expr| match self.const_u32(expr, &mut ctx.as_const()) { + Ok(value) => Ok((value.0, None)), + Err(err) => { + if let Error::ConstantEvaluatorError(ref ty, _) = *err { + match **ty { + proc::ConstantEvaluatorError::OverrideExpr => Ok(( + 0, + Some( + // This is dubious but it seems the code isn't workgroup size specific + self.workgroup_size_override(expr, &mut ctx.as_override())?, + ), + )), + _ => Err(err), + } + } else { + Err(err) + } + } + }; + let (max_vertices, max_vertices_override) = const_u32(mesh_info.vertex_count)?; + let (max_primitives, max_primitives_override) = + const_u32(mesh_info.primitive_count)?; + let vertex_output_type = + self.resolve_ast_type(mesh_info.vertex_type.0, &mut ctx.as_const())?; + let primitive_output_type = + self.resolve_ast_type(mesh_info.primitive_type.0, &mut ctx.as_const())?; + + let mut topology = None; + let struct_span = ctx.module.types.get_span(primitive_output_type); + match &ctx.module.types[primitive_output_type].inner { + &ir::TypeInner::Struct { + ref members, + span: _, + } => { + for member in members { + let out_topology = match member.binding { + Some(ir::Binding::BuiltIn(ir::BuiltIn::TriangleIndices)) => { + Some(ir::MeshOutputTopology::Triangles) + } + 
Some(ir::Binding::BuiltIn(ir::BuiltIn::LineIndices)) => { + Some(ir::MeshOutputTopology::Lines) + } + _ => None, + }; + if out_topology.is_some() { + if topology.is_some() { + return Err(Box::new(Error::MeshPrimitiveNoDefinedTopology { + attribute_span: mesh_info.primitive_type.1, + struct_span, + })); + } + topology = out_topology; + } + } + } + _ => { + return Err(Box::new(Error::MeshPrimitiveNoDefinedTopology { + attribute_span: mesh_info.primitive_type.1, + struct_span, + })) + } } + let topology = if let Some(t) = topology { + t + } else { + return Err(Box::new(Error::MeshPrimitiveNoDefinedTopology { + attribute_span: mesh_info.primitive_type.1, + struct_span, + })); + }; + + Some(ir::MeshStageInfo { + max_vertices, + max_vertices_override, + max_primitives, + max_primitives_override, + + vertex_output_type, + primitive_output_type, + topology, + }) + } else { + None + }; + + let task_payload = if let Some((var_name, var_span)) = entry.task_payload { + Some(match ctx.globals.get(var_name) { + Some(&LoweredGlobalDecl::Var(handle)) => handle, + Some(_) => { + return Err(Box::new(Error::ExpectedGlobalVariable { + name_span: var_span, + })) + } + None => return Err(Box::new(Error::UnknownIdent(var_span, var_name))), + }) } else { - ([0; 3], None) + None }; - let (workgroup_size, workgroup_size_overrides) = workgroup_size_info; ctx.module.entry_points.push(ir::EntryPoint { name: f.name.name.to_string(), stage: entry.stage, @@ -1527,6 +1627,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { workgroup_size, workgroup_size_overrides, function, + mesh_info, + task_payload, }); Ok(LoweredGlobalDecl::EntryPoint( ctx.module.entry_points.len() - 1, @@ -3130,6 +3232,59 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return Ok(Some(result)); } + + "setMeshOutputs" | "setVertex" | "setPrimitive" => { + let mut args = ctx.prepare_args(arguments, 2, span); + let arg1 = args.next()?; + let arg2 = args.next()?; + args.finish()?; + + let mut cast_u32 = |arg| { + // Try to 
convert abstract values to the known argument types + let expr = self.expression_for_abstract(arg, ctx)?; + let goal_ty = + ctx.ensure_type_exists(ir::TypeInner::Scalar(ir::Scalar::U32)); + ctx.try_automatic_conversions( + expr, + &proc::TypeResolution::Handle(goal_ty), + ctx.ast_expressions.get_span(arg), + ) + }; + + let arg1 = cast_u32(arg1)?; + let arg2 = if function.name == "setMeshOutputs" { + cast_u32(arg2)? + } else { + self.expression(arg2, ctx)? + }; + + let rctx = ctx.runtime_expression_ctx(span)?; + + // Emit all previous expressions, even if not used directly + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.block.push( + crate::Statement::MeshFunction(match function.name { + "setMeshOutputs" => crate::MeshFunction::SetMeshOutputs { + vertex_count: arg1, + primitive_count: arg2, + }, + "setVertex" => crate::MeshFunction::SetVertex { + index: arg1, + value: arg2, + }, + "setPrimitive" => crate::MeshFunction::SetPrimitive { + index: arg1, + value: arg2, + }, + _ => unreachable!(), + }), + span, + ); + rctx.emitter.start(&rctx.function.expressions); + + return Ok(None); + } _ => { return Err(Box::new(Error::UnknownIdent(function.span, function.name))) } @@ -4057,6 +4212,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { interpolation, sampling, blend_src, + per_primitive, }) => { let blend_src = if let Some(blend_src) = blend_src { Some(self.const_u32(blend_src, &mut ctx.as_const())?.0) @@ -4069,6 +4225,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { interpolation, sampling, blend_src, + per_primitive, }; binding.apply_default_interpolation(&ctx.module.types[ty].inner); Some(binding) diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index 345e9c4c486..49ecddfdee5 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -128,6 +128,16 @@ pub struct EntryPoint<'a> { pub stage: crate::ShaderStage, pub early_depth_test: Option, pub workgroup_size: 
Option<[Option>>; 3]>, + pub mesh_shader_info: Option>, + pub task_payload: Option<(&'a str, Span)>, +} + +#[derive(Debug, Clone, Copy)] +pub struct EntryPointMeshShaderInfo<'a> { + pub vertex_count: Handle>, + pub primitive_count: Handle>, + pub vertex_type: (Handle>, Span), + pub primitive_type: (Handle>, Span), } #[cfg(doc)] @@ -152,6 +162,7 @@ pub enum Binding<'a> { interpolation: Option, sampling: Option, blend_src: Option>>, + per_primitive: bool, }, } diff --git a/naga/src/front/wgsl/parse/conv.rs b/naga/src/front/wgsl/parse/conv.rs index 30d0eb2d598..2bde001804e 100644 --- a/naga/src/front/wgsl/parse/conv.rs +++ b/naga/src/front/wgsl/parse/conv.rs @@ -16,6 +16,7 @@ pub fn map_address_space(word: &str, span: Span) -> Result<'_, crate::AddressSpa }), "push_constant" => Ok(crate::AddressSpace::PushConstant), "function" => Ok(crate::AddressSpace::Function), + "task_payload" => Ok(crate::AddressSpace::TaskPayload), _ => Err(Box::new(Error::UnknownAddressSpace(span))), } } @@ -49,6 +50,12 @@ pub fn map_built_in( "subgroup_id" => crate::BuiltIn::SubgroupId, "subgroup_size" => crate::BuiltIn::SubgroupSize, "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, + // mesh + "cull_primitive" => crate::BuiltIn::CullPrimitive, + "point_index" => crate::BuiltIn::PointIndex, + "line_indices" => crate::BuiltIn::LineIndices, + "triangle_indices" => crate::BuiltIn::TriangleIndices, + "mesh_task_size" => crate::BuiltIn::MeshTaskSize, _ => return Err(Box::new(Error::UnknownBuiltin(span))), }; match built_in { diff --git a/naga/src/front/wgsl/parse/directive/enable_extension.rs b/naga/src/front/wgsl/parse/directive/enable_extension.rs index 38d6d6719ca..d376c114ff0 100644 --- a/naga/src/front/wgsl/parse/directive/enable_extension.rs +++ b/naga/src/front/wgsl/parse/directive/enable_extension.rs @@ -10,6 +10,7 @@ use alloc::boxed::Box; /// Tracks the status of every enable-extension known to Naga. 
#[derive(Clone, Debug, Eq, PartialEq)] pub struct EnableExtensions { + mesh_shader: bool, dual_source_blending: bool, /// Whether `enable f16;` was written earlier in the shader module. f16: bool, @@ -19,6 +20,7 @@ pub struct EnableExtensions { impl EnableExtensions { pub(crate) const fn empty() -> Self { Self { + mesh_shader: false, f16: false, dual_source_blending: false, clip_distances: false, @@ -28,6 +30,7 @@ impl EnableExtensions { /// Add an enable-extension to the set requested by a module. pub(crate) fn add(&mut self, ext: ImplementedEnableExtension) { let field = match ext { + ImplementedEnableExtension::MeshShader => &mut self.mesh_shader, ImplementedEnableExtension::DualSourceBlending => &mut self.dual_source_blending, ImplementedEnableExtension::F16 => &mut self.f16, ImplementedEnableExtension::ClipDistances => &mut self.clip_distances, @@ -38,6 +41,7 @@ impl EnableExtensions { /// Query whether an enable-extension tracked here has been requested. pub(crate) const fn contains(&self, ext: ImplementedEnableExtension) -> bool { match ext { + ImplementedEnableExtension::MeshShader => self.mesh_shader, ImplementedEnableExtension::DualSourceBlending => self.dual_source_blending, ImplementedEnableExtension::F16 => self.f16, ImplementedEnableExtension::ClipDistances => self.clip_distances, @@ -70,6 +74,7 @@ impl EnableExtension { const F16: &'static str = "f16"; const CLIP_DISTANCES: &'static str = "clip_distances"; const DUAL_SOURCE_BLENDING: &'static str = "dual_source_blending"; + const MESH_SHADER: &'static str = "mesh_shading"; const SUBGROUPS: &'static str = "subgroups"; const PRIMITIVE_INDEX: &'static str = "primitive_index"; @@ -81,6 +86,7 @@ impl EnableExtension { Self::DUAL_SOURCE_BLENDING => { Self::Implemented(ImplementedEnableExtension::DualSourceBlending) } + Self::MESH_SHADER => Self::Implemented(ImplementedEnableExtension::MeshShader), Self::SUBGROUPS => Self::Unimplemented(UnimplementedEnableExtension::Subgroups), Self::PRIMITIVE_INDEX => { 
Self::Unimplemented(UnimplementedEnableExtension::PrimitiveIndex) @@ -93,6 +99,7 @@ impl EnableExtension { pub const fn to_ident(self) -> &'static str { match self { Self::Implemented(kind) => match kind { + ImplementedEnableExtension::MeshShader => Self::MESH_SHADER, ImplementedEnableExtension::DualSourceBlending => Self::DUAL_SOURCE_BLENDING, ImplementedEnableExtension::F16 => Self::F16, ImplementedEnableExtension::ClipDistances => Self::CLIP_DISTANCES, @@ -126,6 +133,8 @@ pub enum ImplementedEnableExtension { /// /// [`enable clip_distances;`]: https://www.w3.org/TR/WGSL/#extension-clip_distances ClipDistances, + /// Enables the `mesh_shader` extension, native only + MeshShader, } /// A variant of [`EnableExtension::Unimplemented`]. diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index c01ba4de30f..29376614d6e 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -178,6 +178,7 @@ struct BindingParser<'a> { sampling: ParsedAttribute, invariant: ParsedAttribute, blend_src: ParsedAttribute>>, + per_primitive: ParsedAttribute<()>, } impl<'a> BindingParser<'a> { @@ -238,6 +239,9 @@ impl<'a> BindingParser<'a> { lexer.skip(Token::Separator(',')); lexer.expect(Token::Paren(')'))?; } + "per_primitive" => { + self.per_primitive.set((), name_span)?; + } _ => return Err(Box::new(Error::UnknownAttribute(name_span))), } Ok(()) @@ -251,9 +255,10 @@ impl<'a> BindingParser<'a> { self.sampling.value, self.invariant.value.unwrap_or_default(), self.blend_src.value, + self.per_primitive.value, ) { - (None, None, None, None, false, None) => Ok(None), - (Some(location), None, interpolation, sampling, false, blend_src) => { + (None, None, None, None, false, None, None) => Ok(None), + (Some(location), None, interpolation, sampling, false, blend_src, per_primitive) => { // Before handing over the completed `Module`, we call // `apply_default_interpolation` to ensure that the interpolation and // sampling have been 
explicitly specified on all vertex shader output and fragment @@ -263,17 +268,18 @@ impl<'a> BindingParser<'a> { interpolation, sampling, blend_src, + per_primitive: per_primitive.is_some(), })) } - (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant, None) => { + (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant, None, None) => { Ok(Some(ast::Binding::BuiltIn(crate::BuiltIn::Position { invariant, }))) } - (None, Some(built_in), None, None, false, None) => { + (None, Some(built_in), None, None, false, None, None) => { Ok(Some(ast::Binding::BuiltIn(built_in))) } - (_, _, _, _, _, _) => Err(Box::new(Error::InconsistentBinding(span))), + (_, _, _, _, _, _, _) => Err(Box::new(Error::InconsistentBinding(span))), } } } @@ -2790,12 +2796,15 @@ impl Parser { // read attributes let mut binding = None; let mut stage = ParsedAttribute::default(); - let mut compute_span = Span::new(0, 0); + let mut compute_like_span = Span::new(0, 0); let mut workgroup_size = ParsedAttribute::default(); let mut early_depth_test = ParsedAttribute::default(); let (mut bind_index, mut bind_group) = (ParsedAttribute::default(), ParsedAttribute::default()); let mut id = ParsedAttribute::default(); + let mut payload = ParsedAttribute::default(); + let mut vertex_output = ParsedAttribute::default(); + let mut primitive_output = ParsedAttribute::default(); let mut must_use: ParsedAttribute = ParsedAttribute::default(); @@ -2854,7 +2863,35 @@ impl Parser { } "compute" => { stage.set(ShaderStage::Compute, name_span)?; - compute_span = name_span; + compute_like_span = name_span; + } + "task" => { + stage.set(ShaderStage::Task, name_span)?; + compute_like_span = name_span; + } + "mesh" => { + stage.set(ShaderStage::Mesh, name_span)?; + compute_like_span = name_span; + } + "payload" => { + lexer.expect(Token::Paren('('))?; + payload.set(lexer.next_ident_with_span()?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } + "vertex_output" | "primitive_output" => { + 
lexer.expect(Token::Paren('('))?; + let type_span = lexer.peek().1; + let r#type = self.type_decl(lexer, &mut ctx)?; + let type_span = lexer.span_from(type_span.to_range().unwrap().start); + lexer.expect(Token::Separator(','))?; + let max_output = self.general_expression(lexer, &mut ctx)?; + let end_span = lexer.expect_span(Token::Paren(')'))?; + let total_span = name_span.until(&end_span); + if name == "vertex_output" { + vertex_output.set((r#type, type_span, max_output), total_span)?; + } else if name == "primitive_output" { + primitive_output.set((r#type, type_span, max_output), total_span)?; + } } "workgroup_size" => { lexer.expect(Token::Paren('('))?; @@ -3020,13 +3057,39 @@ impl Parser { )?; Some(ast::GlobalDeclKind::Fn(ast::Function { entry_point: if let Some(stage) = stage.value { - if stage == ShaderStage::Compute && workgroup_size.value.is_none() { - return Err(Box::new(Error::MissingWorkgroupSize(compute_span))); + if stage.compute_like() && workgroup_size.value.is_none() { + return Err(Box::new(Error::MissingWorkgroupSize(compute_like_span))); } + if stage == ShaderStage::Mesh + && (vertex_output.value.is_none() || primitive_output.value.is_none()) + { + return Err(Box::new(Error::MissingMeshShaderInfo { + mesh_attribute_span: compute_like_span, + })); + } + let mesh_shader_info = match (vertex_output.value, primitive_output.value) { + (Some(vertex_output), Some(primitive_output)) => { + Some(ast::EntryPointMeshShaderInfo { + vertex_count: vertex_output.2, + primitive_count: primitive_output.2, + vertex_type: (vertex_output.0, vertex_output.1), + primitive_type: (primitive_output.0, primitive_output.1), + }) + } + (None, None) => None, + (Some(v), None) | (None, Some(v)) => { + return Err(Box::new(Error::OneMeshShaderAttribute { + attribute_span: v.1, + })) + } + }; + Some(ast::EntryPoint { stage, early_depth_test: early_depth_test.value, workgroup_size: workgroup_size.value, + mesh_shader_info, + task_payload: payload.value, }) } else { None diff --git 
a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 257445952b8..4b0769c2803 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -320,13 +320,21 @@ pub enum ConservativeDepth { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -#[allow(missing_docs)] // The names are self evident pub enum ShaderStage { + /// A vertex shader, in a render pipeline. Vertex, - Fragment, - Compute, + + /// A task shader, in a mesh render pipeline. Task, + + /// A mesh shader, in a mesh render pipeline. Mesh, + + /// A fragment shader, in a render pipeline. + Fragment, + + /// Compute pipeline shader. + Compute, } /// Addressing space of variables. @@ -363,6 +371,8 @@ pub enum AddressSpace { /// /// [`SHADER_FLOAT16`]: crate::valid::Capabilities::SHADER_FLOAT16 PushConstant, + /// Task shader to mesh shader payload + TaskPayload, } /// Built-in inputs and outputs. @@ -371,36 +381,73 @@ pub enum AddressSpace { #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum BuiltIn { + /// Written in vertex/mesh shaders, read in fragment shaders Position { invariant: bool }, + /// Read in task, mesh, vertex, and fragment shaders ViewIndex, - // vertex + + /// Read in vertex shaders BaseInstance, + /// Read in vertex shaders BaseVertex, + /// Written in vertex & mesh shaders ClipDistance, + /// Written in vertex & mesh shaders CullDistance, + /// Read in vertex shaders InstanceIndex, + /// Written in vertex & mesh shaders PointSize, + /// Read in vertex shaders VertexIndex, + /// Read in vertex & task shaders, or mesh shaders in pipelines without task shaders DrawID, - // fragment + + /// Written in fragment shaders FragDepth, + /// Read in fragment shaders PointCoord, + /// Read in fragment shaders FrontFacing, + /// Read in fragment shaders, in the future may be written in mesh shaders PrimitiveIndex, + /// Read in
fragment shaders SampleIndex, + /// Read or written in fragment shaders SampleMask, - // compute + + /// Read in compute, task, and mesh shaders GlobalInvocationId, + /// Read in compute, task, and mesh shaders LocalInvocationId, + /// Read in compute, task, and mesh shaders LocalInvocationIndex, + /// Read in compute, task, and mesh shaders WorkGroupId, + /// Read in compute, task, and mesh shaders WorkGroupSize, + /// Read in compute, task, and mesh shaders NumWorkGroups, - // subgroup + + /// Read in compute, task, and mesh shaders NumSubgroups, + /// Read in compute, task, and mesh shaders SubgroupId, + /// Read in compute, fragment, task, and mesh shaders SubgroupSize, + /// Read in compute, fragment, task, and mesh shaders SubgroupInvocationId, + + /// Written in task shaders + MeshTaskSize, + /// Written in mesh shaders + CullPrimitive, + /// Written in mesh shaders + PointIndex, + /// Written in mesh shaders + LineIndices, + /// Written in mesh shaders + TriangleIndices, } /// Number of bytes per scalar. @@ -945,6 +992,9 @@ pub enum Binding { /// Indexed location. /// + /// This is a value passed to a [`Fragment`] shader from a [`Vertex`] or + /// [`Mesh`] shader. + /// /// Values passed from the [`Vertex`] stage to the [`Fragment`] stage must /// have their `interpolation` defaulted (i.e. not `None`) by the front end /// as appropriate for that language. @@ -958,14 +1008,30 @@ pub enum Binding { /// interpolation must be `Flat`. /// /// [`Vertex`]: crate::ShaderStage::Vertex + /// [`Mesh`]: crate::ShaderStage::Mesh /// [`Fragment`]: crate::ShaderStage::Fragment Location { location: u32, interpolation: Option, sampling: Option, + /// Optional `blend_src` index used for dual source blending. /// See blend_src: Option, + + /// Whether the binding is a per-primitive binding for use with mesh shaders. + /// + /// This must be `true` if this binding is a mesh shader primitive output, or such + /// an output's corresponding fragment shader input. 
It must be `false` otherwise. + /// + /// A stage's outputs must all have unique `location` numbers, regardless of + /// whether they are per-primitive; a mesh shader's per-vertex and per-primitive + /// outputs share the same location numbering space. + /// + /// Per-primitive values are not interpolated at all and are not dependent on the + /// vertices or pixel location. For example, it may be used to store a + /// non-interpolated normal vector. + per_primitive: bool, }, } @@ -1724,10 +1790,12 @@ pub enum Expression { query: Handle, committed: bool, }, + /// Result of a [`SubgroupBallot`] statement. /// /// [`SubgroupBallot`]: Statement::SubgroupBallot SubgroupBallotResult, + /// Result of a [`SubgroupCollectiveOperation`] or [`SubgroupGather`] statement. /// /// [`SubgroupCollectiveOperation`]: Statement::SubgroupCollectiveOperation @@ -2141,6 +2209,8 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, + /// A mesh shader intrinsic. + MeshFunction(MeshFunction), /// Calculate a bitmask using a boolean from each active thread in the subgroup SubgroupBallot { /// The [`SubgroupBallotResult`] expression representing this load's result. @@ -2314,6 +2384,12 @@ pub struct EntryPoint { pub workgroup_size_overrides: Option<[Option>; 3]>, /// The entrance function. pub function: Function, + /// Information for [`Mesh`] shaders. + /// + /// [`Mesh`]: ShaderStage::Mesh + pub mesh_info: Option, + /// The unique global variable used as a task payload from task shader to mesh shader + pub task_payload: Option>, } /// Return types predeclared for the frexp, modf, and atomicCompareExchangeWeak built-in functions. @@ -2490,6 +2566,66 @@ pub struct DocComments { pub module: Vec, } +/// The output topology for a mesh shader. Note that mesh shaders don't allow things like triangle-strips. 
+#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum MeshOutputTopology { + /// Outputs individual vertices to be rendered as points. + Points, + /// Outputs groups of 2 vertices to be rendered as lines. + Lines, + /// Outputs groups of 3 vertices to be rendered as triangles. + Triangles, +} + +/// Information specific to mesh shader entry points. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[allow(dead_code)] +pub struct MeshStageInfo { + /// The type of primitive outputted. + pub topology: MeshOutputTopology, + /// The maximum number of vertices a mesh shader may output. + pub max_vertices: u32, + /// If pipeline constants are used, the expressions that override `max_vertices` + pub max_vertices_override: Option>, + /// The maximum number of primitives a mesh shader may output. + pub max_primitives: u32, + /// If pipeline constants are used, the expressions that override `max_primitives` + pub max_primitives_override: Option>, + /// The type used by vertex outputs, i.e. what is passed to `setVertex`. + pub vertex_output_type: Handle, + /// The type used by primitive outputs, i.e. what is passed to `setPrimitive`. + pub primitive_output_type: Handle, +} + +/// Mesh shader intrinsics +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum MeshFunction { + /// Sets the number of vertices and primitives that will be outputted. + SetMeshOutputs { + vertex_count: Handle, + primitive_count: Handle, + }, + /// Sets the output vertex at a given index. 
+ SetVertex { + index: Handle, + value: Handle, + }, + /// Sets the output primitive at a given index. + SetPrimitive { + index: Handle, + value: Handle, + }, +} + /// Shader module. /// /// A module is a set of constants, global variables and functions, as well as diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 26f873a9435..eca63ee4fb5 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -179,6 +179,9 @@ impl super::AddressSpace { crate::AddressSpace::Storage { access } => access, crate::AddressSpace::Handle => Sa::LOAD, crate::AddressSpace::PushConstant => Sa::LOAD, + // TaskPayload isn't always writable, but this is checked for elsewhere, + // when not using multiple payloads and matching the entry payload is checked. + crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE, } } } @@ -628,6 +631,15 @@ pub fn flatten_compose<'arenas>( .take(size) } +impl super::ShaderStage { + pub const fn compute_like(self) -> bool { + match self { + Self::Vertex | Self::Fragment => false, + Self::Compute | Self::Task | Self::Mesh => true, + } + } +} + #[test] fn test_matrix_size() { let module = crate::Module::default(); diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index b29ccb054a3..f76d4c06a3b 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -36,6 +36,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::ImageStore { .. } | S::Call { .. } | S::RayQuery { .. } + | S::MeshFunction(..) | S::Atomic { .. } | S::ImageAtomic { .. } | S::WorkGroupUniformLoad { .. } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 95ae40dcdb4..14554573c9f 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -85,6 +85,25 @@ struct FunctionUniformity { exit: ExitFlags, } +/// Mesh shader related characteristics of a function. 
+#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr(test, derive(PartialEq))] +pub struct FunctionMeshShaderInfo { + /// The type of value this function passes to [`SetVertex`], and the + /// expression that first established it. + /// + /// [`SetVertex`]: crate::ir::MeshFunction::SetVertex + pub vertex_type: Option<(Handle, Handle)>, + + /// The type of value this function passes to [`SetPrimitive`], and the + /// expression that first established it. + /// + /// [`SetPrimitive`]: crate::ir::MeshFunction::SetPrimitive + pub primitive_type: Option<(Handle, Handle)>, +} + impl ops::BitOr for FunctionUniformity { type Output = Self; fn bitor(self, other: Self) -> Self { @@ -302,6 +321,9 @@ pub struct FunctionInfo { /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in /// validation. diagnostic_filter_leaf: Option>, + + /// Mesh shader info for this function and its callees. + pub mesh_shader_info: FunctionMeshShaderInfo, } impl FunctionInfo { @@ -372,6 +394,14 @@ impl FunctionInfo { info.uniformity.non_uniform_result } + pub fn insert_global_use( + &mut self, + global_use: GlobalUse, + global: Handle, + ) { + self.global_uses[global.index()] |= global_use; + } + /// Record a use of `expr` for its value. /// /// This is used for almost all expression references. Anything @@ -482,6 +512,9 @@ impl FunctionInfo { *mine |= *other; } + // Inherit mesh output types from our callees. 
+ self.try_update_mesh_info(&callee.mesh_shader_info)?; + Ok(FunctionUniformity { result: callee.uniformity.clone(), exit: if callee.may_kill { @@ -635,7 +668,8 @@ impl FunctionInfo { // local data is non-uniform As::Function | As::Private => false, // workgroup memory is exclusively accessed by the group - As::WorkGroup => true, + // task payload memory is very similar to workgroup memory + As::WorkGroup | As::TaskPayload => true, // uniform data As::Uniform | As::PushConstant => true, // storage data is only uniform when read-only @@ -1113,6 +1147,36 @@ impl FunctionInfo { } FunctionUniformity::new() } + S::MeshFunction(func) => { + self.available_stages |= ShaderStages::MESH; + match &func { + // TODO: double check all of this uniformity stuff. I frankly don't fully understand all of it. + &crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + } => { + let _ = self.add_ref(vertex_count); + let _ = self.add_ref(primitive_count); + FunctionUniformity::new() + } + &crate::MeshFunction::SetVertex { index, value } + | &crate::MeshFunction::SetPrimitive { index, value } => { + let _ = self.add_ref(index); + let _ = self.add_ref(value); + let ty = self.expressions[value.index()].ty.handle().ok_or( + FunctionError::InvalidMeshShaderOutputType(value).with_span(), + )?; + + if matches!(func, crate::MeshFunction::SetVertex { .. }) { + self.try_update_mesh_vertex_type(ty, value)?; + } else { + self.try_update_mesh_primitive_type(ty, value)?; + }; + + FunctionUniformity::new() + } + } + } S::SubgroupBallot { result: _, predicate, @@ -1158,6 +1222,72 @@ impl FunctionInfo { } Ok(combined_uniformity) } + + /// Note the type of value passed to [`SetVertex`]. + /// + /// Record that this function passed a value of type `ty` as the second + /// argument to the [`SetVertex`] builtin function. All calls to + /// `SetVertex` must pass the same type, and this must match the + /// function's [`vertex_output_type`]. 
+ /// + /// [`SetVertex`]: crate::ir::MeshFunction::SetVertex + /// [`vertex_output_type`]: crate::ir::MeshStageInfo::vertex_output_type + fn try_update_mesh_vertex_type( + &mut self, + ty: Handle, + value: Handle, + ) -> Result<(), WithSpan> { + if let &Some(ref existing) = &self.mesh_shader_info.vertex_type { + if existing.0 != ty { + return Err( + FunctionError::ConflictingMeshOutputTypes(existing.1, value).with_span() + ); + } + } else { + self.mesh_shader_info.vertex_type = Some((ty, value)); + } + Ok(()) + } + + /// Note the type of value passed to [`SetPrimitive`]. + /// + /// Record that this function passed a value of type `ty` as the second + /// argument to the [`SetPrimitive`] builtin function. All calls to + /// `SetPrimitive` must pass the same type, and this must match the + /// function's [`primitive_output_type`]. + /// + /// [`SetPrimitive`]: crate::ir::MeshFunction::SetPrimitive + /// [`primitive_output_type`]: crate::ir::MeshStageInfo::primitive_output_type + fn try_update_mesh_primitive_type( + &mut self, + ty: Handle, + value: Handle, + ) -> Result<(), WithSpan> { + if let &Some(ref existing) = &self.mesh_shader_info.primitive_type { + if existing.0 != ty { + return Err( + FunctionError::ConflictingMeshOutputTypes(existing.1, value).with_span() + ); + } + } else { + self.mesh_shader_info.primitive_type = Some((ty, value)); + } + Ok(()) + } + + /// Update this function's mesh shader info, given that it calls `callee`. 
+ fn try_update_mesh_info( + &mut self, + callee: &FunctionMeshShaderInfo, + ) -> Result<(), WithSpan> { + if let &Some(ref other_vertex) = &callee.vertex_type { + self.try_update_mesh_vertex_type(other_vertex.0, other_vertex.1)?; + } + if let &Some(ref other_primitive) = &callee.primitive_type { + self.try_update_mesh_primitive_type(other_primitive.0, other_primitive.1)?; + } + Ok(()) + } } impl ModuleInfo { @@ -1193,6 +1323,7 @@ impl ModuleInfo { sampling: crate::FastHashSet::default(), dual_source_blending: false, diagnostic_filter_leaf: fun.diagnostic_filter_leaf, + mesh_shader_info: FunctionMeshShaderInfo::default(), }; let resolve_context = ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments); @@ -1326,6 +1457,7 @@ fn uniform_control_flow() { sampling: crate::FastHashSet::default(), dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: FunctionMeshShaderInfo::default(), }; let resolve_context = ResolveContext { constants: &Arena::new(), diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index dc19e191764..0216c6ef7f6 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -217,6 +217,14 @@ pub enum FunctionError { EmitResult(Handle), #[error("Expression not visited by the appropriate statement")] UnvisitedExpression(Handle), + #[error("Expression {0:?} in mesh shader intrinsic call should be `u32` (is the expression a signed integer?)")] + InvalidMeshFunctionCall(Handle), + #[error("Mesh output types differ from {0:?} to {1:?}")] + ConflictingMeshOutputTypes(Handle, Handle), + #[error("Task payload variables differ from {0:?} to {1:?}")] + ConflictingTaskPayloadVariables(Handle, Handle), + #[error("Mesh shader output at {0:?} is not a user-defined struct")] + InvalidMeshShaderOutputType(Handle), } bitflags::bitflags! 
{ @@ -1539,6 +1547,41 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } + S::MeshFunction(func) => { + let ensure_u32 = + |expr: Handle| -> Result<(), WithSpan> { + let u32_ty = TypeResolution::Value(Ti::Scalar(crate::Scalar::U32)); + let ty = context + .resolve_type_impl(expr, &self.valid_expression_set) + .map_err_inner(|source| { + FunctionError::Expression { + source, + handle: expr, + } + .with_span_handle(expr, context.expressions) + })?; + if !context.compare_types(&u32_ty, ty) { + return Err(FunctionError::InvalidMeshFunctionCall(expr) + .with_span_handle(expr, context.expressions)); + } + Ok(()) + }; + match func { + crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + } => { + ensure_u32(vertex_count)?; + ensure_u32(primitive_count)?; + } + crate::MeshFunction::SetVertex { index, value: _ } + | crate::MeshFunction::SetPrimitive { index, value: _ } => { + ensure_u32(index)?; + // Value is validated elsewhere (since the value type isn't known ahead of time but must match for all calls + // in a function or the function's called functions) + } + } + } S::SubgroupBallot { result, predicate } => { stages &= self.subgroup_stages; if !self.capabilities.contains(super::Capabilities::SUBGROUP) { diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index e8a69013434..adb9f355c11 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -233,6 +233,20 @@ impl super::Validator { validate_const_expr(size)?; } } + if let Some(task_payload) = entry_point.task_payload { + Self::validate_global_variable_handle(task_payload, global_variables)?; + } + if let Some(ref mesh_info) = entry_point.mesh_info { + validate_type(mesh_info.vertex_output_type)?; + validate_type(mesh_info.primitive_output_type)?; + for ov in mesh_info + .max_vertices_override + .iter() + .chain(mesh_info.max_primitives_override.iter()) + { + validate_const_expr(*ov)?; + } + } } for (function_handle, function) in functions.iter() { 
@@ -801,6 +815,22 @@ impl super::Validator { } Ok(()) } + crate::Statement::MeshFunction(func) => match func { + crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + } => { + validate_expr(vertex_count)?; + validate_expr(primitive_count)?; + Ok(()) + } + crate::MeshFunction::SetVertex { index, value } + | crate::MeshFunction::SetPrimitive { index, value } => { + validate_expr(index)?; + validate_expr(value)?; + Ok(()) + } + }, crate::Statement::SubgroupBallot { result, predicate } => { validate_expr_opt(predicate)?; validate_expr(result)?; diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 7c8cc903139..a4e0af99ccc 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -43,6 +43,8 @@ pub enum GlobalVariableError { StorageAddressSpaceWriteOnlyNotSupported, #[error("Type is not valid for use as a push constant")] InvalidPushConstantType(#[source] PushConstantError), + #[error("Task payload must not be zero-sized")] + ZeroSizedTaskPayload, } #[derive(Clone, Debug, thiserror::Error)] @@ -92,6 +94,12 @@ pub enum VaryingError { }, #[error("Workgroup size is multi dimensional, `@builtin(subgroup_id)` and `@builtin(subgroup_invocation_id)` are not supported.")] InvalidMultiDimensionalSubgroupBuiltIn, + #[error("The `@per_primitive` attribute can only be used in fragment shader inputs or mesh shader primitive outputs")] + InvalidPerPrimitive, + #[error("Non-builtin members of a mesh primitive output struct must be decorated with `@per_primitive`")] + MissingPerPrimitive, + #[error("The `MESH_SHADER` capability must be enabled to use per-primitive fragment inputs.")] + PerPrimitiveNotAllowed, } #[derive(Clone, Debug, thiserror::Error)] @@ -123,6 +131,32 @@ pub enum EntryPointError { InvalidIntegerInterpolation { location: u32 }, #[error(transparent)] Function(#[from] FunctionError), + #[error("Non mesh shader entry point cannot have mesh shader attributes")] + UnexpectedMeshShaderAttributes, + #[error("Non 
mesh/task shader entry point cannot have task payload attribute")] + UnexpectedTaskPayload, + #[error("Task payload must be declared with `var`")] + TaskPayloadWrongAddressSpace, + #[error("For a task payload to be used, it must be declared with @payload")] + WrongTaskPayloadUsed, + #[error("A function can only set vertex and primitive types that correspond to the mesh shader attributes")] + WrongMeshOutputType, + #[error("Only mesh shader entry points can write to mesh output vertices and primitives")] + UnexpectedMeshShaderOutput, + #[error("Mesh shader entry point cannot have a return type")] + UnexpectedMeshShaderEntryResult, + #[error("Task shader entry point must return @builtin(mesh_task_size) vec3")] + WrongTaskShaderEntryResult, + #[error("Mesh output type must be a user-defined struct.")] + InvalidMeshOutputType, + #[error("Mesh primitive outputs must have exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`")] + InvalidMeshPrimitiveOutputType, + #[error("Task shaders must declare a task payload output")] + ExpectedTaskPayload, + #[error( + "The `MESH_SHADER` capability must be enabled to compile mesh shaders and task shaders." 
+ )] + MeshShaderCapabilityDisabled, } fn storage_usage(access: crate::StorageAccess) -> GlobalUse { @@ -139,6 +173,13 @@ fn storage_usage(access: crate::StorageAccess) -> GlobalUse { storage_usage } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum MeshOutputType { + None, + VertexOutput, + PrimitiveOutput, +} + struct VaryingContext<'a> { stage: crate::ShaderStage, output: bool, @@ -149,6 +190,8 @@ struct VaryingContext<'a> { built_ins: &'a mut crate::FastHashSet, capabilities: Capabilities, flags: super::ValidationFlags, + mesh_output_type: MeshOutputType, + has_task_payload: bool, } impl VaryingContext<'_> { @@ -201,16 +244,20 @@ impl VaryingContext<'_> { } let (visible, type_good) = match built_in { - Bi::BaseInstance - | Bi::BaseVertex - | Bi::InstanceIndex - | Bi::VertexIndex - | Bi::DrawID => ( + Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => ( self.stage == St::Vertex && !self.output, *ty_inner == Ti::Scalar(crate::Scalar::U32), ), + Bi::DrawID => ( + // Always allowed in task/vertex stage. Allowed in mesh stage if there is no task stage in the pipeline. + (self.stage == St::Vertex + || self.stage == St::Task + || (self.stage == St::Mesh && !self.has_task_payload)) + && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), Bi::ClipDistance | Bi::CullDistance => ( - self.stage == St::Vertex && self.output, + (self.stage == St::Vertex || self.stage == St::Mesh) && self.output, match *ty_inner { Ti::Array { base, size, .. } => { self.types[base].inner == Ti::Scalar(crate::Scalar::F32) @@ -223,7 +270,7 @@ impl VaryingContext<'_> { }, ), Bi::PointSize => ( - self.stage == St::Vertex && self.output, + (self.stage == St::Vertex || self.stage == St::Mesh) && self.output, *ty_inner == Ti::Scalar(crate::Scalar::F32), ), Bi::PointCoord => ( @@ -236,10 +283,9 @@ impl VaryingContext<'_> { ), Bi::Position { .. 
} => ( match self.stage { - St::Vertex => self.output, + St::Vertex | St::Mesh => self.output, St::Fragment => !self.output, - St::Compute => false, - St::Task | St::Mesh => unreachable!(), + St::Compute | St::Task => false, }, *ty_inner == Ti::Vector { @@ -249,9 +295,8 @@ impl VaryingContext<'_> { ), Bi::ViewIndex => ( match self.stage { - St::Vertex | St::Fragment => !self.output, + St::Vertex | St::Fragment | St::Task | St::Mesh => !self.output, St::Compute => false, - St::Task | St::Mesh => unreachable!(), }, *ty_inner == Ti::Scalar(crate::Scalar::I32), ), @@ -276,7 +321,7 @@ impl VaryingContext<'_> { *ty_inner == Ti::Scalar(crate::Scalar::U32), ), Bi::LocalInvocationIndex => ( - self.stage == St::Compute && !self.output, + self.stage.compute_like() && !self.output, *ty_inner == Ti::Scalar(crate::Scalar::U32), ), Bi::GlobalInvocationId @@ -284,7 +329,7 @@ impl VaryingContext<'_> { | Bi::WorkGroupId | Bi::WorkGroupSize | Bi::NumWorkGroups => ( - self.stage == St::Compute && !self.output, + self.stage.compute_like() && !self.output, *ty_inner == Ti::Vector { size: Vs::Tri, @@ -292,17 +337,48 @@ impl VaryingContext<'_> { }, ), Bi::NumSubgroups | Bi::SubgroupId => ( - self.stage == St::Compute && !self.output, + self.stage.compute_like() && !self.output, *ty_inner == Ti::Scalar(crate::Scalar::U32), ), Bi::SubgroupSize | Bi::SubgroupInvocationId => ( match self.stage { - St::Compute | St::Fragment => !self.output, + St::Compute | St::Fragment | St::Task | St::Mesh => !self.output, St::Vertex => false, - St::Task | St::Mesh => unreachable!(), }, *ty_inner == Ti::Scalar(crate::Scalar::U32), ), + Bi::CullPrimitive => ( + self.mesh_output_type == MeshOutputType::PrimitiveOutput, + *ty_inner == Ti::Scalar(crate::Scalar::BOOL), + ), + Bi::PointIndex => ( + self.mesh_output_type == MeshOutputType::PrimitiveOutput, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::LineIndices => ( + self.mesh_output_type == MeshOutputType::PrimitiveOutput, + *ty_inner + == 
Ti::Vector { + size: Vs::Bi, + scalar: crate::Scalar::U32, + }, + ), + Bi::TriangleIndices => ( + self.mesh_output_type == MeshOutputType::PrimitiveOutput, + *ty_inner + == Ti::Vector { + size: Vs::Tri, + scalar: crate::Scalar::U32, + }, + ), + Bi::MeshTaskSize => ( + self.stage == St::Task && self.output, + *ty_inner + == Ti::Vector { + size: Vs::Tri, + scalar: crate::Scalar::U32, + }, + ), }; if !visible { @@ -318,7 +394,11 @@ impl VaryingContext<'_> { interpolation, sampling, blend_src, + per_primitive, } => { + if per_primitive && !self.capabilities.contains(Capabilities::MESH_SHADER) { + return Err(VaryingError::PerPrimitiveNotAllowed); + } // Only IO-shareable types may be stored in locations. if !self.type_info[ty.index()] .flags @@ -327,6 +407,22 @@ impl VaryingContext<'_> { return Err(VaryingError::NotIOShareableType(ty)); } + // Check whether `per_primitive` is appropriate for this stage and direction. + if self.mesh_output_type == MeshOutputType::PrimitiveOutput { + // All mesh shader `Location` outputs must be `per_primitive`. + if !per_primitive { + return Err(VaryingError::MissingPerPrimitive); + } + } else if self.stage == crate::ShaderStage::Fragment && !self.output { + // Fragment stage inputs may be `per_primitive`. We'll only + // know if these are correct when the whole mesh pipeline is + // created and we're paired with a specific mesh or vertex + // shader. + } else if per_primitive { + // All other `Location` bindings must not be `per_primitive`. 
+ return Err(VaryingError::InvalidPerPrimitive); + } + if let Some(blend_src) = blend_src { // `blend_src` is only valid if dual source blending was explicitly enabled, // see https://www.w3.org/TR/WGSL/#extension-dual_source_blending @@ -392,9 +488,9 @@ impl VaryingContext<'_> { let needs_interpolation = match self.stage { crate::ShaderStage::Vertex => self.output, - crate::ShaderStage::Fragment => !self.output, - crate::ShaderStage::Compute => false, - crate::ShaderStage::Task | crate::ShaderStage::Mesh => unreachable!(), + crate::ShaderStage::Fragment => !self.output && !per_primitive, + crate::ShaderStage::Compute | crate::ShaderStage::Task => false, + crate::ShaderStage::Mesh => self.output, }; // It doesn't make sense to specify a sampling when `interpolation` is `Flat`, but @@ -595,7 +691,9 @@ impl super::Validator { TypeFlags::CONSTRUCTIBLE | TypeFlags::CREATION_RESOLVED, false, ), - crate::AddressSpace::WorkGroup => (TypeFlags::DATA | TypeFlags::SIZED, false), + crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => { + (TypeFlags::DATA | TypeFlags::SIZED, false) + } crate::AddressSpace::PushConstant => { if !self.capabilities.contains(Capabilities::PUSH_CONSTANT) { return Err(GlobalVariableError::UnsupportedCapability( @@ -628,6 +726,14 @@ impl super::Validator { } } + if var.space == crate::AddressSpace::TaskPayload { + let ty = &gctx.types[var.ty].inner; + // HLSL doesn't allow zero sized payloads. + if ty.try_size(gctx) == Some(0) { + return Err(GlobalVariableError::ZeroSizedTaskPayload); + } + } + if let Some(init) = var.init { match var.space { crate::AddressSpace::Private | crate::AddressSpace::Function => {} @@ -651,12 +757,72 @@ impl super::Validator { Ok(()) } + /// Validate the mesh shader output type `ty`, used as `mesh_output_type`. 
+ fn validate_mesh_output_type( + &mut self, + ep: &crate::EntryPoint, + module: &crate::Module, + ty: Handle, + mesh_output_type: MeshOutputType, + ) -> Result<(), WithSpan> { + if !matches!(module.types[ty].inner, crate::TypeInner::Struct { .. }) { + return Err(EntryPointError::InvalidMeshOutputType.with_span_handle(ty, &module.types)); + } + let mut result_built_ins = crate::FastHashSet::default(); + let mut ctx = VaryingContext { + stage: ep.stage, + output: true, + types: &module.types, + type_info: &self.types, + location_mask: &mut self.location_mask, + blend_src_mask: &mut self.blend_src_mask, + built_ins: &mut result_built_ins, + capabilities: self.capabilities, + flags: self.flags, + mesh_output_type, + has_task_payload: ep.task_payload.is_some(), + }; + ctx.validate(ep, ty, None) + .map_err_inner(|e| EntryPointError::Result(e).with_span())?; + if mesh_output_type == MeshOutputType::PrimitiveOutput { + let mut num_indices_builtins = 0; + if result_built_ins.contains(&crate::BuiltIn::PointIndex) { + num_indices_builtins += 1; + } + if result_built_ins.contains(&crate::BuiltIn::LineIndices) { + num_indices_builtins += 1; + } + if result_built_ins.contains(&crate::BuiltIn::TriangleIndices) { + num_indices_builtins += 1; + } + if num_indices_builtins != 1 { + return Err(EntryPointError::InvalidMeshPrimitiveOutputType + .with_span_handle(ty, &module.types)); + } + } else if mesh_output_type == MeshOutputType::VertexOutput + && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false }) + { + return Err( + EntryPointError::MissingVertexOutputPosition.with_span_handle(ty, &module.types) + ); + } + + Ok(()) + } + pub(super) fn validate_entry_point( &mut self, ep: &crate::EntryPoint, module: &crate::Module, mod_info: &ModuleInfo, ) -> Result> { + if matches!( + ep.stage, + crate::ShaderStage::Task | crate::ShaderStage::Mesh + ) && !self.capabilities.contains(Capabilities::MESH_SHADER) + { + return 
Err(EntryPointError::MeshShaderCapabilityDisabled.with_span()); + } if ep.early_depth_test.is_some() { let required = Capabilities::EARLY_DEPTH_TEST; if !self.capabilities.contains(required) { @@ -671,7 +837,7 @@ impl super::Validator { } } - if ep.stage == crate::ShaderStage::Compute { + if ep.stage.compute_like() { if ep .workgroup_size .iter() @@ -683,10 +849,48 @@ impl super::Validator { return Err(EntryPointError::UnexpectedWorkgroupSize.with_span()); } + if ep.stage != crate::ShaderStage::Mesh && ep.mesh_info.is_some() { + return Err(EntryPointError::UnexpectedMeshShaderAttributes.with_span()); + } + let mut info = self .validate_function(&ep.function, module, mod_info, true) .map_err(WithSpan::into_other)?; + // Validate the task shader payload. + match ep.stage { + // Task shaders must produce a payload. + crate::ShaderStage::Task => { + let Some(handle) = ep.task_payload else { + return Err(EntryPointError::ExpectedTaskPayload.with_span()); + }; + if module.global_variables[handle].space != crate::AddressSpace::TaskPayload { + return Err(EntryPointError::TaskPayloadWrongAddressSpace + .with_span_handle(handle, &module.global_variables)); + } + info.insert_global_use(GlobalUse::READ | GlobalUse::WRITE, handle); + } + + // Mesh shaders may accept a payload. + crate::ShaderStage::Mesh => { + if let Some(handle) = ep.task_payload { + if module.global_variables[handle].space != crate::AddressSpace::TaskPayload { + return Err(EntryPointError::TaskPayloadWrongAddressSpace + .with_span_handle(handle, &module.global_variables)); + } + info.insert_global_use(GlobalUse::READ, handle); + } + } + + // Other stages must not have a payload. 
+ _ => { + if let Some(handle) = ep.task_payload { + return Err(EntryPointError::UnexpectedTaskPayload + .with_span_handle(handle, &module.global_variables)); + } + } + } + { use super::ShaderStages; @@ -694,7 +898,8 @@ impl super::Validator { crate::ShaderStage::Vertex => ShaderStages::VERTEX, crate::ShaderStage::Fragment => ShaderStages::FRAGMENT, crate::ShaderStage::Compute => ShaderStages::COMPUTE, - crate::ShaderStage::Task | crate::ShaderStage::Mesh => unreachable!(), + crate::ShaderStage::Mesh => ShaderStages::MESH, + crate::ShaderStage::Task => ShaderStages::TASK, }; if !info.available_stages.contains(stage_bit) { @@ -716,6 +921,8 @@ impl super::Validator { built_ins: &mut argument_built_ins, capabilities: self.capabilities, flags: self.flags, + mesh_output_type: MeshOutputType::None, + has_task_payload: ep.task_payload.is_some(), }; ctx.validate(ep, fa.ty, fa.binding.as_ref()) .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?; @@ -734,6 +941,8 @@ impl super::Validator { built_ins: &mut result_built_ins, capabilities: self.capabilities, flags: self.flags, + mesh_output_type: MeshOutputType::None, + has_task_payload: ep.task_payload.is_some(), }; ctx.validate(ep, fr.ty, fr.binding.as_ref()) .map_err_inner(|e| EntryPointError::Result(e).with_span())?; @@ -742,11 +951,25 @@ impl super::Validator { { return Err(EntryPointError::MissingVertexOutputPosition.with_span()); } + if ep.stage == crate::ShaderStage::Mesh { + return Err(EntryPointError::UnexpectedMeshShaderEntryResult.with_span()); + } + // Task shaders must have a single `MeshTaskSize` output, and nothing else. 
+ if ep.stage == crate::ShaderStage::Task { + let ok = result_built_ins.contains(&crate::BuiltIn::MeshTaskSize) + && result_built_ins.len() == 1 + && self.location_mask.is_empty(); + if !ok { + return Err(EntryPointError::WrongTaskShaderEntryResult.with_span()); + } + } if !self.blend_src_mask.is_empty() { info.dual_source_blending = true; } } else if ep.stage == crate::ShaderStage::Vertex { return Err(EntryPointError::MissingVertexOutputPosition.with_span()); + } else if ep.stage == crate::ShaderStage::Task { + return Err(EntryPointError::WrongTaskShaderEntryResult.with_span()); } { @@ -771,6 +994,13 @@ impl super::Validator { continue; } + if var.space == crate::AddressSpace::TaskPayload { + if ep.task_payload != Some(var_handle) { + return Err(EntryPointError::WrongTaskPayloadUsed + .with_span_handle(var_handle, &module.global_variables)); + } + } + let allowed_usage = match var.space { crate::AddressSpace::Function => unreachable!(), crate::AddressSpace::Uniform => GlobalUse::READ | GlobalUse::QUERY, @@ -792,6 +1022,15 @@ impl super::Validator { crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => { GlobalUse::READ | GlobalUse::WRITE | GlobalUse::QUERY } + crate::AddressSpace::TaskPayload => { + GlobalUse::READ + | GlobalUse::QUERY + | if ep.stage == crate::ShaderStage::Task { + GlobalUse::WRITE + } else { + GlobalUse::empty() + } + } crate::AddressSpace::PushConstant => GlobalUse::READ, }; if !allowed_usage.contains(usage) { @@ -811,6 +1050,46 @@ impl super::Validator { } } + // If this is a `Mesh` entry point, check its vertex and primitive output types. + // We verified previously that only mesh shaders can have `mesh_info`. + if let &Some(ref mesh_info) = &ep.mesh_info { + // Mesh shaders don't return any value. All their results are supplied through + // [`SetVertex`] and [`SetPrimitive`] calls. 
+ if let Some((used_vertex_type, _)) = info.mesh_shader_info.vertex_type { + if used_vertex_type != mesh_info.vertex_output_type { + return Err(EntryPointError::WrongMeshOutputType + .with_span_handle(mesh_info.vertex_output_type, &module.types)); + } + } + if let Some((used_primitive_type, _)) = info.mesh_shader_info.primitive_type { + if used_primitive_type != mesh_info.primitive_output_type { + return Err(EntryPointError::WrongMeshOutputType + .with_span_handle(mesh_info.primitive_output_type, &module.types)); + } + } + + self.validate_mesh_output_type( + ep, + module, + mesh_info.vertex_output_type, + MeshOutputType::VertexOutput, + )?; + self.validate_mesh_output_type( + ep, + module, + mesh_info.primitive_output_type, + MeshOutputType::PrimitiveOutput, + )?; + } else { + // This is not a `Mesh` entry point, so ensure that it never tries to produce + // vertices or primitives. + if info.mesh_shader_info.vertex_type.is_some() + || info.mesh_shader_info.primitive_type.is_some() + { + return Err(EntryPointError::UnexpectedMeshShaderOutput.with_span()); + } + } + Ok(info) } } diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 426b3d637d7..2460a46df4b 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -186,6 +186,8 @@ bitflags::bitflags! { /// Support for `quantizeToF16`, `pack2x16float`, and `unpack2x16float`, which store /// `f16`-precision values in `f32`s. const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28; + /// Support for task shaders, mesh shaders, and per-primitive fragment inputs + const MESH_SHADER = 1 << 29; } } @@ -278,6 +280,8 @@ bitflags::bitflags! 
{ const VERTEX = 0x1; const FRAGMENT = 0x2; const COMPUTE = 0x4; + const MESH = 0x8; + const TASK = 0x10; } } diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index e8b83ff08f3..aa0633e1852 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -220,9 +220,12 @@ const fn ptr_space_argument_flag(space: crate::AddressSpace) -> TypeFlags { use crate::AddressSpace as As; match space { As::Function | As::Private => TypeFlags::ARGUMENT, - As::Uniform | As::Storage { .. } | As::Handle | As::PushConstant | As::WorkGroup => { - TypeFlags::empty() - } + As::Uniform + | As::Storage { .. } + | As::Handle + | As::PushConstant + | As::WorkGroup + | As::TaskPayload => TypeFlags::empty(), } } diff --git a/naga/tests/in/wgsl/mesh-shader.toml b/naga/tests/in/wgsl/mesh-shader.toml new file mode 100644 index 00000000000..1f8b4e23baa --- /dev/null +++ b/naga/tests/in/wgsl/mesh-shader.toml @@ -0,0 +1,19 @@ +# Stolen from ray-query.toml + +god_mode = true +targets = "IR | ANALYSIS" + +[msl] +fake_missing_bindings = true +lang_version = [2, 4] +spirv_cross_compatibility = false +zero_initialize_workgroup_memory = false + +[hlsl] +shader_model = "V6_5" +fake_missing_bindings = true +zero_initialize_workgroup_memory = true + +[spv] +version = [1, 4] +capabilities = ["MeshShadingEXT"] diff --git a/naga/tests/in/wgsl/mesh-shader.wgsl b/naga/tests/in/wgsl/mesh-shader.wgsl new file mode 100644 index 00000000000..70fc2aec333 --- /dev/null +++ b/naga/tests/in/wgsl/mesh-shader.wgsl @@ -0,0 +1,71 @@ +enable mesh_shading; + +const positions = array( + vec4(0.,1.,0.,1.), + vec4(-1.,-1.,0.,1.), + vec4(1.,-1.,0.,1.) +); +const colors = array( + vec4(0.,1.,0.,1.), + vec4(0.,0.,1.,1.), + vec4(1.,0.,0.,1.) 
+); +struct TaskPayload { + colorMask: vec4<f32>, + visible: bool, +} +var<task_payload> taskPayload: TaskPayload; +var<workgroup> workgroupData: f32; +struct VertexOutput { + @builtin(position) position: vec4<f32>, + @location(0) color: vec4<f32>, +} +struct PrimitiveOutput { + @builtin(triangle_indices) index: vec3<u32>, + @builtin(cull_primitive) cull: bool, + @per_primitive @location(1) colorMask: vec4<f32>, +} +struct PrimitiveInput { + @per_primitive @location(1) colorMask: vec4<f32>, +} + +@task +@payload(taskPayload) +@workgroup_size(1) +fn ts_main() -> @builtin(mesh_task_size) vec3<u32> { + workgroupData = 1.0; + taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + return vec3(3, 1, 1); +} +@mesh +@payload(taskPayload) +@vertex_output(VertexOutput, 3) @primitive_output(PrimitiveOutput, 1) +@workgroup_size(1) +fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3<u32>) { + setMeshOutputs(3, 1); + workgroupData = 2.0; + var v: VertexOutput; + + v.position = positions[0]; + v.color = colors[0] * taskPayload.colorMask; + setVertex(0, v); + + v.position = positions[1]; + v.color = colors[1] * taskPayload.colorMask; + setVertex(1, v); + + v.position = positions[2]; + v.color = colors[2] * taskPayload.colorMask; + setVertex(2, v); + + var p: PrimitiveOutput; + p.index = vec3(0, 1, 2); + p.cull = !taskPayload.visible; + p.colorMask = vec4(1.0, 0.0, 1.0, 1.0); + setPrimitive(0, p); +} +@fragment +fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4<f32> { + return vertex.color * primitive.colorMask; +} diff --git a/naga/tests/out/analysis/spv-shadow.info.ron b/naga/tests/out/analysis/spv-shadow.info.ron index 6ddda61f5c6..b08a28438ed 100644 --- a/naga/tests/out/analysis/spv-shadow.info.ron +++ b/naga/tests/out/analysis/spv-shadow.info.ron @@ -18,7 +18,7 @@ functions: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + 
available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(1), requirements: (""), @@ -413,10 +413,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(1), requirements: (""), @@ -1591,12 +1595,16 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(1), requirements: (""), @@ -1685,6 +1693,10 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-access.info.ron b/naga/tests/out/analysis/wgsl-access.info.ron index 319f62bdf13..d297b09a404 100644 --- a/naga/tests/out/analysis/wgsl-access.info.ron +++ b/naga/tests/out/analysis/wgsl-access.info.ron @@ -42,7 +42,7 @@ functions: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -1197,10 +1197,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: 
("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2523,10 +2527,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -2563,10 +2571,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -2612,10 +2624,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2655,10 +2671,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | 
FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2749,10 +2769,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2870,10 +2894,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -2922,10 +2950,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2977,10 +3009,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -3029,10 +3065,14 @@ sampling: [], dual_source_blending: false, 
diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -3084,10 +3124,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -3148,10 +3192,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(2), requirements: (""), @@ -3221,10 +3269,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(2), requirements: (""), @@ -3297,10 +3349,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | 
STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -3397,10 +3453,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(1), requirements: (""), @@ -3593,12 +3653,16 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -4290,10 +4354,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -4742,10 +4810,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), 
uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -4812,6 +4884,10 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-collatz.info.ron b/naga/tests/out/analysis/wgsl-collatz.info.ron index 7ec5799d758..2796f544510 100644 --- a/naga/tests/out/analysis/wgsl-collatz.info.ron +++ b/naga/tests/out/analysis/wgsl-collatz.info.ron @@ -8,7 +8,7 @@ functions: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(3), requirements: (""), @@ -275,12 +275,16 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(3), requirements: (""), @@ -430,6 +434,10 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [], diff --git a/naga/tests/out/analysis/wgsl-mesh-shader.info.ron b/naga/tests/out/analysis/wgsl-mesh-shader.info.ron new file mode 100644 index 00000000000..208e0aac84e --- /dev/null +++ b/naga/tests/out/analysis/wgsl-mesh-shader.info.ron @@ -0,0 +1,1211 @@ +( + type_flags: [ + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | 
SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ], + functions: [], + entry_points: [ + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [ + ("READ | WRITE"), + ("WRITE"), + ], + expressions: [ + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(1), + ty: Value(Pointer( + base: 0, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + 
requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 2, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Bool, + width: 1, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(6), + ), + ], + sampling: [], + dual_source_blending: false, + diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), 
+ ), + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [ + ("READ"), + ("WRITE"), + ], + expressions: [ + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 0, + assignable_global: None, + ty: Handle(5), + ), + ( + uniformity: ( + non_uniform_result: Some(1), + requirements: (""), + ), + ref_count: 0, + assignable_global: None, + ty: Handle(6), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(1), + ty: Value(Pointer( + base: 0, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 9, + assignable_global: None, + ty: Value(Pointer( + base: 4, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + 
assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + 
uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(4), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + 
( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(4), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + 
( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + 
assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(4), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 4, + assignable_global: None, + ty: Value(Pointer( + base: 7, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 6, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(6), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 2, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, 
+ assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 2, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(2), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(2), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(7), + ), + ], + sampling: [], + dual_source_blending: false, + diagnostic_filter_leaf: 
None, + mesh_shader_info: ( + vertex_type: Some((4, 24)), + primitive_type: Some((7, 79)), + ), + ), + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [ + (""), + (""), + ], + expressions: [ + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(4), + ), + ( + uniformity: ( + non_uniform_result: Some(1), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(8), + ), + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(1), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ], + sampling: [], + dual_source_blending: false, + diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), + ), + ], + const_expression_types: [], +) \ No newline at end of file diff --git a/naga/tests/out/analysis/wgsl-overrides.info.ron b/naga/tests/out/analysis/wgsl-overrides.info.ron index 0e0ae318042..a76c9c89c9b 100644 --- a/naga/tests/out/analysis/wgsl-overrides.info.ron +++ b/naga/tests/out/analysis/wgsl-overrides.info.ron @@ -8,7 +8,7 @@ entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -201,6 +201,10 @@ sampling: [], 
dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-storage-textures.info.ron b/naga/tests/out/analysis/wgsl-storage-textures.info.ron index fbbf7206c33..35b5a7e320c 100644 --- a/naga/tests/out/analysis/wgsl-storage-textures.info.ron +++ b/naga/tests/out/analysis/wgsl-storage-textures.info.ron @@ -11,7 +11,7 @@ entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -184,10 +184,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -396,6 +400,10 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [], diff --git a/naga/tests/out/ir/spv-fetch_depth.compact.ron b/naga/tests/out/ir/spv-fetch_depth.compact.ron index 1fbee2deb35..98f4426c3eb 100644 --- a/naga/tests/out/ir/spv-fetch_depth.compact.ron +++ b/naga/tests/out/ir/spv-fetch_depth.compact.ron @@ -196,6 +196,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-fetch_depth.ron b/naga/tests/out/ir/spv-fetch_depth.ron index 186f78354ad..104de852c17 100644 --- a/naga/tests/out/ir/spv-fetch_depth.ron +++ 
b/naga/tests/out/ir/spv-fetch_depth.ron @@ -266,6 +266,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-shadow.compact.ron b/naga/tests/out/ir/spv-shadow.compact.ron index b49cd9b55be..bed86a5334d 100644 --- a/naga/tests/out/ir/spv-shadow.compact.ron +++ b/naga/tests/out/ir/spv-shadow.compact.ron @@ -974,6 +974,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), ), ( @@ -984,6 +985,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), ), ], @@ -994,6 +996,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -1032,6 +1035,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-shadow.ron b/naga/tests/out/ir/spv-shadow.ron index e1f0f60b6bb..bdda1d18566 100644 --- a/naga/tests/out/ir/spv-shadow.ron +++ b/naga/tests/out/ir/spv-shadow.ron @@ -1252,6 +1252,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), ), ( @@ -1262,6 +1263,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), ), ], @@ -1272,6 +1274,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -1310,6 +1313,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-spec-constants.compact.ron b/naga/tests/out/ir/spv-spec-constants.compact.ron index 3fa6ffef4ff..67eb29c2475 100644 --- a/naga/tests/out/ir/spv-spec-constants.compact.ron +++ b/naga/tests/out/ir/spv-spec-constants.compact.ron @@ -151,6 +151,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: 
false, )), offset: 0, ), @@ -510,6 +511,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ( @@ -520,6 +522,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ( @@ -530,6 +533,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ], @@ -613,6 +617,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-spec-constants.ron b/naga/tests/out/ir/spv-spec-constants.ron index 94c90aa78f9..51686aa20eb 100644 --- a/naga/tests/out/ir/spv-spec-constants.ron +++ b/naga/tests/out/ir/spv-spec-constants.ron @@ -242,6 +242,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), offset: 0, ), @@ -616,6 +617,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ( @@ -626,6 +628,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ( @@ -636,6 +639,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ], @@ -719,6 +723,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-access.compact.ron b/naga/tests/out/ir/wgsl-access.compact.ron index 30e88984f3c..c3df0c8c500 100644 --- a/naga/tests/out/ir/wgsl-access.compact.ron +++ b/naga/tests/out/ir/wgsl-access.compact.ron @@ -2655,6 +2655,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "foo_frag", @@ -2672,6 +2674,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -2848,6 +2851,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "foo_compute", @@ -2907,6 +2912,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + 
task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-access.ron b/naga/tests/out/ir/wgsl-access.ron index 30e88984f3c..c3df0c8c500 100644 --- a/naga/tests/out/ir/wgsl-access.ron +++ b/naga/tests/out/ir/wgsl-access.ron @@ -2655,6 +2655,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "foo_frag", @@ -2672,6 +2674,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -2848,6 +2851,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "foo_compute", @@ -2907,6 +2912,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-collatz.compact.ron b/naga/tests/out/ir/wgsl-collatz.compact.ron index f72c652d032..fc4daaa1296 100644 --- a/naga/tests/out/ir/wgsl-collatz.compact.ron +++ b/naga/tests/out/ir/wgsl-collatz.compact.ron @@ -334,6 +334,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-collatz.ron b/naga/tests/out/ir/wgsl-collatz.ron index f72c652d032..fc4daaa1296 100644 --- a/naga/tests/out/ir/wgsl-collatz.ron +++ b/naga/tests/out/ir/wgsl-collatz.ron @@ -334,6 +334,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-const_assert.compact.ron b/naga/tests/out/ir/wgsl-const_assert.compact.ron index 2816364f88b..648f4ff9bc9 100644 --- a/naga/tests/out/ir/wgsl-const_assert.compact.ron +++ b/naga/tests/out/ir/wgsl-const_assert.compact.ron @@ -34,6 +34,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-const_assert.ron b/naga/tests/out/ir/wgsl-const_assert.ron index 2816364f88b..648f4ff9bc9 100644 --- 
a/naga/tests/out/ir/wgsl-const_assert.ron +++ b/naga/tests/out/ir/wgsl-const_assert.ron @@ -34,6 +34,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-diagnostic-filter.compact.ron b/naga/tests/out/ir/wgsl-diagnostic-filter.compact.ron index c5746696d52..9a2bf193f30 100644 --- a/naga/tests/out/ir/wgsl-diagnostic-filter.compact.ron +++ b/naga/tests/out/ir/wgsl-diagnostic-filter.compact.ron @@ -73,6 +73,8 @@ ], diagnostic_filter_leaf: Some(0), ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [ diff --git a/naga/tests/out/ir/wgsl-diagnostic-filter.ron b/naga/tests/out/ir/wgsl-diagnostic-filter.ron index c5746696d52..9a2bf193f30 100644 --- a/naga/tests/out/ir/wgsl-diagnostic-filter.ron +++ b/naga/tests/out/ir/wgsl-diagnostic-filter.ron @@ -73,6 +73,8 @@ ], diagnostic_filter_leaf: Some(0), ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [ diff --git a/naga/tests/out/ir/wgsl-index-by-value.compact.ron b/naga/tests/out/ir/wgsl-index-by-value.compact.ron index a4f84a7a6b2..addd0e5871c 100644 --- a/naga/tests/out/ir/wgsl-index-by-value.compact.ron +++ b/naga/tests/out/ir/wgsl-index-by-value.compact.ron @@ -465,6 +465,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-index-by-value.ron b/naga/tests/out/ir/wgsl-index-by-value.ron index a4f84a7a6b2..addd0e5871c 100644 --- a/naga/tests/out/ir/wgsl-index-by-value.ron +++ b/naga/tests/out/ir/wgsl-index-by-value.ron @@ -465,6 +465,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-local-const.compact.ron b/naga/tests/out/ir/wgsl-local-const.compact.ron index 512972657ed..0e4e2e4d40e 100644 --- a/naga/tests/out/ir/wgsl-local-const.compact.ron +++ b/naga/tests/out/ir/wgsl-local-const.compact.ron @@ 
-100,6 +100,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-local-const.ron b/naga/tests/out/ir/wgsl-local-const.ron index 512972657ed..0e4e2e4d40e 100644 --- a/naga/tests/out/ir/wgsl-local-const.ron +++ b/naga/tests/out/ir/wgsl-local-const.ron @@ -100,6 +100,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-mesh-shader.compact.ron b/naga/tests/out/ir/wgsl-mesh-shader.compact.ron new file mode 100644 index 00000000000..38c79cba451 --- /dev/null +++ b/naga/tests/out/ir/wgsl-mesh-shader.compact.ron @@ -0,0 +1,846 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Quad, + scalar: ( + kind: Float, + width: 4, + ), + ), + ), + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: Some("TaskPayload"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: None, + offset: 0, + ), + ( + name: Some("visible"), + ty: 2, + binding: None, + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("VertexOutput"), + inner: Struct( + members: [ + ( + name: Some("position"), + ty: 1, + binding: Some(BuiltIn(Position( + invariant: false, + ))), + offset: 0, + ), + ( + name: Some("color"), + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Uint, + width: 4, + ), + ), + ), + ( + name: Some("PrimitiveOutput"), + inner: Struct( + members: [ + ( + name: Some("index"), + ty: 6, + binding: Some(BuiltIn(TriangleIndices)), + offset: 0, + ), + ( + name: 
Some("cull"), + ty: 2, + binding: Some(BuiltIn(CullPrimitive)), + offset: 12, + ), + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("PrimitiveInput"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 0, + ), + ], + span: 16, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ray_vertex_return: None, + external_texture_params: None, + external_texture_transfer_function: None, + predeclared_types: {}, + ), + constants: [], + overrides: [], + global_variables: [ + ( + name: Some("taskPayload"), + space: TaskPayload, + binding: None, + ty: 3, + init: None, + ), + ( + name: Some("workgroupData"), + space: WorkGroup, + binding: None, + ty: 0, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "ts_main", + stage: Task, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ts_main"), + arguments: [], + result: Some(( + ty: 6, + binding: Some(BuiltIn(MeshTaskSize)), + )), + local_variables: [], + expressions: [ + GlobalVariable(1), + Literal(F32(1.0)), + GlobalVariable(0), + AccessIndex( + base: 2, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 4, + 5, + 6, + 7, + ], + ), + GlobalVariable(0), + AccessIndex( + base: 9, + index: 1, + ), + Literal(Bool(true)), + Literal(U32(3)), + Literal(U32(1)), + Literal(U32(1)), + Compose( + ty: 6, + components: [ + 12, + 13, + 14, + ], + ), + ], + named_expressions: {}, + body: [ + Store( + pointer: 0, + value: 1, + ), + Emit(( + 
start: 3, + end: 4, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 8, + end: 9, + )), + Store( + pointer: 3, + value: 8, + ), + Emit(( + start: 10, + end: 11, + )), + Store( + pointer: 10, + value: 11, + ), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 15, + end: 16, + )), + Return( + value: Some(15), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: Some(0), + ), + ( + name: "ms_main", + stage: Mesh, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ms_main"), + arguments: [ + ( + name: Some("index"), + ty: 5, + binding: Some(BuiltIn(LocalInvocationIndex)), + ), + ( + name: Some("id"), + ty: 6, + binding: Some(BuiltIn(GlobalInvocationId)), + ), + ], + result: None, + local_variables: [ + ( + name: Some("v"), + ty: 4, + init: None, + ), + ( + name: Some("p"), + ty: 7, + init: None, + ), + ], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + Literal(U32(3)), + Literal(U32(1)), + GlobalVariable(1), + Literal(F32(2.0)), + LocalVariable(0), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 8, + 9, + 10, + 11, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 14, + index: 0, + ), + Load( + pointer: 15, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 17, + 18, + 19, + 20, + ], + ), + Binary( + op: Multiply, + left: 21, + right: 16, + ), + Literal(U32(0)), + Load( + pointer: 6, + ), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(-1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 26, + 27, + 28, + 29, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 32, + index: 0, + ), + Load( + 
pointer: 33, + ), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 35, + 36, + 37, + 38, + ], + ), + Binary( + op: Multiply, + left: 39, + right: 34, + ), + Literal(U32(1)), + Load( + pointer: 6, + ), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 44, + 45, + 46, + 47, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 50, + index: 0, + ), + Load( + pointer: 51, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 53, + 54, + 55, + 56, + ], + ), + Binary( + op: Multiply, + left: 57, + right: 52, + ), + Literal(U32(2)), + Load( + pointer: 6, + ), + LocalVariable(1), + AccessIndex( + base: 61, + index: 0, + ), + Literal(U32(0)), + Literal(U32(1)), + Literal(U32(2)), + Compose( + ty: 6, + components: [ + 63, + 64, + 65, + ], + ), + AccessIndex( + base: 61, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 68, + index: 1, + ), + Load( + pointer: 69, + ), + Unary( + op: LogicalNot, + expr: 70, + ), + AccessIndex( + base: 61, + index: 2, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 73, + 74, + 75, + 76, + ], + ), + Literal(U32(0)), + Load( + pointer: 61, + ), + ], + named_expressions: { + 0: "index", + 1: "id", + }, + body: [ + MeshFunction(SetMeshOutputs( + vertex_count: 2, + primitive_count: 3, + )), + Store( + pointer: 4, + value: 5, + ), + Emit(( + start: 7, + end: 8, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 12, + end: 13, + )), + Store( + pointer: 7, + value: 12, + ), + Emit(( + start: 13, + end: 14, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + 
)), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 15, + end: 17, + )), + Emit(( + start: 21, + end: 23, + )), + Store( + pointer: 13, + value: 22, + ), + Emit(( + start: 24, + end: 25, + )), + MeshFunction(SetVertex( + index: 23, + value: 24, + )), + Emit(( + start: 25, + end: 26, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 30, + end: 31, + )), + Store( + pointer: 25, + value: 30, + ), + Emit(( + start: 31, + end: 32, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 33, + end: 35, + )), + Emit(( + start: 39, + end: 41, + )), + Store( + pointer: 31, + value: 40, + ), + Emit(( + start: 42, + end: 43, + )), + MeshFunction(SetVertex( + index: 41, + value: 42, + )), + Emit(( + start: 43, + end: 44, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 48, + end: 49, + )), + Store( + pointer: 43, + value: 48, + ), + Emit(( + start: 49, + end: 50, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 51, + end: 53, + )), + Emit(( + start: 57, + end: 59, + )), + Store( + pointer: 49, + value: 58, + ), + Emit(( + start: 60, + end: 61, + )), + MeshFunction(SetVertex( + index: 59, + value: 60, + )), + Emit(( + start: 62, + end: 63, + )), + Emit(( + start: 66, + end: 67, + )), + Store( + pointer: 62, + value: 66, + ), + Emit(( + start: 67, + end: 68, + )), + Emit(( + start: 69, + end: 72, + )), + Store( + pointer: 67, + value: 71, + ), + Emit(( + start: 72, + end: 73, + )), + Emit(( + start: 77, + end: 78, + )), + Store( + pointer: 72, + value: 77, + ), + Emit(( + start: 79, + end: 80, + )), + MeshFunction(SetPrimitive( + index: 78, + value: 79, + )), + Return( + value: None, + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: 
Some(( + topology: Triangles, + max_vertices: 3, + max_vertices_override: None, + max_primitives: 1, + max_primitives_override: None, + vertex_output_type: 4, + primitive_output_type: 7, + )), + task_payload: Some(0), + ), + ( + name: "fs_main", + stage: Fragment, + early_depth_test: None, + workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, + function: ( + name: Some("fs_main"), + arguments: [ + ( + name: Some("vertex"), + ty: 4, + binding: None, + ), + ( + name: Some("primitive"), + ty: 8, + binding: None, + ), + ], + result: Some(( + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + )), + local_variables: [], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + AccessIndex( + base: 0, + index: 1, + ), + AccessIndex( + base: 1, + index: 0, + ), + Binary( + op: Multiply, + left: 2, + right: 3, + ), + ], + named_expressions: { + 0: "vertex", + 1: "primitive", + }, + body: [ + Emit(( + start: 2, + end: 5, + )), + Return( + value: Some(4), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: None, + ), + ], + diagnostic_filters: [], + diagnostic_filter_leaf: None, + doc_comments: None, +) \ No newline at end of file diff --git a/naga/tests/out/ir/wgsl-mesh-shader.ron b/naga/tests/out/ir/wgsl-mesh-shader.ron new file mode 100644 index 00000000000..38c79cba451 --- /dev/null +++ b/naga/tests/out/ir/wgsl-mesh-shader.ron @@ -0,0 +1,846 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Quad, + scalar: ( + kind: Float, + width: 4, + ), + ), + ), + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: Some("TaskPayload"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: None, + offset: 0, + ), + ( + name: Some("visible"), + ty: 2, + binding: None, + offset: 16, + ), + ], + 
span: 32, + ), + ), + ( + name: Some("VertexOutput"), + inner: Struct( + members: [ + ( + name: Some("position"), + ty: 1, + binding: Some(BuiltIn(Position( + invariant: false, + ))), + offset: 0, + ), + ( + name: Some("color"), + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Uint, + width: 4, + ), + ), + ), + ( + name: Some("PrimitiveOutput"), + inner: Struct( + members: [ + ( + name: Some("index"), + ty: 6, + binding: Some(BuiltIn(TriangleIndices)), + offset: 0, + ), + ( + name: Some("cull"), + ty: 2, + binding: Some(BuiltIn(CullPrimitive)), + offset: 12, + ), + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("PrimitiveInput"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 0, + ), + ], + span: 16, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ray_vertex_return: None, + external_texture_params: None, + external_texture_transfer_function: None, + predeclared_types: {}, + ), + constants: [], + overrides: [], + global_variables: [ + ( + name: Some("taskPayload"), + space: TaskPayload, + binding: None, + ty: 3, + init: None, + ), + ( + name: Some("workgroupData"), + space: WorkGroup, + binding: None, + ty: 0, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "ts_main", + stage: Task, + early_depth_test: None, + 
workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ts_main"), + arguments: [], + result: Some(( + ty: 6, + binding: Some(BuiltIn(MeshTaskSize)), + )), + local_variables: [], + expressions: [ + GlobalVariable(1), + Literal(F32(1.0)), + GlobalVariable(0), + AccessIndex( + base: 2, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 4, + 5, + 6, + 7, + ], + ), + GlobalVariable(0), + AccessIndex( + base: 9, + index: 1, + ), + Literal(Bool(true)), + Literal(U32(3)), + Literal(U32(1)), + Literal(U32(1)), + Compose( + ty: 6, + components: [ + 12, + 13, + 14, + ], + ), + ], + named_expressions: {}, + body: [ + Store( + pointer: 0, + value: 1, + ), + Emit(( + start: 3, + end: 4, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 8, + end: 9, + )), + Store( + pointer: 3, + value: 8, + ), + Emit(( + start: 10, + end: 11, + )), + Store( + pointer: 10, + value: 11, + ), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 15, + end: 16, + )), + Return( + value: Some(15), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: Some(0), + ), + ( + name: "ms_main", + stage: Mesh, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ms_main"), + arguments: [ + ( + name: Some("index"), + ty: 5, + binding: Some(BuiltIn(LocalInvocationIndex)), + ), + ( + name: Some("id"), + ty: 6, + binding: Some(BuiltIn(GlobalInvocationId)), + ), + ], + result: None, + local_variables: [ + ( + name: Some("v"), + ty: 4, + init: None, + ), + ( + name: Some("p"), + ty: 7, + init: None, + ), + ], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + Literal(U32(3)), + Literal(U32(1)), + GlobalVariable(1), + Literal(F32(2.0)), + LocalVariable(0), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), 
+ Compose( + ty: 1, + components: [ + 8, + 9, + 10, + 11, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 14, + index: 0, + ), + Load( + pointer: 15, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 17, + 18, + 19, + 20, + ], + ), + Binary( + op: Multiply, + left: 21, + right: 16, + ), + Literal(U32(0)), + Load( + pointer: 6, + ), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(-1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 26, + 27, + 28, + 29, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 32, + index: 0, + ), + Load( + pointer: 33, + ), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 35, + 36, + 37, + 38, + ], + ), + Binary( + op: Multiply, + left: 39, + right: 34, + ), + Literal(U32(1)), + Load( + pointer: 6, + ), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 44, + 45, + 46, + 47, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 50, + index: 0, + ), + Load( + pointer: 51, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 53, + 54, + 55, + 56, + ], + ), + Binary( + op: Multiply, + left: 57, + right: 52, + ), + Literal(U32(2)), + Load( + pointer: 6, + ), + LocalVariable(1), + AccessIndex( + base: 61, + index: 0, + ), + Literal(U32(0)), + Literal(U32(1)), + Literal(U32(2)), + Compose( + ty: 6, + components: [ + 63, + 64, + 65, + ], + ), + AccessIndex( + base: 61, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 68, + index: 1, + ), + Load( + pointer: 69, + ), + Unary( + op: LogicalNot, + expr: 
70, + ), + AccessIndex( + base: 61, + index: 2, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 73, + 74, + 75, + 76, + ], + ), + Literal(U32(0)), + Load( + pointer: 61, + ), + ], + named_expressions: { + 0: "index", + 1: "id", + }, + body: [ + MeshFunction(SetMeshOutputs( + vertex_count: 2, + primitive_count: 3, + )), + Store( + pointer: 4, + value: 5, + ), + Emit(( + start: 7, + end: 8, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 12, + end: 13, + )), + Store( + pointer: 7, + value: 12, + ), + Emit(( + start: 13, + end: 14, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 15, + end: 17, + )), + Emit(( + start: 21, + end: 23, + )), + Store( + pointer: 13, + value: 22, + ), + Emit(( + start: 24, + end: 25, + )), + MeshFunction(SetVertex( + index: 23, + value: 24, + )), + Emit(( + start: 25, + end: 26, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 30, + end: 31, + )), + Store( + pointer: 25, + value: 30, + ), + Emit(( + start: 31, + end: 32, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 33, + end: 35, + )), + Emit(( + start: 39, + end: 41, + )), + Store( + pointer: 31, + value: 40, + ), + Emit(( + start: 42, + end: 43, + )), + MeshFunction(SetVertex( + index: 41, + value: 42, + )), + Emit(( + start: 43, + end: 44, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 48, + end: 49, + )), + Store( + pointer: 43, + value: 48, + ), + Emit(( + start: 49, + end: 50, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + 
end: 0, + )), + Emit(( + start: 51, + end: 53, + )), + Emit(( + start: 57, + end: 59, + )), + Store( + pointer: 49, + value: 58, + ), + Emit(( + start: 60, + end: 61, + )), + MeshFunction(SetVertex( + index: 59, + value: 60, + )), + Emit(( + start: 62, + end: 63, + )), + Emit(( + start: 66, + end: 67, + )), + Store( + pointer: 62, + value: 66, + ), + Emit(( + start: 67, + end: 68, + )), + Emit(( + start: 69, + end: 72, + )), + Store( + pointer: 67, + value: 71, + ), + Emit(( + start: 72, + end: 73, + )), + Emit(( + start: 77, + end: 78, + )), + Store( + pointer: 72, + value: 77, + ), + Emit(( + start: 79, + end: 80, + )), + MeshFunction(SetPrimitive( + index: 78, + value: 79, + )), + Return( + value: None, + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: Some(( + topology: Triangles, + max_vertices: 3, + max_vertices_override: None, + max_primitives: 1, + max_primitives_override: None, + vertex_output_type: 4, + primitive_output_type: 7, + )), + task_payload: Some(0), + ), + ( + name: "fs_main", + stage: Fragment, + early_depth_test: None, + workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, + function: ( + name: Some("fs_main"), + arguments: [ + ( + name: Some("vertex"), + ty: 4, + binding: None, + ), + ( + name: Some("primitive"), + ty: 8, + binding: None, + ), + ], + result: Some(( + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + )), + local_variables: [], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + AccessIndex( + base: 0, + index: 1, + ), + AccessIndex( + base: 1, + index: 0, + ), + Binary( + op: Multiply, + left: 2, + right: 3, + ), + ], + named_expressions: { + 0: "vertex", + 1: "primitive", + }, + body: [ + Emit(( + start: 2, + end: 5, + )), + Return( + value: Some(4), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: None, + ), + ], + diagnostic_filters: [], + 
diagnostic_filter_leaf: None, + doc_comments: None, +) \ No newline at end of file diff --git a/naga/tests/out/ir/wgsl-must-use.compact.ron b/naga/tests/out/ir/wgsl-must-use.compact.ron index a701a6805da..16e925f2fb8 100644 --- a/naga/tests/out/ir/wgsl-must-use.compact.ron +++ b/naga/tests/out/ir/wgsl-must-use.compact.ron @@ -201,6 +201,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-must-use.ron b/naga/tests/out/ir/wgsl-must-use.ron index a701a6805da..16e925f2fb8 100644 --- a/naga/tests/out/ir/wgsl-must-use.ron +++ b/naga/tests/out/ir/wgsl-must-use.ron @@ -201,6 +201,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.compact.ron b/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.compact.ron index 640ee25ca49..28a824bb035 100644 --- a/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.compact.ron +++ b/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.compact.ron @@ -128,6 +128,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.ron b/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.ron index 640ee25ca49..28a824bb035 100644 --- a/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.ron +++ b/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.ron @@ -128,6 +128,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides-ray-query.compact.ron b/naga/tests/out/ir/wgsl-overrides-ray-query.compact.ron index f65e8f186db..152a45008c5 100644 --- a/naga/tests/out/ir/wgsl-overrides-ray-query.compact.ron +++ b/naga/tests/out/ir/wgsl-overrides-ray-query.compact.ron @@ 
-263,6 +263,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides-ray-query.ron b/naga/tests/out/ir/wgsl-overrides-ray-query.ron index f65e8f186db..152a45008c5 100644 --- a/naga/tests/out/ir/wgsl-overrides-ray-query.ron +++ b/naga/tests/out/ir/wgsl-overrides-ray-query.ron @@ -263,6 +263,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides.compact.ron b/naga/tests/out/ir/wgsl-overrides.compact.ron index 81221ff7941..fe136e71e4d 100644 --- a/naga/tests/out/ir/wgsl-overrides.compact.ron +++ b/naga/tests/out/ir/wgsl-overrides.compact.ron @@ -221,6 +221,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides.ron b/naga/tests/out/ir/wgsl-overrides.ron index 81221ff7941..fe136e71e4d 100644 --- a/naga/tests/out/ir/wgsl-overrides.ron +++ b/naga/tests/out/ir/wgsl-overrides.ron @@ -221,6 +221,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-storage-textures.compact.ron b/naga/tests/out/ir/wgsl-storage-textures.compact.ron index ec63fecac27..68c867a19e2 100644 --- a/naga/tests/out/ir/wgsl-storage-textures.compact.ron +++ b/naga/tests/out/ir/wgsl-storage-textures.compact.ron @@ -218,6 +218,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "csStore", @@ -315,6 +317,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-storage-textures.ron b/naga/tests/out/ir/wgsl-storage-textures.ron index ec63fecac27..68c867a19e2 100644 --- a/naga/tests/out/ir/wgsl-storage-textures.ron +++ b/naga/tests/out/ir/wgsl-storage-textures.ron @@ -218,6 +218,8 @@ ], 
diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "csStore", @@ -315,6 +317,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-template-list-trailing-comma.compact.ron b/naga/tests/out/ir/wgsl-template-list-trailing-comma.compact.ron index a8208c09b86..db619dff836 100644 --- a/naga/tests/out/ir/wgsl-template-list-trailing-comma.compact.ron +++ b/naga/tests/out/ir/wgsl-template-list-trailing-comma.compact.ron @@ -190,6 +190,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-template-list-trailing-comma.ron b/naga/tests/out/ir/wgsl-template-list-trailing-comma.ron index a8208c09b86..db619dff836 100644 --- a/naga/tests/out/ir/wgsl-template-list-trailing-comma.ron +++ b/naga/tests/out/ir/wgsl-template-list-trailing-comma.ron @@ -190,6 +190,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-texture-external.compact.ron b/naga/tests/out/ir/wgsl-texture-external.compact.ron index 2b9e1c8d5e4..689fe215e36 100644 --- a/naga/tests/out/ir/wgsl-texture-external.compact.ron +++ b/naga/tests/out/ir/wgsl-texture-external.compact.ron @@ -394,6 +394,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -416,6 +417,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "vertex_main", @@ -452,6 +455,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "compute_main", @@ -483,6 +488,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-texture-external.ron b/naga/tests/out/ir/wgsl-texture-external.ron index 
2b9e1c8d5e4..689fe215e36 100644 --- a/naga/tests/out/ir/wgsl-texture-external.ron +++ b/naga/tests/out/ir/wgsl-texture-external.ron @@ -394,6 +394,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -416,6 +417,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "vertex_main", @@ -452,6 +455,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "compute_main", @@ -483,6 +488,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-types_with_comments.compact.ron b/naga/tests/out/ir/wgsl-types_with_comments.compact.ron index 7186209f00e..7c0d856946f 100644 --- a/naga/tests/out/ir/wgsl-types_with_comments.compact.ron +++ b/naga/tests/out/ir/wgsl-types_with_comments.compact.ron @@ -116,6 +116,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-types_with_comments.ron b/naga/tests/out/ir/wgsl-types_with_comments.ron index 480b0d2337f..34e44cb9653 100644 --- a/naga/tests/out/ir/wgsl-types_with_comments.ron +++ b/naga/tests/out/ir/wgsl-types_with_comments.ron @@ -172,6 +172,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/wgpu-core/src/validation.rs b/wgpu-core/src/validation.rs index ffb0bf5a175..ffe2c7e7572 100644 --- a/wgpu-core/src/validation.rs +++ b/wgpu-core/src/validation.rs @@ -1094,6 +1094,8 @@ impl Interface { wgt::ShaderStages::VERTEX => naga::ShaderStage::Vertex, wgt::ShaderStages::FRAGMENT => naga::ShaderStage::Fragment, wgt::ShaderStages::COMPUTE => naga::ShaderStage::Compute, + wgt::ShaderStages::MESH => naga::ShaderStage::Mesh, + wgt::ShaderStages::TASK => naga::ShaderStage::Task, _ => unreachable!(), } } @@ -1238,7 +1240,7 @@ impl 
Interface { } // check workgroup size limits - if shader_stage == naga::ShaderStage::Compute { + if shader_stage.compute_like() { let max_workgroup_size_limits = [ self.limits.max_compute_workgroup_size_x, self.limits.max_compute_workgroup_size_y, diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index ba50eed76f0..d28854caa1d 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -2134,6 +2134,9 @@ impl super::Adapter { if features.contains(wgt::Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN) { capabilities.push(spv::Capability::RayQueryPositionFetchKHR) } + if features.contains(wgt::Features::EXPERIMENTAL_MESH_SHADER) { + capabilities.push(spv::Capability::MeshShadingEXT); + } if self.private_caps.shader_integer_dot_product { // See . capabilities.extend(&[ diff --git a/wgpu/src/api/render_pass.rs b/wgpu/src/api/render_pass.rs index c8cc56018eb..c73394db261 100644 --- a/wgpu/src/api/render_pass.rs +++ b/wgpu/src/api/render_pass.rs @@ -231,7 +231,33 @@ impl RenderPass<'_> { self.inner.draw_indexed(indices, base_vertex, instances); } - /// Draws using a mesh shader pipeline + /// Draws using a mesh pipeline. + /// + /// The current pipeline must be a mesh pipeline. + /// + /// If the current pipeline has a task shader, run it with a workgroup for + /// every `vec3(i, j, k)` where `i`, `j`, and `k` are between `0` and + /// `group_count_x`, `group_count_y`, and `group_count_z`. The invocation with + /// index zero in each group is responsible for determining the mesh shader dispatch. + /// Its return value indicates the number of workgroups of mesh shaders to invoke. It also + /// passes a payload value for them to consume. Because each task workgroup is essentially + /// a mesh shader draw call, mesh workgroups dispatched by different task workgroups + /// cannot interact in any way, and `workgroup_id` corresponds to its location in the + /// calling specific task shader's dispatch group.
+ /// + /// If the current pipeline lacks a task shader, run its mesh shader with a + /// workgroup for every `vec3(i, j, k)` where `i`, `j`, and `k` are + /// between `0` and `group_count_x`, `group_count_y`, and `group_count_z`. + /// + /// Each mesh shader workgroup outputs a set of vertices and indices for primitives. + /// The indices outputted correspond to the vertices outputted by that same workgroup; + /// there is no global vertex buffer. These primitives are passed to the rasterizer and + /// essentially treated like a vertex shader output, except that the mesh shader may + /// choose to cull specific primitives or pass per-primitive non-interpolated values + /// to the fragment shader. As such, each primitive is then rendered with the current + /// pipeline's fragment shader, if present. Otherwise, [No Color Output mode] is used. + /// + /// [No Color Output mode]: https://www.w3.org/TR/webgpu/#no-color-output pub fn draw_mesh_tasks(&mut self, group_count_x: u32, group_count_y: u32, group_count_z: u32) { self.inner .draw_mesh_tasks(group_count_x, group_count_y, group_count_z); @@ -264,7 +290,7 @@ impl RenderPass<'_> { .draw_indexed_indirect(&indirect_buffer.inner, indirect_offset); } - /// Draws using a mesh shader pipeline, + /// Draws using a mesh pipeline, /// based on the contents of the `indirect_buffer` /// /// This is like calling [`RenderPass::draw_mesh_tasks`] but the contents of the call are specified in the `indirect_buffer`. diff --git a/wgpu/src/api/render_pipeline.rs b/wgpu/src/api/render_pipeline.rs index e887bb4b97e..35b74100d00 100644 --- a/wgpu/src/api/render_pipeline.rs +++ b/wgpu/src/api/render_pipeline.rs @@ -152,13 +152,15 @@ static_assertions::assert_impl_all!(FragmentState<'_>: Send, Sync); pub struct TaskState<'a> { /// The compiled shader module for this stage. pub module: &'a ShaderModule, - /// The name of the entry point in the compiled shader to use. 
+ + /// The name of the task shader entry point in the shader module to use. /// - /// If [`Some`], there must be a vertex-stage shader entry point with this name in `module`. - /// Otherwise, expect exactly one vertex-stage entry point in `module`, which will be - /// selected. + /// If [`Some`], there must be a task shader entry point with the given name + /// in `module`. Otherwise, there must be exactly one task shader entry + /// point in `module`, which will be selected. pub entry_point: Option<&'a str>, - /// Advanced options for when this pipeline is compiled + + /// Advanced options for when this pipeline is compiled. /// /// This implements `Default`, and for most users can be set to `Default::default()` pub compilation_options: PipelineCompilationOptions<'a>, @@ -238,7 +240,43 @@ static_assertions::assert_impl_all!(RenderPipelineDescriptor<'_>: Send, Sync); /// Describes a mesh shader (graphics) pipeline. /// -/// For use with [`Device::create_mesh_pipeline`]. +/// For use with [`Device::create_mesh_pipeline`]. A mesh pipeline is very much +/// like a render pipeline, except that instead of [`RenderPass::draw`] it is +/// invoked with [`RenderPass::draw_mesh_tasks`], and instead of a vertex shader +/// and a fragment shader: +/// +/// - [`task`] specifies an optional task shader entry point, which determines how +/// many groups of mesh shaders to dispatch. +/// +/// - [`mesh`] specifies a mesh shader entry point, which generates groups of +/// primitives to draw. +/// +/// - [`fragment`] specifies a fragment shader for drawing those primitives, +/// just like in an ordinary render pipeline. +/// +/// The key difference is that, whereas a vertex shader is invoked on the +/// elements of vertex buffers, the task shader gets to decide how many mesh +/// shader workgroups to make, and then each mesh shader workgroup gets to +/// decide which primitives it wants to generate, and what their vertex +/// attributes are.
Task and mesh shaders can use whatever they please as +/// inputs, like a compute shader. However, they cannot use specialized vertex +/// or index buffers. +/// +/// A mesh pipeline is invoked by [`RenderPass::draw_mesh_tasks`], which looks +/// like a compute shader dispatch with [`ComputePass::dispatch_workgroups`]: +/// you pass `x`, `y`, and `z` values indicating the number of task shaders to +/// invoke in parallel. The output value of the first thread in a task shader +/// workgroup determines how many mesh workgroups should be dispatched from there. +/// Those mesh workgroups also get a special payload passed from the task shader. +/// +/// If the task shader is omitted, then the (`x`, `y`, `z`) parameters to +/// `draw_mesh_tasks` are used to decide how many mesh shader workgroups +/// to dispatch directly, without a task payload. +/// +/// [vertex formats]: wgpu_types::VertexFormat +/// [`task`]: Self::task +/// [`mesh`]: Self::mesh +/// [`fragment`]: Self::fragment #[derive(Clone, Debug)] pub struct MeshPipelineDescriptor<'a> { /// Debug label of the pipeline. This will show up in graphics debuggers for easy identification. @@ -263,8 +301,15 @@ pub struct MeshPipelineDescriptor<'a> { /// /// [default layout]: https://www.w3.org/TR/webgpu/#default-pipeline-layout pub layout: Option<&'a PipelineLayout>, - /// The compiled task stage, its entry point, and the color targets. + + /// The mesh pipeline's task shader. + /// + /// If this is `None`, the mesh pipeline has no task shader. Executing a + /// mesh drawing command simply dispatches a grid of mesh shaders directly. + /// + /// [`draw_mesh_tasks`]: RenderPass::draw_mesh_tasks pub task: Option<TaskState<'a>>, + /// The compiled mesh stage and its entry point pub mesh: MeshState<'a>, /// The properties of the pipeline at the primitive assembly and rasterization level.