Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add associated constant for WGSL struct definition #44

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Having to manually lay out data into GPU buffers can become very tedious and err

The core trait is [`ShaderType`] which mainly contains metadata about the given type.

The [`WriteInto`], [`ReadFrom`] and [`CreateFrom`] traits represent the ability of a type to be written into the buffer, read from the buffer and created from the buffer respectively.
The [`WriteInto`], [`ReadFrom`] and [`CreateFrom`] traits represent the ability of a type to be written into the buffer, read from the buffer and created from the buffer respectively. The [`WgslStruct`] trait allows to generate the WGSL struct definition.

Most data types can implement the above traits via their respective macros:

Expand Down
2 changes: 1 addition & 1 deletion derive/impl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ keywords = ["wgsl", "wgpu"]
categories = ["rendering"]

[dependencies]
syn = "2"
syn = "2.0.1"
quote = "1"
proc-macro2 = "1"
41 changes: 41 additions & 0 deletions derive/impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,29 @@ impl FieldData {
fn ident(&self) -> &Ident {
self.field.ident.as_ref().unwrap()
}

fn wgsl_type(&self, root: &Path) -> TokenStream {
let ty = &self.field.ty;
quote! {
<#ty as #root::ShaderType>::wgsl_type()
}
}

fn wgsl_layout_attributes(&self) -> String {
let mut attribs = Vec::new();
if let Some((size, _)) = self.size {
attribs.push(format!("@size({})", size));
}
if let Some((align, _)) = self.align {
attribs.push(format!("@align({})", align));
}

if attribs.is_empty() {
String::new()
} else {
format!(" {}\n", &attribs.join(" "))
}
}
}

struct AlignmentAttr(u32);
Expand Down Expand Up @@ -522,6 +545,11 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

let field_strings = field_data.iter().map(|data| data.ident().to_string());
let field_layout_attributes = field_data.iter().map(|data| data.wgsl_layout_attributes());
let field_wgsl_types = field_data.iter().map(|data| data.wgsl_type(root));
let name_string = name.to_string();

let set_contained_rt_sized_array_length = if is_runtime_sized {
quote! {
writer.ctx.rts_array_length = ::core::option::Option::Some(
Expand Down Expand Up @@ -607,6 +635,10 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
offset += #root::ShaderType::size(&self.#last_field_ident).get();
#root::SizeValue::new(Self::METADATA.alignment().round_up(offset)).0
}

fn wgsl_type() -> ::std::string::String {
victorvde marked this conversation as resolved.
Show resolved Hide resolved
::std::string::ToString::to_string(#name_string)
}
}

impl #impl_generics #root::WriteInto for #name #ty_generics
Expand Down Expand Up @@ -642,6 +674,15 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
}
}

impl #impl_generics #root::WgslStruct for #name #ty_generics
victorvde marked this conversation as resolved.
Show resolved Hide resolved
{
fn wgsl_struct() -> ::std::string::String {
victorvde marked this conversation as resolved.
Show resolved Hide resolved
::std::format!("struct {} {{\n", #name_string)
#( + &::std::format!("{} {}: {},\n", #field_layout_attributes, #field_strings, #field_wgsl_types) )*
+ "}\n"
}
}

#extra
}
}
Expand Down
10 changes: 10 additions & 0 deletions src/core/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ pub trait ShaderType {
#[doc(hidden)]
const UNIFORM_COMPAT_ASSERT: fn() = || {};

/// Returns the [WGSL type name](https://www.w3.org/TR/WGSL/#types) for the implementing Rust type.
///
/// Note that for structs, this is just the name of the struct. See also [`WgslStruct`].
fn wgsl_type() -> String;

/// Asserts that `Self` meets the requirements of the
/// [uniform address space restrictions on stored values](https://gpuweb.github.io/gpuweb/wgsl/#address-spaces-uniform) and the
/// [uniform address space layout constraints](https://gpuweb.github.io/gpuweb/wgsl/#address-space-layout-constraints)
Expand Down Expand Up @@ -244,3 +249,8 @@ pub trait CreateFrom: Sized {
where
B: BufferRef;
}

pub trait WgslStruct {
/// Returns the [WGSL struct](https://www.w3.org/TR/WGSL/#struct-types) definition for the implementing Rust struct
fn wgsl_struct() -> String;
}
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ mod impls;

pub use crate::core::{
CalculateSizeFor, DynamicStorageBuffer, DynamicUniformBuffer, ShaderSize, ShaderType,
StorageBuffer, UniformBuffer,
StorageBuffer, UniformBuffer, WgslStruct,
};
pub use types::runtime_sized_array::ArrayLength;

Expand Down Expand Up @@ -152,6 +152,7 @@ pub mod private {
pub use super::core::Reader;
pub use super::core::RuntimeSizedArray;
pub use super::core::SizeValue;
pub use super::core::WgslStruct;
pub use super::core::WriteInto;
pub use super::core::Writer;
pub use super::types::array::ArrayMetadata;
Expand Down
4 changes: 4 additions & 0 deletions src/types/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ impl<T: ShaderType + ShaderSize, const N: usize> ShaderType for [T; N] {
},
])
};

fn wgsl_type() -> String {
format!("array<{},{}>", T::wgsl_type(), N)
victorvde marked this conversation as resolved.
Show resolved Hide resolved
}
}

impl<T: ShaderSize, const N: usize> ShaderSize for [T; N] {}
Expand Down
4 changes: 4 additions & 0 deletions src/types/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ macro_rules! impl_matrix_inner {
},
}
};

fn wgsl_type() -> ::std::string::String {
::std::format!("mat{}x{}<{}>", $c, $r, <$el_ty as $crate::private::ShaderType>::wgsl_type())
}
}

impl<$($generics)*> $crate::private::ShaderSize for $type
Expand Down
7 changes: 7 additions & 0 deletions src/types/runtime_sized_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ pub struct ArrayLength;
impl ShaderType for ArrayLength {
type ExtraMetadata = ();
const METADATA: Metadata<Self::ExtraMetadata> = Metadata::from_alignment_and_size(4, 4);
fn wgsl_type() -> String {
"u32".to_string()
}
}

impl ShaderSize for ArrayLength {}
Expand Down Expand Up @@ -151,6 +154,10 @@ macro_rules! impl_rts_array_inner {
.mul($crate::private::Length::length(self).max(1) as ::core::primitive::u64)
.0
}

fn wgsl_type() -> ::std::string::String {
::std::format!("array<{}>", <T as $crate::private::ShaderType>::wgsl_type())
}
}

impl<$($generics)*> $crate::private::RuntimeSizedArray for $type
Expand Down
39 changes: 21 additions & 18 deletions src/types/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,22 @@ use core::num::{NonZeroI32, NonZeroU32, Wrapping};
use core::sync::atomic::{AtomicI32, AtomicU32};

macro_rules! impl_basic_traits {
($type:ty) => {
($type:ty, $wgsl:literal) => {
impl ShaderType for $type {
type ExtraMetadata = ();
const METADATA: Metadata<Self::ExtraMetadata> = Metadata::from_alignment_and_size(4, 4);
fn wgsl_type() -> String {
$wgsl.to_string()
}
}

impl ShaderSize for $type {}
};
}

macro_rules! impl_traits {
($type:ty) => {
impl_basic_traits!($type);
($type:ty, $wgsl:literal) => {
impl_basic_traits!($type, $wgsl);

impl WriteInto for $type {
#[inline]
Expand All @@ -43,13 +46,13 @@ macro_rules! impl_traits {
};
}

impl_traits!(f32);
impl_traits!(u32);
impl_traits!(i32);
impl_traits!(f32, "f32");
impl_traits!(u32, "u32");
impl_traits!(i32, "i32");

macro_rules! impl_traits_for_non_zero_option {
($type:ty) => {
impl_basic_traits!(Option<$type>);
($type:ty, $wgsl:literal) => {
impl_basic_traits!(Option<$type>, $wgsl);

impl WriteInto for Option<$type> {
#[inline]
Expand All @@ -75,12 +78,12 @@ macro_rules! impl_traits_for_non_zero_option {
};
}

impl_traits_for_non_zero_option!(NonZeroU32);
impl_traits_for_non_zero_option!(NonZeroI32);
impl_traits_for_non_zero_option!(NonZeroU32, "u32");
impl_traits_for_non_zero_option!(NonZeroI32, "i32");

macro_rules! impl_traits_for_wrapping {
($type:ty) => {
impl_basic_traits!($type);
($type:ty, $wgsl:literal) => {
impl_basic_traits!($type, $wgsl);

impl WriteInto for $type {
#[inline]
Expand All @@ -105,12 +108,12 @@ macro_rules! impl_traits_for_wrapping {
};
}

impl_traits_for_wrapping!(Wrapping<u32>);
impl_traits_for_wrapping!(Wrapping<i32>);
impl_traits_for_wrapping!(Wrapping<u32>, "u32");
impl_traits_for_wrapping!(Wrapping<i32>, "i32");

macro_rules! impl_traits_for_atomic {
($type:ty) => {
impl_basic_traits!($type);
($type:ty, $wgsl:literal) => {
impl_basic_traits!($type, $wgsl);

impl WriteInto for $type {
#[inline]
Expand All @@ -136,8 +139,8 @@ macro_rules! impl_traits_for_atomic {
};
}

impl_traits_for_atomic!(AtomicU32);
impl_traits_for_atomic!(AtomicI32);
impl_traits_for_atomic!(AtomicU32, "atomic<u32>");
impl_traits_for_atomic!(AtomicI32, "atomic<i32>");

macro_rules! impl_marker_trait_for_f32 {
($trait:path) => {
Expand Down
4 changes: 4 additions & 0 deletions src/types/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ macro_rules! impl_vector_inner {
extra: ()
}
};

fn wgsl_type() -> ::std::string::String {
::std::format!("vec{}<{}>", $n, <$el_ty as $crate::private::ShaderType>::wgsl_type())
}
}

impl<$($generics)*> $crate::private::ShaderSize for $type
Expand Down
4 changes: 4 additions & 0 deletions src/types/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ macro_rules! impl_wrapper_inner {
fn size(&self) -> ::core::num::NonZeroU64 {
<T as $crate::private::ShaderType>::size(&self$($get_ref)*)
}

fn wgsl_type() -> ::std::string::String {
<T as $crate::private::ShaderType>::wgsl_type()
}
}
impl<$($generics)*> $crate::private::ShaderSize for $type
where
Expand Down
47 changes: 46 additions & 1 deletion tests/general.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use encase::{ArrayLength, CalculateSizeFor, ShaderType, StorageBuffer};
use encase::{ArrayLength, CalculateSizeFor, ShaderType, StorageBuffer, WgslStruct};

macro_rules! gen {
($rng:ident, $ty:ty) => {{
Expand Down Expand Up @@ -143,3 +143,48 @@ fn all_types() {

assert_eq!(raw_buffer, raw_buffer_2);
}

#[test]
fn wgsl_struct() {
assert_eq!(A::wgsl_type(), "A");
assert_eq!(
A::wgsl_struct(),
"struct A {
f: f32,
u: u32,
i: i32,
nu: u32,
ni: i32,
wu: u32,
wi: i32,
au: atomic<u32>,
ai: atomic<i32>,
v2: vec2<f32>,
v3: vec3<u32>,
v4: vec4<i32>,
p2: vec2<f32>,
p3: vec3<f32>,
mat2: mat2x2<f32>,
mat2x3: mat3x2<f32>,
mat2x4: mat4x2<f32>,
mat3x2: mat2x3<f32>,
mat3: mat3x3<f32>,
mat3x4: mat4x3<f32>,
mat4x2: mat2x4<f32>,
mat4x3: mat3x4<f32>,
mat4: mat4x4<f32>,
arrf: array<f32,32>,
arru: array<u32,32>,
arri: array<i32,32>,
arrvf: array<vec2<f32>,16>,
arrvu: array<vec3<u32>,16>,
arrvi: array<vec4<i32>,16>,
arrm2: array<mat2x2<f32>,8>,
arrm3: array<mat3x3<f32>,8>,
arrm4: array<mat4x4<f32>,8>,
rt_arr_len: u32,
rt_arr: array<mat3x2<f32>>,
}
"
);
}
8 changes: 0 additions & 8 deletions tests/shaders/array_length.wgsl
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
struct A {
array_length: u32,
array_length_call_ret_val: u32,
a: vec3<u32>,
@align(16)
arr: array<u32>,
}

@group(0) @binding(0)
var<storage> in: A;

Expand Down
23 changes: 0 additions & 23 deletions tests/shaders/general.wgsl
Original file line number Diff line number Diff line change
@@ -1,26 +1,3 @@
struct A {
u: u32,
v: u32,
w: vec2<u32>,
@size(16) @align(8)
x: u32,
xx: u32,
}

struct B {
a: vec2<u32>,
b: vec3<u32>,
c: u32,
d: u32,
@align(16)
e: A,
f: vec3<u32>,
g: array<A, 3>,
h: i32,
@align(32)
i: array<A>,
}

@group(0) @binding(0)
var<storage> in: B;

Expand Down
Loading