Skip to content

Commit c82f577

Browse files
authored
[AMDGPU] convert HIP struct type vector to llvm vector type (llvm#4344)
2 parents ca3d57f + 1d02e4f commit c82f577

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

llvm/lib/Transforms/Scalar/SROA.cpp

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,7 @@
8383
#include "llvm/Transforms/Scalar.h"
8484
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
8585
#include "llvm/Transforms/Utils/Local.h"
86+
#include "llvm/TargetParser/Triple.h"
8687
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
8788
#include "llvm/Transforms/Utils/SSAUpdater.h"
8889
#include <algorithm>
@@ -4905,6 +4906,34 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
49054906
// FIXME: We might want to defer PHI speculation until after here.
49064907
// FIXME: return nullptr;
49074908
} else {
4909+
// AMDGPU: If the target is AMDGPU and the chosen SliceTy is a HIP vector
4910+
// struct of 2 or 4 identical elements, canonicalize it to an IR vector.
4911+
// This lets SROA treat it as a single value and enables vector loads/stores.
4912+
// We pattern-match struct names starting with "struct.HIP_vector".
4913+
if (Function *F = AI.getFunction()) {
4914+
Triple TT(F->getParent()->getTargetTriple());
4915+
if (TT.isAMDGPU()) {
4916+
if (auto *STy = dyn_cast<StructType>(SliceTy)) {
4917+
StringRef Name = STy->hasName() ? STy->getName() : StringRef();
4918+
if (Name.starts_with("struct.HIP_vector")) {
4919+
unsigned NumElts = STy->getNumElements();
4920+
if ((NumElts == 2 || NumElts == 4) && NumElts > 0) {
4921+
Type *EltTy = STy->getElementType(0);
4922+
bool AllSame = true;
4923+
for (unsigned I = 1; I < NumElts; ++I)
4924+
if (STy->getElementType(I) != EltTy) {
4925+
AllSame = false;
4926+
break;
4927+
}
4928+
if (AllSame && VectorType::isValidElementType(EltTy)) {
4929+
SliceTy = FixedVectorType::get(EltTy, NumElts);
4930+
}
4931+
}
4932+
}
4933+
}
4934+
}
4935+
}
4936+
49084937
// Make sure the alignment is compatible with P.beginOffset().
49094938
const Align Alignment = commonAlignment(AI.getAlign(), P.beginOffset());
49104939
// If we will get at least this much alignment from the type alone, leave

0 commit comments

Comments
 (0)