Skip to content

Commit

Permalink
[naga msl-out] chore: Use Scalar::to_msl_name and adjust atomic predef
Browse files Browse the repository at this point in the history
  • Loading branch information
AsherJingkongChen committed Oct 10, 2024
1 parent 88f9cad commit 31548e2
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 71 deletions.
2 changes: 0 additions & 2 deletions naga/src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ pub enum Error {
UnsupportedAttribute(String),
#[error("function '{0}' is not supported for target MSL version")]
UnsupportedFunction(String),
#[error("scalar {0:?} is not supported for target MSL version")]
UnsupportedScalar(crate::Scalar),
#[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")]
UnsupportedWriteableStorageBuffer,
#[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")]
Expand Down
86 changes: 38 additions & 48 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

pub(crate) const ATOMIC_COMP_EXCH_FUNCTION_KEY: &str = "naga_atomic_compare_exchange_weak";
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

Expand Down Expand Up @@ -3177,23 +3177,27 @@ impl<W: Write> Writer<W> {
write!(self.out, " ? ")?;
}

write!(
self.out,
"{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, context)?;

// Put the extra argument if provided.
if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
write!(self.out, ", ")?;
self.put_expression(cmp, context, true)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ")")?;
} else {
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
// Put the atomic function invocation.
match *fun {
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(cmp, context, true)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ")")?;
}
_ => {
write!(
self.out,
"{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
}
}

// Finish the ternary expression.
Expand Down Expand Up @@ -3830,45 +3834,31 @@ impl<W: Write> Writer<W> {
)?;
}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
let crate::Scalar { kind, width } = scalar;
let arg_type_name = match width {
1 => "bool",
4 => match kind {
crate::ScalarKind::Sint => "int",
crate::ScalarKind::Uint => "uint",
crate::ScalarKind::Float => "float",
_ => return Err(Error::UnsupportedScalar(scalar)),
},
_ => return Err(Error::UnsupportedScalar(scalar)),
};

let arg_type_name = scalar.to_msl_name();
let called_func_name = "atomic_compare_exchange_weak_explicit";
let defined_func_key = ATOMIC_COMP_EXCH_FUNCTION_KEY;
let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
let struct_name = &self.names[&NameKey::Type(*struct_ty)];

writeln!(self.out)?;
writeln!(self.out, "namespace {NAMESPACE} {{")?;

for address_space_name in ["device", "threadgroup"] {
writeln!(
self.out,
" \
template <typename A>
{struct_name} atomic_{defined_func_key}_explicit(
volatile {address_space_name} A *atomic_ptr,
{arg_type_name} cmp,
{arg_type_name} v
) {{
bool swapped = {NAMESPACE}::{called_func_name}(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return {struct_name}{{cmp, swapped}};
}}"
"\
template <typename A>
{struct_name} {defined_func_name}(
{address_space_name} A *atomic_ptr,
{arg_type_name} cmp,
{arg_type_name} v
) {{
bool swapped = {NAMESPACE}::{called_func_name}(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return {struct_name}{{cmp, swapped}};
}}"
)?;
}

writeln!(self.out, "}}")?;
}
}
}
Expand Down Expand Up @@ -6117,7 +6107,7 @@ impl crate::AtomicFunction {
Self::Min => "fetch_min",
Self::Max => "fetch_max",
Self::Exchange { compare: None } => "exchange",
Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION_KEY,
Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
}
}

Expand Down
24 changes: 10 additions & 14 deletions naga/tests/out/msl/atomicCompareExchange.msl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@ struct _atomic_compare_exchange_resultUint4_ {
bool exchanged;
};

namespace metal {
template <typename A>
_atomic_compare_exchange_resultSint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
volatile device A *atomic_ptr,
_atomic_compare_exchange_resultSint4_ naga_atomic_compare_exchange_weak_explicit(
device A *atomic_ptr,
int cmp,
int v
) {
Expand All @@ -33,8 +32,8 @@ namespace metal {
return _atomic_compare_exchange_resultSint4_{cmp, swapped};
}
template <typename A>
_atomic_compare_exchange_resultSint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
volatile threadgroup A *atomic_ptr,
_atomic_compare_exchange_resultSint4_ naga_atomic_compare_exchange_weak_explicit(
threadgroup A *atomic_ptr,
int cmp,
int v
) {
Expand All @@ -44,12 +43,10 @@ namespace metal {
);
return _atomic_compare_exchange_resultSint4_{cmp, swapped};
}
}

namespace metal {
template <typename A>
_atomic_compare_exchange_resultUint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
volatile device A *atomic_ptr,
_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit(
device A *atomic_ptr,
uint cmp,
uint v
) {
Expand All @@ -60,8 +57,8 @@ namespace metal {
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
}
template <typename A>
_atomic_compare_exchange_resultUint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
volatile threadgroup A *atomic_ptr,
_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit(
threadgroup A *atomic_ptr,
uint cmp,
uint v
) {
Expand All @@ -71,7 +68,6 @@ namespace metal {
);
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
}
}
constant uint SIZE = 128u;

kernel void test_atomic_compare_exchange_i32_(
Expand Down Expand Up @@ -108,7 +104,7 @@ kernel void test_atomic_compare_exchange_i32_(
int new_ = as_type<int>(as_type<float>(_e14) + 1.0);
uint _e20 = i;
int _e22 = old;
_atomic_compare_exchange_resultSint4_ _e23 = metal::atomic_naga_atomic_compare_exchange_weak_explicit(&arr_i32_.inner[_e20], _e22, new_);
_atomic_compare_exchange_resultSint4_ _e23 = naga_atomic_compare_exchange_weak_explicit(&arr_i32_.inner[_e20], _e22, new_);
old = _e23.old_value;
exchanged = _e23.exchanged;
}
Expand Down Expand Up @@ -153,7 +149,7 @@ kernel void test_atomic_compare_exchange_u32_(
uint new_1 = as_type<uint>(as_type<float>(_e14) + 1.0);
uint _e20 = i_1;
uint _e22 = old_1;
_atomic_compare_exchange_resultUint4_ _e23 = metal::atomic_naga_atomic_compare_exchange_weak_explicit(&arr_u32_.inner[_e20], _e22, new_1);
_atomic_compare_exchange_resultUint4_ _e23 = naga_atomic_compare_exchange_weak_explicit(&arr_u32_.inner[_e20], _e22, new_1);
old_1 = _e23.old_value;
exchanged_1 = _e23.exchanged;
}
Expand Down
12 changes: 5 additions & 7 deletions naga/tests/out/msl/overrides-atomicCompareExchangeWeak.msl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@ struct _atomic_compare_exchange_resultUint4_ {
bool exchanged;
};

namespace metal {
template <typename A>
_atomic_compare_exchange_resultUint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
volatile device A *atomic_ptr,
_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit(
device A *atomic_ptr,
uint cmp,
uint v
) {
Expand All @@ -23,8 +22,8 @@ namespace metal {
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
}
template <typename A>
_atomic_compare_exchange_resultUint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
volatile threadgroup A *atomic_ptr,
_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit(
threadgroup A *atomic_ptr,
uint cmp,
uint v
) {
Expand All @@ -34,7 +33,6 @@ namespace metal {
);
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
}
}
constant int o = 2;

kernel void f(
Expand All @@ -45,6 +43,6 @@ kernel void f(
metal::atomic_store_explicit(&a, 0, metal::memory_order_relaxed);
}
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
_atomic_compare_exchange_resultUint4_ _e5 = metal::atomic_naga_atomic_compare_exchange_weak_explicit(&a, 2u, 1u);
_atomic_compare_exchange_resultUint4_ _e5 = naga_atomic_compare_exchange_weak_explicit(&a, 2u, 1u);
return;
}

0 comments on commit 31548e2

Please sign in to comment.