Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga msl-out] Implement atomicCompareExchangeWeak for MSL backend #6265

Merged
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).

- Allow using [VK_GOOGLE_display_timing](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_GOOGLE_display_timing.html) unsafely with the `VULKAN_GOOGLE_DISPLAY_TIMING` feature. By @DJMcNab in [#6149](https://github.com/gfx-rs/wgpu/pull/6149)

#### Metal

- Implement `atomicCompareExchangeWeak`. By @AsherJingkongChen in [#6265](https://github.com/gfx-rs/wgpu/pull/6265)

### Bug Fixes

- Fix incorrect hlsl image output type conversion. By @atlv24 in [#6123](https://github.com/gfx-rs/wgpu/pull/6123)
Expand Down
130 changes: 80 additions & 50 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

/// Name of the helper emitted for WGSL `atomicCompareExchangeWeak`; the helper
/// wraps `metal::atomic_compare_exchange_weak_explicit` and returns the
/// `{old_value, exchanged}` result struct (see the predeclared-type writer).
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
/// Name of the emitted polyfill for WGSL `modf`.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
/// Name of the emitted polyfill for WGSL `frexp`.
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

Expand Down Expand Up @@ -1279,42 +1280,6 @@ impl<W: Write> Writer<W> {
Ok(())
}

/// Emit the MSL expression for a single atomic operation on `pointer`.
///
/// `key` is the MSL atomic-function suffix (e.g. `fetch_add`); the emitted
/// call has the shape
/// `metal::atomic_<key>_explicit(&<access chain>, <value>, metal::memory_order_relaxed)`.
///
/// NOTE(review): this PR removes this helper, inlining its logic at the one
/// `Statement::Atomic` call site so compare-exchange can emit a different
/// call shape — confirm against the full diff.
fn put_atomic_operation(
&mut self,
pointer: Handle<crate::Expression>,
key: &str,
value: Handle<crate::Expression>,
context: &ExpressionContext,
) -> BackendResult {
// If the pointer we're passing to the atomic operation needs to be conditional
// for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
// the pointer operand should be unchecked.
let policy = context.choose_bounds_check_policy(pointer);
let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
&& self.put_bounds_checks(pointer, context, back::Level(0), "")?;

// If requested and successfully put bounds checks, continue the ternary expression.
if checked {
write!(self.out, " ? ")?;
}

write!(
self.out,
"{NAMESPACE}::atomic_{key}_explicit({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;

// Finish the ternary expression.
// Out-of-bounds access under ReadZeroSkipWrite yields a default value instead.
if checked {
write!(self.out, " : DefaultConstructible()")?;
}

Ok(())
}

/// Emit code for the arithmetic expression of the dot product.
///
fn put_dot_product(
Expand Down Expand Up @@ -3182,24 +3147,65 @@ impl<W: Write> Writer<W> {
value,
result,
} => {
let context = &context.expression;

// This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
// `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
// `Some`, we are not operating on a 64-bit value, and that if we are
// operating on a 64-bit value, `result` is `None`.
write!(self.out, "{level}")?;
let fun_str = if let Some(result) = result {
let fun_key = if let Some(result) = result {
let res_name = Baked(result).to_string();
self.start_baking_expression(result, &context.expression, &res_name)?;
self.start_baking_expression(result, context, &res_name)?;
self.named_expressions.insert(result, res_name);
fun.to_msl()?
} else if context.expression.resolve_type(value).scalar_width() == Some(8) {
fun.to_msl()
} else if context.resolve_type(value).scalar_width() == Some(8) {
fun.to_msl_64_bit()?
} else {
fun.to_msl()?
fun.to_msl()
};

self.put_atomic_operation(pointer, fun_str, value, &context.expression)?;
// done
// If the pointer we're passing to the atomic operation needs to be conditional
// for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
// the pointer operand should be unchecked.
let policy = context.choose_bounds_check_policy(pointer);
let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
&& self.put_bounds_checks(pointer, context, back::Level(0), "")?;

// If requested and successfully put bounds checks, continue the ternary expression.
if checked {
write!(self.out, " ? ")?;
}

// Put the atomic function invocation.
match *fun {
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(cmp, context, true)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ")")?;
}
_ => {
write!(
self.out,
"{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
}
}

// Finish the ternary expression.
if checked {
write!(self.out, " : DefaultConstructible()")?;
}

// Done
writeln!(self.out, ";")?;
}
crate::Statement::WorkGroupUniformLoad { pointer, result } => {
Expand Down Expand Up @@ -3827,7 +3833,33 @@ impl<W: Write> Writer<W> {
}}"
)?;
}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
let arg_type_name = scalar.to_msl_name();
let called_func_name = "atomic_compare_exchange_weak_explicit";
let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
let struct_name = &self.names[&NameKey::Type(*struct_ty)];

writeln!(self.out)?;

for address_space_name in ["device", "threadgroup"] {
writeln!(
self.out,
"\
template <typename A>
{struct_name} {defined_func_name}(
{address_space_name} A *atomic_ptr,
{arg_type_name} cmp,
{arg_type_name} v
) {{
bool swapped = {NAMESPACE}::{called_func_name}(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return {struct_name}{{cmp, swapped}};
}}"
)?;
}
}
}
}

Expand Down Expand Up @@ -6065,8 +6097,8 @@ fn test_stack_size() {
}

impl crate::AtomicFunction {
fn to_msl(self) -> Result<&'static str, Error> {
Ok(match self {
const fn to_msl(self) -> &'static str {
match self {
Self::Add => "fetch_add",
Self::Subtract => "fetch_sub",
Self::And => "fetch_and",
Expand All @@ -6075,10 +6107,8 @@ impl crate::AtomicFunction {
Self::Min => "fetch_min",
Self::Max => "fetch_max",
Self::Exchange { compare: None } => "exchange",
Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented(
"atomic CompareExchange".to_string(),
))?,
})
Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
}
}

fn to_msl_64_bit(self) -> Result<&'static str, Error> {
Expand Down
160 changes: 160 additions & 0 deletions naga/tests/out/msl/atomicCompareExchange.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// language: metal1.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;

// Generated Naga snapshot types — do not hand-edit substantive content.
// Backing storage for the two test arrays of atomics.
struct type_2 {
    metal::atomic_int inner[128];
};
struct type_4 {
    metal::atomic_uint inner[128];
};
// Result of atomicCompareExchangeWeak on a signed 32-bit atomic:
// the value observed before the attempt, and whether the swap happened.
struct _atomic_compare_exchange_resultSint4_ {
    int old_value;
    bool exchanged;
};
// Same result shape for the unsigned 32-bit case.
struct _atomic_compare_exchange_resultUint4_ {
    uint old_value;
    bool exchanged;
};

// Overload set emitted by the Naga MSL backend for WGSL
// `atomicCompareExchangeWeak`: one overload per (result type, address space)
// pair. Each wraps `metal::atomic_compare_exchange_weak_explicit`, which
// updates `cmp` in place to the observed value and returns whether the swap
// succeeded; both are packaged into the result struct.
template <typename A>
_atomic_compare_exchange_resultSint4_ naga_atomic_compare_exchange_weak_explicit(
    device A *atomic_ptr,
    int cmp,
    int v
) {
    bool swapped = metal::atomic_compare_exchange_weak_explicit(
        atomic_ptr, &cmp, v,
        metal::memory_order_relaxed, metal::memory_order_relaxed
    );
    return _atomic_compare_exchange_resultSint4_{cmp, swapped};
}
template <typename A>
_atomic_compare_exchange_resultSint4_ naga_atomic_compare_exchange_weak_explicit(
    threadgroup A *atomic_ptr,
    int cmp,
    int v
) {
    bool swapped = metal::atomic_compare_exchange_weak_explicit(
        atomic_ptr, &cmp, v,
        metal::memory_order_relaxed, metal::memory_order_relaxed
    );
    return _atomic_compare_exchange_resultSint4_{cmp, swapped};
}

template <typename A>
_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit(
    device A *atomic_ptr,
    uint cmp,
    uint v
) {
    bool swapped = metal::atomic_compare_exchange_weak_explicit(
        atomic_ptr, &cmp, v,
        metal::memory_order_relaxed, metal::memory_order_relaxed
    );
    return _atomic_compare_exchange_resultUint4_{cmp, swapped};
}
template <typename A>
_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit(
    threadgroup A *atomic_ptr,
    uint cmp,
    uint v
) {
    bool swapped = metal::atomic_compare_exchange_weak_explicit(
        atomic_ptr, &cmp, v,
        metal::memory_order_relaxed, metal::memory_order_relaxed
    );
    return _atomic_compare_exchange_resultUint4_{cmp, swapped};
}
// Element count of both test arrays (matches `inner[128]` above).
constant uint SIZE = 128u;

// Generated snapshot kernel: for each of the 128 i32 atomics, loads the
// current value, then spins on compare-exchange until a swap succeeds,
// storing `bitcast<f32>(old) + 1.0` reinterpreted back as i32.
kernel void test_atomic_compare_exchange_i32_(
  device type_2& arr_i32_ [[user(fake0)]]
) {
    uint i = 0u;
    int old = {};
    bool exchanged = {};
    // `loop_init` emulates WGSL's continuing-block semantics: the increment
    // runs at the top of every iteration except the first.
    bool loop_init = true;
    while(true) {
        if (!loop_init) {
            uint _e27 = i;
            i = _e27 + 1u;
        }
        loop_init = false;
        uint _e2 = i;
        if (_e2 < SIZE) {
        } else {
            break;
        }
        {
            uint _e6 = i;
            int _e8 = metal::atomic_load_explicit(&arr_i32_.inner[_e6], metal::memory_order_relaxed);
            old = _e8;
            exchanged = false;
            // Retry until the weak compare-exchange reports success.
            while(true) {
                bool _e12 = exchanged;
                if (!(_e12)) {
                } else {
                    break;
                }
                {
                    int _e14 = old;
                    // Reinterpret the i32 bits as f32, add 1.0, reinterpret back.
                    int new_ = as_type<int>(as_type<float>(_e14) + 1.0);
                    uint _e20 = i;
                    int _e22 = old;
                    _atomic_compare_exchange_resultSint4_ _e23 = naga_atomic_compare_exchange_weak_explicit(&arr_i32_.inner[_e20], _e22, new_);
                    old = _e23.old_value;
                    exchanged = _e23.exchanged;
                }
            }
        }
    }
    return;
}


// Generated snapshot kernel: same compare-exchange loop as the i32 variant
// above, operating on the u32 atomic array.
kernel void test_atomic_compare_exchange_u32_(
  device type_4& arr_u32_ [[user(fake0)]]
) {
    uint i_1 = 0u;
    uint old_1 = {};
    bool exchanged_1 = {};
    // First-iteration guard so the increment behaves like a WGSL continuing block.
    bool loop_init_1 = true;
    while(true) {
        if (!loop_init_1) {
            uint _e27 = i_1;
            i_1 = _e27 + 1u;
        }
        loop_init_1 = false;
        uint _e2 = i_1;
        if (_e2 < SIZE) {
        } else {
            break;
        }
        {
            uint _e6 = i_1;
            uint _e8 = metal::atomic_load_explicit(&arr_u32_.inner[_e6], metal::memory_order_relaxed);
            old_1 = _e8;
            exchanged_1 = false;
            // Retry until the weak compare-exchange reports success.
            while(true) {
                bool _e12 = exchanged_1;
                if (!(_e12)) {
                } else {
                    break;
                }
                {
                    uint _e14 = old_1;
                    // Reinterpret the u32 bits as f32, add 1.0, reinterpret back.
                    uint new_1 = as_type<uint>(as_type<float>(_e14) + 1.0);
                    uint _e20 = i_1;
                    uint _e22 = old_1;
                    _atomic_compare_exchange_resultUint4_ _e23 = naga_atomic_compare_exchange_weak_explicit(&arr_u32_.inner[_e20], _e22, new_1);
                    old_1 = _e23.old_value;
                    exchanged_1 = _e23.exchanged;
                }
            }
        }
    }
    return;
}
Loading
Loading