Commit

[naga msl-out] Implement atomicCompareExchangeWeak for MSL backend (#…
AsherJingkongChen authored Oct 10, 2024
1 parent d9178a1 commit bf33e48
Showing 5 changed files with 298 additions and 52 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -93,6 +93,10 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).

- Allow using [VK_GOOGLE_display_timing](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_GOOGLE_display_timing.html) unsafely with the `VULKAN_GOOGLE_DISPLAY_TIMING` feature. By @DJMcNab in [#6149](https://github.com/gfx-rs/wgpu/pull/6149)

#### Metal

- Implement `atomicCompareExchangeWeak`. By @AsherJingkongChen in [#6265](https://github.com/gfx-rs/wgpu/pull/6265)

### Bug Fixes

- Fix incorrect hlsl image output type conversion. By @atlv24 in [#6123](https://github.com/gfx-rs/wgpu/pull/6123)
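For readers unfamiliar with the Metal primitive involved, here is a minimal, self-contained MSL sketch (not part of this commit) of the built-in that the new backend support wraps. `metal::atomic_compare_exchange_weak_explicit` writes the value it actually observed back into `expected` and returns whether the swap happened, which is the `(old_value, exchanged)` pair that WGSL's `atomicCompareExchangeWeak` reports; the kernel and buffer names here are illustrative.

```cpp
#include <metal_stdlib>

// Illustrative kernel: try to swap 0 -> 1 on an atomic counter.
kernel void cas_demo(device metal::atomic_int *counter [[buffer(0)]]) {
    int expected = 0;
    bool exchanged = metal::atomic_compare_exchange_weak_explicit(
        counter, &expected, 1,
        metal::memory_order_relaxed, metal::memory_order_relaxed);
    // On failure (including spurious failure, since this is the *weak* form),
    // `expected` now holds the value that was observed in *counter.
    (void)exchanged;
}
```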
130 changes: 80 additions & 50 deletions naga/src/back/msl/writer.rs
@@ -33,6 +33,7 @@ const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

@@ -1279,42 +1280,6 @@ impl<W: Write> Writer<W> {
Ok(())
}

fn put_atomic_operation(
&mut self,
pointer: Handle<crate::Expression>,
key: &str,
value: Handle<crate::Expression>,
context: &ExpressionContext,
) -> BackendResult {
// If the pointer we're passing to the atomic operation needs to be conditional
// for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
// the pointer operand should be unchecked.
let policy = context.choose_bounds_check_policy(pointer);
let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
&& self.put_bounds_checks(pointer, context, back::Level(0), "")?;

// If requested and successfully put bounds checks, continue the ternary expression.
if checked {
write!(self.out, " ? ")?;
}

write!(
self.out,
"{NAMESPACE}::atomic_{key}_explicit({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;

// Finish the ternary expression.
if checked {
write!(self.out, " : DefaultConstructible()")?;
}

Ok(())
}

/// Emit code for the arithmetic expression of the dot product.
///
fn put_dot_product(
@@ -3182,24 +3147,65 @@ impl<W: Write> Writer<W> {
value,
result,
} => {
let context = &context.expression;

// This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
// `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
// `Some`, we are not operating on a 64-bit value, and that if we are
// operating on a 64-bit value, `result` is `None`.
write!(self.out, "{level}")?;
let fun_str = if let Some(result) = result {
let fun_key = if let Some(result) = result {
let res_name = Baked(result).to_string();
self.start_baking_expression(result, &context.expression, &res_name)?;
self.start_baking_expression(result, context, &res_name)?;
self.named_expressions.insert(result, res_name);
fun.to_msl()?
} else if context.expression.resolve_type(value).scalar_width() == Some(8) {
fun.to_msl()
} else if context.resolve_type(value).scalar_width() == Some(8) {
fun.to_msl_64_bit()?
} else {
fun.to_msl()?
fun.to_msl()
};

self.put_atomic_operation(pointer, fun_str, value, &context.expression)?;
// done
// If the pointer we're passing to the atomic operation needs to be conditional
// for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
// the pointer operand should be unchecked.
let policy = context.choose_bounds_check_policy(pointer);
let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
&& self.put_bounds_checks(pointer, context, back::Level(0), "")?;

// If requested and successfully put bounds checks, continue the ternary expression.
if checked {
write!(self.out, " ? ")?;
}

// Put the atomic function invocation.
match *fun {
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(cmp, context, true)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ")")?;
}
_ => {
write!(
self.out,
"{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
}
}

// Finish the ternary expression.
if checked {
write!(self.out, " : DefaultConstructible()")?;
}

// Done
writeln!(self.out, ";")?;
}
crate::Statement::WorkGroupUniformLoad { pointer, result } => {
@@ -3827,7 +3833,33 @@ impl<W: Write> Writer<W> {
}}"
)?;
}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
let arg_type_name = scalar.to_msl_name();
let called_func_name = "atomic_compare_exchange_weak_explicit";
let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
let struct_name = &self.names[&NameKey::Type(*struct_ty)];

writeln!(self.out)?;

for address_space_name in ["device", "threadgroup"] {
writeln!(
self.out,
"\
template <typename A>
{struct_name} {defined_func_name}(
{address_space_name} A *atomic_ptr,
{arg_type_name} cmp,
{arg_type_name} v
) {{
bool swapped = {NAMESPACE}::{called_func_name}(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return {struct_name}{{cmp, swapped}};
}}"
)?;
}
}
}
}

@@ -6065,8 +6097,8 @@ fn test_stack_size() {
}

impl crate::AtomicFunction {
fn to_msl(self) -> Result<&'static str, Error> {
Ok(match self {
const fn to_msl(self) -> &'static str {
match self {
Self::Add => "fetch_add",
Self::Subtract => "fetch_sub",
Self::And => "fetch_and",
@@ -6075,10 +6107,8 @@ impl crate::AtomicFunction {
Self::Min => "fetch_min",
Self::Max => "fetch_max",
Self::Exchange { compare: None } => "exchange",
Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented(
"atomic CompareExchange".to_string(),
))?,
})
Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
}
}

fn to_msl_64_bit(self) -> Result<&'static str, Error> {
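To make the bounds-checked path above concrete: under `ReadZeroSkipWrite`, the writer wraps the entire atomic call in a ternary so the unchecked access chain is only formed when the index is in range, falling back to `DefaultConstructible()` otherwise. Below is a hedged, self-contained sketch of that emitted shape; the result struct and `DefaultConstructible` helper are simplified stand-ins for what naga already generates elsewhere, and all identifiers are illustrative rather than copied from this commit.

```cpp
#include <metal_stdlib>

// Simplified stand-ins for helpers naga emits (illustrative only).
struct _atomic_compare_exchange_resultSint4_ { int old_value; bool exchanged; };
struct DefaultConstructible {
    template <typename T> operator T() { return T{}; }
};

template <typename A>
_atomic_compare_exchange_resultSint4_ naga_atomic_compare_exchange_weak_explicit(
    device A *atomic_ptr, int cmp, int v
) {
    bool swapped = metal::atomic_compare_exchange_weak_explicit(
        atomic_ptr, &cmp, v,
        metal::memory_order_relaxed, metal::memory_order_relaxed);
    return _atomic_compare_exchange_resultSint4_{cmp, swapped};
}

struct Buf { metal::atomic_int inner[128]; };

kernel void checked_cas(device Buf &arr [[buffer(0)]],
                        constant uint &i [[buffer(1)]]) {
    // Shape of the statement emitted under ReadZeroSkipWrite: the whole atomic
    // call sits inside the ternary, so nothing is dereferenced when `i` is out
    // of bounds; the result is default-constructed instead.
    _atomic_compare_exchange_resultSint4_ r = i < 128u
        ? naga_atomic_compare_exchange_weak_explicit(&arr.inner[i], 0, 1)
        : DefaultConstructible();
    (void)r;
}
```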
161 changes: 161 additions & 0 deletions naga/tests/out/msl/atomicCompareExchange.msl
@@ -0,0 +1,161 @@
// language: metal1.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;

struct type_2 {
metal::atomic_int inner[128];
};
struct type_4 {
metal::atomic_uint inner[128];
};
struct _atomic_compare_exchange_resultSint4_ {
int old_value;
bool exchanged;
};
struct _atomic_compare_exchange_resultUint4_ {
uint old_value;
bool exchanged;
};

template <typename A>
_atomic_compare_exchange_resultSint4_ naga_atomic_compare_exchange_weak_explicit(
device A *atomic_ptr,
int cmp,
int v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultSint4_{cmp, swapped};
}
template <typename A>
_atomic_compare_exchange_resultSint4_ naga_atomic_compare_exchange_weak_explicit(
threadgroup A *atomic_ptr,
int cmp,
int v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultSint4_{cmp, swapped};
}

template <typename A>
_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit(
device A *atomic_ptr,
uint cmp,
uint v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
}
template <typename A>
_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit(
threadgroup A *atomic_ptr,
uint cmp,
uint v
) {
bool swapped = metal::atomic_compare_exchange_weak_explicit(
atomic_ptr, &cmp, v,
metal::memory_order_relaxed, metal::memory_order_relaxed
);
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
}
constant uint SIZE = 128u;

kernel void test_atomic_compare_exchange_i32_(
device type_2& arr_i32_ [[user(fake0)]]
) {
uint i = 0u;
int old = {};
bool exchanged = {};
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
bool loop_init = true;
LOOP_IS_REACHABLE while(true) {
if (!loop_init) {
uint _e27 = i;
i = _e27 + 1u;
}
loop_init = false;
uint _e2 = i;
if (_e2 < SIZE) {
} else {
break;
}
{
uint _e6 = i;
int _e8 = metal::atomic_load_explicit(&arr_i32_.inner[_e6], metal::memory_order_relaxed);
old = _e8;
exchanged = false;
LOOP_IS_REACHABLE while(true) {
bool _e12 = exchanged;
if (!(_e12)) {
} else {
break;
}
{
int _e14 = old;
int new_ = as_type<int>(as_type<float>(_e14) + 1.0);
uint _e20 = i;
int _e22 = old;
_atomic_compare_exchange_resultSint4_ _e23 = naga_atomic_compare_exchange_weak_explicit(&arr_i32_.inner[_e20], _e22, new_);
old = _e23.old_value;
exchanged = _e23.exchanged;
}
}
}
}
return;
}


kernel void test_atomic_compare_exchange_u32_(
device type_4& arr_u32_ [[user(fake0)]]
) {
uint i_1 = 0u;
uint old_1 = {};
bool exchanged_1 = {};
bool loop_init_1 = true;
LOOP_IS_REACHABLE while(true) {
if (!loop_init_1) {
uint _e27 = i_1;
i_1 = _e27 + 1u;
}
loop_init_1 = false;
uint _e2 = i_1;
if (_e2 < SIZE) {
} else {
break;
}
{
uint _e6 = i_1;
uint _e8 = metal::atomic_load_explicit(&arr_u32_.inner[_e6], metal::memory_order_relaxed);
old_1 = _e8;
exchanged_1 = false;
LOOP_IS_REACHABLE while(true) {
bool _e12 = exchanged_1;
if (!(_e12)) {
} else {
break;
}
{
uint _e14 = old_1;
uint new_1 = as_type<uint>(as_type<float>(_e14) + 1.0);
uint _e20 = i_1;
uint _e22 = old_1;
_atomic_compare_exchange_resultUint4_ _e23 = naga_atomic_compare_exchange_weak_explicit(&arr_u32_.inner[_e20], _e22, new_1);
old_1 = _e23.old_value;
exchanged_1 = _e23.exchanged;
}
}
}
}
return;
}
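For reference, each kernel in this generated test output implements the standard compare-and-swap retry loop: load the current value, compute a replacement (here by reinterpreting the integer bits as a `float` via `as_type`, adding 1.0, and reinterpreting back), then retry the weak exchange until `exchanged` comes back true, feeding `old_value` into the next attempt.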
