Skip to content

Commit 7718a18

Browse files
committed
don't do pre-shuffle if not necessary
1 parent da91552 commit 7718a18

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

projects/hipblaslt/tensilelite/client/src/DataInitialization.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -936,8 +936,11 @@ namespace TensileLite
936936
auto numAllocatedElements = problem.tensors()[i].totalAllocatedElements();
937937
auto numAllocatedBytes = problem.tensors()[i].totalAllocatedBytes();
938938

939-
if((problem.swizzleTensorA() && i == ContractionProblemGemm::TENSOR::A)
940-
|| (problem.swizzleTensorB() && i == ContractionProblemGemm::TENSOR::B))
939+
bool needSwizzle
940+
= (problem.swizzleTensorA() && i == ContractionProblemGemm::TENSOR::A)
941+
|| (problem.swizzleTensorB()
942+
&& i == ContractionProblemGemm::TENSOR::B);
943+
if(needSwizzle)
941944
{
942945
//TODO: support more swizzle types,
943946
// currently, if A then it means MiM = 16, if B then it means MiN = 16
@@ -947,6 +950,9 @@ namespace TensileLite
947950
problem.tensors()[i], MiM_N, MiK, PackK);
948951
numAllocatedBytes
949952
= numAllocatedElements * rocisa::GetElementSize(dataType);
953+
954+
// std::cout << "DataInitialization- needSwizzle: numAllocatedElements:"
955+
// << numAllocatedElements << std::endl;
950956
}
951957

952958
pristine.maxElements = std::max(pristine.maxElements, numAllocatedElements);
@@ -1492,8 +1498,10 @@ namespace TensileLite
14921498
{
14931499
padding = pUnit.maxElements - problem.tensors()[i].totalAllocatedElements();
14941500

1495-
if((problem.swizzleTensorA() && i == ContractionProblemGemm::TENSOR::A)
1496-
|| (problem.swizzleTensorB() && i == ContractionProblemGemm::TENSOR::B))
1501+
bool needSwizzle
1502+
= (problem.swizzleTensorA() && i == ContractionProblemGemm::TENSOR::A)
1503+
|| (problem.swizzleTensorB() && i == ContractionProblemGemm::TENSOR::B);
1504+
if(needSwizzle)
14971505
{
14981506
//TODO: support more swizzle types,
14991507
// currently, if A then it means MiM = 16, if B then it means MiN = 16
@@ -2055,7 +2063,8 @@ namespace TensileLite
20552063

20562064
void* ptr{};
20572065

2058-
if(needSwizzle)
2066+
// When needSwizzle, if no need to do validation, we can save the time doing data-relayout
2067+
if(needSwizzle && m_elementsToValidate)
20592068
{
20602069
using Tensor = Tensor::Manipulation::Tensor;
20612070
// currently, if A then it means MiM = 16, if B then it means MiN = 16
@@ -2070,6 +2079,7 @@ namespace TensileLite
20702079
auto swizzleKey
20712080
= std::make_tuple(toBitWidth(desc.dataType()), unrolledSize, tiledSize);
20722081

2082+
// Cache-hit
20732083
if(g_swizzleCache.count(swizzleKey))
20742084
{
20752085
if(swizzleKey != g_swizzleCache.back())
@@ -2086,6 +2096,7 @@ namespace TensileLite
20862096
ptr = p.gpuInput.valid.get();
20872097
}
20882098
}
2099+
// No Cache-hit, do pre-shuffle...
20892100
else
20902101
{
20912102
auto tmpTensor = Tensor({tiledSize, unrolledSize}, desc.elementBytes());
@@ -2108,6 +2119,8 @@ namespace TensileLite
21082119
permuted.getDesc().flattenSize(),
21092120
hipMemcpyHostToDevice);
21102121
g_swizzleCache.emplace(swizzleKey, std::move(permuted));
2122+
// std::cout << "needSwizzle and do permute- Copied elems:"
2123+
// << paddedTensor.getDesc().flattenSize() << std::endl;
21112124
}
21122125
}
21132126
else
@@ -2117,6 +2130,10 @@ namespace TensileLite
21172130
p.cpuInput.valid.get(),
21182131
p.maxElements,
21192132
hipMemcpyHostToDevice);
2133+
// if(needSwizzle)
2134+
// std::cout
2135+
// << "needSwizzle but no validation- don't do pre-shuffle: Copied elems:"
2136+
// << p.maxElements << std::endl;
21202137
}
21212138

21222139
if(ptr == nullptr)

0 commit comments

Comments
 (0)