Open
Conversation
…e that will be always the same for improved performance
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
the current function works as:
device uint64_t make_smem_desc(bf16* ptr) {
// Convert shared memory pointer to integer
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
uint64_t desc = matrix_descriptor_encode(addr);
desc |= matrix_descriptor_encode((uint64_t)16) << 16;
desc |= matrix_descriptor_encode((uint64_t)1024) << 32;
desc |= 1llu << 62; // 128B swizzle
return desc;
}
which does a lot of unecessary computations that can be otimized.
the line "desc |= matrix_descriptor_encode((uint64_t)16) << 16;" can be completely taken out as on the ptx documentation it says that for K-major swizzled layouts "not used, assumed to be 1." so it is not necessary to write any number(can remain 0-ed)
the line "desc |= matrix_descriptor_encode((uint64_t)1024) << 32;" does the following computation
1024 & 0x3FFFF = 1024 -> 1024 >> 16 = 64 every time that the function is called so we can change that for 0x0000004000000000
so when the "or" operation happens it will change only its value on the descriptor
the line "desc |= 1llu << 62; // 128B swizzle" will always be 0x4000000000000000 so it is unecessary to do the left shift
this way we can combine our last two alterations, making the descriptor "0x4000004000000000" and then we will only need to compute our addr
Logically, there is no need to this function exist so we can delete it completely and write a logical "or" for every smem descriptor created like showed bellow:
...
uint64_t desc_a = 0x4000004000000000 | (matrix_descriptor_encode(static_cast<uint32_t>(__cvta_generic_to_shared(&sA[0]))));
uint64_t desc_b = 0x4000004000000000 | (matrix_descriptor_encode(static_cast<uint32_t>(__cvta_generic_to_shared(&sB[0]))));
...
This way improving the performance of the kernel by about 1% :)