Skip to content

Commit

Permalink
[AutoDiff] NFC: formatting. (swiftlang#30573)
Browse files Browse the repository at this point in the history
Run `clang-format` on changes in swiftlang#30564.
  • Loading branch information
dan-zheng authored Mar 23, 2020
1 parent 51bc44e commit 11551e1
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 52 deletions.
4 changes: 1 addition & 3 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {

public:
AutoDiffDerivativeFunctionKind getKind() const { return kind; }
IndexSubset *getParameterIndices() const {
return parameterIndices;
}
IndexSubset *getParameterIndices() const { return parameterIndices; }
GenericSignature getDerivativeGenericSignature() const {
return derivativeGenericSignature;
}
Expand Down
53 changes: 24 additions & 29 deletions include/swift/SIL/SILDeclRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,17 @@ struct SILDeclRef {
unsigned defaultArgIndex : 10;
/// The derivative function identifier.
AutoDiffDerivativeFunctionIdentifier *derivativeFunctionIdentifier = nullptr;

/// Produces a null SILDeclRef.
SILDeclRef() : loc(), kind(Kind::Func), isForeign(0), defaultArgIndex(0),
derivativeFunctionIdentifier(nullptr) {}

SILDeclRef()
: loc(), kind(Kind::Func), isForeign(0), defaultArgIndex(0),
derivativeFunctionIdentifier(nullptr) {}

/// Produces a SILDeclRef of the given kind for the given decl.
explicit SILDeclRef(ValueDecl *decl, Kind kind,
bool isForeign = false,
AutoDiffDerivativeFunctionIdentifier *derivativeId = nullptr);
explicit SILDeclRef(
ValueDecl *decl, Kind kind, bool isForeign = false,
AutoDiffDerivativeFunctionIdentifier *derivativeId = nullptr);

/// Produces a SILDeclRef for the given ValueDecl or
/// AbstractClosureExpr:
/// - If 'loc' is a func or closure, this returns a Func SILDeclRef.
Expand Down Expand Up @@ -283,11 +284,10 @@ struct SILDeclRef {
}

bool operator==(SILDeclRef rhs) const {
return loc.getOpaqueValue() == rhs.loc.getOpaqueValue()
&& kind == rhs.kind
&& isForeign == rhs.isForeign
&& defaultArgIndex == rhs.defaultArgIndex
&& derivativeFunctionIdentifier == rhs.derivativeFunctionIdentifier;
return loc.getOpaqueValue() == rhs.loc.getOpaqueValue() &&
kind == rhs.kind && isForeign == rhs.isForeign &&
defaultArgIndex == rhs.defaultArgIndex &&
derivativeFunctionIdentifier == rhs.derivativeFunctionIdentifier;
}
bool operator!=(SILDeclRef rhs) const {
return !(*this == rhs);
Expand All @@ -301,8 +301,8 @@ struct SILDeclRef {
/// Returns the foreign (or native) entry point corresponding to the same
/// decl.
SILDeclRef asForeign(bool foreign = true) const {
return SILDeclRef(loc.getOpaqueValue(), kind,
foreign, defaultArgIndex, derivativeFunctionIdentifier);
return SILDeclRef(loc.getOpaqueValue(), kind, foreign, defaultArgIndex,
derivativeFunctionIdentifier);
}

/// Returns the entry point for the corresponding autodiff derivative
Expand Down Expand Up @@ -400,16 +400,12 @@ struct SILDeclRef {
private:
friend struct llvm::DenseMapInfo<swift::SILDeclRef>;
/// Produces a SILDeclRef from an opaque value.
explicit SILDeclRef(void *opaqueLoc,
Kind kind,
bool isForeign,
explicit SILDeclRef(void *opaqueLoc, Kind kind, bool isForeign,
unsigned defaultArgIndex,
AutoDiffDerivativeFunctionIdentifier *derivativeId)
: loc(Loc::getFromOpaqueValue(opaqueLoc)), kind(kind),
isForeign(isForeign), defaultArgIndex(defaultArgIndex),
derivativeFunctionIdentifier(derivativeId)
{}

: loc(Loc::getFromOpaqueValue(opaqueLoc)), kind(kind),
isForeign(isForeign), defaultArgIndex(defaultArgIndex),
derivativeFunctionIdentifier(derivativeId) {}
};

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, SILDeclRef C) {
Expand All @@ -430,12 +426,12 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {
using UnsignedInfo = DenseMapInfo<unsigned>;

static SILDeclRef getEmptyKey() {
return SILDeclRef(PointerInfo::getEmptyKey(), Kind::Func,
false, 0, nullptr);
return SILDeclRef(PointerInfo::getEmptyKey(), Kind::Func, false, 0,
nullptr);
}
static SILDeclRef getTombstoneKey() {
return SILDeclRef(PointerInfo::getTombstoneKey(), Kind::Func,
false, 0, nullptr);
return SILDeclRef(PointerInfo::getTombstoneKey(), Kind::Func, false, 0,
nullptr);
}
static unsigned getHashValue(swift::SILDeclRef Val) {
unsigned h1 = PointerInfo::getHashValue(Val.loc.getOpaqueValue());
Expand All @@ -444,8 +440,7 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {
? UnsignedInfo::getHashValue(Val.defaultArgIndex)
: 0;
unsigned h4 = UnsignedInfo::getHashValue(Val.isForeign);
unsigned h5 =
PointerInfo::getHashValue(Val.derivativeFunctionIdentifier);
unsigned h5 = PointerInfo::getHashValue(Val.derivativeFunctionIdentifier);
return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7) ^ (h5 << 11);
}
static bool isEqual(swift::SILDeclRef const &LHS,
Expand Down
10 changes: 5 additions & 5 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

using namespace swift;

AutoDiffDerivativeFunctionKind::
AutoDiffDerivativeFunctionKind(StringRef string) {
Optional<innerty> result =
llvm::StringSwitch<Optional<innerty>>(string)
.Case("jvp", JVP).Case("vjp", VJP);
AutoDiffDerivativeFunctionKind::AutoDiffDerivativeFunctionKind(
StringRef string) {
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
.Case("jvp", JVP)
.Case("vjp", VJP);
assert(result && "Invalid string");
rawValue = *result;
}
Expand Down
4 changes: 2 additions & 2 deletions lib/ParseSIL/ParseSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1456,8 +1456,8 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result,
return true;
}
// Parse parameter indices.
parameterIndices = IndexSubset::getFromString(
SILMod.getASTContext(), P.Tok.getText());
parameterIndices =
IndexSubset::getFromString(SILMod.getASTContext(), P.Tok.getText());
if (!parameterIndices) {
P.diagnose(P.Tok, diag::invalid_index_subset);
return true;
Expand Down
17 changes: 6 additions & 11 deletions lib/SIL/SILDeclRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,13 @@ bool swift::requiresForeignEntryPoint(ValueDecl *vd) {
return false;
}

SILDeclRef::SILDeclRef(ValueDecl *vd, SILDeclRef::Kind kind,
bool isForeign,
SILDeclRef::SILDeclRef(ValueDecl *vd, SILDeclRef::Kind kind, bool isForeign,
AutoDiffDerivativeFunctionIdentifier *derivativeId)
: loc(vd), kind(kind), isForeign(isForeign), defaultArgIndex(0),
derivativeFunctionIdentifier(derivativeId)
{}
: loc(vd), kind(kind), isForeign(isForeign), defaultArgIndex(0),
derivativeFunctionIdentifier(derivativeId) {}

SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool asForeign)
: defaultArgIndex(0), derivativeFunctionIdentifier(nullptr)
{
SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool asForeign)
: defaultArgIndex(0), derivativeFunctionIdentifier(nullptr) {
if (auto *vd = baseLoc.dyn_cast<ValueDecl*>()) {
if (auto *fd = dyn_cast<FuncDecl>(vd)) {
// Map FuncDecls directly to Func SILDeclRefs.
Expand Down Expand Up @@ -900,16 +897,14 @@ SILDeclRef SILDeclRef::getNextOverriddenVTableEntry() const {
return SILDeclRef();

// JVPs/VJPs are overridden only if the base declaration has a
// `@differentiable` with the same parameter indices.
// `@differentiable` attribute with the same parameter indices.
if (derivativeFunctionIdentifier) {
auto overriddenAttrs =
overridden.getDecl()->getAttrs().getAttributes<DifferentiableAttr>();
for (const auto *attr : overriddenAttrs) {
if (attr->getParameterIndices() !=
derivativeFunctionIdentifier->getParameterIndices())
continue;

// TODO(TF-1056): Do we need to check generic signature requirements?
auto *overriddenDerivativeId = overridden.derivativeFunctionIdentifier;
overridden.derivativeFunctionIdentifier =
AutoDiffDerivativeFunctionIdentifier::get(
Expand Down
4 changes: 2 additions & 2 deletions lib/SIL/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3115,8 +3115,8 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion,
auto *loweredIndices = autodiff::getLoweredParameterIndices(
derivativeId->getParameterIndices(), formalInterfaceType);
silFnType = origFnConstantInfo.SILFnType->getAutoDiffDerivativeFunctionType(
loweredIndices, /*resultIndex*/ 0, derivativeId->getKind(),
*this, LookUpConformanceInModule(&M));
loweredIndices, /*resultIndex*/ 0, derivativeId->getKind(), *this,
LookUpConformanceInModule(&M));
}

LLVM_DEBUG(llvm::dbgs() << "lowering type for constant ";
Expand Down

0 comments on commit 11551e1

Please sign in to comment.