Skip to content

Commit

Permalink
Merge pull request #471 from swiftwasm/master
Browse files Browse the repository at this point in the history
[pull] swiftwasm from master
  • Loading branch information
pull[bot] authored Mar 23, 2020
2 parents 2dcf860 + 11551e1 commit eac57a3
Show file tree
Hide file tree
Showing 29 changed files with 908 additions and 135 deletions.
7 changes: 6 additions & 1 deletion docs/SIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1088,8 +1088,9 @@ Declaration References
::

sil-decl-ref ::= '#' sil-identifier ('.' sil-identifier)* sil-decl-subref?
sil-decl-subref ::= '!' sil-decl-subref-part ('.' sil-decl-lang)?
sil-decl-subref ::= '!' sil-decl-subref-part ('.' sil-decl-lang)? ('.' sil-decl-autodiff)?
sil-decl-subref ::= '!' sil-decl-lang
sil-decl-subref ::= '!' sil-decl-autodiff
sil-decl-subref-part ::= 'getter'
sil-decl-subref-part ::= 'setter'
sil-decl-subref-part ::= 'allocator'
Expand All @@ -1102,6 +1103,10 @@ Declaration References
sil-decl-subref-part ::= 'ivarinitializer'
sil-decl-subref-part ::= 'defaultarg' '.' [0-9]+
sil-decl-lang ::= 'foreign'
sil-decl-autodiff ::= sil-decl-autodiff-kind '.' sil-decl-autodiff-indices
sil-decl-autodiff-kind ::= 'jvp'
sil-decl-autodiff-kind ::= 'vjp'
sil-decl-autodiff-indices ::= [SU]+

Some SIL instructions need to reference Swift declarations directly. These
references are introduced with the ``#`` sigil followed by the fully qualified
Expand Down
33 changes: 33 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,39 @@ struct AutoDiffDerivativeFunctionKind {
}
};

/// A derivative function configuration, uniqued in `ASTContext`.
/// Identifies a specific derivative function given an original function.
class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
const AutoDiffDerivativeFunctionKind kind;
IndexSubset *const parameterIndices;
GenericSignature derivativeGenericSignature;

AutoDiffDerivativeFunctionIdentifier(
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices,
GenericSignature derivativeGenericSignature)
: kind(kind), parameterIndices(parameterIndices),
derivativeGenericSignature(derivativeGenericSignature) {}

public:
AutoDiffDerivativeFunctionKind getKind() const { return kind; }
IndexSubset *getParameterIndices() const { return parameterIndices; }
GenericSignature getDerivativeGenericSignature() const {
return derivativeGenericSignature;
}

static AutoDiffDerivativeFunctionIdentifier *
get(AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices,
GenericSignature derivativeGenericSignature, ASTContext &C);

void Profile(llvm::FoldingSetNodeID &ID) {
ID.AddInteger(kind);
ID.AddPointer(parameterIndices);
auto derivativeCanGenSig =
derivativeGenericSignature.getCanonicalSignature();
ID.AddPointer(derivativeCanGenSig.getPointer());
}
};

/// The kind of a differentiability witness function.
struct DifferentiabilityWitnessFunctionKind {
enum innerty : uint8_t {
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,9 @@ ERROR(expected_sil_colon,none,
"expected ':' before %0", (StringRef))
ERROR(expected_sil_tuple_index,none,
"expected tuple element index", ())
ERROR(invalid_index_subset,none,
"invalid index subset; expected '[SU]+' where 'S' represents set indices "
"and 'U' represents unset indices", ())

// SIL Values
ERROR(sil_value_redefinition,none,
Expand Down
6 changes: 3 additions & 3 deletions include/swift/SIL/OwnershipUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,13 @@ struct BorrowingOperand {

/// Visit all of the results of the operand's user instruction that are
/// consuming uses.
void visitUserResultConsumingUses(function_ref<void(Operand *)> visitor);
void visitUserResultConsumingUses(function_ref<void(Operand *)> visitor) const;

/// Visit all of the "results" of the user of this operand that are borrow
/// scope introducers for the specific scope that this borrow scope operand
/// summarizes.
void
visitBorrowIntroducingUserResults(function_ref<void(BorrowedValue)> visitor);
visitBorrowIntroducingUserResults(function_ref<void(BorrowedValue)> visitor) const;

/// Passes to visitor all of the consuming uses of this use's using
/// instruction.
Expand All @@ -192,7 +192,7 @@ struct BorrowingOperand {
/// guaranteed scope by using a worklist and checking if any of the operands
/// are BorrowScopeOperands.
void visitConsumingUsesOfBorrowIntroducingUserResults(
function_ref<void(Operand *)> visitor);
function_ref<void(Operand *)> visitor) const;

void print(llvm::raw_ostream &os) const;
SWIFT_DEBUG_DUMP { print(llvm::dbgs()); }
Expand Down
84 changes: 57 additions & 27 deletions include/swift/SIL/SILDeclRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace swift {
enum class EffectsKind : uint8_t;
class AbstractFunctionDecl;
class AbstractClosureExpr;
class AutoDiffDerivativeFunctionIdentifier;
class ValueDecl;
class FuncDecl;
class ClosureExpr;
Expand Down Expand Up @@ -147,14 +148,19 @@ struct SILDeclRef {
unsigned isForeign : 1;
/// The default argument index for a default argument getter.
unsigned defaultArgIndex : 10;

/// The derivative function identifier.
AutoDiffDerivativeFunctionIdentifier *derivativeFunctionIdentifier = nullptr;

/// Produces a null SILDeclRef.
SILDeclRef() : loc(), kind(Kind::Func), isForeign(0), defaultArgIndex(0) {}

SILDeclRef()
: loc(), kind(Kind::Func), isForeign(0), defaultArgIndex(0),
derivativeFunctionIdentifier(nullptr) {}

/// Produces a SILDeclRef of the given kind for the given decl.
explicit SILDeclRef(ValueDecl *decl, Kind kind,
bool isForeign = false);

explicit SILDeclRef(
ValueDecl *decl, Kind kind, bool isForeign = false,
AutoDiffDerivativeFunctionIdentifier *derivativeId = nullptr);

/// Produces a SILDeclRef for the given ValueDecl or
/// AbstractClosureExpr:
/// - If 'loc' is a func or closure, this returns a Func SILDeclRef.
Expand All @@ -166,8 +172,7 @@ struct SILDeclRef {
/// for the containing ClassDecl.
/// - If 'loc' is a global VarDecl, this returns its GlobalAccessor
/// SILDeclRef.
explicit SILDeclRef(Loc loc,
bool isForeign = false);
explicit SILDeclRef(Loc loc, bool isForeign = false);

/// Produce a SIL constant for a default argument generator.
static SILDeclRef getDefaultArgGenerator(Loc loc, unsigned defaultArgIndex);
Expand Down Expand Up @@ -279,10 +284,10 @@ struct SILDeclRef {
}

bool operator==(SILDeclRef rhs) const {
return loc.getOpaqueValue() == rhs.loc.getOpaqueValue()
&& kind == rhs.kind
&& isForeign == rhs.isForeign
&& defaultArgIndex == rhs.defaultArgIndex;
return loc.getOpaqueValue() == rhs.loc.getOpaqueValue() &&
kind == rhs.kind && isForeign == rhs.isForeign &&
defaultArgIndex == rhs.defaultArgIndex &&
derivativeFunctionIdentifier == rhs.derivativeFunctionIdentifier;
}
bool operator!=(SILDeclRef rhs) const {
return !(*this == rhs);
Expand All @@ -296,8 +301,34 @@ struct SILDeclRef {
/// Returns the foreign (or native) entry point corresponding to the same
/// decl.
SILDeclRef asForeign(bool foreign = true) const {
return SILDeclRef(loc.getOpaqueValue(), kind,
foreign, defaultArgIndex);
return SILDeclRef(loc.getOpaqueValue(), kind, foreign, defaultArgIndex,
derivativeFunctionIdentifier);
}

/// Returns the entry point for the corresponding autodiff derivative
/// function.
SILDeclRef asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier *derivativeId) const {
assert(!derivativeFunctionIdentifier);
SILDeclRef declRef = *this;
declRef.derivativeFunctionIdentifier = derivativeId;
return declRef;
}

/// Returns the entry point for the original function corresponding to an
/// autodiff derivative function.
SILDeclRef asAutoDiffOriginalFunction() const {
assert(derivativeFunctionIdentifier);
SILDeclRef declRef = *this;
declRef.derivativeFunctionIdentifier = nullptr;
return declRef;
}

/// Returns this `SILDeclRef` replacing `loc` with `decl`.
SILDeclRef withDecl(ValueDecl *decl) const {
SILDeclRef result = *this;
result.loc = decl;
return result;
}

/// True if the decl ref references a thunk from a natively foreign
Expand Down Expand Up @@ -369,14 +400,12 @@ struct SILDeclRef {
private:
friend struct llvm::DenseMapInfo<swift::SILDeclRef>;
/// Produces a SILDeclRef from an opaque value.
explicit SILDeclRef(void *opaqueLoc,
Kind kind,
bool isForeign,
unsigned defaultArgIndex)
: loc(Loc::getFromOpaqueValue(opaqueLoc)), kind(kind),
isForeign(isForeign), defaultArgIndex(defaultArgIndex)
{}

explicit SILDeclRef(void *opaqueLoc, Kind kind, bool isForeign,
unsigned defaultArgIndex,
AutoDiffDerivativeFunctionIdentifier *derivativeId)
: loc(Loc::getFromOpaqueValue(opaqueLoc)), kind(kind),
isForeign(isForeign), defaultArgIndex(defaultArgIndex),
derivativeFunctionIdentifier(derivativeId) {}
};

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, SILDeclRef C) {
Expand All @@ -397,12 +426,12 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {
using UnsignedInfo = DenseMapInfo<unsigned>;

static SILDeclRef getEmptyKey() {
return SILDeclRef(PointerInfo::getEmptyKey(), Kind::Func,
false, 0);
return SILDeclRef(PointerInfo::getEmptyKey(), Kind::Func, false, 0,
nullptr);
}
static SILDeclRef getTombstoneKey() {
return SILDeclRef(PointerInfo::getTombstoneKey(), Kind::Func,
false, 0);
return SILDeclRef(PointerInfo::getTombstoneKey(), Kind::Func, false, 0,
nullptr);
}
static unsigned getHashValue(swift::SILDeclRef Val) {
unsigned h1 = PointerInfo::getHashValue(Val.loc.getOpaqueValue());
Expand All @@ -411,7 +440,8 @@ template<> struct DenseMapInfo<swift::SILDeclRef> {
? UnsignedInfo::getHashValue(Val.defaultArgIndex)
: 0;
unsigned h4 = UnsignedInfo::getHashValue(Val.isForeign);
return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7);
unsigned h5 = PointerInfo::getHashValue(Val.derivativeFunctionIdentifier);
return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7) ^ (h5 << 11);
}
static bool isEqual(swift::SILDeclRef const &LHS,
swift::SILDeclRef const &RHS) {
Expand Down
38 changes: 36 additions & 2 deletions include/swift/SIL/SILVTableVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,24 @@ template <class T> class SILVTableVisitor {
void maybeAddMethod(FuncDecl *fd) {
assert(!fd->hasClangNode());

maybeAddEntry(SILDeclRef(fd, SILDeclRef::Kind::Func));
SILDeclRef constant(fd, SILDeclRef::Kind::Func);
maybeAddEntry(constant);

for (auto *diffAttr : fd->getAttrs().getAttributes<DifferentiableAttr>()) {
auto jvpConstant = constant.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::JVP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(), fd->getASTContext()));
maybeAddEntry(jvpConstant);

auto vjpConstant = constant.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::VJP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(), fd->getASTContext()));
maybeAddEntry(vjpConstant);
}
}

void maybeAddConstructor(ConstructorDecl *cd) {
Expand All @@ -96,7 +113,24 @@ template <class T> class SILVTableVisitor {
// The initializing entry point for designated initializers is only
// necessary for super.init chaining, which is sufficiently constrained
// to never need dynamic dispatch.
maybeAddEntry(SILDeclRef(cd, SILDeclRef::Kind::Allocator));
SILDeclRef constant(cd, SILDeclRef::Kind::Allocator);
maybeAddEntry(constant);

for (auto *diffAttr : cd->getAttrs().getAttributes<DifferentiableAttr>()) {
auto jvpConstant = constant.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::JVP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(), cd->getASTContext()));
maybeAddEntry(jvpConstant);

auto vjpConstant = constant.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::VJP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(), cd->getASTContext()));
maybeAddEntry(vjpConstant);
}
}

void maybeAddAccessors(AbstractStorageDecl *asd) {
Expand Down
33 changes: 30 additions & 3 deletions include/swift/SIL/SILWitnessVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,19 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {

void visitAbstractStorageDecl(AbstractStorageDecl *sd) {
sd->visitOpaqueAccessors([&](AccessorDecl *accessor) {
if (SILDeclRef::requiresNewWitnessTableEntry(accessor))
if (SILDeclRef::requiresNewWitnessTableEntry(accessor)) {
asDerived().addMethod(SILDeclRef(accessor, SILDeclRef::Kind::Func));
addAutoDiffDerivativeMethodsIfRequired(accessor,
SILDeclRef::Kind::Func);
}
});
}

void visitConstructorDecl(ConstructorDecl *cd) {
if (SILDeclRef::requiresNewWitnessTableEntry(cd))
if (SILDeclRef::requiresNewWitnessTableEntry(cd)) {
asDerived().addMethod(SILDeclRef(cd, SILDeclRef::Kind::Allocator));
addAutoDiffDerivativeMethodsIfRequired(cd, SILDeclRef::Kind::Allocator);
}
}

void visitAccessorDecl(AccessorDecl *func) {
Expand All @@ -138,8 +143,10 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {

void visitFuncDecl(FuncDecl *func) {
assert(!isa<AccessorDecl>(func));
if (SILDeclRef::requiresNewWitnessTableEntry(func))
if (SILDeclRef::requiresNewWitnessTableEntry(func)) {
asDerived().addMethod(SILDeclRef(func, SILDeclRef::Kind::Func));
addAutoDiffDerivativeMethodsIfRequired(func, SILDeclRef::Kind::Func);
}
}

void visitMissingMemberDecl(MissingMemberDecl *placeholder) {
Expand All @@ -166,6 +173,26 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
void visitPoundDiagnosticDecl(PoundDiagnosticDecl *pdd) {
// We don't care about diagnostics at this stage.
}

private:
void addAutoDiffDerivativeMethodsIfRequired(AbstractFunctionDecl *AFD,
SILDeclRef::Kind kind) {
SILDeclRef declRef(AFD, kind);
for (auto *diffAttr : AFD->getAttrs().getAttributes<DifferentiableAttr>()) {
asDerived().addMethod(declRef.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::JVP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(),
AFD->getASTContext())));
asDerived().addMethod(declRef.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::VJP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(),
AFD->getASTContext())));
}
}
};

} // end namespace swift
Expand Down
31 changes: 29 additions & 2 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,9 +420,9 @@ struct ASTContext::Implementation {
llvm::FoldingSet<BuiltinVectorType> BuiltinVectorTypes;
llvm::FoldingSet<DeclName::CompoundDeclName> CompoundNames;
llvm::DenseMap<UUID, OpenedArchetypeType *> OpenedExistentialArchetypes;

/// For uniquifying `IndexSubset` allocations.
llvm::FoldingSet<IndexSubset> IndexSubsets;
llvm::FoldingSet<AutoDiffDerivativeFunctionIdentifier>
AutoDiffDerivativeFunctionIdentifiers;

/// A cache of information about whether particular nominal types
/// are representable in a foreign language.
Expand Down Expand Up @@ -4754,3 +4754,30 @@ IndexSubset::get(ASTContext &ctx, const SmallBitVector &indices) {
foldingSet.InsertNode(newNode, insertPos);
return newNode;
}

AutoDiffDerivativeFunctionIdentifier *AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices,
GenericSignature derivativeGenericSignature, ASTContext &C) {
assert(parameterIndices);
auto &foldingSet = C.getImpl().AutoDiffDerivativeFunctionIdentifiers;
llvm::FoldingSetNodeID id;
id.AddInteger((unsigned)kind);
id.AddPointer(parameterIndices);
CanGenericSignature derivativeCanGenSig;
if (derivativeGenericSignature)
derivativeCanGenSig = derivativeGenericSignature->getCanonicalSignature();
id.AddPointer(derivativeCanGenSig.getPointer());

void *insertPos;
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
if (existing)
return existing;

void *mem = C.Allocate(sizeof(AutoDiffDerivativeFunctionIdentifier),
alignof(AutoDiffDerivativeFunctionIdentifier));
auto *newNode = ::new (mem) AutoDiffDerivativeFunctionIdentifier(
kind, parameterIndices, derivativeGenericSignature);
foldingSet.InsertNode(newNode, insertPos);

return newNode;
}
Loading

0 comments on commit eac57a3

Please sign in to comment.