Skip to content

Commit

Permalink
[DirectX][OpLowering] Simplify named struct handling (llvm#128247)
Browse files Browse the repository at this point in the history
This removes "replaceFunctionWithNamedStructOp" and folds its
functionality into "replaceFunctionWithOp". It turns out we were
overcomplicating things and this is trivial to handle generically.

Fixes llvm#113192
  • Loading branch information
bogner authored Feb 22, 2025
1 parent 75bb25b commit f404047
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 61 deletions.
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,7 @@ def MakeDouble : DXILOp<101, makeDouble> {

def SplitDouble : DXILOp<102, splitDouble> {
let Doc = "Splits a double into 2 uints";
let intrinsics = [IntrinSelect<int_dx_splitdouble>];
let arguments = [OverloadTy];
let result = SplitDoubleTy;
let overloads = [Overloads<DXIL1_0, [DoubleTy]>];
Expand Down
4 changes: 0 additions & 4 deletions llvm/lib/Target/DirectX/DXILOpBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,10 +535,6 @@ StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
return ::getResRetType(ElementTy);
}

StructType *DXILOpBuilder::getSplitDoubleType(LLVMContext &Context) {
return ::getSplitDoubleType(Context);
}

StructType *DXILOpBuilder::getHandleType() {
return ::getHandleType(IRB.getContext());
}
Expand Down
3 changes: 0 additions & 3 deletions llvm/lib/Target/DirectX/DXILOpBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ class DXILOpBuilder {
/// Get a `%dx.types.ResRet` type with the given element type.
StructType *getResRetType(Type *ElementTy);

/// Get the `%dx.types.splitdouble` type.
StructType *getSplitDoubleType(LLVMContext &Context);

/// Get the `%dx.types.Handle` type.
StructType *getHandleType();

Expand Down
83 changes: 29 additions & 54 deletions llvm/lib/Target/DirectX/DXILOpLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,30 @@ class OpLowerer {
int Value;
};

/// Replaces uses of a struct with uses of an equivalent named struct.
///
/// DXIL operations that return structs give them well known names, so we need
/// to update uses when we switch from an LLVM intrinsic to an op.
Error replaceNamedStructUses(CallInst *Intrin, CallInst *DXILOp) {
auto *IntrinTy = cast<StructType>(Intrin->getType());
auto *DXILOpTy = cast<StructType>(DXILOp->getType());
if (!IntrinTy->isLayoutIdentical(DXILOpTy))
return make_error<StringError>(
"Type mismatch between intrinsic and DXIL op",
inconvertibleErrorCode());

for (Use &U : make_early_inc_range(Intrin->uses()))
if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser()))
EVI->setOperand(0, DXILOp);
else if (auto *IVI = dyn_cast<InsertValueInst>(U.getUser()))
IVI->setOperand(0, DXILOp);
else
return make_error<StringError>("DXIL ops that return structs may only "
"be used by insert- and extractvalue",
inconvertibleErrorCode());
return Error::success();
}

[[nodiscard]] bool
replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
ArrayRef<IntrinArgSelect> ArgSelects) {
Expand Down Expand Up @@ -154,32 +178,13 @@ class OpLowerer {
if (Error E = OpCall.takeError())
return E;

CI->replaceAllUsesWith(*OpCall);
CI->eraseFromParent();
return Error::success();
});
}

[[nodiscard]] bool replaceFunctionWithNamedStructOp(
Function &F, dxil::OpCode DXILOp, Type *NewRetTy,
llvm::function_ref<Error(CallInst *CI, CallInst *Op)> ReplaceUses) {
bool IsVectorArgExpansion = isVectorArgExpansion(F);
return replaceFunction(F, [&](CallInst *CI) -> Error {
SmallVector<Value *> Args;
OpBuilder.getIRB().SetInsertPoint(CI);
if (IsVectorArgExpansion) {
SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
Args.append(NewArgs.begin(), NewArgs.end());
if (isa<StructType>(CI->getType())) {
if (Error E = replaceNamedStructUses(CI, *OpCall))
return E;
} else
Args.append(CI->arg_begin(), CI->arg_end());

Expected<CallInst *> OpCall =
OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), NewRetTy);
if (Error E = OpCall.takeError())
return E;
if (Error E = ReplaceUses(CI, *OpCall))
return E;
CI->replaceAllUsesWith(*OpCall);

CI->eraseFromParent();
return Error::success();
});
}
Expand Down Expand Up @@ -359,26 +364,6 @@ class OpLowerer {
return lowerToBindAndAnnotateHandle(F);
}

Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) {
for (Use &U : make_early_inc_range(Intrin->uses())) {
if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {

if (EVI->getNumIndices() != 1)
return createStringError(std::errc::invalid_argument,
"Splitdouble has only 2 elements");
EVI->setOperand(0, Op);
} else {
return make_error<StringError>(
"Splitdouble use is not ExtractValueInst",
inconvertibleErrorCode());
}
}

Intrin->eraseFromParent();

return Error::success();
}

/// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
/// Since we expect to be post-scalarization, make an effort to avoid vectors.
Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
Expand Down Expand Up @@ -814,16 +799,6 @@ class OpLowerer {
case Intrinsic::dx_resource_updatecounter:
HasErrors |= lowerUpdateCounter(F);
break;
// TODO: this can be removed when
// https://github.com/llvm/llvm-project/issues/113192 is fixed
case Intrinsic::dx_splitdouble:
HasErrors |= replaceFunctionWithNamedStructOp(
F, OpCode::SplitDouble,
OpBuilder.getSplitDoubleType(M.getContext()),
[&](CallInst *CI, CallInst *Op) {
return replaceSplitDoubleCallUsages(CI, Op);
});
break;
case Intrinsic::ctpop:
HasErrors |= lowerCtpopToCountBits(F);
break;
Expand Down

0 comments on commit f404047

Please sign in to comment.