Skip to content

Commit

Permalink
Preserve primal shadow info (#2113)
Browse files Browse the repository at this point in the history
* Preserve primal shadow info

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
wsmoses authored Oct 9, 2024
1 parent 152441a commit 716d674
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 27 deletions.
10 changes: 6 additions & 4 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ using namespace mlir;
using namespace mlir::enzyme;

void createTerminator(MGradientUtils *gutils, mlir::Block *oBB,
const std::vector<bool> &returnPrimals,
const std::vector<bool> &returnShadows) {
const ArrayRef<bool> returnPrimals,
const ArrayRef<bool> returnShadows) {
auto inst = oBB->getTerminator();

mlir::Block *nBB = gutils->getNewFromOriginal(inst->getBlock());
Expand Down Expand Up @@ -100,8 +100,10 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff(
for (auto act : RetActivity) {
returnShadows.push_back(act != DIFFE_TYPE::CONSTANT);
}
SmallVector<bool> returnPrimalsP(returnPrimals.begin(), returnPrimals.end());
SmallVector<bool> returnShadowsP(returnShadows.begin(), returnShadows.end());
auto gutils = MDiffeGradientUtils::CreateFromClone(
*this, mode, width, fn, TA, type_args, returnPrimals, returnShadows,
*this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
RetActivity, ArgActivity, addedType,
/*omp*/ false);
ForwardCachedFunctions[tup] = gutils->newFunc;
Expand Down Expand Up @@ -166,7 +168,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff(
valid &= res.succeeded();
}

createTerminator(gutils, &oBB, returnPrimals, returnShadows);
createTerminator(gutils, &oBB, returnPrimalsP, returnShadowsP);
}

// if (mode == DerivativeMode::ForwardModeSplit && augmenteddata)
Expand Down
5 changes: 4 additions & 1 deletion enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,11 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
llvm_unreachable("Differentiating empty function");
}

SmallVector<bool> returnPrimalsP(returnPrimals.begin(), returnPrimals.end());
SmallVector<bool> returnShadowsP(returnShadows.begin(), returnShadows.end());

MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone(
*this, mode, width, fn, TA, type_args, returnPrimals, returnShadows,
*this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
retType, constants, addedType);

Region &oldRegion = gutils->oldFunc.getFunctionBody();
Expand Down
4 changes: 3 additions & 1 deletion enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ using namespace mlir::enzyme;
mlir::enzyme::MGradientUtils::MGradientUtils(
MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, MTypeResults TR_,
IRMapping &invertedPointers_,
IRMapping &invertedPointers_, const llvm::ArrayRef<bool> returnPrimals,
const llvm::ArrayRef<bool> returnShadows,
const SmallPtrSetImpl<mlir::Value> &constantvalues_,
const SmallPtrSetImpl<mlir::Value> &activevals_,
ArrayRef<DIFFE_TYPE> ReturnActivity, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
Expand All @@ -40,6 +41,7 @@ mlir::enzyme::MGradientUtils::MGradientUtils(
: newFunc(newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_),
invertedPointers(invertedPointers_), originalToNewFn(originalToNewFn_),
originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(),
returnPrimals(returnPrimals), returnShadows(returnShadows),
activityAnalyzer(std::make_unique<enzyme::ActivityAnalyzer>(
blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)),
TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_),
Expand Down
24 changes: 16 additions & 8 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class MGradientUtils {
MTypeAnalysis &TA;
MTypeResults TR;
bool omp;
const llvm::ArrayRef<bool> returnPrimals;
const llvm::ArrayRef<bool> returnShadows;

unsigned width;
ArrayRef<DIFFE_TYPE> ArgDiffeTypes;
Expand All @@ -48,6 +50,8 @@ class MGradientUtils {
MGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
FunctionOpInterface oldFunc_, MTypeAnalysis &TA_,
MTypeResults TR_, IRMapping &invertedPointers_,
const llvm::ArrayRef<bool> returnPrimals,
const llvm::ArrayRef<bool> returnShadows,
const SmallPtrSetImpl<mlir::Value> &constantvalues_,
const SmallPtrSetImpl<mlir::Value> &activevals_,
ArrayRef<DIFFE_TYPE> ReturnActivities,
Expand Down Expand Up @@ -102,24 +106,28 @@ class MDiffeGradientUtils : public MGradientUtils {
MDiffeGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
FunctionOpInterface oldFunc_, MTypeAnalysis &TA,
MTypeResults TR, IRMapping &invertedPointers_,
const llvm::ArrayRef<bool> returnPrimals,
const llvm::ArrayRef<bool> returnShadows,
const SmallPtrSetImpl<mlir::Value> &constantvalues_,
const SmallPtrSetImpl<mlir::Value> &activevals_,
ArrayRef<DIFFE_TYPE> RetActivity,
ArrayRef<DIFFE_TYPE> ArgActivity, IRMapping &origToNew_,
std::map<Operation *, Operation *> &origToNewOps_,
DerivativeMode mode, unsigned width, bool omp)
: MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_,
constantvalues_, activevals_, RetActivity, ArgActivity,
origToNew_, origToNewOps_, mode, width, omp),
returnPrimals, returnShadows, constantvalues_,
activevals_, RetActivity, ArgActivity, origToNew_,
origToNewOps_, mode, width, omp),
initializationBlock(&*(newFunc.getFunctionBody().begin())) {}

// Technically diffe constructor
static MDiffeGradientUtils *CreateFromClone(
MEnzymeLogic &Logic, DerivativeMode mode, unsigned width,
FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
const std::vector<bool> &returnPrimals,
const std::vector<bool> &returnShadows, ArrayRef<DIFFE_TYPE> RetActivity,
ArrayRef<DIFFE_TYPE> ArgActivity, mlir::Type additionalArg, bool omp) {
const llvm::ArrayRef<bool> returnPrimals,
const llvm::ArrayRef<bool> returnShadows,
ArrayRef<DIFFE_TYPE> RetActivity, ArrayRef<DIFFE_TYPE> ArgActivity,
mlir::Type additionalArg, bool omp) {
std::string prefix;

switch (mode) {
Expand Down Expand Up @@ -153,9 +161,9 @@ class MDiffeGradientUtils : public MGradientUtils {
additionalArg);
MTypeResults TR; // TODO
return new MDiffeGradientUtils(
Logic, newFunc, todiff, TA, TR, invertedPointers, constant_values,
nonconstant_values, RetActivity, ArgActivity, originalToNew,
originalToNewOps, mode, width, omp);
Logic, newFunc, todiff, TA, TR, invertedPointers, returnPrimals,
returnShadows, constant_values, nonconstant_values, RetActivity,
ArgActivity, originalToNew, originalToNewOps, mode, width, omp);
}
};

Expand Down
24 changes: 13 additions & 11 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,19 @@ using namespace mlir::enzyme;
mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse(
MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
FunctionOpInterface oldFunc_, MTypeAnalysis &TA_,
IRMapping invertedPointers_,
IRMapping invertedPointers_, const llvm::ArrayRef<bool> returnPrimals,
const llvm::ArrayRef<bool> returnShadows,
const SmallPtrSetImpl<mlir::Value> &constantvalues_,
const SmallPtrSetImpl<mlir::Value> &activevals_,
ArrayRef<DIFFE_TYPE> ReturnActivity, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode_, unsigned width)
: MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {},
invertedPointers_, constantvalues_, activevals_,
ReturnActivity, ArgDiffeTypes_, originalToNewFn_,
originalToNewFnOps_, mode_, width, /*omp*/ false) {}
invertedPointers_, returnPrimals, returnShadows,
constantvalues_, activevals_, ReturnActivity,
ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_,
mode_, width, /*omp*/ false) {}

Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() {
Type indexType = getIndexType();
Expand Down Expand Up @@ -134,9 +136,9 @@ void MGradientUtilsReverse::createReverseModeBlocks(Region &oldFunc,
MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone(
MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width,
FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
const std::vector<bool> &returnPrimals,
const std::vector<bool> &returnShadows, ArrayRef<DIFFE_TYPE> retType,
ArrayRef<DIFFE_TYPE> constant_args, mlir::Type additionalArg) {
const ArrayRef<bool> returnPrimals, const ArrayRef<bool> returnShadows,
ArrayRef<DIFFE_TYPE> retType, ArrayRef<DIFFE_TYPE> constant_args,
mlir::Type additionalArg) {
std::string prefix;

switch (mode_) {
Expand Down Expand Up @@ -169,8 +171,8 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone(
prefix + todiff.getName(), originalToNew, originalToNewOps,
additionalArg);

return new MGradientUtilsReverse(Logic, newFunc, todiff, TA, invertedPointers,
constant_values, nonconstant_values, retType,
constant_args, originalToNew,
originalToNewOps, mode_, width);
return new MGradientUtilsReverse(
Logic, newFunc, todiff, TA, invertedPointers, returnPrimals,
returnShadows, constant_values, nonconstant_values, retType,
constant_args, originalToNew, originalToNewOps, mode_, width);
}
5 changes: 3 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class MGradientUtilsReverse : public MDiffeGradientUtils {
MGradientUtilsReverse(MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
FunctionOpInterface oldFunc_, MTypeAnalysis &TA_,
IRMapping invertedPointers_,
const llvm::ArrayRef<bool> returnPrimals,
const llvm::ArrayRef<bool> returnShadows,
const SmallPtrSetImpl<mlir::Value> &constantvalues_,
const SmallPtrSetImpl<mlir::Value> &activevals_,
ArrayRef<DIFFE_TYPE> ReturnActivity,
Expand Down Expand Up @@ -62,8 +64,7 @@ class MGradientUtilsReverse : public MDiffeGradientUtils {
static MGradientUtilsReverse *CreateFromClone(
MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width,
FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
const std::vector<bool> &returnPrimals,
const std::vector<bool> &returnShadows,
const ArrayRef<bool> returnPrimals, const ArrayRef<bool> returnShadows,
llvm::ArrayRef<DIFFE_TYPE> retType,
llvm::ArrayRef<DIFFE_TYPE> constant_args, mlir::Type additionalArg);
};
Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ using namespace mlir::enzyme;
using namespace enzyme;

std::vector<DIFFE_TYPE> parseActivityString(StringRef inp) {
if (inp.size() == 0)
return {};
std::vector<DIFFE_TYPE> ArgActivity;
SmallVector<StringRef, 1> split;
StringRef(inp.data(), inp.size()).split(split, ',');
Expand Down

0 comments on commit 716d674

Please sign in to comment.