Skip to content

Commit

Permalink
JLInstSimplify multi arg
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Aug 11, 2024
1 parent 636db9f commit c1b6729
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 14 deletions.
113 changes: 99 additions & 14 deletions enzyme/Enzyme/JLInstSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,58 @@ bool notCapturedBefore(llvm::Value *V, Instruction *inst) {
return true;
}

static inline SetVector<llvm::Value *> getBaseObjects(llvm::Value *V,
bool offsetAllowed) {
SetVector<llvm::Value *> results;

SmallPtrSet<llvm::Value *, 2> seen;
SmallVector<llvm::Value *, 1> todo = {V};

while (todo.size()) {
auto cur = todo.back();
todo.pop_back();
if (seen.count(cur))
continue;
seen.insert(cur);
auto obj = getBaseObject(cur, offsetAllowed);
if (auto PN = dyn_cast<PHINode>(obj)) {
for (auto &val : PN->incoming_values()) {
todo.push_back(val);
}
continue;
}
if (auto SI = dyn_cast<SelectInst>(obj)) {
todo.push_back(SI->getTrueValue());
todo.push_back(SI->getFalseValue());
continue;
}
results.insert(obj);
}
return results;
}

bool noaliased_or_arg(SetVector<llvm::Value *> &lhs_v,
SetVector<llvm::Value *> &rhs_v) {
for (auto lhs : lhs_v) {
auto lhs_na = isNoAlias(lhs);
auto lhs_arg = isa<Argument>(lhs);

// This LHS value is neither noalias or an argument
if (!lhs_na && !lhs_arg)
return false;

for (auto rhs : rhs_v) {
if (lhs == rhs)
return false;
if (isNoAlias(lhs))
continue;
if (!lhs_na && !isa<Argument>(rhs))
return false;
}
}
return true;
}

bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
llvm::AAResults &AA, llvm::LoopInfo &LI) {
bool changed = false;
Expand Down Expand Up @@ -175,33 +227,59 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
}

if (legal) {
auto lhs = getBaseObject(I.getOperand(0), /*offsetAllowed*/ false);
auto rhs = getBaseObject(I.getOperand(1), /*offsetAllowed*/ false);
if (lhs == rhs) {
auto lhs_v = getBaseObjects(I.getOperand(0), /*offsetAllowed*/ false);
auto rhs_v = getBaseObjects(I.getOperand(1), /*offsetAllowed*/ false);
if (lhs_v.size() == 1 && rhs_v.size() == 1 && lhs_v[0] == rhs_v[0]) {
auto repval = ICmpInst::isTrueWhenEqual(pred)
? ConstantInt::get(I.getType(), 1)
: ConstantInt::get(I.getType(), 0);
I.replaceAllUsesWith(repval);
changed = true;
continue;
}
if ((isNoAlias(lhs) && (isNoAlias(rhs) || isa<Argument>(rhs))) ||
(isNoAlias(rhs) && isa<Argument>(lhs))) {
if (noaliased_or_arg(lhs_v, rhs_v)) {
auto repval = ICmpInst::isTrueWhenEqual(pred)
? ConstantInt::get(I.getType(), 0)
: ConstantInt::get(I.getType(), 1);
I.replaceAllUsesWith(repval);
changed = true;
continue;
}
auto llhs = dyn_cast<LoadInst>(lhs);
auto lrhs = dyn_cast<LoadInst>(rhs);
if (llhs && lrhs && isa<PointerType>(llhs->getType()) &&
isa<PointerType>(lrhs->getType())) {
auto lhsv =
getBaseObject(llhs->getOperand(0), /*offsetAllowed*/ false);
auto rhsv =
getBaseObject(lrhs->getOperand(0), /*offsetAllowed*/ false);
bool loadlegal = true;
SmallVector<LoadInst *, 1> llhs, lrhs;
for (auto lhs : lhs_v) {
auto ld = dyn_cast<LoadInst>(lhs);
if (!ld || !isa<PointerType>(ld->getType())) {
loadlegal = false;
break;
}
llhs.push_back(ld);
}
for (auto rhs : rhs_v) {
auto ld = dyn_cast<LoadInst>(rhs);
if (!ld || !isa<PointerType>(ld->getType())) {
loadlegal = false;
break;
}
lrhs.push_back(ld);
}
SetVector<Value *> llhs_s, lrhs_s;
for (auto v : llhs) {
for (auto obj :
getBaseObjects(v->getOperand(0), /*offsetAllowed*/ false)) {
llhs_s.insert(obj);
}
}
for (auto v : lrhs) {
for (auto obj :
getBaseObjects(v->getOperand(0), /*offsetAllowed*/ false)) {
lrhs_s.insert(obj);
}
}
// TODO handle multi size
if (llhs_s.size() == 1 && lrhs_s.size() == 1 && loadlegal) {
auto lhsv = llhs_s[0];
auto rhsv = lrhs_s[0];
if ((isNoAlias(lhsv) && (isNoAlias(rhsv) || isa<Argument>(rhsv) ||
notCapturedBefore(lhsv, &I))) ||
(isNoAlias(rhsv) &&
Expand All @@ -225,7 +303,14 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
if (!I->mayWriteToMemory())
return /*earlyBreak*/ false;

for (auto LI : {llhs, lrhs})
for (auto LI : llhs)
if (writesToMemoryReadBy(AA, TLI,
/*maybeReader*/ LI,
/*maybeWriter*/ I)) {
overwritten = true;
return /*earlyBreak*/ true;
}
for (auto LI : lrhs)
if (writesToMemoryReadBy(AA, TLI,
/*maybeReader*/ LI,
/*maybeWriter*/ I)) {
Expand Down
26 changes: 26 additions & 0 deletions enzyme/test/Enzyme/JLSimplify/yesptr2.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -jl-inst-simplify -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -passes="jl-inst-simplify" -S | FileCheck %s

declare i8** @malloc(i64)

define fastcc i1 @augmented_julia__affine_normalize_1484(i1 %c) {
%i5 = call noalias i8** @malloc(i64 16)
br i1 %c, label %tval, label %fval

tval:
%j29 = load i8*, i8** %i5, align 8
br label %end

fval:
%k29 = load i8*, i8** %i5, align 8
br label %end

end:
%i29 = phi i8* [ %j29, %tval ], [ %k29, %fval ]
%i31 = call noalias nonnull i8* addrspace(10)* inttoptr (i64 137352001798896 to i8* addrspace(10)* ({} addrspace(10)*, i64, i64)*)({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 137351863426640 to {}*) to {} addrspace(10)*), i64 10, i64 10)
%i35 = load i8*, i8* addrspace(10)* %i31, align 8
%i39 = icmp ne i8* %i35, %i29
ret i1 %i39
}

; CHECK: ret i1 true

0 comments on commit c1b6729

Please sign in to comment.