-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add associated constant for WGSL struct definition #44
base: main
Are you sure you want to change the base?
Changes from all commits
10928a4
2822bfe
4c42f01
cb23601
a16032b
22d97a6
b0b65dd
64fd334
2dfd7f8
88a3900
9ce2ee4
9216df6
192dbf2
915bafc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -18,7 +18,7 @@ Having to manually lay out data into GPU buffers can become very tedious and err | |||||
|
||||||
The core trait is [`ShaderType`] which mainly contains metadata about the given type. | ||||||
|
||||||
The [`WriteInto`], [`ReadFrom`] and [`CreateFrom`] traits represent the ability of a type to be written into the buffer, read from the buffer and created from the buffer respectively. | ||||||
The [`WriteInto`], [`ReadFrom`] and [`CreateFrom`] traits represent the ability of a type to be written into the buffer, read from the buffer and created from the buffer respectively. The [`ShaderStructDeclaration`] trait allows to generate the WGSL struct definition. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
Most data types can implement the above traits via their respective macros: | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -14,7 +14,7 @@ pub use syn; | |||
#[macro_export] | ||||
macro_rules! implement { | ||||
($path:expr) => { | ||||
#[proc_macro_derive(ShaderType, attributes(align, size))] | ||||
#[proc_macro_derive(ShaderType, attributes(align, size, shader_atomic))] | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for implementing the new attribute! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you mention it in the docs here? Line 22 in 308bb72
|
||||
pub fn derive_shader_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream { | ||||
let input = $crate::syn::parse_macro_input!(input as $crate::syn::DeriveInput); | ||||
let expanded = encase_derive_impl::derive_shader_type(input, &$path); | ||||
|
@@ -40,6 +40,7 @@ struct FieldData { | |||
pub field: syn::Field, | ||||
pub size: Option<(u32, Span)>, | ||||
pub align: Option<(u32, Span)>, | ||||
pub shader_atomic: bool, | ||||
} | ||||
|
||||
impl FieldData { | ||||
|
@@ -97,6 +98,29 @@ impl FieldData { | |||
fn ident(&self) -> &Ident { | ||||
self.field.ident.as_ref().unwrap() | ||||
} | ||||
|
||||
fn wgsl_type(&self, root: &Path) -> TokenStream { | ||||
let ty = &self.field.ty; | ||||
quote! { | ||||
<#ty as #root::ShaderType>::SHADER_TYPE | ||||
} | ||||
} | ||||
|
||||
fn wgsl_layout_attributes(&self) -> String { | ||||
let mut attribs = Vec::new(); | ||||
if let Some((size, _)) = self.size { | ||||
attribs.push(format!("@size({})", size)); | ||||
} | ||||
if let Some((align, _)) = self.align { | ||||
attribs.push(format!("@align({})", align)); | ||||
} | ||||
|
||||
if attribs.is_empty() { | ||||
String::new() | ||||
} else { | ||||
format!(" {}\n", &attribs.join(" ")) | ||||
} | ||||
} | ||||
} | ||||
|
||||
struct AlignmentAttr(u32); | ||||
|
@@ -195,21 +219,25 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream { | |||
field: field.clone(), | ||||
size: None, | ||||
align: None, | ||||
shader_atomic: false, | ||||
}; | ||||
for attr in &field.attrs { | ||||
if !(attr.meta.path().is_ident("size") || attr.meta.path().is_ident("align")) { | ||||
continue; | ||||
} | ||||
match attr.meta.require_list() { | ||||
Ok(meta_list) => { | ||||
let span = meta_list.tokens.span(); | ||||
if meta_list.path.is_ident("align") { | ||||
if attr.meta.path().is_ident("align") { | ||||
match attr.meta.require_list() { | ||||
Ok(meta_list) => { | ||||
let span = meta_list.tokens.span(); | ||||
let res = attr.parse_args::<AlignmentAttr>(); | ||||
match res { | ||||
Ok(val) => data.align = Some((val.0, span)), | ||||
Err(err) => errors.append(err), | ||||
} | ||||
} else if meta_list.path.is_ident("size") { | ||||
} | ||||
Err(err) => errors.append(err), | ||||
} | ||||
} else if attr.meta.path().is_ident("size") { | ||||
match attr.meta.require_list() { | ||||
Ok(meta_list) => { | ||||
let span = meta_list.tokens.span(); | ||||
let res = if i == last_field_index { | ||||
attr.parse_args::<SizeAttr>().map(|val| match val { | ||||
SizeAttr::Runtime => { | ||||
|
@@ -227,9 +255,16 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream { | |||
Err(err) => errors.append(err), | ||||
} | ||||
} | ||||
Err(err) => errors.append(err), | ||||
} | ||||
Err(err) => errors.append(err), | ||||
}; | ||||
} else if attr.meta.path().is_ident("shader_atomic") { | ||||
match attr.meta.require_path_only() { | ||||
Ok(_) => { | ||||
data.shader_atomic = true; | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we need to make sure that this is only set for integer scalars. The validation could go in |
||||
} | ||||
Err(err) => errors.append(err), | ||||
} | ||||
} | ||||
} | ||||
data | ||||
}) | ||||
|
@@ -516,12 +551,32 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream { | |||
let field_types_2 = field_types.clone(); | ||||
let field_types_3 = field_types.clone(); | ||||
let field_types_4 = field_types.clone(); | ||||
let field_types_5 = field_types.clone(); | ||||
let all_other = field_types.clone().take(last_field_index); | ||||
let last_field_type = &last_field.field.ty; | ||||
|
||||
let name = &input.ident; | ||||
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); | ||||
|
||||
let field_strings = field_data.iter().map(|data| data.ident().to_string()); | ||||
let field_layout_attributes = field_data.iter().map(|data| data.wgsl_layout_attributes()); | ||||
let field_wgsl_types = field_data.iter().map(|data| data.wgsl_type(root)); | ||||
let field_atomics_pre = field_data.iter().map(|data| { | ||||
if data.shader_atomic { | ||||
Literal::string("atomic<") | ||||
} else { | ||||
Literal::string("") | ||||
} | ||||
}); | ||||
let field_atomics_post = field_data.iter().map(|data| { | ||||
if data.shader_atomic { | ||||
Literal::string(">") | ||||
} else { | ||||
Literal::string("") | ||||
} | ||||
}); | ||||
let name_string = name.to_string(); | ||||
|
||||
let set_contained_rt_sized_array_length = if is_runtime_sized { | ||||
quote! { | ||||
writer.ctx.rts_array_length = ::core::option::Option::Some( | ||||
|
@@ -607,6 +662,8 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream { | |||
offset += #root::ShaderType::size(&self.#last_field_ident).get(); | ||||
#root::SizeValue::new(Self::METADATA.alignment().round_up(offset)).0 | ||||
} | ||||
|
||||
const SHADER_TYPE: &'static ::core::primitive::str = #name_string; | ||||
} | ||||
|
||||
impl #impl_generics #root::WriteInto for #name #ty_generics | ||||
|
@@ -642,6 +699,20 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream { | |||
} | ||||
} | ||||
|
||||
impl #impl_generics #root::ShaderStructDeclaration for #name #ty_generics | ||||
where | ||||
#( #field_types_5: #root::ShaderType, )* | ||||
{ | ||||
const SHADER_STRUCT_DECLARATION: &'static ::core::primitive::str = | ||||
#root::ConstStr::new() | ||||
.str("struct ").str(#name_string).str(" {\n") | ||||
#( | ||||
.str(#field_layout_attributes).str(" ").str(#field_strings).str(": ") | ||||
.str(#field_atomics_pre).str(#field_wgsl_types).str(#field_atomics_post).str(",\n") | ||||
)* | ||||
.str("}\n").as_str(); | ||||
} | ||||
|
||||
#extra | ||||
} | ||||
} | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
// Const string implementation for SHADER_TYPE and SHADER_STRUCT_DECLARATION | ||
// Used instead of crates like const_str because of E0401 when trying to use them in traits | ||
// See also https://old.reddit.com/r/rust/comments/sv119a/concat_static_str_at_compile_time/ | ||
|
||
// Must be constant to avoid running into E0401. Should only affect compilation. | ||
const BUFFER_SIZE: usize = 8192; | ||
|
||
pub struct ConstStr { | ||
data: [u8; BUFFER_SIZE], | ||
len: usize, | ||
} | ||
|
||
impl ConstStr { | ||
pub const fn new() -> ConstStr { | ||
ConstStr { | ||
data: [0u8; BUFFER_SIZE], | ||
len: 0, | ||
} | ||
} | ||
|
||
pub const fn str(mut self, s: &str) -> Self { | ||
let b = s.as_bytes(); | ||
let mut index = 0; | ||
while index < b.len() { | ||
self.data[self.len] = b[index]; | ||
self.len += 1; | ||
index += 1; | ||
} | ||
|
||
self | ||
} | ||
|
||
pub const fn u64(mut self, x: u64) -> Self { | ||
let mut x2 = x; | ||
let mut l = 0; | ||
loop { | ||
l += 1; | ||
x2 /= 10; | ||
if x2 == 0 { | ||
break; | ||
} | ||
} | ||
let mut x3 = x; | ||
let mut index = 0; | ||
loop { | ||
self.data[self.len + l - 1 - index] = (x3 % 10) as u8 + b'0'; | ||
index += 1; | ||
x3 /= 10; | ||
if x3 == 0 { | ||
break; | ||
} | ||
} | ||
self.len += l; | ||
|
||
self | ||
} | ||
|
||
pub const fn as_str(&self) -> &str { | ||
// SAFETY: safe because this is only used in const, and should be correct by construction | ||
unsafe { | ||
std::str::from_utf8_unchecked(std::slice::from_raw_parts(self.data.as_ptr(), self.len)) | ||
} | ||
} | ||
} | ||
|
||
mod test { | ||
use super::ConstStr; | ||
|
||
trait Name { | ||
const NAME: &'static str; | ||
} | ||
|
||
trait Prefix: Name {} | ||
|
||
trait Root: Name {} | ||
|
||
struct Kilo; | ||
|
||
impl Prefix for Kilo {} | ||
|
||
struct Meter; | ||
|
||
impl Root for Meter {} | ||
|
||
impl Name for Kilo { | ||
const NAME: &'static str = "kilo"; | ||
} | ||
|
||
impl Name for Meter { | ||
const NAME: &'static str = "meter"; | ||
} | ||
|
||
impl<P: Prefix, R: Root> Name for (P, R) { | ||
const NAME: &'static str = ConstStr::new() | ||
.str(P::NAME) | ||
.str(R::NAME) | ||
.u64(1234567) | ||
.as_str(); | ||
} | ||
|
||
#[test] | ||
fn test_trait() { | ||
assert_eq!(<(Kilo, Meter)>::NAME, "kilometer1234567"); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to bump the MSRV?