Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add associated constant for WGSL struct definition #44

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "encase"
version = "0.6.1"
edition = "2021"
rust-version = "1.63"
rust-version = "1.64"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to bump the MSRV?


license = "MIT-0"
readme = "./README.md"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Having to manually lay out data into GPU buffers can become very tedious and err

The core trait is [`ShaderType`] which mainly contains metadata about the given type.

The [`WriteInto`], [`ReadFrom`] and [`CreateFrom`] traits represent the ability of a type to be written into the buffer, read from the buffer and created from the buffer respectively.
The [`WriteInto`], [`ReadFrom`] and [`CreateFrom`] traits represent the ability of a type to be written into the buffer, read from the buffer and created from the buffer respectively. The [`ShaderStructDeclaration`] trait allows to generate the WGSL struct definition.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The [`WriteInto`], [`ReadFrom`] and [`CreateFrom`] traits represent the ability of a type to be written into the buffer, read from the buffer and created from the buffer respectively. The [`ShaderStructDeclaration`] trait allows to generate the WGSL struct definition.
The [`WriteInto`], [`ReadFrom`] and [`CreateFrom`] traits represent the ability of a type to be written into the buffer, read from the buffer and created from the buffer respectively. The [`ShaderStructDeclaration`] trait contains the WGSL struct definition.


Most data types can implement the above traits via their respective macros:

Expand Down
2 changes: 1 addition & 1 deletion derive/impl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ keywords = ["wgsl", "wgpu"]
categories = ["rendering"]

[dependencies]
syn = "2"
syn = "2.0.1"
quote = "1"
proc-macro2 = "1"
93 changes: 82 additions & 11 deletions derive/impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub use syn;
#[macro_export]
macro_rules! implement {
($path:expr) => {
#[proc_macro_derive(ShaderType, attributes(align, size))]
#[proc_macro_derive(ShaderType, attributes(align, size, shader_atomic))]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for implementing the new attribute!

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you mention it in the docs here?

/// Used to implement `ShaderType` for structs

pub fn derive_shader_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = $crate::syn::parse_macro_input!(input as $crate::syn::DeriveInput);
let expanded = encase_derive_impl::derive_shader_type(input, &$path);
Expand All @@ -40,6 +40,7 @@ struct FieldData {
pub field: syn::Field,
pub size: Option<(u32, Span)>,
pub align: Option<(u32, Span)>,
pub shader_atomic: bool,
}

impl FieldData {
Expand Down Expand Up @@ -97,6 +98,29 @@ impl FieldData {
fn ident(&self) -> &Ident {
self.field.ident.as_ref().unwrap()
}

fn wgsl_type(&self, root: &Path) -> TokenStream {
let ty = &self.field.ty;
quote! {
<#ty as #root::ShaderType>::SHADER_TYPE
}
}

fn wgsl_layout_attributes(&self) -> String {
let mut attribs = Vec::new();
if let Some((size, _)) = self.size {
attribs.push(format!("@size({})", size));
}
if let Some((align, _)) = self.align {
attribs.push(format!("@align({})", align));
}

if attribs.is_empty() {
String::new()
} else {
format!(" {}\n", &attribs.join(" "))
}
}
}

struct AlignmentAttr(u32);
Expand Down Expand Up @@ -195,21 +219,25 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
field: field.clone(),
size: None,
align: None,
shader_atomic: false,
};
for attr in &field.attrs {
if !(attr.meta.path().is_ident("size") || attr.meta.path().is_ident("align")) {
continue;
}
match attr.meta.require_list() {
Ok(meta_list) => {
let span = meta_list.tokens.span();
if meta_list.path.is_ident("align") {
if attr.meta.path().is_ident("align") {
match attr.meta.require_list() {
Ok(meta_list) => {
let span = meta_list.tokens.span();
let res = attr.parse_args::<AlignmentAttr>();
match res {
Ok(val) => data.align = Some((val.0, span)),
Err(err) => errors.append(err),
}
} else if meta_list.path.is_ident("size") {
}
Err(err) => errors.append(err),
}
} else if attr.meta.path().is_ident("size") {
match attr.meta.require_list() {
Ok(meta_list) => {
let span = meta_list.tokens.span();
let res = if i == last_field_index {
attr.parse_args::<SizeAttr>().map(|val| match val {
SizeAttr::Runtime => {
Expand All @@ -227,9 +255,16 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
Err(err) => errors.append(err),
}
}
Err(err) => errors.append(err),
}
Err(err) => errors.append(err),
};
} else if attr.meta.path().is_ident("shader_atomic") {
match attr.meta.require_path_only() {
Ok(_) => {
data.shader_atomic = true;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to make sure that this is only set for integer scalars. The validation could go in generate_field_trait_constraints but I think we need some new metadata for integer scalars to distinguish them from other types.

}
Err(err) => errors.append(err),
}
}
}
data
})
Expand Down Expand Up @@ -516,12 +551,32 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
let field_types_2 = field_types.clone();
let field_types_3 = field_types.clone();
let field_types_4 = field_types.clone();
let field_types_5 = field_types.clone();
let all_other = field_types.clone().take(last_field_index);
let last_field_type = &last_field.field.ty;

let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

let field_strings = field_data.iter().map(|data| data.ident().to_string());
let field_layout_attributes = field_data.iter().map(|data| data.wgsl_layout_attributes());
let field_wgsl_types = field_data.iter().map(|data| data.wgsl_type(root));
let field_atomics_pre = field_data.iter().map(|data| {
if data.shader_atomic {
Literal::string("atomic<")
} else {
Literal::string("")
}
});
let field_atomics_post = field_data.iter().map(|data| {
if data.shader_atomic {
Literal::string(">")
} else {
Literal::string("")
}
});
let name_string = name.to_string();

let set_contained_rt_sized_array_length = if is_runtime_sized {
quote! {
writer.ctx.rts_array_length = ::core::option::Option::Some(
Expand Down Expand Up @@ -607,6 +662,8 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
offset += #root::ShaderType::size(&self.#last_field_ident).get();
#root::SizeValue::new(Self::METADATA.alignment().round_up(offset)).0
}

const SHADER_TYPE: &'static ::core::primitive::str = #name_string;
}

impl #impl_generics #root::WriteInto for #name #ty_generics
Expand Down Expand Up @@ -642,6 +699,20 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
}
}

impl #impl_generics #root::ShaderStructDeclaration for #name #ty_generics
where
#( #field_types_5: #root::ShaderType, )*
{
const SHADER_STRUCT_DECLARATION: &'static ::core::primitive::str =
#root::ConstStr::new()
.str("struct ").str(#name_string).str(" {\n")
#(
.str(#field_layout_attributes).str(" ").str(#field_strings).str(": ")
.str(#field_atomics_pre).str(#field_wgsl_types).str(#field_atomics_post).str(",\n")
)*
.str("}\n").as_str();
}

#extra
}
}
Expand Down
105 changes: 105 additions & 0 deletions src/const_str.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Const string implementation for SHADER_TYPE and SHADER_STRUCT_DECLARATION
// Used instead of crates like const_str because of E0401 when trying to use them in traits
// See also https://old.reddit.com/r/rust/comments/sv119a/concat_static_str_at_compile_time/

// Must be constant to avoid running into E0401. Should only affect compilation.
const BUFFER_SIZE: usize = 8192;

pub struct ConstStr {
data: [u8; BUFFER_SIZE],
len: usize,
}

impl ConstStr {
pub const fn new() -> ConstStr {
ConstStr {
data: [0u8; BUFFER_SIZE],
len: 0,
}
}

pub const fn str(mut self, s: &str) -> Self {
let b = s.as_bytes();
let mut index = 0;
while index < b.len() {
self.data[self.len] = b[index];
self.len += 1;
index += 1;
}

self
}

pub const fn u64(mut self, x: u64) -> Self {
let mut x2 = x;
let mut l = 0;
loop {
l += 1;
x2 /= 10;
if x2 == 0 {
break;
}
}
let mut x3 = x;
let mut index = 0;
loop {
self.data[self.len + l - 1 - index] = (x3 % 10) as u8 + b'0';
index += 1;
x3 /= 10;
if x3 == 0 {
break;
}
}
self.len += l;

self
}

pub const fn as_str(&self) -> &str {
// SAFETY: safe because this is only used in const, and should be correct by construction
unsafe {
std::str::from_utf8_unchecked(std::slice::from_raw_parts(self.data.as_ptr(), self.len))
}
}
}

mod test {
use super::ConstStr;

trait Name {
const NAME: &'static str;
}

trait Prefix: Name {}

trait Root: Name {}

struct Kilo;

impl Prefix for Kilo {}

struct Meter;

impl Root for Meter {}

impl Name for Kilo {
const NAME: &'static str = "kilo";
}

impl Name for Meter {
const NAME: &'static str = "meter";
}

impl<P: Prefix, R: Root> Name for (P, R) {
const NAME: &'static str = ConstStr::new()
.str(P::NAME)
.str(R::NAME)
.u64(1234567)
.as_str();
}

#[test]
fn test_trait() {
assert_eq!(<(Kilo, Meter)>::NAME, "kilometer1234567");
}
}
10 changes: 10 additions & 0 deletions src/core/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ pub trait ShaderType {
#[doc(hidden)]
const METADATA: Metadata<Self::ExtraMetadata>;

/// The [WGSL type name](https://www.w3.org/TR/WGSL/#types) for the implementing Rust type.
///
/// Note that for structs, this is just the name of the struct. See also [`ShaderStructDeclaration`].
const SHADER_TYPE: &'static str;

/// Represents the minimum size of `Self` (equivalent to [GPUBufferBindingLayout.minBindingSize](https://gpuweb.github.io/gpuweb/#dom-gpubufferbindinglayout-minbindingsize))
///
/// For [WGSL fixed-footprint types](https://gpuweb.github.io/gpuweb/wgsl/#fixed-footprint-types)
Expand Down Expand Up @@ -244,3 +249,8 @@ pub trait CreateFrom: Sized {
where
B: BufferRef;
}

pub trait ShaderStructDeclaration {
/// The [WGSL struct](https://www.w3.org/TR/WGSL/#struct-types) definition for the implementing Rust struct
const SHADER_STRUCT_DECLARATION: &'static str;
}
7 changes: 5 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,15 @@ pub use encase_derive::ShaderType;

#[macro_use]
mod utils;
mod const_str;
mod core;
mod types;

mod impls;

pub use crate::core::{
CalculateSizeFor, DynamicStorageBuffer, DynamicUniformBuffer, ShaderSize, ShaderType,
StorageBuffer, UniformBuffer,
CalculateSizeFor, DynamicStorageBuffer, DynamicUniformBuffer, ShaderSize,
ShaderStructDeclaration, ShaderType, StorageBuffer, UniformBuffer,
};
pub use types::runtime_sized_array::ArrayLength;

Expand Down Expand Up @@ -143,6 +144,7 @@ pub mod matrix {
#[doc(hidden)]
pub mod private {
pub use super::build_struct;
pub use super::const_str::ConstStr;
pub use super::core::AlignmentValue;
pub use super::core::BufferMut;
pub use super::core::BufferRef;
Expand All @@ -151,6 +153,7 @@ pub mod private {
pub use super::core::ReadFrom;
pub use super::core::Reader;
pub use super::core::RuntimeSizedArray;
pub use super::core::ShaderStructDeclaration;
pub use super::core::SizeValue;
pub use super::core::WriteInto;
pub use super::core::Writer;
Expand Down
9 changes: 9 additions & 0 deletions src/types/array.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::const_str::ConstStr;
use crate::core::{
BufferMut, BufferRef, CreateFrom, Metadata, ReadFrom, Reader, ShaderSize, ShaderType,
SizeValue, WriteInto, Writer,
Expand Down Expand Up @@ -55,6 +56,14 @@ impl<T: ShaderType + ShaderSize, const N: usize> ShaderType for [T; N] {
},
])
};

const SHADER_TYPE: &'static ::core::primitive::str = ConstStr::new()
.str("array<")
.str(T::SHADER_TYPE)
.str(", ")
.u64(N as u64)
.str(">")
.as_str();
}

impl<T: ShaderSize, const N: usize> ShaderSize for [T; N] {}
Expand Down
3 changes: 3 additions & 0 deletions src/types/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ macro_rules! impl_matrix_inner {
},
}
};

const SHADER_TYPE: &'static ::core::primitive::str = $crate::private::ConstStr::new()
.str("mat").u64($c).str("x").u64($r).str("<").str(<$el_ty as $crate::private::ShaderType>::SHADER_TYPE).str(">").as_str();
}

impl<$($generics)*> $crate::private::ShaderSize for $type
Expand Down
Loading