@@ -936,8 +936,11 @@ namespace TensileLite
936
936
auto numAllocatedElements = problem.tensors ()[i].totalAllocatedElements ();
937
937
auto numAllocatedBytes = problem.tensors ()[i].totalAllocatedBytes ();
938
938
939
- if ((problem.swizzleTensorA () && i == ContractionProblemGemm::TENSOR::A)
940
- || (problem.swizzleTensorB () && i == ContractionProblemGemm::TENSOR::B))
939
+ bool needSwizzle
940
+ = (problem.swizzleTensorA () && i == ContractionProblemGemm::TENSOR::A)
941
+ || (problem.swizzleTensorB ()
942
+ && i == ContractionProblemGemm::TENSOR::B);
943
+ if (needSwizzle)
941
944
{
942
945
// TODO: support more swizzle types,
943
946
// currently, if A then it means MiM = 16, if B then it means MiN = 16
@@ -947,6 +950,9 @@ namespace TensileLite
947
950
problem.tensors ()[i], MiM_N, MiK, PackK);
948
951
numAllocatedBytes
949
952
= numAllocatedElements * rocisa::GetElementSize (dataType);
953
+
954
+ // std::cout << "DataInitialization- needSwizzle: numAllocatedElements:"
955
+ // << numAllocatedElements << std::endl;
950
956
}
951
957
952
958
pristine.maxElements = std::max (pristine.maxElements , numAllocatedElements);
@@ -1492,8 +1498,10 @@ namespace TensileLite
1492
1498
{
1493
1499
padding = pUnit.maxElements - problem.tensors ()[i].totalAllocatedElements ();
1494
1500
1495
- if ((problem.swizzleTensorA () && i == ContractionProblemGemm::TENSOR::A)
1496
- || (problem.swizzleTensorB () && i == ContractionProblemGemm::TENSOR::B))
1501
+ bool needSwizzle
1502
+ = (problem.swizzleTensorA () && i == ContractionProblemGemm::TENSOR::A)
1503
+ || (problem.swizzleTensorB () && i == ContractionProblemGemm::TENSOR::B);
1504
+ if (needSwizzle)
1497
1505
{
1498
1506
// TODO: support more swizzle types,
1499
1507
// currently, if A then it means MiM = 16, if B then it means MiN = 16
@@ -2055,7 +2063,8 @@ namespace TensileLite
2055
2063
2056
2064
void * ptr{};
2057
2065
2058
- if (needSwizzle)
2066
+ // When needSwizzle, if no need to do validation, we can save the time doing data-relayout
2067
+ if (needSwizzle && m_elementsToValidate)
2059
2068
{
2060
2069
using Tensor = Tensor::Manipulation::Tensor;
2061
2070
// currently, if A then it means MiM = 16, if B then it means MiN = 16
@@ -2070,6 +2079,7 @@ namespace TensileLite
2070
2079
auto swizzleKey
2071
2080
= std::make_tuple (toBitWidth (desc.dataType ()), unrolledSize, tiledSize);
2072
2081
2082
+ // Cache-hit
2073
2083
if (g_swizzleCache.count (swizzleKey))
2074
2084
{
2075
2085
if (swizzleKey != g_swizzleCache.back ())
@@ -2086,6 +2096,7 @@ namespace TensileLite
2086
2096
ptr = p.gpuInput .valid .get ();
2087
2097
}
2088
2098
}
2099
+ // No Cache-hit, do pre-shuffle...
2089
2100
else
2090
2101
{
2091
2102
auto tmpTensor = Tensor ({tiledSize, unrolledSize}, desc.elementBytes ());
@@ -2108,6 +2119,8 @@ namespace TensileLite
2108
2119
permuted.getDesc ().flattenSize (),
2109
2120
hipMemcpyHostToDevice);
2110
2121
g_swizzleCache.emplace (swizzleKey, std::move (permuted));
2122
+ // std::cout << "needSwizzle and do permute- Copied elems:"
2123
+ // << paddedTensor.getDesc().flattenSize() << std::endl;
2111
2124
}
2112
2125
}
2113
2126
else
@@ -2117,6 +2130,10 @@ namespace TensileLite
2117
2130
p.cpuInput .valid .get (),
2118
2131
p.maxElements ,
2119
2132
hipMemcpyHostToDevice);
2133
+ // if(needSwizzle)
2134
+ // std::cout
2135
+ // << "needSwizzle but no validation- don't do pre-shuffle: Copied elems:"
2136
+ // << p.maxElements << std::endl;
2120
2137
}
2121
2138
2122
2139
if (ptr == nullptr )
0 commit comments