@@ -157,6 +157,7 @@ typedef struct SpvReflectPrvString {
157
157
// OpAtomicIAdd -> OpAccessChain -> OpVariable
158
158
// OpAtomicLoad -> OpImageTexelPointer -> OpVariable
159
159
typedef struct SpvReflectPrvAccessedVariable {
160
+ SpvReflectPrvNode * p_node ;
160
161
uint32_t result_id ;
161
162
uint32_t variable_ptr ;
162
163
} SpvReflectPrvAccessedVariable ;
@@ -981,6 +982,15 @@ static SpvReflectResult ParseNodes(SpvReflectPrvParser* p_parser) {
981
982
case SpvOpFunctionParameter : {
982
983
CHECKED_READU32 (p_parser , p_node -> word_offset + 2 , p_node -> result_id );
983
984
} break ;
985
+ case SpvOpBitcast :
986
+ case SpvOpShiftRightLogical :
987
+ case SpvOpIAdd :
988
+ case SpvOpISub :
989
+ case SpvOpIMul :
990
+ case SpvOpUDiv :
991
+ case SpvOpSDiv : {
992
+ CHECKED_READU32 (p_parser , p_node -> word_offset + 2 , p_node -> result_id );
993
+ } break ;
984
994
}
985
995
986
996
if (p_node -> is_type ) {
@@ -1152,6 +1162,7 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
1152
1162
const uint32_t ptr_index = p_node -> word_offset + 3 ;
1153
1163
SpvReflectPrvAccessedVariable * access_ptr = & p_func -> accessed_variables [p_func -> accessed_variable_count ];
1154
1164
1165
+ access_ptr -> p_node = p_node ;
1155
1166
// Need to track Result ID as not sure there has been any memory access through here yet
1156
1167
CHECKED_READU32 (p_parser , result_index , access_ptr -> result_id );
1157
1168
CHECKED_READU32 (p_parser , ptr_index , access_ptr -> variable_ptr );
@@ -1160,11 +1171,12 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
1160
1171
case SpvOpStore : {
1161
1172
const uint32_t result_index = p_node -> word_offset + 2 ;
1162
1173
CHECKED_READU32 (p_parser , result_index , p_func -> accessed_variables [p_func -> accessed_variable_count ].variable_ptr );
1174
+ p_func -> accessed_variables [p_func -> accessed_variable_count ].p_node = p_node ;
1163
1175
(++ p_func -> accessed_variable_count );
1164
1176
} break ;
1165
1177
case SpvOpCopyMemory :
1166
1178
case SpvOpCopyMemorySized : {
1167
- // There is no result_id is being zero is same as being invalid
1179
+ // There is no result_id or node, being zero is same as being invalid
1168
1180
CHECKED_READU32 (p_parser , p_node -> word_offset + 1 ,
1169
1181
p_func -> accessed_variables [p_func -> accessed_variable_count ].variable_ptr );
1170
1182
(++ p_func -> accessed_variable_count );
@@ -3221,6 +3233,99 @@ static SpvReflectResult TraverseCallGraph(SpvReflectPrvParser* p_parser, SpvRefl
3221
3233
return SPV_REFLECT_RESULT_SUCCESS ;
3222
3234
}
3223
3235
3236
+ static uint32_t GetUint32Constant (SpvReflectPrvParser * p_parser , uint32_t id ) {
3237
+ uint32_t result = (uint32_t )INVALID_VALUE ;
3238
+ SpvReflectPrvNode * p_node = FindNode (p_parser , id );
3239
+ if (p_node && p_node -> op == SpvOpConstant ) {
3240
+ UNCHECKED_READU32 (p_parser , p_node -> word_offset + 3 , result );
3241
+ }
3242
+ return result ;
3243
+ }
3244
+
3245
+ static bool HasByteAddressBufferOffset (SpvReflectPrvNode * p_node , SpvReflectDescriptorBinding * p_binding ) {
3246
+ return IsNotNull (p_node ) && IsNotNull (p_binding ) && p_node -> op == SpvOpAccessChain && p_node -> word_count == 6 &&
3247
+ (p_binding -> user_type == SPV_REFLECT_USER_TYPE_BYTE_ADDRESS_BUFFER ||
3248
+ p_binding -> user_type == SPV_REFLECT_USER_TYPE_RW_BYTE_ADDRESS_BUFFER );
3249
+ }
3250
+
3251
+ static SpvReflectResult ParseByteAddressBuffer (SpvReflectPrvParser * p_parser , SpvReflectPrvNode * p_node ,
3252
+ SpvReflectDescriptorBinding * p_binding ) {
3253
+ const SpvReflectResult not_found = SPV_REFLECT_RESULT_SUCCESS ;
3254
+ if (!HasByteAddressBufferOffset (p_node , p_binding )) {
3255
+ return not_found ;
3256
+ }
3257
+
3258
+ uint32_t base_id = 0 ;
3259
+ // expect first index of 2D access is zero
3260
+ UNCHECKED_READU32 (p_parser , p_node -> word_offset + 4 , base_id );
3261
+ if (GetUint32Constant (p_parser , base_id ) != 0 ) {
3262
+ return not_found ;
3263
+ }
3264
+ UNCHECKED_READU32 (p_parser , p_node -> word_offset + 5 , base_id );
3265
+ SpvReflectPrvNode * p_next_node = FindNode (p_parser , base_id );
3266
+ if (IsNull (p_next_node )) {
3267
+ return not_found ;
3268
+ }
3269
+
3270
+ // there is usually 2 (sometimes 3) instrucitons that make up the arithmetic logic to calculate the offset
3271
+ SpvReflectPrvNode * arithmetic_node_stack [5 ];
3272
+ uint32_t arithmetic_count = 0 ;
3273
+
3274
+ while (IsNotNull (p_next_node )) {
3275
+ if (p_next_node -> op == SpvOpLoad || p_next_node -> op == SpvOpBitcast ) {
3276
+ break ; // arithmetic starts here
3277
+ }
3278
+ arithmetic_node_stack [arithmetic_count ++ ] = p_next_node ;
3279
+ if (arithmetic_count > 5 ) {
3280
+ return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED ;
3281
+ }
3282
+
3283
+ UNCHECKED_READU32 (p_parser , p_next_node -> word_offset + 3 , base_id );
3284
+ p_next_node = FindNode (p_parser , base_id );
3285
+ }
3286
+
3287
+ uint32_t offset = 0 ; // starting offset
3288
+ const uint32_t count = arithmetic_count ;
3289
+ for (uint32_t i = 0 ; i < count ; i ++ ) {
3290
+ p_next_node = arithmetic_node_stack [-- arithmetic_count ];
3291
+ // All arithmetic ops takes 2 operands, assumption is the 2nd operand has the constant
3292
+ UNCHECKED_READU32 (p_parser , p_next_node -> word_offset + 4 , base_id );
3293
+ uint32_t value = GetUint32Constant (p_parser , base_id );
3294
+ if (value == INVALID_VALUE ) {
3295
+ return not_found ;
3296
+ }
3297
+
3298
+ switch (p_next_node -> op ) {
3299
+ case SpvOpShiftRightLogical :
3300
+ offset >>= value ;
3301
+ break ;
3302
+ case SpvOpIAdd :
3303
+ offset += value ;
3304
+ break ;
3305
+ case SpvOpISub :
3306
+ offset -= value ;
3307
+ break ;
3308
+ case SpvOpIMul :
3309
+ offset *= value ;
3310
+ break ;
3311
+ case SpvOpUDiv :
3312
+ offset /= value ;
3313
+ break ;
3314
+ case SpvOpSDiv :
3315
+ // OpConstant might be signed, but value should never be negative
3316
+ assert ((int32_t )value > 0 );
3317
+ offset /= value ;
3318
+ break ;
3319
+ default :
3320
+ return not_found ;
3321
+ }
3322
+ }
3323
+
3324
+ p_binding -> byte_address_buffer_offsets [p_binding -> byte_address_buffer_offset_count ] = offset ;
3325
+ p_binding -> byte_address_buffer_offset_count ++ ;
3326
+ return SPV_REFLECT_RESULT_SUCCESS ;
3327
+ }
3328
+
3224
3329
static SpvReflectResult ParseStaticallyUsedResources (SpvReflectPrvParser * p_parser , SpvReflectShaderModule * p_module ,
3225
3330
SpvReflectEntryPoint * p_entry , size_t uniform_count , uint32_t * uniforms ,
3226
3331
size_t push_constant_count , uint32_t * push_constants ) {
@@ -3253,6 +3358,7 @@ static SpvReflectResult ParseStaticallyUsedResources(SpvReflectPrvParser* p_pars
3253
3358
called_function_count = 0 ;
3254
3359
result = TraverseCallGraph (p_parser , p_func , & called_function_count , p_called_functions , 0 );
3255
3360
if (result != SPV_REFLECT_RESULT_SUCCESS ) {
3361
+ SafeFree (p_called_functions );
3256
3362
return result ;
3257
3363
}
3258
3364
@@ -3296,30 +3402,57 @@ static SpvReflectResult ParseStaticallyUsedResources(SpvReflectPrvParser* p_pars
3296
3402
3297
3403
// Do set intersection to find the used uniform and push constants
3298
3404
size_t used_uniform_count = 0 ;
3299
- SpvReflectResult result0 = IntersectSortedAccessedVariable (p_used_accesses , used_acessed_count , uniforms , uniform_count ,
3300
- & p_entry -> used_uniforms , & used_uniform_count );
3405
+ result = IntersectSortedAccessedVariable (p_used_accesses , used_acessed_count , uniforms , uniform_count , & p_entry -> used_uniforms ,
3406
+ & used_uniform_count );
3407
+ if (result != SPV_REFLECT_RESULT_SUCCESS ) {
3408
+ SafeFree (p_used_accesses );
3409
+ return result ;
3410
+ }
3301
3411
3302
3412
size_t used_push_constant_count = 0 ;
3303
- SpvReflectResult result1 =
3304
- IntersectSortedAccessedVariable (p_used_accesses , used_acessed_count , push_constants , push_constant_count ,
3305
- & p_entry -> used_push_constants , & used_push_constant_count );
3413
+ result = IntersectSortedAccessedVariable (p_used_accesses , used_acessed_count , push_constants , push_constant_count ,
3414
+ & p_entry -> used_push_constants , & used_push_constant_count );
3415
+ if (result != SPV_REFLECT_RESULT_SUCCESS ) {
3416
+ SafeFree (p_used_accesses );
3417
+ return result ;
3418
+ }
3306
3419
3307
3420
for (uint32_t i = 0 ; i < p_module -> descriptor_binding_count ; ++ i ) {
3308
3421
SpvReflectDescriptorBinding * p_binding = & p_module -> descriptor_bindings [i ];
3422
+ uint32_t byte_address_buffer_offset_count = 0 ;
3423
+
3309
3424
for (uint32_t j = 0 ; j < used_acessed_count ; j ++ ) {
3310
3425
if (p_used_accesses [j ].variable_ptr == p_binding -> spirv_id ) {
3311
3426
p_binding -> accessed = 1 ;
3427
+
3428
+ if (HasByteAddressBufferOffset (p_used_accesses [j ].p_node , p_binding )) {
3429
+ byte_address_buffer_offset_count ++ ;
3430
+ }
3431
+ }
3432
+ }
3433
+
3434
+ // only if SPIR-V has ByteAddressBuffer user type
3435
+ if (byte_address_buffer_offset_count > 0 ) {
3436
+ // possible not all allocated offset slots are used, but this will be a max per binding
3437
+ p_binding -> byte_address_buffer_offsets = (uint32_t * )calloc (byte_address_buffer_offset_count , sizeof (uint32_t ));
3438
+ if (IsNull (p_binding -> byte_address_buffer_offsets )) {
3439
+ SafeFree (p_used_accesses );
3440
+ return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED ;
3441
+ }
3442
+
3443
+ for (uint32_t j = 0 ; j < used_acessed_count ; j ++ ) {
3444
+ if (p_used_accesses [j ].variable_ptr == p_binding -> spirv_id ) {
3445
+ result = ParseByteAddressBuffer (p_parser , p_used_accesses [j ].p_node , p_binding );
3446
+ if (result != SPV_REFLECT_RESULT_SUCCESS ) {
3447
+ SafeFree (p_used_accesses );
3448
+ return result ;
3449
+ }
3450
+ }
3312
3451
}
3313
3452
}
3314
3453
}
3315
3454
3316
3455
SafeFree (p_used_accesses );
3317
- if (result0 != SPV_REFLECT_RESULT_SUCCESS ) {
3318
- return result0 ;
3319
- }
3320
- if (result1 != SPV_REFLECT_RESULT_SUCCESS ) {
3321
- return result1 ;
3322
- }
3323
3456
3324
3457
p_entry -> used_uniform_count = (uint32_t )used_uniform_count ;
3325
3458
p_entry -> used_push_constant_count = (uint32_t )used_push_constant_count ;
@@ -4112,6 +4245,9 @@ void spvReflectDestroyShaderModule(SpvReflectShaderModule* p_module) {
4112
4245
// Descriptor binding blocks
4113
4246
for (size_t i = 0 ; i < p_module -> descriptor_binding_count ; ++ i ) {
4114
4247
SpvReflectDescriptorBinding * p_descriptor = & p_module -> descriptor_bindings [i ];
4248
+ if (IsNotNull (p_descriptor -> byte_address_buffer_offsets )) {
4249
+ SafeFree (p_descriptor -> byte_address_buffer_offsets );
4250
+ }
4115
4251
SafeFreeBlockVariables (& p_descriptor -> block );
4116
4252
}
4117
4253
SafeFree (p_module -> descriptor_bindings );
0 commit comments