Skip to content

Commit 782a0fc

Browse files
committed
coop: rewire WGSL support using references
1 parent 626cf05 commit 782a0fc

22 files changed

+217
-218
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ By @cwfitzgerald in [#8162](https://github.com/gfx-rs/wgpu/pull/8162).
166166

167167
- Added support for external textures based on WebGPU's [`GPUExternalTexture`](https://www.w3.org/TR/webgpu/#gpuexternaltexture). These allow shaders to transparently operate on potentially multiplanar source texture data in either RGB or YCbCr formats via WGSL's `texture_external` type. This is gated behind the `Features::EXTERNAL_TEXTURE` feature, which is currently only supported on DX12. By @jamienicol in [#4386](https://github.com/gfx-rs/wgpu/issues/4386).
168168

169-
- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V with METAL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).
169+
- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V,METAL, and WGSL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).
170170

171171
### Changes
172172

naga/src/back/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,16 @@ pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str {
311311
}
312312
}
313313

314+
impl crate::TypeInner {
315+
/// Returns true if a variable of this type is a handle.
316+
pub const fn is_handle(&self) -> bool {
317+
match *self {
318+
Self::Image { .. } | Self::Sampler { .. } | Self::AccelerationStructure { .. } => true,
319+
_ => false,
320+
}
321+
}
322+
}
323+
314324
impl crate::Statement {
315325
/// Returns true if the statement directly terminates the current block.
316326
///

naga/src/back/msl/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,10 @@ pub enum Error {
228228
UnsupportedArrayOf(String),
229229
#[error("array of type '{0:?}' is not supported")]
230230
UnsupportedArrayOfType(Handle<crate::Type>),
231-
#[error("ray tracing is not supported prior to MSL 2.3")]
231+
#[error("ray tracing is not supported prior to MSL 2.4")]
232232
UnsupportedRayTracing,
233+
#[error("cooperative matrix is not supported prior to MSL 2.3")]
234+
UnsupportedCooperativeMatrix,
233235
#[error("overrides should not be present at this stage")]
234236
Override,
235237
#[error("bitcasting to {0:?} is not supported")]

naga/src/back/msl/writer.rs

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ impl Display for TypeContext<'_> {
236236
rows,
237237
scalar,
238238
} => put_numeric_type(out, scalar, &[rows, columns]),
239+
// Requires Metal-2.3
239240
crate::TypeInner::CooperativeMatrix {
240241
columns,
241242
rows,
@@ -244,8 +245,7 @@ impl Display for TypeContext<'_> {
244245
} => {
245246
write!(
246247
out,
247-
"{}::simdgroup_{}{}x{}",
248-
NAMESPACE,
248+
"{NAMESPACE}::simdgroup_{}{}x{}",
249249
scalar.to_msl_name(),
250250
columns as u32,
251251
rows as u32,
@@ -485,6 +485,7 @@ enum WrappedFunction {
485485
class: crate::ImageClass,
486486
},
487487
CooperativeMultiplyAdd {
488+
space: crate::AddressSpace,
488489
columns: crate::CooperativeSize,
489490
rows: crate::CooperativeSize,
490491
intermediate: crate::CooperativeSize,
@@ -2842,6 +2843,9 @@ impl<W: Write> Writer<W> {
28422843
write!(self.out, "}}")?;
28432844
}
28442845
crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2846+
if context.lang_version < (2, 3) {
2847+
return Err(Error::UnsupportedCooperativeMatrix);
2848+
}
28452849
write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
28462850
self.put_expression(a, context, true)?;
28472851
write!(self.out, ", ")?;
@@ -4239,10 +4243,14 @@ impl<W: Write> Writer<W> {
42394243
row_major,
42404244
} => {
42414245
let op_str = if store { "store" } else { "load" };
4242-
write!(self.out, "{level}{NAMESPACE}::simdgroup_{op_str}(")?;
4246+
write!(self.out, "{level}simdgroup_{op_str}(")?;
42434247
self.put_expression(target, &context.expression, true)?;
4244-
write!(self.out, ", ")?;
4245-
self.put_expression(pointer, &context.expression, true)?;
4248+
write!(self.out, ", &")?;
4249+
self.put_access_chain(
4250+
pointer,
4251+
context.expression.policies.index,
4252+
&context.expression,
4253+
)?;
42464254
write!(self.out, ", ")?;
42474255
self.put_expression(stride, &context.expression, true)?;
42484256
if row_major {
@@ -6312,23 +6320,31 @@ template <typename A>
63126320
&mut self,
63136321
module: &crate::Module,
63146322
func_ctx: &back::FunctionCtx,
6323+
space: crate::AddressSpace,
63156324
a: Handle<crate::Expression>,
63166325
b: Handle<crate::Expression>,
63176326
) -> BackendResult {
63186327
let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6319-
crate::TypeInner::CooperativeMatrix {
6320-
columns,
6321-
rows,
6322-
scalar,
6323-
..
6324-
} => (columns, rows, scalar),
6328+
crate::TypeInner::Pointer { base, space: _ } => match module.types[base].inner {
6329+
crate::TypeInner::CooperativeMatrix {
6330+
columns,
6331+
rows,
6332+
scalar,
6333+
..
6334+
} => (columns, rows, scalar),
6335+
_ => unreachable!(),
6336+
},
63256337
_ => unreachable!(),
63266338
};
63276339
let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6328-
crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6340+
crate::TypeInner::Pointer { base, space: _ } => match module.types[base].inner {
6341+
crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6342+
_ => unreachable!(),
6343+
},
63296344
_ => unreachable!(),
63306345
};
63316346
let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6347+
space,
63326348
columns: b_c,
63336349
rows: a_r,
63346350
intermediate: a_c,
@@ -6337,15 +6353,11 @@ template <typename A>
63376353
if !self.wrapped_functions.insert(wrapped) {
63386354
return Ok(());
63396355
}
6340-
let scalar_name = match scalar.width {
6341-
2 => "half",
6342-
4 => "float",
6343-
8 => "double",
6344-
_ => unreachable!(),
6345-
};
6356+
let space_name = space.to_msl_name().unwrap_or_default();
6357+
let scalar_name = scalar.to_msl_name();
63466358
writeln!(
63476359
self.out,
6348-
"{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6360+
"{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
63496361
b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
63506362
)?;
63516363
let l1 = back::Level(1);
@@ -6354,10 +6366,7 @@ template <typename A>
63546366
"{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
63556367
b_c as u32, a_r as u32
63566368
)?;
6357-
writeln!(
6358-
self.out,
6359-
"{l1}{NAMESPACE}::simdgroup_multiply_accumulate(d,a,b,c);"
6360-
)?;
6369+
writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
63616370
writeln!(self.out, "{l1}return d;")?;
63626371
writeln!(self.out, "}}")?;
63636372
writeln!(self.out)?;
@@ -6439,7 +6448,8 @@ template <typename A>
64396448
self.write_wrapped_image_query(module, func_ctx, image, query)?;
64406449
}
64416450
crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6442-
self.write_wrapped_cooperative_multiply_add(module, func_ctx, a, b)?;
6451+
let space = crate::AddressSpace::Private;
6452+
self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
64436453
}
64446454
_ => {}
64456455
}
@@ -6632,7 +6642,6 @@ template <typename A>
66326642
names: &self.names,
66336643
handle,
66346644
usage: fun_info[handle],
6635-
66366645
reference: true,
66376646
};
66386647
let separator =

naga/src/back/spv/block.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3717,7 +3717,13 @@ impl BlockContext<'_> {
37173717
self.cached[stride],
37183718
));
37193719
} else {
3720-
let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty);
3720+
let result_type_id =
3721+
match *self.fun_info[target].ty.inner_with(&self.ir_module.types) {
3722+
crate::TypeInner::Pointer { base, space: _ } => {
3723+
self.get_handle_type_id(base)
3724+
}
3725+
_ => unreachable!(),
3726+
};
37213727
let id = self.gen_id();
37223728
block.body.push(Instruction::coop_load(
37233729
result_type_id,

naga/src/back/spv/writer.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -971,14 +971,13 @@ impl Writer {
971971
}
972972
}
973973

974-
// Handle globals are pre-emitted and should be loaded automatically.
975-
//
976-
// Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
977974
match ir_module.types[var.ty].inner {
975+
// Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
978976
crate::TypeInner::BindingArray { .. } => {
979977
gv.access_id = gv.var_id;
980978
}
981979
_ => {
980+
// Handle globals are pre-emitted and should be loaded automatically.
982981
if var.space == crate::AddressSpace::Handle {
983982
let var_type_id = self.get_handle_type_id(var.ty);
984983
let id = self.id_gen.next();
@@ -1064,6 +1063,7 @@ impl Writer {
10641063
}
10651064
}),
10661065
);
1066+
10671067
context
10681068
.function
10691069
.variables

naga/src/back/wgsl/writer.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -993,13 +993,13 @@ impl<W: Write> Writer<W> {
993993
} => {
994994
let op_str = if store { "Store" } else { "Load" };
995995
let suffix = if row_major { "T" } else { "" };
996-
write!(self.out, "coop{op_str}{suffix}(")?;
997-
self.write_expr(module, target, func_ctx)?;
996+
write!(self.out, "{level}coop{op_str}{suffix}(")?;
997+
self.write_expr_with_indirection(module, target, func_ctx, Indirection::Reference)?;
998998
write!(self.out, ", ")?;
999999
self.write_expr(module, pointer, func_ctx)?;
10001000
write!(self.out, ", ")?;
10011001
self.write_expr(module, stride, func_ctx)?;
1002-
write!(self.out, ")")?
1002+
writeln!(self.out, ");")?
10031003
}
10041004
}
10051005

@@ -1714,11 +1714,11 @@ impl<W: Write> Writer<W> {
17141714
| Expression::WorkGroupUniformLoadResult { .. } => {}
17151715
Expression::CooperativeMultiplyAdd { a, b, c } => {
17161716
write!(self.out, "coopMultiplyAdd(")?;
1717-
self.write_expr(module, a, func_ctx)?;
1717+
self.write_expr_with_indirection(module, a, func_ctx, Indirection::Reference)?;
17181718
write!(self.out, ", ")?;
1719-
self.write_expr(module, b, func_ctx)?;
1719+
self.write_expr_with_indirection(module, b, func_ctx, Indirection::Reference)?;
17201720
write!(self.out, ", ")?;
1721-
self.write_expr(module, c, func_ctx)?;
1721+
self.write_expr_with_indirection(module, c, func_ctx, Indirection::Reference)?;
17221722
write!(self.out, ")")?;
17231723
}
17241724
}

naga/src/front/wgsl/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ pub(crate) enum Error<'a> {
412412
TypeTooLarge {
413413
span: Span,
414414
},
415+
InvalidCooperativeMatrix,
415416
UnderspecifiedCooperativeMatrix,
416417
UnsupportedCooperativeScalar(Span),
417418
}
@@ -1388,6 +1389,11 @@ impl<'a> Error<'a> {
13881389
crate::valid::MAX_TYPE_SIZE
13891390
)],
13901391
},
1392+
Error::InvalidCooperativeMatrix => ParseError {
1393+
message: "given type is not a cooperative matrix".into(),
1394+
labels: vec![],
1395+
notes: vec![format!("must be coop_mat")],
1396+
},
13911397
Error::UnderspecifiedCooperativeMatrix => ParseError {
13921398
message: "cooperative matrix constructor is underspecified".into(),
13931399
labels: vec![],

naga/src/front/wgsl/lower/mod.rs

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,15 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
846846
fn ensure_type_exists(&mut self, inner: ir::TypeInner) -> Handle<ir::Type> {
847847
self.as_global().ensure_type_exists(None, inner)
848848
}
849+
850+
fn _get_runtime_expression(&self, expr: Handle<ir::Expression>) -> &ir::Expression {
851+
match self.expr_type {
852+
ExpressionContextType::Runtime(ref ctx) => &ctx.function.expressions[expr],
853+
ExpressionContextType::Constant(_) | ExpressionContextType::Override => {
854+
unreachable!()
855+
}
856+
}
857+
}
849858
}
850859

851860
struct ArgumentContext<'ctx, 'source> {
@@ -955,6 +964,13 @@ impl<T> Typed<T> {
955964
Self::Plain(expr) => Typed::Plain(f(expr)?),
956965
})
957966
}
967+
968+
fn ref_or<E>(self, error: E) -> core::result::Result<T, E> {
969+
match self {
970+
Self::Reference(v) => Ok(v),
971+
Self::Plain(_) => Err(error),
972+
}
973+
}
958974
}
959975

960976
/// A single vector component or swizzle.
@@ -1677,12 +1693,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
16771693
.as_expression(block, &mut emitter)
16781694
.interrupt_emitter(ir::Expression::LocalVariable(var), Span::UNDEFINED)?;
16791695
block.extend(emitter.finish(&ctx.function.expressions));
1680-
let typed = if ctx.module.types[ty].inner.is_handle() {
1681-
Typed::Plain(handle)
1682-
} else {
1683-
Typed::Reference(handle)
1684-
};
1685-
ctx.local_table.insert(v.handle, Declared::Runtime(typed));
1696+
ctx.local_table
1697+
.insert(v.handle, Declared::Runtime(Typed::Reference(handle)));
16861698

16871699
match initializer {
16881700
Some(initializer) => ir::Statement::Store {
@@ -1977,12 +1989,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
19771989
let value_span = ctx.ast_expressions.get_span(value);
19781990
let target = self
19791991
.expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?;
1980-
let target_handle = match target {
1981-
Typed::Reference(handle) => handle,
1982-
Typed::Plain(_) => {
1983-
return Err(Box::new(Error::BadIncrDecrReferenceType(value_span)))
1984-
}
1985-
};
1992+
let target_handle = target.ref_or(Error::BadIncrDecrReferenceType(value_span))?;
19861993

19871994
let mut ectx = ctx.as_expression(block, &mut emitter);
19881995
let scalar = match *resolve_inner!(ectx, target_handle) {
@@ -2139,10 +2146,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
21392146
LoweredGlobalDecl::Var(handle) => {
21402147
let expr = ir::Expression::GlobalVariable(handle);
21412148
let v = &ctx.module.global_variables[handle];
2142-
let force_value = ctx.module.types[v.ty].inner.is_handle();
21432149
match v.space {
21442150
ir::AddressSpace::Handle => Typed::Plain(expr),
2145-
_ if force_value => Typed::Plain(expr),
21462151
_ => Typed::Reference(expr),
21472152
}
21482153
}
@@ -3140,7 +3145,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31403145
let row_major = function.name.ends_with("T");
31413146

31423147
let mut args = ctx.prepare_args(arguments, 2, span);
3143-
let target = self.expression(args.next()?, ctx)?;
3148+
let target = self
3149+
.expression_for_reference(args.next()?, ctx)?
3150+
.ref_or(Error::InvalidCooperativeMatrix)?;
31443151
let pointer = self.expression(args.next()?, ctx)?;
31453152
let stride = if args.total_args > 2 {
31463153
self.expression(args.next()?, ctx)?
@@ -3178,9 +3185,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31783185
}
31793186
"coopMultiplyAdd" => {
31803187
let mut args = ctx.prepare_args(arguments, 3, span);
3181-
let a = self.expression(args.next()?, ctx)?;
3182-
let b = self.expression(args.next()?, ctx)?;
3183-
let c = self.expression(args.next()?, ctx)?;
3188+
let a = self
3189+
.expression_for_reference(args.next()?, ctx)?
3190+
.ref_or(Error::InvalidCooperativeMatrix)?;
3191+
let b = self
3192+
.expression_for_reference(args.next()?, ctx)?
3193+
.ref_or(Error::InvalidCooperativeMatrix)?;
3194+
let c = self
3195+
.expression_for_reference(args.next()?, ctx)?
3196+
.ref_or(Error::InvalidCooperativeMatrix)?;
31843197
args.finish()?;
31853198

31863199
ir::Expression::CooperativeMultiplyAdd { a, b, c }

naga/src/proc/type_methods.rs

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -191,17 +191,6 @@ impl crate::TypeInner {
191191
}
192192
}
193193

194-
/// Returns true if a variable of this type is a handle.
195-
pub const fn is_handle(&self) -> bool {
196-
match *self {
197-
Self::Image { .. }
198-
| Self::Sampler { .. }
199-
| Self::AccelerationStructure { .. }
200-
| Self::CooperativeMatrix { .. } => true,
201-
_ => false,
202-
}
203-
}
204-
205194
/// Attempt to calculate the size of this type. Returns `None` if the size
206195
/// exceeds the limit of [`crate::valid::MAX_TYPE_SIZE`].
207196
pub fn try_size(&self, gctx: super::GlobalCtx) -> Option<u32> {

0 commit comments

Comments
 (0)