Skip to content

Commit 7ebf3ad

Browse files
Add ByteAddressBuffer support
1 parent 42878ed commit 7ebf3ad

File tree

5 files changed

+164
-12
lines changed

5 files changed

+164
-12
lines changed

common/output_stream.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -1920,6 +1920,18 @@ void SpvReflectToYaml::WriteDescriptorBinding(std::ostream& os, const SpvReflect
19201920
assert(itor != descriptor_binding_to_index_.end());
19211921
os << t1 << "uav_counter_binding: *db" << itor->second << " # " << SafeString(db.uav_counter_binding->name) << std::endl;
19221922
}
1923+
1924+
if (db.byte_address_buffer_offset_count > 0) {
1925+
os << t1 << "ByteAddressBuffer offsets: [";
1926+
for (uint32_t i = 0; i < db.byte_address_buffer_offset_count; i++) {
1927+
os << db.byte_address_buffer_offsets[i];
1928+
if (i < (db.byte_address_buffer_offset_count - 1)) {
1929+
os << ", ";
1930+
}
1931+
}
1932+
os << "]\n";
1933+
}
1934+
19231935
if (verbosity_ >= 1) {
19241936
// SpvReflectTypeDescription* type_description;
19251937
if (db.type_description == nullptr) {

spirv_reflect.c

+148-12
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ typedef struct SpvReflectPrvString {
157157
// OpAtomicIAdd -> OpAccessChain -> OpVariable
158158
// OpAtomicLoad -> OpImageTexelPointer -> OpVariable
159159
typedef struct SpvReflectPrvAccessedVariable {
160+
SpvReflectPrvNode* p_node;
160161
uint32_t result_id;
161162
uint32_t variable_ptr;
162163
} SpvReflectPrvAccessedVariable;
@@ -981,6 +982,15 @@ static SpvReflectResult ParseNodes(SpvReflectPrvParser* p_parser) {
981982
case SpvOpFunctionParameter: {
982983
CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
983984
} break;
985+
case SpvOpBitcast:
986+
case SpvOpShiftRightLogical:
987+
case SpvOpIAdd:
988+
case SpvOpISub:
989+
case SpvOpIMul:
990+
case SpvOpUDiv:
991+
case SpvOpSDiv: {
992+
CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
993+
} break;
984994
}
985995

986996
if (p_node->is_type) {
@@ -1152,6 +1162,7 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
11521162
const uint32_t ptr_index = p_node->word_offset + 3;
11531163
SpvReflectPrvAccessedVariable* access_ptr = &p_func->accessed_variables[p_func->accessed_variable_count];
11541164

1165+
access_ptr->p_node = p_node;
11551166
// Need to track Result ID as not sure there has been any memory access through here yet
11561167
CHECKED_READU32(p_parser, result_index, access_ptr->result_id);
11571168
CHECKED_READU32(p_parser, ptr_index, access_ptr->variable_ptr);
@@ -1160,11 +1171,12 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
11601171
case SpvOpStore: {
11611172
const uint32_t result_index = p_node->word_offset + 2;
11621173
CHECKED_READU32(p_parser, result_index, p_func->accessed_variables[p_func->accessed_variable_count].variable_ptr);
1174+
p_func->accessed_variables[p_func->accessed_variable_count].p_node = p_node;
11631175
(++p_func->accessed_variable_count);
11641176
} break;
11651177
case SpvOpCopyMemory:
11661178
case SpvOpCopyMemorySized: {
1167-
// There is no result_id is being zero is same as being invalid
1179+
// There is no result_id or node, being zero is same as being invalid
11681180
CHECKED_READU32(p_parser, p_node->word_offset + 1,
11691181
p_func->accessed_variables[p_func->accessed_variable_count].variable_ptr);
11701182
(++p_func->accessed_variable_count);
@@ -3221,6 +3233,99 @@ static SpvReflectResult TraverseCallGraph(SpvReflectPrvParser* p_parser, SpvRefl
32213233
return SPV_REFLECT_RESULT_SUCCESS;
32223234
}
32233235

3236+
static uint32_t GetUint32Constant(SpvReflectPrvParser* p_parser, uint32_t id) {
3237+
uint32_t result = (uint32_t)INVALID_VALUE;
3238+
SpvReflectPrvNode* p_node = FindNode(p_parser, id);
3239+
if (p_node && p_node->op == SpvOpConstant) {
3240+
UNCHECKED_READU32(p_parser, p_node->word_offset + 3, result);
3241+
}
3242+
return result;
3243+
}
3244+
3245+
static bool HasByteAddressBufferOffset(SpvReflectPrvNode* p_node, SpvReflectDescriptorBinding* p_binding) {
3246+
return IsNotNull(p_node) && IsNotNull(p_binding) && p_node->op == SpvOpAccessChain && p_node->word_count == 6 &&
3247+
(p_binding->user_type == SPV_REFLECT_USER_TYPE_BYTE_ADDRESS_BUFFER ||
3248+
p_binding->user_type == SPV_REFLECT_USER_TYPE_RW_BYTE_ADDRESS_BUFFER);
3249+
}
3250+
3251+
static SpvReflectResult ParseByteAddressBuffer(SpvReflectPrvParser* p_parser, SpvReflectPrvNode* p_node,
3252+
SpvReflectDescriptorBinding* p_binding) {
3253+
const SpvReflectResult not_found = SPV_REFLECT_RESULT_SUCCESS;
3254+
if (!HasByteAddressBufferOffset(p_node, p_binding)) {
3255+
return not_found;
3256+
}
3257+
3258+
uint32_t base_id = 0;
3259+
// expect first index of 2D access is zero
3260+
UNCHECKED_READU32(p_parser, p_node->word_offset + 4, base_id);
3261+
if (GetUint32Constant(p_parser, base_id) != 0) {
3262+
return not_found;
3263+
}
3264+
UNCHECKED_READU32(p_parser, p_node->word_offset + 5, base_id);
3265+
SpvReflectPrvNode* p_next_node = FindNode(p_parser, base_id);
3266+
if (IsNull(p_next_node)) {
3267+
return not_found;
3268+
}
3269+
3270+
// there is usually 2 (sometimes 3) instrucitons that make up the arithmetic logic to calculate the offset
3271+
SpvReflectPrvNode* arithmetic_node_stack[5];
3272+
uint32_t arithmetic_count = 0;
3273+
3274+
while (IsNotNull(p_next_node)) {
3275+
if (p_next_node->op == SpvOpLoad || p_next_node->op == SpvOpBitcast) {
3276+
break; // arithmetic starts here
3277+
}
3278+
arithmetic_node_stack[arithmetic_count++] = p_next_node;
3279+
if (arithmetic_count > 5) {
3280+
return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
3281+
}
3282+
3283+
UNCHECKED_READU32(p_parser, p_next_node->word_offset + 3, base_id);
3284+
p_next_node = FindNode(p_parser, base_id);
3285+
}
3286+
3287+
uint32_t offset = 0; // starting offset
3288+
const uint32_t count = arithmetic_count;
3289+
for (uint32_t i = 0; i < count; i++) {
3290+
p_next_node = arithmetic_node_stack[--arithmetic_count];
3291+
// All arithmetic ops takes 2 operands, assumption is the 2nd operand has the constant
3292+
UNCHECKED_READU32(p_parser, p_next_node->word_offset + 4, base_id);
3293+
uint32_t value = GetUint32Constant(p_parser, base_id);
3294+
if (value == INVALID_VALUE) {
3295+
return not_found;
3296+
}
3297+
3298+
switch (p_next_node->op) {
3299+
case SpvOpShiftRightLogical:
3300+
offset >>= value;
3301+
break;
3302+
case SpvOpIAdd:
3303+
offset += value;
3304+
break;
3305+
case SpvOpISub:
3306+
offset -= value;
3307+
break;
3308+
case SpvOpIMul:
3309+
offset *= value;
3310+
break;
3311+
case SpvOpUDiv:
3312+
offset /= value;
3313+
break;
3314+
case SpvOpSDiv:
3315+
// OpConstant might be signed, but value should never be negative
3316+
assert((int32_t)value > 0);
3317+
offset /= value;
3318+
break;
3319+
default:
3320+
return not_found;
3321+
}
3322+
}
3323+
3324+
p_binding->byte_address_buffer_offsets[p_binding->byte_address_buffer_offset_count] = offset;
3325+
p_binding->byte_address_buffer_offset_count++;
3326+
return SPV_REFLECT_RESULT_SUCCESS;
3327+
}
3328+
32243329
static SpvReflectResult ParseStaticallyUsedResources(SpvReflectPrvParser* p_parser, SpvReflectShaderModule* p_module,
32253330
SpvReflectEntryPoint* p_entry, size_t uniform_count, uint32_t* uniforms,
32263331
size_t push_constant_count, uint32_t* push_constants) {
@@ -3253,6 +3358,7 @@ static SpvReflectResult ParseStaticallyUsedResources(SpvReflectPrvParser* p_pars
32533358
called_function_count = 0;
32543359
result = TraverseCallGraph(p_parser, p_func, &called_function_count, p_called_functions, 0);
32553360
if (result != SPV_REFLECT_RESULT_SUCCESS) {
3361+
SafeFree(p_called_functions);
32563362
return result;
32573363
}
32583364

@@ -3296,30 +3402,57 @@ static SpvReflectResult ParseStaticallyUsedResources(SpvReflectPrvParser* p_pars
32963402

32973403
// Do set intersection to find the used uniform and push constants
32983404
size_t used_uniform_count = 0;
3299-
SpvReflectResult result0 = IntersectSortedAccessedVariable(p_used_accesses, used_acessed_count, uniforms, uniform_count,
3300-
&p_entry->used_uniforms, &used_uniform_count);
3405+
result = IntersectSortedAccessedVariable(p_used_accesses, used_acessed_count, uniforms, uniform_count, &p_entry->used_uniforms,
3406+
&used_uniform_count);
3407+
if (result != SPV_REFLECT_RESULT_SUCCESS) {
3408+
SafeFree(p_used_accesses);
3409+
return result;
3410+
}
33013411

33023412
size_t used_push_constant_count = 0;
3303-
SpvReflectResult result1 =
3304-
IntersectSortedAccessedVariable(p_used_accesses, used_acessed_count, push_constants, push_constant_count,
3305-
&p_entry->used_push_constants, &used_push_constant_count);
3413+
result = IntersectSortedAccessedVariable(p_used_accesses, used_acessed_count, push_constants, push_constant_count,
3414+
&p_entry->used_push_constants, &used_push_constant_count);
3415+
if (result != SPV_REFLECT_RESULT_SUCCESS) {
3416+
SafeFree(p_used_accesses);
3417+
return result;
3418+
}
33063419

33073420
for (uint32_t i = 0; i < p_module->descriptor_binding_count; ++i) {
33083421
SpvReflectDescriptorBinding* p_binding = &p_module->descriptor_bindings[i];
3422+
uint32_t byte_address_buffer_offset_count = 0;
3423+
33093424
for (uint32_t j = 0; j < used_acessed_count; j++) {
33103425
if (p_used_accesses[j].variable_ptr == p_binding->spirv_id) {
33113426
p_binding->accessed = 1;
3427+
3428+
if (HasByteAddressBufferOffset(p_used_accesses[j].p_node, p_binding)) {
3429+
byte_address_buffer_offset_count++;
3430+
}
3431+
}
3432+
}
3433+
3434+
// only if SPIR-V has ByteAddressBuffer user type
3435+
if (byte_address_buffer_offset_count > 0) {
3436+
// possible not all allocated offset slots are used, but this will be a max per binding
3437+
p_binding->byte_address_buffer_offsets = (uint32_t*)calloc(byte_address_buffer_offset_count, sizeof(uint32_t));
3438+
if (IsNull(p_binding->byte_address_buffer_offsets)) {
3439+
SafeFree(p_used_accesses);
3440+
return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
3441+
}
3442+
3443+
for (uint32_t j = 0; j < used_acessed_count; j++) {
3444+
if (p_used_accesses[j].variable_ptr == p_binding->spirv_id) {
3445+
result = ParseByteAddressBuffer(p_parser, p_used_accesses[j].p_node, p_binding);
3446+
if (result != SPV_REFLECT_RESULT_SUCCESS) {
3447+
SafeFree(p_used_accesses);
3448+
return result;
3449+
}
3450+
}
33123451
}
33133452
}
33143453
}
33153454

33163455
SafeFree(p_used_accesses);
3317-
if (result0 != SPV_REFLECT_RESULT_SUCCESS) {
3318-
return result0;
3319-
}
3320-
if (result1 != SPV_REFLECT_RESULT_SUCCESS) {
3321-
return result1;
3322-
}
33233456

33243457
p_entry->used_uniform_count = (uint32_t)used_uniform_count;
33253458
p_entry->used_push_constant_count = (uint32_t)used_push_constant_count;
@@ -4112,6 +4245,9 @@ void spvReflectDestroyShaderModule(SpvReflectShaderModule* p_module) {
41124245
// Descriptor binding blocks
41134246
for (size_t i = 0; i < p_module->descriptor_binding_count; ++i) {
41144247
SpvReflectDescriptorBinding* p_descriptor = &p_module->descriptor_bindings[i];
4248+
if (IsNotNull(p_descriptor->byte_address_buffer_offsets)) {
4249+
SafeFree(p_descriptor->byte_address_buffer_offsets);
4250+
}
41154251
SafeFreeBlockVariables(&p_descriptor->block);
41164252
}
41174253
SafeFree(p_module->descriptor_bindings);

spirv_reflect.h

+2
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,8 @@ typedef struct SpvReflectDescriptorBinding {
481481
uint32_t accessed;
482482
uint32_t uav_counter_id;
483483
struct SpvReflectDescriptorBinding* uav_counter_binding;
484+
uint32_t byte_address_buffer_offset_count;
485+
uint32_t* byte_address_buffer_offsets;
484486

485487
SpvReflectTypeDescription* type_description;
486488

tests/user_type/byte_address_buffer.spv.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ all_descriptor_bindings:
100100
accessed: 1
101101
uav_counter_id: 4294967295
102102
uav_counter_binding:
103+
ByteAddressBuffer offsets: [4, 5, 11, 13]
103104
type_description: *td1
104105
word_offset: { binding: 129, set: 125 }
105106
user_type: ByteAddressBuffer

tests/user_type/rw_byte_address_buffer.spv.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ all_descriptor_bindings:
100100
accessed: 1
101101
uav_counter_id: 4294967295
102102
uav_counter_binding:
103+
ByteAddressBuffer offsets: [4, 5, 11, 13]
103104
type_description: *td1
104105
word_offset: { binding: 130, set: 126 }
105106
user_type: RWByteAddressBuffer

0 commit comments

Comments
 (0)