Skip to content

Commit

Permalink
shader_atomic attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
victorvde committed Jul 26, 2023
1 parent 2dfd7f8 commit 88a3900
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 18 deletions.
53 changes: 41 additions & 12 deletions derive/impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub use syn;
#[macro_export]
macro_rules! implement {
($path:expr) => {
#[proc_macro_derive(ShaderType, attributes(align, size))]
#[proc_macro_derive(ShaderType, attributes(align, size, shader_atomic))]
pub fn derive_shader_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = $crate::syn::parse_macro_input!(input as $crate::syn::DeriveInput);
let expanded = encase_derive_impl::derive_shader_type(input, &$path);
Expand All @@ -40,6 +40,7 @@ struct FieldData {
pub field: syn::Field,
pub size: Option<(u32, Span)>,
pub align: Option<(u32, Span)>,
pub shader_atomic: bool,
}

impl FieldData {
Expand Down Expand Up @@ -218,21 +219,25 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
field: field.clone(),
size: None,
align: None,
shader_atomic: false,
};
for attr in &field.attrs {
if !(attr.meta.path().is_ident("size") || attr.meta.path().is_ident("align")) {
continue;
}
match attr.meta.require_list() {
Ok(meta_list) => {
let span = meta_list.tokens.span();
if meta_list.path.is_ident("align") {
if attr.meta.path().is_ident("align") {
match attr.meta.require_list() {
Ok(meta_list) => {
let span = meta_list.tokens.span();
let res = attr.parse_args::<AlignmentAttr>();
match res {
Ok(val) => data.align = Some((val.0, span)),
Err(err) => errors.append(err),
}
} else if meta_list.path.is_ident("size") {
}
Err(err) => errors.append(err),
}
} else if attr.meta.path().is_ident("size") {
match attr.meta.require_list() {
Ok(meta_list) => {
let span = meta_list.tokens.span();
let res = if i == last_field_index {
attr.parse_args::<SizeAttr>().map(|val| match val {
SizeAttr::Runtime => {
Expand All @@ -250,9 +255,16 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
Err(err) => errors.append(err),
}
}
Err(err) => errors.append(err),
}
Err(err) => errors.append(err),
};
} else if attr.meta.path().is_ident("shader_atomic") {
match attr.meta.require_path_only() {
Ok(_) => {
data.shader_atomic = true;
}
Err(err) => errors.append(err),
}
}
}
data
})
Expand Down Expand Up @@ -548,6 +560,20 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
let field_strings = field_data.iter().map(|data| data.ident().to_string());
let field_layout_attributes = field_data.iter().map(|data| data.wgsl_layout_attributes());
let field_wgsl_types = field_data.iter().map(|data| data.wgsl_type(root));
let field_atomics_pre = field_data.iter().map(|data| {
if data.shader_atomic {
Literal::string("atomic<")
} else {
Literal::string("")
}
});
let field_atomics_post = field_data.iter().map(|data| {
if data.shader_atomic {
Literal::string(">")
} else {
Literal::string("")
}
});
let name_string = name.to_string();

let set_contained_rt_sized_array_length = if is_runtime_sized {
Expand Down Expand Up @@ -677,7 +703,10 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
const WGSL_STRUCT: &'static ::core::primitive::str =
#root::ConstStr::new()
.str("struct ").str(#name_string).str(" {\n")
#( .str(#field_layout_attributes).str(" ").str(#field_strings).str(": ").str(#field_wgsl_types).str(",\n") )*
#(
.str(#field_layout_attributes).str(" ").str(#field_strings).str(": ")
.str(#field_atomics_pre).str(#field_wgsl_types).str(#field_atomics_post).str(",\n")
)*
.str("}\n").as_str();
}

Expand Down
4 changes: 2 additions & 2 deletions src/types/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ macro_rules! impl_traits_for_atomic {
};
}

impl_traits_for_atomic!(AtomicU32, "atomic<u32>");
impl_traits_for_atomic!(AtomicI32, "atomic<i32>");
impl_traits_for_atomic!(AtomicU32, "u32");
impl_traits_for_atomic!(AtomicI32, "i32");

macro_rules! impl_marker_trait_for_f32 {
($trait:path) => {
Expand Down
16 changes: 12 additions & 4 deletions tests/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ struct A {
wi: core::num::Wrapping<i32>,
au: core::sync::atomic::AtomicU32,
ai: core::sync::atomic::AtomicI32,
#[shader_atomic]
aau: u32,
#[shader_atomic]
aai: i32,
v2: mint::Vector2<f32>,
v3: mint::Vector3<u32>,
v4: mint::Vector4<i32>,
Expand Down Expand Up @@ -77,6 +81,8 @@ fn gen_a(rng: &mut rand::rngs::StdRng) -> A {
wi: core::num::Wrapping(gen!(rng, i32)),
au: core::sync::atomic::AtomicU32::new(gen!(rng, u32)),
ai: core::sync::atomic::AtomicI32::new(gen!(rng, i32)),
aau: gen!(rng, u32),
aai: gen!(rng, i32),
v2: mint::Vector2::from(gen_arr!(rng, f32, 2)),
v3: mint::Vector3::from(gen_arr!(rng, u32, 3)),
v4: mint::Vector4::from(gen_arr!(rng, i32, 4)),
Expand Down Expand Up @@ -111,12 +117,12 @@ fn size() {
let mut rng = rand::rngs::StdRng::seed_from_u64(1234);
let a = gen_a(&mut rng);

assert_eq!(a.size().get(), 4080);
assert_eq!(a.size().get(), 4096);
}

#[test]
fn calculate_size_for() {
assert_eq!(<&A>::calculate_size_for(12).get(), 2832);
assert_eq!(<&A>::calculate_size_for(12).get(), 2848);
}

#[test]
Expand Down Expand Up @@ -157,8 +163,10 @@ fn wgsl_struct() {
ni: i32,
wu: u32,
wi: i32,
au: atomic<u32>,
ai: atomic<i32>,
au: u32,
ai: i32,
aau: atomic<u32>,
aai: atomic<i32>,
v2: vec2<f32>,
v3: vec3<u32>,
v4: vec4<i32>,
Expand Down

0 comments on commit 88a3900

Please sign in to comment.