diff --git a/Cargo.toml b/Cargo.toml index 43fa6de..5f58629 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "encase" version = "0.6.1" edition = "2021" -rust-version = "1.63" +rust-version = "1.64" license = "MIT-0" readme = "./README.md" diff --git a/README.md b/README.md index 757d4e9..02ffe39 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Having to manually lay out data into GPU buffers can become very tedious and err The core trait is [`ShaderType`] which mainly contains metadata about the given type. -The [`WriteInto`], [`ReadFrom`] and [`CreateFrom`] traits represent the ability of a type to be written into the buffer, read from the buffer and created from the buffer respectively. +The [`WriteInto`], [`ReadFrom`] and [`CreateFrom`] traits represent the ability of a type to be written into the buffer, read from the buffer and created from the buffer respectively. The [`ShaderStructDeclaration`] trait allows to generate the WGSL struct definition. Most data types can implement the above traits via their respective macros: diff --git a/derive/impl/Cargo.toml b/derive/impl/Cargo.toml index efa3d97..bd60531 100644 --- a/derive/impl/Cargo.toml +++ b/derive/impl/Cargo.toml @@ -11,6 +11,6 @@ keywords = ["wgsl", "wgpu"] categories = ["rendering"] [dependencies] -syn = "2" +syn = "2.0.1" quote = "1" proc-macro2 = "1" \ No newline at end of file diff --git a/derive/impl/src/lib.rs b/derive/impl/src/lib.rs index 4fc75d4..b55ad99 100644 --- a/derive/impl/src/lib.rs +++ b/derive/impl/src/lib.rs @@ -14,7 +14,7 @@ pub use syn; #[macro_export] macro_rules! implement { ($path:expr) => { - #[proc_macro_derive(ShaderType, attributes(align, size))] + #[proc_macro_derive(ShaderType, attributes(align, size, shader_atomic))] pub fn derive_shader_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = $crate::syn::parse_macro_input!(input as $crate::syn::DeriveInput); let expanded = encase_derive_impl::derive_shader_type(input, &$path); @@ -40,6 +40,7 @@ struct FieldData { pub field: syn::Field, pub size: Option<(u32, Span)>, pub align: Option<(u32, Span)>, + pub shader_atomic: bool, } impl FieldData { @@ -97,6 +98,29 @@ impl FieldData { fn ident(&self) -> &Ident { self.field.ident.as_ref().unwrap() } + + fn wgsl_type(&self, root: &Path) -> TokenStream { + let ty = &self.field.ty; + quote! { + <#ty as #root::ShaderType>::SHADER_TYPE + } + } + + fn wgsl_layout_attributes(&self) -> String { + let mut attribs = Vec::new(); + if let Some((size, _)) = self.size { + attribs.push(format!("@size({})", size)); + } + if let Some((align, _)) = self.align { + attribs.push(format!("@align({})", align)); + } + + if attribs.is_empty() { + String::new() + } else { + format!(" {}\n", &attribs.join(" ")) + } + } } struct AlignmentAttr(u32); @@ -195,21 +219,25 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream { field: field.clone(), size: None, align: None, + shader_atomic: false, }; for attr in &field.attrs { - if !(attr.meta.path().is_ident("size") || attr.meta.path().is_ident("align")) { - continue; - } - match attr.meta.require_list() { - Ok(meta_list) => { - let span = meta_list.tokens.span(); - if meta_list.path.is_ident("align") { + if attr.meta.path().is_ident("align") { + match attr.meta.require_list() { + Ok(meta_list) => { + let span = meta_list.tokens.span(); let res = attr.parse_args::(); match res { Ok(val) => data.align = Some((val.0, span)), Err(err) => errors.append(err), } - } else if meta_list.path.is_ident("size") { + } + Err(err) => errors.append(err), + } + } else if attr.meta.path().is_ident("size") { + match attr.meta.require_list() { + Ok(meta_list) => { + let span = meta_list.tokens.span(); let res = if i == last_field_index { attr.parse_args::().map(|val| match val { SizeAttr::Runtime => { @@ -227,9 +255,16 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream { Err(err) => errors.append(err), } } + Err(err) => errors.append(err), } - Err(err) => errors.append(err), - }; + } else if attr.meta.path().is_ident("shader_atomic") { + match attr.meta.require_path_only() { + Ok(_) => { + data.shader_atomic = true; + } + Err(err) => errors.append(err), + } + } } data }) @@ -516,12 +551,32 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream { let field_types_2 = field_types.clone(); let field_types_3 = field_types.clone(); let field_types_4 = field_types.clone(); + let field_types_5 = field_types.clone(); let all_other = field_types.clone().take(last_field_index); let last_field_type = &last_field.field.ty; let name = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let field_strings = field_data.iter().map(|data| data.ident().to_string()); + let field_layout_attributes = field_data.iter().map(|data| data.wgsl_layout_attributes()); + let field_wgsl_types = field_data.iter().map(|data| data.wgsl_type(root)); + let field_atomics_pre = field_data.iter().map(|data| { + if data.shader_atomic { + Literal::string("atomic<") + } else { + Literal::string("") + } + }); + let field_atomics_post = field_data.iter().map(|data| { + if data.shader_atomic { + Literal::string(">") + } else { + Literal::string("") + } + }); + let name_string = name.to_string(); + let set_contained_rt_sized_array_length = if is_runtime_sized { quote! { writer.ctx.rts_array_length = ::core::option::Option::Some( @@ -607,6 +662,8 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream { offset += #root::ShaderType::size(&self.#last_field_ident).get(); #root::SizeValue::new(Self::METADATA.alignment().round_up(offset)).0 } + + const SHADER_TYPE: &'static ::core::primitive::str = #name_string; } impl #impl_generics #root::WriteInto for #name #ty_generics @@ -642,6 +699,20 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream { } } + impl #impl_generics #root::ShaderStructDeclaration for #name #ty_generics + where + #( #field_types_5: #root::ShaderType, )* + { + const SHADER_STRUCT_DECLARATION: &'static ::core::primitive::str = + #root::ConstStr::new() + .str("struct ").str(#name_string).str(" {\n") + #( + .str(#field_layout_attributes).str(" ").str(#field_strings).str(": ") + .str(#field_atomics_pre).str(#field_wgsl_types).str(#field_atomics_post).str(",\n") + )* + .str("}\n").as_str(); + } + #extra } } diff --git a/src/const_str.rs b/src/const_str.rs new file mode 100644 index 0000000..48a678b --- /dev/null +++ b/src/const_str.rs @@ -0,0 +1,105 @@ +// Const string implementation for SHADER_TYPE and SHADER_STRUCT_DECLARATION +// Used instead of crates like const_str because of E0401 when trying to use them in traits +// See also https://old.reddit.com/r/rust/comments/sv119a/concat_static_str_at_compile_time/ + +// Must be constant to avoid running into E0401. Should only affect compilation. +const BUFFER_SIZE: usize = 8192; + +pub struct ConstStr { + data: [u8; BUFFER_SIZE], + len: usize, +} + +impl ConstStr { + pub const fn new() -> ConstStr { + ConstStr { + data: [0u8; BUFFER_SIZE], + len: 0, + } + } + + pub const fn str(mut self, s: &str) -> Self { + let b = s.as_bytes(); + let mut index = 0; + while index < b.len() { + self.data[self.len] = b[index]; + self.len += 1; + index += 1; + } + + self + } + + pub const fn u64(mut self, x: u64) -> Self { + let mut x2 = x; + let mut l = 0; + loop { + l += 1; + x2 /= 10; + if x2 == 0 { + break; + } + } + let mut x3 = x; + let mut index = 0; + loop { + self.data[self.len + l - 1 - index] = (x3 % 10) as u8 + b'0'; + index += 1; + x3 /= 10; + if x3 == 0 { + break; + } + } + self.len += l; + + self + } + + pub const fn as_str(&self) -> &str { + // SAFETY: safe because this is only used in const, and should be correct by construction + unsafe { + std::str::from_utf8_unchecked(std::slice::from_raw_parts(self.data.as_ptr(), self.len)) + } + } +} + +mod test { + use super::ConstStr; + + trait Name { + const NAME: &'static str; + } + + trait Prefix: Name {} + + trait Root: Name {} + + struct Kilo; + + impl Prefix for Kilo {} + + struct Meter; + + impl Root for Meter {} + + impl Name for Kilo { + const NAME: &'static str = "kilo"; + } + + impl Name for Meter { + const NAME: &'static str = "meter"; + } + + impl Name for (P, R) { + const NAME: &'static str = ConstStr::new() + .str(P::NAME) + .str(R::NAME) + .u64(1234567) + .as_str(); + } + + #[test] + fn test_trait() { + assert_eq!(<(Kilo, Meter)>::NAME, "kilometer1234567"); + } +} diff --git a/src/core/traits.rs b/src/core/traits.rs index 60caa5c..01401ba 100644 --- a/src/core/traits.rs +++ b/src/core/traits.rs @@ -58,6 +58,11 @@ pub trait ShaderType { #[doc(hidden)] const METADATA: Metadata; + /// The [WGSL type name](https://www.w3.org/TR/WGSL/#types) for the implementing Rust type. + /// + /// Note that for structs, this is just the name of the struct. See also [`ShaderStructDeclaration`]. + const SHADER_TYPE: &'static str; + /// Represents the minimum size of `Self` (equivalent to [GPUBufferBindingLayout.minBindingSize](https://gpuweb.github.io/gpuweb/#dom-gpubufferbindinglayout-minbindingsize)) /// /// For [WGSL fixed-footprint types](https://gpuweb.github.io/gpuweb/wgsl/#fixed-footprint-types) @@ -244,3 +249,8 @@ pub trait CreateFrom: Sized { where B: BufferRef; } + +pub trait ShaderStructDeclaration { + /// The [WGSL struct](https://www.w3.org/TR/WGSL/#struct-types) definition for the implementing Rust struct + const SHADER_STRUCT_DECLARATION: &'static str; +} diff --git a/src/lib.rs b/src/lib.rs index cc399da..29dd7e8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -96,14 +96,15 @@ pub use encase_derive::ShaderType; #[macro_use] mod utils; +mod const_str; mod core; mod types; mod impls; pub use crate::core::{ - CalculateSizeFor, DynamicStorageBuffer, DynamicUniformBuffer, ShaderSize, ShaderType, - StorageBuffer, UniformBuffer, + CalculateSizeFor, DynamicStorageBuffer, DynamicUniformBuffer, ShaderSize, + ShaderStructDeclaration, ShaderType, StorageBuffer, UniformBuffer, }; pub use types::runtime_sized_array::ArrayLength; @@ -143,6 +144,7 @@ pub mod matrix { #[doc(hidden)] pub mod private { pub use super::build_struct; + pub use super::const_str::ConstStr; pub use super::core::AlignmentValue; pub use super::core::BufferMut; pub use super::core::BufferRef; @@ -151,6 +153,7 @@ pub mod private { pub use super::core::ReadFrom; pub use super::core::Reader; pub use super::core::RuntimeSizedArray; + pub use super::core::ShaderStructDeclaration; pub use super::core::SizeValue; pub use super::core::WriteInto; pub use super::core::Writer; diff --git a/src/types/array.rs b/src/types/array.rs index b94c5ab..688f846 100644 --- a/src/types/array.rs +++ b/src/types/array.rs @@ -1,3 +1,4 @@ +use crate::const_str::ConstStr; use crate::core::{ BufferMut, BufferRef, CreateFrom, Metadata, ReadFrom, Reader, ShaderSize, ShaderType, SizeValue, WriteInto, Writer, @@ -55,6 +56,14 @@ impl ShaderType for [T; N] { }, ]) }; + + const SHADER_TYPE: &'static ::core::primitive::str = ConstStr::new() + .str("array<") + .str(T::SHADER_TYPE) + .str(", ") + .u64(N as u64) + .str(">") + .as_str(); } impl ShaderSize for [T; N] {} diff --git a/src/types/matrix.rs b/src/types/matrix.rs index be593b5..6f3e810 100644 --- a/src/types/matrix.rs +++ b/src/types/matrix.rs @@ -147,6 +147,9 @@ macro_rules! impl_matrix_inner { }, } }; + + const SHADER_TYPE: &'static ::core::primitive::str = $crate::private::ConstStr::new() + .str("mat").u64($c).str("x").u64($r).str("<").str(<$el_ty as $crate::private::ShaderType>::SHADER_TYPE).str(">").as_str(); } impl<$($generics)*> $crate::private::ShaderSize for $type diff --git a/src/types/runtime_sized_array.rs b/src/types/runtime_sized_array.rs index 6fc0d66..39de368 100644 --- a/src/types/runtime_sized_array.rs +++ b/src/types/runtime_sized_array.rs @@ -34,6 +34,7 @@ pub struct ArrayLength; impl ShaderType for ArrayLength { type ExtraMetadata = (); const METADATA: Metadata = Metadata::from_alignment_and_size(4, 4); + const SHADER_TYPE: &'static str = "u32"; } impl ShaderSize for ArrayLength {} @@ -151,6 +152,9 @@ macro_rules! impl_rts_array_inner { .mul($crate::private::Length::length(self).max(1) as ::core::primitive::u64) .0 } + + const SHADER_TYPE: &'static ::core::primitive::str = $crate::private::ConstStr::new(). + str("array<").str(::SHADER_TYPE).str(">").as_str(); } impl<$($generics)*> $crate::private::RuntimeSizedArray for $type @@ -264,7 +268,7 @@ mod array_length { #[test] fn derived_traits() { - assert_eq!(ArrayLength::default(), ArrayLength.clone()); + assert_eq!(ArrayLength, ArrayLength.clone()); assert_eq!(format!("{ArrayLength:?}"), "ArrayLength"); } diff --git a/src/types/scalar.rs b/src/types/scalar.rs index cf03d81..90fbc15 100644 --- a/src/types/scalar.rs +++ b/src/types/scalar.rs @@ -6,10 +6,11 @@ use core::num::{NonZeroI32, NonZeroU32, Wrapping}; use core::sync::atomic::{AtomicI32, AtomicU32}; macro_rules! impl_basic_traits { - ($type:ty) => { + ($type:ty, $wgsl:literal) => { impl ShaderType for $type { type ExtraMetadata = (); const METADATA: Metadata = Metadata::from_alignment_and_size(4, 4); + const SHADER_TYPE: &'static str = $wgsl; } impl ShaderSize for $type {} @@ -17,8 +18,8 @@ macro_rules! impl_basic_traits { } macro_rules! impl_traits { - ($type:ty) => { - impl_basic_traits!($type); + ($type:ty, $wgsl:literal) => { + impl_basic_traits!($type, $wgsl); impl WriteInto for $type { #[inline] @@ -43,13 +44,13 @@ macro_rules! impl_traits { }; } -impl_traits!(f32); -impl_traits!(u32); -impl_traits!(i32); +impl_traits!(f32, "f32"); +impl_traits!(u32, "u32"); +impl_traits!(i32, "i32"); macro_rules! impl_traits_for_non_zero_option { - ($type:ty) => { - impl_basic_traits!(Option<$type>); + ($type:ty, $wgsl:literal) => { + impl_basic_traits!(Option<$type>, $wgsl); impl WriteInto for Option<$type> { #[inline] @@ -75,12 +76,12 @@ macro_rules! impl_traits_for_non_zero_option { }; } -impl_traits_for_non_zero_option!(NonZeroU32); -impl_traits_for_non_zero_option!(NonZeroI32); +impl_traits_for_non_zero_option!(NonZeroU32, "u32"); +impl_traits_for_non_zero_option!(NonZeroI32, "i32"); macro_rules! impl_traits_for_wrapping { - ($type:ty) => { - impl_basic_traits!($type); + ($type:ty, $wgsl:literal) => { + impl_basic_traits!($type, $wgsl); impl WriteInto for $type { #[inline] @@ -105,12 +106,12 @@ macro_rules! impl_traits_for_wrapping { }; } -impl_traits_for_wrapping!(Wrapping); -impl_traits_for_wrapping!(Wrapping); +impl_traits_for_wrapping!(Wrapping, "u32"); +impl_traits_for_wrapping!(Wrapping, "i32"); macro_rules! impl_traits_for_atomic { - ($type:ty) => { - impl_basic_traits!($type); + ($type:ty, $wgsl:literal) => { + impl_basic_traits!($type, $wgsl); impl WriteInto for $type { #[inline] @@ -136,8 +137,8 @@ macro_rules! impl_traits_for_atomic { }; } -impl_traits_for_atomic!(AtomicU32); -impl_traits_for_atomic!(AtomicI32); +impl_traits_for_atomic!(AtomicU32, "u32"); +impl_traits_for_atomic!(AtomicI32, "i32"); macro_rules! impl_marker_trait_for_f32 { ($trait:path) => { diff --git a/src/types/vector.rs b/src/types/vector.rs index f826144..781478d 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -126,6 +126,9 @@ macro_rules! impl_vector_inner { extra: () } }; + + const SHADER_TYPE: &'static ::core::primitive::str = $crate::private::ConstStr::new() + .str("vec").u64($n).str("<").str(<$el_ty as $crate::private::ShaderType>::SHADER_TYPE).str(">").as_str(); } impl<$($generics)*> $crate::private::ShaderSize for $type diff --git a/src/types/wrapper.rs b/src/types/wrapper.rs index 20d961d..f76ef1c 100644 --- a/src/types/wrapper.rs +++ b/src/types/wrapper.rs @@ -49,6 +49,8 @@ macro_rules! impl_wrapper_inner { fn size(&self) -> ::core::num::NonZeroU64 { ::size(&self$($get_ref)*) } + + const SHADER_TYPE: &'static ::core::primitive::str = ::SHADER_TYPE; } impl<$($generics)*> $crate::private::ShaderSize for $type where diff --git a/tests/general.rs b/tests/general.rs index 23a5835..3109152 100644 --- a/tests/general.rs +++ b/tests/general.rs @@ -1,4 +1,4 @@ -use encase::{ArrayLength, CalculateSizeFor, ShaderType, StorageBuffer}; +use encase::{ArrayLength, CalculateSizeFor, ShaderStructDeclaration, ShaderType, StorageBuffer}; macro_rules! gen { ($rng:ident, $ty:ty) => {{ @@ -38,6 +38,10 @@ struct A { wi: core::num::Wrapping, au: core::sync::atomic::AtomicU32, ai: core::sync::atomic::AtomicI32, + #[shader_atomic] + aau: u32, + #[shader_atomic] + aai: i32, v2: mint::Vector2, v3: mint::Vector3, v4: mint::Vector4, @@ -77,6 +81,8 @@ fn gen_a(rng: &mut rand::rngs::StdRng) -> A { wi: core::num::Wrapping(gen!(rng, i32)), au: core::sync::atomic::AtomicU32::new(gen!(rng, u32)), ai: core::sync::atomic::AtomicI32::new(gen!(rng, i32)), + aau: gen!(rng, u32), + aai: gen!(rng, i32), v2: mint::Vector2::from(gen_arr!(rng, f32, 2)), v3: mint::Vector3::from(gen_arr!(rng, u32, 3)), v4: mint::Vector4::from(gen_arr!(rng, i32, 4)), @@ -111,12 +117,12 @@ fn size() { let mut rng = rand::rngs::StdRng::seed_from_u64(1234); let a = gen_a(&mut rng); - assert_eq!(a.size().get(), 4080); + assert_eq!(a.size().get(), 4096); } #[test] fn calculate_size_for() { - assert_eq!(<&A>::calculate_size_for(12).get(), 2832); + assert_eq!(<&A>::calculate_size_for(12).get(), 2848); } #[test] @@ -143,3 +149,50 @@ fn all_types() { assert_eq!(raw_buffer, raw_buffer_2); } + +#[test] +fn wgsl_struct() { + assert_eq!(A::SHADER_TYPE, "A"); + assert_eq!( + A::SHADER_STRUCT_DECLARATION, + "struct A { + f: f32, + u: u32, + i: i32, + nu: u32, + ni: i32, + wu: u32, + wi: i32, + au: u32, + ai: i32, + aau: atomic, + aai: atomic, + v2: vec2, + v3: vec3, + v4: vec4, + p2: vec2, + p3: vec3, + mat2: mat2x2, + mat2x3: mat3x2, + mat2x4: mat4x2, + mat3x2: mat2x3, + mat3: mat3x3, + mat3x4: mat4x3, + mat4x2: mat2x4, + mat4x3: mat3x4, + mat4: mat4x4, + arrf: array, + arru: array, + arri: array, + arrvf: array, 16>, + arrvu: array, 16>, + arrvi: array, 16>, + arrm2: array, 8>, + arrm3: array, 8>, + arrm4: array, 8>, + rt_arr_len: u32, + rt_arr: array>, +} +" + ); +} diff --git a/tests/shaders/array_length.wgsl b/tests/shaders/array_length.wgsl index 0bec1dd..02c8571 100644 --- a/tests/shaders/array_length.wgsl +++ b/tests/shaders/array_length.wgsl @@ -1,11 +1,3 @@ -struct A { - array_length: u32, - array_length_call_ret_val: u32, - a: vec3, - @align(16) - arr: array, -} - @group(0) @binding(0) var in: A; diff --git a/tests/shaders/general.wgsl b/tests/shaders/general.wgsl index bf822fd..ccbd57d 100644 --- a/tests/shaders/general.wgsl +++ b/tests/shaders/general.wgsl @@ -1,26 +1,3 @@ -struct A { - u: u32, - v: u32, - w: vec2, - @size(16) @align(8) - x: u32, - xx: u32, -} - -struct B { - a: vec2, - b: vec3, - c: u32, - d: u32, - @align(16) - e: A, - f: vec3, - g: array, - h: i32, - @align(32) - i: array, -} - @group(0) @binding(0) var in: B; diff --git a/tests/wgpu.rs b/tests/wgpu.rs index 48089d6..bbaaab5 100644 --- a/tests/wgpu.rs +++ b/tests/wgpu.rs @@ -1,7 +1,7 @@ -use encase::{ArrayLength, ShaderType, StorageBuffer}; +use encase::{ArrayLength, ShaderStructDeclaration, ShaderType, StorageBuffer}; use futures::executor::block_on; use mint::{Vector2, Vector3}; -use wgpu::{include_wgsl, util::DeviceExt}; +use wgpu::{util::DeviceExt, ShaderModuleDescriptor, ShaderSource}; #[derive(Debug, ShaderType, PartialEq)] struct A { @@ -104,7 +104,13 @@ fn test_wgpu() { in_buffer.write(&b).unwrap(); assert_eq!(in_byte_buffer.len(), b.size().get() as _); - let shader = include_wgsl!("./shaders/general.wgsl"); + let shader_text = A::SHADER_STRUCT_DECLARATION.to_string() + + B::SHADER_STRUCT_DECLARATION + + include_str!("./shaders/general.wgsl"); + let shader = ShaderModuleDescriptor { + label: Some("./shaders/general.wgsl"), + source: ShaderSource::Wgsl(shader_text.into()), + }; let out_byte_buffer = in_out::(shader, &in_byte_buffer, false); assert_eq!(in_byte_buffer, out_byte_buffer); @@ -139,7 +145,12 @@ fn array_length() { in_buffer.write(&in_value).unwrap(); assert_eq!(in_byte_buffer.len(), in_value.size().get() as _); - let shader = include_wgsl!("./shaders/array_length.wgsl"); + let shader_text = + A::SHADER_STRUCT_DECLARATION.to_string() + include_str!("./shaders/array_length.wgsl"); + let shader = ShaderModuleDescriptor { + label: Some("./shaders/array_length.wgsl"), + source: ShaderSource::Wgsl(shader_text.into()), + }; let out_byte_buffer = in_out::(shader, &in_byte_buffer, false); assert_eq!(in_byte_buffer, out_byte_buffer);