Skip to content

Commit 8ff3044

Browse files
committed
Add vk_subgroup_zoo tests for maximal reconvergence behaviour
Diverged threads which re-converge at expected points
1 parent 0afc26f commit 8ff3044

File tree

2 files changed

+160
-27
lines changed

2 files changed

+160
-27
lines changed

util/test/demos/vk/vk_subgroup_zoo.cpp

+123-3
Original file line numberDiff line numberDiff line change
@@ -129,16 +129,136 @@ layout(binding = 0, std430) buffer outbuftype {
129129
130130
layout(local_size_x = GROUP_SIZE_X, local_size_y = GROUP_SIZE_Y, local_size_z = 1) in;
131131
132+
vec4 funcD(uint id)
133+
{
134+
return vec4(subgroupAdd(id/2));
135+
}
136+
137+
vec4 nestedFunc(uint id)
138+
{
139+
vec4 ret = funcD(id/3);
140+
ret.w = subgroupAdd(id);
141+
return ret;
142+
}
143+
144+
vec4 funcA(uint id)
145+
{
146+
return nestedFunc(id*2);
147+
}
148+
149+
vec4 funcB(uint id)
150+
{
151+
return nestedFunc(id*4);
152+
}
153+
154+
vec4 funcTest(uint id)
155+
{
156+
if ((id % 2) == 0)
157+
{
158+
return vec4(0);
159+
}
160+
else
161+
{
162+
float value = subgroupAdd(id);
163+
if (id < 10)
164+
{
165+
return vec4(value);
166+
}
167+
value += subgroupAdd(id/2);
168+
return vec4(value);
169+
}
170+
}
171+
172+
void SetOuput(vec4 data)
173+
{
174+
outbuf.data[push.test].vals[gl_LocalInvocationID.y * GROUP_SIZE_X + gl_LocalInvocationID.x] = data;
175+
}
132176
void main()
133177
{
134178
vec4 data = vec4(0);
179+
uint id = gl_SubgroupInvocationID;
180+
SetOuput(data);
135181
136182
if(IsTest(0))
137-
data = vec4(gl_SubgroupInvocationID, 0, 0, 0);
183+
{
184+
data.x = id;
185+
}
138186
else if(IsTest(1))
139-
data = vec4(subgroupAdd(gl_SubgroupInvocationID), 0, 0, 0);
187+
{
188+
data.x = subgroupAdd(id);
189+
}
190+
else if(IsTest(2))
191+
{
192+
// Diverged threads which reconverge
193+
if (id < 10)
194+
{
195+
// active threads 0-9
196+
data.x = subgroupAdd(id);
140197
141-
outbuf.data[push.test].vals[gl_LocalInvocationID.y * GROUP_SIZE_X + gl_LocalInvocationID.x] = data;
198+
if ((id % 2) == 0)
199+
data.y = subgroupAdd(id);
200+
else
201+
data.y = subgroupAdd(id);
202+
203+
data.x += subgroupAdd(id);
204+
}
205+
else
206+
{
207+
// active threads 10...
208+
data.x = subgroupAdd(id);
209+
}
210+
data.y = subgroupAdd(id);
211+
}
212+
else if(IsTest(3))
213+
{
214+
// Converged threads calling a function
215+
data = funcTest(id);
216+
data.y = subgroupAdd(id);
217+
}
218+
else if(IsTest(4))
219+
{
220+
// Converged threads calling a function which has a nested function call in it
221+
data = nestedFunc(id);
222+
data.y = subgroupAdd(id);
223+
}
224+
else if(IsTest(5))
225+
{
226+
// Diverged threads calling the same function
227+
if (id < 10)
228+
{
229+
data = funcD(id);
230+
}
231+
else
232+
{
233+
data = funcD(id);
234+
}
235+
data.y = subgroupAdd(id);
236+
}
237+
else if(IsTest(6))
238+
{
239+
// Diverged threads calling the same function which has a nested function call in it
240+
if (id < 10)
241+
{
242+
data = funcA(id);
243+
}
244+
else
245+
{
246+
data = funcB(id);
247+
}
248+
data.y = subgroupAdd(id);
249+
}
250+
else if(IsTest(7))
251+
{
252+
// Diverged threads which early exit
253+
if (id < 10)
254+
{
255+
data.x = subgroupAdd(id+10);
256+
SetOuput(data);
257+
return;
258+
}
259+
data.x = subgroupAdd(id);
260+
}
261+
SetOuput(data);
142262
}
143263
144264
)EOSHADER";

util/test/tests/Vulkan/VK_Subgroup_Zoo.py

+37-24
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def check_capture(self):
149149
rdtest.log.success(f"Test {idx} successful")
150150

151151
rdtest.log.end_section("Graphics tests")
152-
152+
overallFailed = False
153153
for comp_dim in compute_dims:
154154
rdtest.log.begin_section(
155155
f"Compute tests with {comp_dim.customName} workgroup")
@@ -175,6 +175,7 @@ def check_capture(self):
175175
bufdata = self.controller.GetBufferData(
176176
rw[0].descriptor.resource, test*16*1024, 16*1024)
177177

178+
failed = False
178179
for t in thread_checks:
179180
xrange = 1
180181
yrange = dim[1]
@@ -195,39 +196,51 @@ def check_capture(self):
195196
if x >= dim[0] or y >= dim[1]:
196197
continue
197198

198-
real = struct.unpack_from(
199-
"4f", bufdata, 16*y*dim[0] + 16*x)
199+
try:
200+
real = struct.unpack_from(
201+
"4f", bufdata, 16*y*dim[0] + 16*x)
200202

201-
trace = self.controller.DebugThread(
202-
(0, 0, 0), (x, y, z))
203+
trace = self.controller.DebugThread(
204+
(0, 0, 0), (x, y, z))
203205

204-
_, variables = self.process_trace(trace)
206+
_, variables = self.process_trace(trace)
205207

206-
if trace.debugger is None:
207-
self.controller.FreeTrace(trace)
208+
if trace.debugger is None:
209+
raise rdtest.TestFailureException(f"Test {test} at {action.eventId} got no debug result at {x},{y},{z}")
208210

209-
rdtest.log.error(
210-
f"Test {test} at {action.eventId} got no debug result at {x},{y},{z}")
211-
continue
211+
# Find the source variable 'data' at the highest instruction index
212+
maxInstInfo = None
213+
for instInfo in trace.instInfo:
214+
for v in instInfo.sourceVars:
215+
if v.name == 'data':
216+
maxInstInfo = instInfo
217+
break
212218

213-
sourceVars = [
214-
v for v in trace.instInfo[-1].sourceVars if v.name == 'data']
219+
sourceVars = [v for v in maxInstInfo.sourceVars if v.name == 'data']
220+
if len(sourceVars) != 1:
221+
raise rdtest.TestFailureException(f"Couldn't find compute source variable 'data' {x}, {y}, {z}")
215222

216-
if len(sourceVars) != 1:
217-
rdtest.log.error(
218-
"Couldn't find compute data variable")
219-
continue
223+
debugged = self.evaluate_source_var(
224+
sourceVars[0], variables)
220225

221-
debugged = self.evaluate_source_var(
222-
sourceVars[0], variables)
226+
debuggedValue = list(debugged.value.f32v[0:4])
223227

224-
debuggedValue = list(debugged.value.f32v[0:4])
228+
if not rdtest.value_compare(real, debuggedValue, eps=5.0E-06):
229+
raise rdtest.TestFailureException(f"EID:{action.eventId} TID:{x},{y},{z} debugged thread value {debuggedValue} does not match output {real}")
225230

226-
if not rdtest.value_compare(real, debuggedValue, eps=5.0E-06):
227-
rdtest.log.error(
228-
f"Test {test} at {action.eventId} debugged thread value {debuggedValue} at {x},{y},{z} does not match output {real}")
231+
except rdtest.TestFailureException as ex:
232+
rdtest.log.error(f"Test {test} failed {ex}")
233+
failed = True
234+
continue
235+
finally:
236+
self.controller.FreeTrace(trace)
229237

230-
rdtest.log.success(f"Test {test} successful")
238+
overallFailed |= failed
239+
if not failed:
240+
rdtest.log.success(f"Test {test} successful")
231241

232242
rdtest.log.end_section(
233243
f"Compute tests with {comp_dim.customName} workgroup")
244+
245+
if overallFailed:
246+
raise rdtest.TestFailureException("Some tests were not as expected")

0 commit comments

Comments
 (0)