diff --git a/docs/SIL.rst b/docs/SIL.rst index 932e585b0066c..337339c061798 100644 --- a/docs/SIL.rst +++ b/docs/SIL.rst @@ -1088,8 +1088,9 @@ Declaration References :: sil-decl-ref ::= '#' sil-identifier ('.' sil-identifier)* sil-decl-subref? - sil-decl-subref ::= '!' sil-decl-subref-part ('.' sil-decl-lang)? + sil-decl-subref ::= '!' sil-decl-subref-part ('.' sil-decl-lang)? ('.' sil-decl-autodiff)? sil-decl-subref ::= '!' sil-decl-lang + sil-decl-subref ::= '!' sil-decl-autodiff sil-decl-subref-part ::= 'getter' sil-decl-subref-part ::= 'setter' sil-decl-subref-part ::= 'allocator' @@ -1102,6 +1103,10 @@ Declaration References sil-decl-subref-part ::= 'ivarinitializer' sil-decl-subref-part ::= 'defaultarg' '.' [0-9]+ sil-decl-lang ::= 'foreign' + sil-decl-autodiff ::= sil-decl-autodiff-kind '.' sil-decl-autodiff-indices + sil-decl-autodiff-kind ::= 'jvp' + sil-decl-autodiff-kind ::= 'vjp' + sil-decl-autodiff-indices ::= [SU]+ Some SIL instructions need to reference Swift declarations directly. These references are introduced with the ``#`` sigil followed by the fully qualified diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 6f95ff13d7df3..a0ec63684162b 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -75,6 +75,39 @@ struct AutoDiffDerivativeFunctionKind { } }; +/// A derivative function configuration, uniqued in `ASTContext`. +/// Identifies a specific derivative function given an original function. +class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode { + const AutoDiffDerivativeFunctionKind kind; + IndexSubset *const parameterIndices; + GenericSignature derivativeGenericSignature; + + AutoDiffDerivativeFunctionIdentifier( + AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices, + GenericSignature derivativeGenericSignature) + : kind(kind), parameterIndices(parameterIndices), + derivativeGenericSignature(derivativeGenericSignature) {} + +public: + AutoDiffDerivativeFunctionKind getKind() const { return kind; } + IndexSubset *getParameterIndices() const { return parameterIndices; } + GenericSignature getDerivativeGenericSignature() const { + return derivativeGenericSignature; + } + + static AutoDiffDerivativeFunctionIdentifier * + get(AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices, + GenericSignature derivativeGenericSignature, ASTContext &C); + + void Profile(llvm::FoldingSetNodeID &ID) { + ID.AddInteger(kind); + ID.AddPointer(parameterIndices); + auto derivativeCanGenSig = + derivativeGenericSignature.getCanonicalSignature(); + ID.AddPointer(derivativeCanGenSig.getPointer()); + } +}; + /// The kind of a differentiability witness function. struct DifferentiabilityWitnessFunctionKind { enum innerty : uint8_t { diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index 6ef74d41f59be..bca7857545d2c 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -504,6 +504,9 @@ ERROR(expected_sil_colon,none, "expected ':' before %0", (StringRef)) ERROR(expected_sil_tuple_index,none, "expected tuple element index", ()) +ERROR(invalid_index_subset,none, + "invalid index subset; expected '[SU]+' where 'S' represents set indices " + "and 'U' represents unset indices", ()) // SIL Values ERROR(sil_value_redefinition,none, diff --git a/include/swift/SIL/OwnershipUtils.h b/include/swift/SIL/OwnershipUtils.h index ce285fb8210a0..a66c74221aefb 100644 --- a/include/swift/SIL/OwnershipUtils.h +++ b/include/swift/SIL/OwnershipUtils.h @@ -177,13 +177,13 @@ struct BorrowingOperand { /// Visit all of the results of the operand's user instruction that are /// consuming uses. - void visitUserResultConsumingUses(function_ref visitor); + void visitUserResultConsumingUses(function_ref visitor) const; /// Visit all of the "results" of the user of this operand that are borrow /// scope introducers for the specific scope that this borrow scope operand /// summarizes. void - visitBorrowIntroducingUserResults(function_ref visitor); + visitBorrowIntroducingUserResults(function_ref visitor) const; /// Passes to visitor all of the consuming uses of this use's using /// instruction. @@ -192,7 +192,7 @@ struct BorrowingOperand { /// guaranteed scope by using a worklist and checking if any of the operands /// are BorrowScopeOperands. void visitConsumingUsesOfBorrowIntroducingUserResults( - function_ref visitor); + function_ref visitor) const; void print(llvm::raw_ostream &os) const; SWIFT_DEBUG_DUMP { print(llvm::dbgs()); } diff --git a/include/swift/SIL/SILDeclRef.h b/include/swift/SIL/SILDeclRef.h index d5d19d46306ed..4487fa4d8ef5d 100644 --- a/include/swift/SIL/SILDeclRef.h +++ b/include/swift/SIL/SILDeclRef.h @@ -34,6 +34,7 @@ namespace swift { enum class EffectsKind : uint8_t; class AbstractFunctionDecl; class AbstractClosureExpr; + class AutoDiffDerivativeFunctionIdentifier; class ValueDecl; class FuncDecl; class ClosureExpr; @@ -147,14 +148,19 @@ struct SILDeclRef { unsigned isForeign : 1; /// The default argument index for a default argument getter. unsigned defaultArgIndex : 10; - + /// The derivative function identifier. + AutoDiffDerivativeFunctionIdentifier *derivativeFunctionIdentifier = nullptr; + /// Produces a null SILDeclRef. - SILDeclRef() : loc(), kind(Kind::Func), isForeign(0), defaultArgIndex(0) {} - + SILDeclRef() + : loc(), kind(Kind::Func), isForeign(0), defaultArgIndex(0), + derivativeFunctionIdentifier(nullptr) {} + /// Produces a SILDeclRef of the given kind for the given decl. - explicit SILDeclRef(ValueDecl *decl, Kind kind, - bool isForeign = false); - + explicit SILDeclRef( + ValueDecl *decl, Kind kind, bool isForeign = false, + AutoDiffDerivativeFunctionIdentifier *derivativeId = nullptr); + /// Produces a SILDeclRef for the given ValueDecl or /// AbstractClosureExpr: /// - If 'loc' is a func or closure, this returns a Func SILDeclRef. @@ -166,8 +172,7 @@ struct SILDeclRef { /// for the containing ClassDecl. /// - If 'loc' is a global VarDecl, this returns its GlobalAccessor /// SILDeclRef. - explicit SILDeclRef(Loc loc, - bool isForeign = false); + explicit SILDeclRef(Loc loc, bool isForeign = false); /// Produce a SIL constant for a default argument generator. static SILDeclRef getDefaultArgGenerator(Loc loc, unsigned defaultArgIndex); @@ -279,10 +284,10 @@ struct SILDeclRef { } bool operator==(SILDeclRef rhs) const { - return loc.getOpaqueValue() == rhs.loc.getOpaqueValue() - && kind == rhs.kind - && isForeign == rhs.isForeign - && defaultArgIndex == rhs.defaultArgIndex; + return loc.getOpaqueValue() == rhs.loc.getOpaqueValue() && + kind == rhs.kind && isForeign == rhs.isForeign && + defaultArgIndex == rhs.defaultArgIndex && + derivativeFunctionIdentifier == rhs.derivativeFunctionIdentifier; } bool operator!=(SILDeclRef rhs) const { return !(*this == rhs); @@ -296,8 +301,34 @@ struct SILDeclRef { /// Returns the foreign (or native) entry point corresponding to the same /// decl. SILDeclRef asForeign(bool foreign = true) const { - return SILDeclRef(loc.getOpaqueValue(), kind, - foreign, defaultArgIndex); + return SILDeclRef(loc.getOpaqueValue(), kind, foreign, defaultArgIndex, + derivativeFunctionIdentifier); + } + + /// Returns the entry point for the corresponding autodiff derivative + /// function. + SILDeclRef asAutoDiffDerivativeFunction( + AutoDiffDerivativeFunctionIdentifier *derivativeId) const { + assert(!derivativeFunctionIdentifier); + SILDeclRef declRef = *this; + declRef.derivativeFunctionIdentifier = derivativeId; + return declRef; + } + + /// Returns the entry point for the original function corresponding to an + /// autodiff derivative function. + SILDeclRef asAutoDiffOriginalFunction() const { + assert(derivativeFunctionIdentifier); + SILDeclRef declRef = *this; + declRef.derivativeFunctionIdentifier = nullptr; + return declRef; + } + + /// Returns this `SILDeclRef` replacing `loc` with `decl`. + SILDeclRef withDecl(ValueDecl *decl) const { + SILDeclRef result = *this; + result.loc = decl; + return result; } /// True if the decl ref references a thunk from a natively foreign @@ -369,14 +400,12 @@ struct SILDeclRef { private: friend struct llvm::DenseMapInfo; /// Produces a SILDeclRef from an opaque value. - explicit SILDeclRef(void *opaqueLoc, - Kind kind, - bool isForeign, - unsigned defaultArgIndex) - : loc(Loc::getFromOpaqueValue(opaqueLoc)), kind(kind), - isForeign(isForeign), defaultArgIndex(defaultArgIndex) - {} - + explicit SILDeclRef(void *opaqueLoc, Kind kind, bool isForeign, + unsigned defaultArgIndex, + AutoDiffDerivativeFunctionIdentifier *derivativeId) + : loc(Loc::getFromOpaqueValue(opaqueLoc)), kind(kind), + isForeign(isForeign), defaultArgIndex(defaultArgIndex), + derivativeFunctionIdentifier(derivativeId) {} }; inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, SILDeclRef C) { @@ -397,12 +426,12 @@ template<> struct DenseMapInfo { using UnsignedInfo = DenseMapInfo; static SILDeclRef getEmptyKey() { - return SILDeclRef(PointerInfo::getEmptyKey(), Kind::Func, - false, 0); + return SILDeclRef(PointerInfo::getEmptyKey(), Kind::Func, false, 0, + nullptr); } static SILDeclRef getTombstoneKey() { - return SILDeclRef(PointerInfo::getTombstoneKey(), Kind::Func, - false, 0); + return SILDeclRef(PointerInfo::getTombstoneKey(), Kind::Func, false, 0, + nullptr); } static unsigned getHashValue(swift::SILDeclRef Val) { unsigned h1 = PointerInfo::getHashValue(Val.loc.getOpaqueValue()); @@ -411,7 +440,8 @@ template<> struct DenseMapInfo { ? UnsignedInfo::getHashValue(Val.defaultArgIndex) : 0; unsigned h4 = UnsignedInfo::getHashValue(Val.isForeign); - return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7); + unsigned h5 = PointerInfo::getHashValue(Val.derivativeFunctionIdentifier); + return h1 ^ (h2 << 4) ^ (h3 << 9) ^ (h4 << 7) ^ (h5 << 11); } static bool isEqual(swift::SILDeclRef const &LHS, swift::SILDeclRef const &RHS) { diff --git a/include/swift/SIL/SILVTableVisitor.h b/include/swift/SIL/SILVTableVisitor.h index 184781bb93a46..b9be52452b402 100644 --- a/include/swift/SIL/SILVTableVisitor.h +++ b/include/swift/SIL/SILVTableVisitor.h @@ -86,7 +86,24 @@ template class SILVTableVisitor { void maybeAddMethod(FuncDecl *fd) { assert(!fd->hasClangNode()); - maybeAddEntry(SILDeclRef(fd, SILDeclRef::Kind::Func)); + SILDeclRef constant(fd, SILDeclRef::Kind::Func); + maybeAddEntry(constant); + + for (auto *diffAttr : fd->getAttrs().getAttributes()) { + auto jvpConstant = constant.asAutoDiffDerivativeFunction( + AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::JVP, + diffAttr->getParameterIndices(), + diffAttr->getDerivativeGenericSignature(), fd->getASTContext())); + maybeAddEntry(jvpConstant); + + auto vjpConstant = constant.asAutoDiffDerivativeFunction( + AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::VJP, + diffAttr->getParameterIndices(), + diffAttr->getDerivativeGenericSignature(), fd->getASTContext())); + maybeAddEntry(vjpConstant); + } } void maybeAddConstructor(ConstructorDecl *cd) { @@ -96,7 +113,24 @@ template class SILVTableVisitor { // The initializing entry point for designated initializers is only // necessary for super.init chaining, which is sufficiently constrained // to never need dynamic dispatch. - maybeAddEntry(SILDeclRef(cd, SILDeclRef::Kind::Allocator)); + SILDeclRef constant(cd, SILDeclRef::Kind::Allocator); + maybeAddEntry(constant); + + for (auto *diffAttr : cd->getAttrs().getAttributes()) { + auto jvpConstant = constant.asAutoDiffDerivativeFunction( + AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::JVP, + diffAttr->getParameterIndices(), + diffAttr->getDerivativeGenericSignature(), cd->getASTContext())); + maybeAddEntry(jvpConstant); + + auto vjpConstant = constant.asAutoDiffDerivativeFunction( + AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::VJP, + diffAttr->getParameterIndices(), + diffAttr->getDerivativeGenericSignature(), cd->getASTContext())); + maybeAddEntry(vjpConstant); + } } void maybeAddAccessors(AbstractStorageDecl *asd) { diff --git a/include/swift/SIL/SILWitnessVisitor.h b/include/swift/SIL/SILWitnessVisitor.h index 84768e1a0eeea..4c8bfeb4af50d 100644 --- a/include/swift/SIL/SILWitnessVisitor.h +++ b/include/swift/SIL/SILWitnessVisitor.h @@ -122,14 +122,19 @@ template class SILWitnessVisitor : public ASTVisitor { void visitAbstractStorageDecl(AbstractStorageDecl *sd) { sd->visitOpaqueAccessors([&](AccessorDecl *accessor) { - if (SILDeclRef::requiresNewWitnessTableEntry(accessor)) + if (SILDeclRef::requiresNewWitnessTableEntry(accessor)) { asDerived().addMethod(SILDeclRef(accessor, SILDeclRef::Kind::Func)); + addAutoDiffDerivativeMethodsIfRequired(accessor, + SILDeclRef::Kind::Func); + } }); } void visitConstructorDecl(ConstructorDecl *cd) { - if (SILDeclRef::requiresNewWitnessTableEntry(cd)) + if (SILDeclRef::requiresNewWitnessTableEntry(cd)) { asDerived().addMethod(SILDeclRef(cd, SILDeclRef::Kind::Allocator)); + addAutoDiffDerivativeMethodsIfRequired(cd, SILDeclRef::Kind::Allocator); + } } void visitAccessorDecl(AccessorDecl *func) { @@ -138,8 +143,10 @@ template class SILWitnessVisitor : public ASTVisitor { void visitFuncDecl(FuncDecl *func) { assert(!isa(func)); - if (SILDeclRef::requiresNewWitnessTableEntry(func)) + if (SILDeclRef::requiresNewWitnessTableEntry(func)) { asDerived().addMethod(SILDeclRef(func, SILDeclRef::Kind::Func)); + addAutoDiffDerivativeMethodsIfRequired(func, SILDeclRef::Kind::Func); + } } void visitMissingMemberDecl(MissingMemberDecl *placeholder) { @@ -166,6 +173,26 @@ template class SILWitnessVisitor : public ASTVisitor { void visitPoundDiagnosticDecl(PoundDiagnosticDecl *pdd) { // We don't care about diagnostics at this stage. } + +private: + void addAutoDiffDerivativeMethodsIfRequired(AbstractFunctionDecl *AFD, + SILDeclRef::Kind kind) { + SILDeclRef declRef(AFD, kind); + for (auto *diffAttr : AFD->getAttrs().getAttributes()) { + asDerived().addMethod(declRef.asAutoDiffDerivativeFunction( + AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::JVP, + diffAttr->getParameterIndices(), + diffAttr->getDerivativeGenericSignature(), + AFD->getASTContext()))); + asDerived().addMethod(declRef.asAutoDiffDerivativeFunction( + AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind::VJP, + diffAttr->getParameterIndices(), + diffAttr->getDerivativeGenericSignature(), + AFD->getASTContext()))); + } + } }; } // end namespace swift diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 281f4f0c3756f..231eec4c32a5a 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -420,9 +420,9 @@ struct ASTContext::Implementation { llvm::FoldingSet BuiltinVectorTypes; llvm::FoldingSet CompoundNames; llvm::DenseMap OpenedExistentialArchetypes; - - /// For uniquifying `IndexSubset` allocations. llvm::FoldingSet IndexSubsets; + llvm::FoldingSet + AutoDiffDerivativeFunctionIdentifiers; /// A cache of information about whether particular nominal types /// are representable in a foreign language. @@ -4754,3 +4754,30 @@ IndexSubset::get(ASTContext &ctx, const SmallBitVector &indices) { foldingSet.InsertNode(newNode, insertPos); return newNode; } + +AutoDiffDerivativeFunctionIdentifier *AutoDiffDerivativeFunctionIdentifier::get( + AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices, + GenericSignature derivativeGenericSignature, ASTContext &C) { + assert(parameterIndices); + auto &foldingSet = C.getImpl().AutoDiffDerivativeFunctionIdentifiers; + llvm::FoldingSetNodeID id; + id.AddInteger((unsigned)kind); + id.AddPointer(parameterIndices); + CanGenericSignature derivativeCanGenSig; + if (derivativeGenericSignature) + derivativeCanGenSig = derivativeGenericSignature->getCanonicalSignature(); + id.AddPointer(derivativeCanGenSig.getPointer()); + + void *insertPos; + auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos); + if (existing) + return existing; + + void *mem = C.Allocate(sizeof(AutoDiffDerivativeFunctionIdentifier), + alignof(AutoDiffDerivativeFunctionIdentifier)); + auto *newNode = ::new (mem) AutoDiffDerivativeFunctionIdentifier( + kind, parameterIndices, derivativeGenericSignature); + foldingSet.InsertNode(newNode, insertPos); + + return newNode; +} diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index f4db2a0b1e1c5..c806382b65d31 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -19,6 +19,15 @@ using namespace swift; +AutoDiffDerivativeFunctionKind::AutoDiffDerivativeFunctionKind( + StringRef string) { + Optional result = llvm::StringSwitch>(string) + .Case("jvp", JVP) + .Case("vjp", VJP); + assert(result && "Invalid string"); + rawValue = *result; +} + DifferentiabilityWitnessFunctionKind::DifferentiabilityWitnessFunctionKind( StringRef string) { Optional result = llvm::StringSwitch>(string) diff --git a/lib/IRGen/GenDiffWitness.cpp b/lib/IRGen/GenDiffWitness.cpp index f27a5861a8840..a13293dca59cf 100644 --- a/lib/IRGen/GenDiffWitness.cpp +++ b/lib/IRGen/GenDiffWitness.cpp @@ -39,15 +39,28 @@ void IRGenModule::emitSILDifferentiabilityWitness( ConstantInitBuilder builder(*this); auto diffWitnessContents = builder.beginStruct(); + // TODO(TF-1211): Uncomment assertions after upstreaming differentiation + // transform. + // The mandatory differentiation transform canonicalizes differentiability + // witnesses and ensures that JVPs/VJPs are populated. + /* assert(dw->getJVP() && "Differentiability witness definition should have JVP"); assert(dw->getVJP() && "Differentiability witness definition should have VJP"); - diffWitnessContents.addBitCast( getAddrOfSILFunction(dw->getJVP(), NotForDefinition), Int8PtrTy); diffWitnessContents.addBitCast( getAddrOfSILFunction(dw->getVJP(), NotForDefinition), Int8PtrTy); + */ + llvm::Constant *jvpValue = llvm::UndefValue::get(Int8PtrTy); + llvm::Constant *vjpValue = llvm::UndefValue::get(Int8PtrTy); + if (auto *jvpFn = dw->getJVP()) + jvpValue = getAddrOfSILFunction(dw->getJVP(), NotForDefinition); + if (auto *vjpFn = dw->getJVP()) + vjpValue = getAddrOfSILFunction(dw->getVJP(), NotForDefinition); + diffWitnessContents.addBitCast(jvpValue, Int8PtrTy); + diffWitnessContents.addBitCast(vjpValue, Int8PtrTy); getAddrOfDifferentiabilityWitness( dw, diffWitnessContents.finishAndCreateFuture()); diff --git a/lib/IRGen/GenKeyPath.cpp b/lib/IRGen/GenKeyPath.cpp index da6db5cb0739e..cd91cd96c03d5 100644 --- a/lib/IRGen/GenKeyPath.cpp +++ b/lib/IRGen/GenKeyPath.cpp @@ -998,8 +998,7 @@ emitKeyPathComponent(IRGenModule &IGM, auto methodProto = cast(dc); auto &protoInfo = IGM.getProtocolInfo(methodProto, ProtocolInfoKind::Full); - auto index = protoInfo.getFunctionIndex( - cast(declRef.getDecl())); + auto index = protoInfo.getFunctionIndex(declRef); idValue = llvm::ConstantInt::get(IGM.SizeTy, -index.getValue()); idResolution = KeyPathComponentHeader::Resolved; } diff --git a/lib/IRGen/GenProto.cpp b/lib/IRGen/GenProto.cpp index 2f32961082161..3a6567b94b21c 100644 --- a/lib/IRGen/GenProto.cpp +++ b/lib/IRGen/GenProto.cpp @@ -792,20 +792,19 @@ namespace { } void addMethod(SILDeclRef func) { - auto decl = cast(func.getDecl()); // If this assert needs to be changed, be sure to also change // ProtocolDescriptorBuilder::getRequirementInfo. - assert((isa(decl) - ? (func.kind == SILDeclRef::Kind::Allocator) - : (func.kind == SILDeclRef::Kind::Func)) - && "unexpected kind for protocol witness declaration ref"); - Entries.push_back(WitnessTableEntry::forFunction(decl)); + assert((isa(func.getDecl()) + ? (func.kind == SILDeclRef::Kind::Allocator) + : (func.kind == SILDeclRef::Kind::Func)) && + "unexpected kind for protocol witness declaration ref"); + Entries.push_back(WitnessTableEntry::forFunction(func)); } void addPlaceholder(MissingMemberDecl *placeholder) { for (auto i : range(placeholder->getNumberOfVTableEntries())) { (void)i; - Entries.push_back(WitnessTableEntry()); + Entries.push_back(WitnessTableEntry::forPlaceholder()); } } @@ -1318,8 +1317,7 @@ class AccessorConformanceInfo : public ConformanceInfo { && "sil witness table does not match protocol"); assert(entry.getMethodWitness().Requirement == requirement && "sil witness table does not match protocol"); - auto piIndex = - PI.getFunctionIndex(cast(requirement.getDecl())); + auto piIndex = PI.getFunctionIndex(requirement); assert((size_t)piIndex.getValue() == Table.size() - WitnessTableFirstRequirementOffset && "offset doesn't match ProtocolInfo layout"); @@ -3277,7 +3275,7 @@ FunctionPointer irgen::emitWitnessMethodValue(IRGenFunction &IGF, // Find the witness we're interested in. auto &fnProtoInfo = IGF.IGM.getProtocolInfo(proto, ProtocolInfoKind::Full); - auto index = fnProtoInfo.getFunctionIndex(fn); + auto index = fnProtoInfo.getFunctionIndex(member); llvm::Value *slot; llvm::Value *witnessFnPtr = emitInvariantLoadOfOpaqueWitness(IGF, wtable, diff --git a/lib/IRGen/ProtocolInfo.h b/lib/IRGen/ProtocolInfo.h index 23f500ea1789e..d0a192ad39e98 100644 --- a/lib/IRGen/ProtocolInfo.h +++ b/lib/IRGen/ProtocolInfo.h @@ -39,28 +39,59 @@ namespace irgen { /// ProtocolTypeInfo stores one of these for each requirement /// introduced by the protocol. class WitnessTableEntry { -public: - llvm::PointerUnion MemberOrAssociatedType; - ProtocolDecl *Protocol; - - WitnessTableEntry(llvm::PointerUnion member, - ProtocolDecl *protocol) - : MemberOrAssociatedType(member), Protocol(protocol) {} + enum WitnessKind { + PlaceholderKind, + OutOfLineBaseKind, + MethodKind, + AssociatedTypeKind, + AssociatedConformanceKind + }; + + struct OutOfLineBaseWitness { + ProtocolDecl *Protocol; + }; + + struct MethodWitness { + SILDeclRef Witness; + }; + + struct AssociatedTypeWitness { + AssociatedTypeDecl *Association; + }; + + struct AssociatedConformanceWitness { + TypeBase *AssociatedType; + ProtocolDecl *Protocol; + }; + + WitnessKind Kind; + union { + OutOfLineBaseWitness OutOfLineBaseEntry; + MethodWitness MethodEntry; + AssociatedTypeWitness AssociatedTypeEntry; + AssociatedConformanceWitness AssociatedConformanceEntry; + }; + + WitnessTableEntry(WitnessKind Kind) : Kind(Kind) {} public: - WitnessTableEntry() = default; + static WitnessTableEntry forPlaceholder() { + return WitnessTableEntry(WitnessKind::PlaceholderKind); + } static WitnessTableEntry forOutOfLineBase(ProtocolDecl *proto) { assert(proto != nullptr); - return WitnessTableEntry({}, proto); + WitnessTableEntry entry(WitnessKind::OutOfLineBaseKind); + entry.OutOfLineBaseEntry = {proto}; + return entry; } /// Is this a base-protocol entry? - bool isBase() const { return MemberOrAssociatedType.isNull(); } + bool isBase() const { return Kind == WitnessKind::OutOfLineBaseKind; } bool matchesBase(ProtocolDecl *proto) const { assert(proto != nullptr); - return MemberOrAssociatedType.isNull() && Protocol == proto; + return isBase() && OutOfLineBaseEntry.Protocol == proto; } /// Given that this is a base-protocol entry, is the table @@ -72,84 +103,97 @@ class WitnessTableEntry { ProtocolDecl *getBase() const { assert(isBase()); - return Protocol; + return OutOfLineBaseEntry.Protocol; } - static WitnessTableEntry forFunction(AbstractFunctionDecl *func) { - assert(func != nullptr); - return WitnessTableEntry(func, nullptr); - } - - bool isFunction() const { - auto decl = MemberOrAssociatedType.dyn_cast(); - return Protocol == nullptr && decl && isa(decl); + static WitnessTableEntry forFunction(SILDeclRef declRef) { + assert(!declRef.isNull()); + WitnessTableEntry entry(WitnessKind::MethodKind); + entry.MethodEntry = {declRef}; + return entry; } - bool matchesFunction(AbstractFunctionDecl *func) const { - assert(func != nullptr); - if (auto decl = MemberOrAssociatedType.dyn_cast()) - return decl == func && Protocol == nullptr; - return false; + bool isFunction() const { return Kind == WitnessKind::MethodKind; } + + bool matchesFunction(SILDeclRef declRef) const { + return isFunction() && MethodEntry.Witness == declRef; } - AbstractFunctionDecl *getFunction() const { + SILDeclRef getFunction() const { assert(isFunction()); - auto decl = MemberOrAssociatedType.get(); - return static_cast(decl); + return MethodEntry.Witness; } static WitnessTableEntry forAssociatedType(AssociatedType ty) { - return WitnessTableEntry(ty.getAssociation(), nullptr); + WitnessTableEntry entry(WitnessKind::AssociatedTypeKind); + entry.AssociatedTypeEntry = {ty.getAssociation()}; + return entry; } bool isAssociatedType() const { - if (auto decl = MemberOrAssociatedType.dyn_cast()) - return Protocol == nullptr && isa(decl); - return false; + return Kind == WitnessKind::AssociatedTypeKind; } bool matchesAssociatedType(AssociatedType assocType) const { - if (auto decl = MemberOrAssociatedType.dyn_cast()) - return decl == assocType.getAssociation() && Protocol == nullptr; - return false; + return isAssociatedType() && + AssociatedTypeEntry.Association == assocType.getAssociation(); } AssociatedTypeDecl *getAssociatedType() const { assert(isAssociatedType()); - auto decl = MemberOrAssociatedType.get(); - return static_cast(decl); + return AssociatedTypeEntry.Association; } - static WitnessTableEntry forAssociatedConformance(AssociatedConformance conf){ - return WitnessTableEntry(conf.getAssociation().getPointer(), - conf.getAssociatedRequirement()); + static WitnessTableEntry + forAssociatedConformance(AssociatedConformance conf) { + WitnessTableEntry entry(WitnessKind::AssociatedConformanceKind); + entry.AssociatedConformanceEntry = {conf.getAssociation().getPointer(), + conf.getAssociatedRequirement()}; + return entry; } bool isAssociatedConformance() const { - return Protocol != nullptr && !MemberOrAssociatedType.isNull(); + return Kind == WitnessKind::AssociatedConformanceKind; } bool matchesAssociatedConformance(const AssociatedConformance &conf) const { - if (auto type = MemberOrAssociatedType.dyn_cast()) - return type == conf.getAssociation().getPointer() && - Protocol == conf.getAssociatedRequirement(); - return false; + return isAssociatedConformance() && + AssociatedConformanceEntry.AssociatedType == + conf.getAssociation().getPointer() && + AssociatedConformanceEntry.Protocol == + conf.getAssociatedRequirement(); } CanType getAssociatedConformancePath() const { assert(isAssociatedConformance()); - auto type = MemberOrAssociatedType.get(); - return CanType(type); + return CanType(AssociatedConformanceEntry.AssociatedType); } ProtocolDecl *getAssociatedConformanceRequirement() const { assert(isAssociatedConformance()); - return Protocol; + return AssociatedConformanceEntry.Protocol; } friend bool operator==(WitnessTableEntry left, WitnessTableEntry right) { - return left.MemberOrAssociatedType == right.MemberOrAssociatedType && - left.Protocol == right.Protocol; + if (left.Kind != right.Kind) + return false; + switch (left.Kind) { + case WitnessKind::PlaceholderKind: + return true; + case WitnessKind::OutOfLineBaseKind: + return left.OutOfLineBaseEntry.Protocol == + right.OutOfLineBaseEntry.Protocol; + case WitnessKind::MethodKind: + return left.MethodEntry.Witness == right.MethodEntry.Witness; + case WitnessKind::AssociatedTypeKind: + return left.AssociatedTypeEntry.Association == + right.AssociatedTypeEntry.Association; + case WitnessKind::AssociatedConformanceKind: + return left.AssociatedConformanceEntry.AssociatedType == + right.AssociatedConformanceEntry.AssociatedType && + left.AssociatedConformanceEntry.Protocol == + right.AssociatedConformanceEntry.Protocol; + } } }; @@ -236,10 +280,10 @@ class ProtocolInfo final : /// Return the witness index for the witness function for the given /// function requirement. - WitnessIndex getFunctionIndex(AbstractFunctionDecl *function) const { + WitnessIndex getFunctionIndex(SILDeclRef declRef) const { assert(getKind() >= ProtocolInfoKind::Full); for (auto &witness : getWitnessEntries()) { - if (witness.matchesFunction(function)) + if (witness.matchesFunction(declRef)) return getNonBaseWitnessIndex(&witness); } llvm_unreachable("didn't find entry for function"); diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 2c51097233b47..2bc8370406fee 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -1341,6 +1341,7 @@ static Optional getAccessorKind(StringRef ident) { /// sil-decl-ref ::= '#' sil-identifier ('.' sil-identifier)* sil-decl-subref? /// sil-decl-subref ::= '!' sil-decl-subref-part ('.' sil-decl-lang)? +/// ('.' sil-decl-autodiff)? /// sil-decl-subref ::= '!' sil-decl-lang /// sil-decl-subref-part ::= 'getter' /// sil-decl-subref-part ::= 'setter' @@ -1350,27 +1351,33 @@ static Optional getAccessorKind(StringRef ident) { /// sil-decl-subref-part ::= 'destroyer' /// sil-decl-subref-part ::= 'globalaccessor' /// sil-decl-lang ::= 'foreign' +/// sil-decl-autodiff ::= sil-decl-autodiff-kind '.' sil-decl-autodiff-indices +/// sil-decl-autodiff-kind ::= 'jvp' +/// sil-decl-autodiff-kind ::= 'vjp' +/// sil-decl-autodiff-indices ::= [SU]+ bool SILParser::parseSILDeclRef(SILDeclRef &Result, SmallVectorImpl &values) { ValueDecl *VD; if (parseSILDottedPath(VD, values)) return true; - // Initialize Kind and IsObjC. + // Initialize SILDeclRef components. SILDeclRef::Kind Kind = SILDeclRef::Kind::Func; bool IsObjC = false; + AutoDiffDerivativeFunctionIdentifier *DerivativeId = nullptr; if (!P.consumeIf(tok::sil_exclamation)) { // Construct SILDeclRef. - Result = SILDeclRef(VD, Kind, IsObjC); + Result = SILDeclRef(VD, Kind, IsObjC, DerivativeId); return false; } - // Handle sil-constant-kind-and-uncurry-level. - // ParseState indicates the value we just handled. - // 1 means we just handled Kind. - // We accept func|getter|setter|...|foreign when ParseState is 0; - // accept foreign when ParseState is 1. + // Handle SILDeclRef components. ParseState tracks the last parsed component. + // + // When ParseState is 0, accept kind (`func|getter|setter|...`) and set + // ParseState to 1. + // + // Always accept `foreign` and derivative function identifier. unsigned ParseState = 0; Identifier Id; do { @@ -1439,15 +1446,47 @@ bool SILParser::parseSILDeclRef(SILDeclRef &Result, } else if (Id.str() == "foreign") { IsObjC = true; break; - } else + } else if (Id.str() == "jvp" || Id.str() == "vjp") { + IndexSubset *parameterIndices = nullptr; + GenericSignature derivativeGenSig; + // Parse derivative function kind. + AutoDiffDerivativeFunctionKind derivativeKind(Id.str()); + if (!P.consumeIf(tok::period)) { + P.diagnose(P.Tok, diag::expected_tok_in_sil_instr, "."); + return true; + } + // Parse parameter indices. + parameterIndices = + IndexSubset::getFromString(SILMod.getASTContext(), P.Tok.getText()); + if (!parameterIndices) { + P.diagnose(P.Tok, diag::invalid_index_subset); + return true; + } + P.consumeToken(); + // Parse derivative generic signature (optional). + if (P.Tok.is(tok::oper_binary_unspaced) && P.Tok.getText() == ".<") { + P.consumeStartingCharacterOfCurrentToken(tok::period); + // Create a new scope to avoid type redefinition errors. + Scope genericsScope(&P, ScopeKind::Generics); + auto *genericParams = P.maybeParseGenericParams().getPtrOrNull(); + assert(genericParams); + auto *derivativeGenEnv = handleSILGenericParams(genericParams, &P.SF); + derivativeGenSig = derivativeGenEnv->getGenericSignature(); + } + DerivativeId = AutoDiffDerivativeFunctionIdentifier::get( + derivativeKind, parameterIndices, derivativeGenSig, + SILMod.getASTContext()); + break; + } else { break; + } } else break; } while (P.consumeIf(tok::period)); // Construct SILDeclRef. - Result = SILDeclRef(VD, Kind, IsObjC); + Result = SILDeclRef(VD, Kind, IsObjC, DerivativeId); return false; } diff --git a/lib/SIL/OwnershipUtils.cpp b/lib/SIL/OwnershipUtils.cpp index 44f3812af86aa..5f671d9cf613c 100644 --- a/lib/SIL/OwnershipUtils.cpp +++ b/lib/SIL/OwnershipUtils.cpp @@ -190,7 +190,7 @@ void BorrowingOperand::visitEndScopeInstructions( } void BorrowingOperand::visitBorrowIntroducingUserResults( - function_ref visitor) { + function_ref visitor) const { switch (kind) { case BorrowingOperandKind::BeginApply: llvm_unreachable("Never has borrow introducer results!"); @@ -212,7 +212,7 @@ void BorrowingOperand::visitBorrowIntroducingUserResults( } void BorrowingOperand::visitConsumingUsesOfBorrowIntroducingUserResults( - function_ref func) { + function_ref func) const { // First visit all of the results of our user that are borrow introducing // values. visitBorrowIntroducingUserResults([&](BorrowedValue value) { @@ -238,7 +238,7 @@ void BorrowingOperand::visitConsumingUsesOfBorrowIntroducingUserResults( } void BorrowingOperand::visitUserResultConsumingUses( - function_ref visitor) { + function_ref visitor) const { auto *ti = dyn_cast(op->getUser()); if (!ti) { for (SILValue result : op->getUser()->getResults()) { diff --git a/lib/SIL/SILDeclRef.cpp b/lib/SIL/SILDeclRef.cpp index a86cb734f5b5f..18c3d1d71ffb1 100644 --- a/lib/SIL/SILDeclRef.cpp +++ b/lib/SIL/SILDeclRef.cpp @@ -113,14 +113,13 @@ bool swift::requiresForeignEntryPoint(ValueDecl *vd) { return false; } -SILDeclRef::SILDeclRef(ValueDecl *vd, SILDeclRef::Kind kind, - bool isForeign) - : loc(vd), kind(kind), isForeign(isForeign), defaultArgIndex(0) -{} - -SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool asForeign) - : defaultArgIndex(0) -{ +SILDeclRef::SILDeclRef(ValueDecl *vd, SILDeclRef::Kind kind, bool isForeign, + AutoDiffDerivativeFunctionIdentifier *derivativeId) + : loc(vd), kind(kind), isForeign(isForeign), defaultArgIndex(0), + derivativeFunctionIdentifier(derivativeId) {} + +SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool asForeign) + : defaultArgIndex(0), derivativeFunctionIdentifier(nullptr) { if (auto *vd = baseLoc.dyn_cast()) { if (auto *fd = dyn_cast(vd)) { // Map FuncDecls directly to Func SILDeclRefs. @@ -653,6 +652,21 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const { using namespace Mangle; ASTMangler mangler; + if (derivativeFunctionIdentifier) { + std::string originalMangled = asAutoDiffOriginalFunction().mangle(MKind); + auto *silParameterIndices = autodiff::getLoweredParameterIndices( + derivativeFunctionIdentifier->getParameterIndices(), + getDecl()->getInterfaceType()->castTo()); + auto &ctx = getDecl()->getASTContext(); + auto *resultIndices = IndexSubset::get(ctx, 1, {0}); + AutoDiffConfig silConfig( + silParameterIndices, resultIndices, + derivativeFunctionIdentifier->getDerivativeGenericSignature()); + auto derivativeFnKind = derivativeFunctionIdentifier->getKind(); + return mangler.mangleAutoDiffDerivativeFunctionHelper( + originalMangled, derivativeFnKind, silConfig); + } + // As a special case, Clang functions and globals don't get mangled at all. if (hasDecl()) { if (auto clangDecl = getDecl()->getClangDecl()) { @@ -764,7 +778,53 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const { llvm_unreachable("bad entity kind!"); } +// Returns true if the given JVP/VJP SILDeclRef requires a new vtable entry. +// FIXME(TF-1213): Also consider derived declaration `@derivative` attributes. +static bool derivativeFunctionRequiresNewVTableEntry(SILDeclRef declRef) { + assert(declRef.derivativeFunctionIdentifier && + "Expected a derivative function SILDeclRef"); + auto overridden = declRef.getOverridden(); + if (!overridden) + return false; + // Get the derived `@differentiable` attribute. + auto *derivedDiffAttr = *llvm::find_if( + declRef.getDecl()->getAttrs().getAttributes(), + [&](const DifferentiableAttr *derivedDiffAttr) { + return derivedDiffAttr->getParameterIndices() == + declRef.derivativeFunctionIdentifier->getParameterIndices(); + }); + assert(derivedDiffAttr && "Expected `@differentiable` attribute"); + // If the derived `@differentiable` attribute specifies a derivative function, + // then a new vtable entry is needed. Return true. + switch (declRef.derivativeFunctionIdentifier->getKind()) { + case AutoDiffDerivativeFunctionKind::JVP: + if (!overridden.requiresNewVTableEntry() && derivedDiffAttr->getJVP()) + return true; + break; + case AutoDiffDerivativeFunctionKind::VJP: + if (!overridden.requiresNewVTableEntry() && derivedDiffAttr->getVJP()) + return true; + break; + } + // Otherwise, if the base `@differentiable` attribute specifies a derivative + // function, then the derivative is inherited and no new vtable entry is + // needed. Return false. + auto baseDiffAttrs = + overridden.getDecl()->getAttrs().getAttributes(); + for (auto *baseDiffAttr : baseDiffAttrs) { + if (baseDiffAttr->getParameterIndices() == + declRef.derivativeFunctionIdentifier->getParameterIndices()) + return false; + } + // Otherwise, if there is no base `@differentiable` attribute exists, then a + // new vtable entry is needed. Return true. + return true; +} + bool SILDeclRef::requiresNewVTableEntry() const { + if (derivativeFunctionIdentifier) + if (derivativeFunctionRequiresNewVTableEntry(*this)) + return true; if (cast(getDecl())->needsNewVTableEntry()) return true; return false; @@ -784,8 +844,7 @@ SILDeclRef SILDeclRef::getOverridden() const { auto overridden = getDecl()->getOverriddenDecl(); if (!overridden) return SILDeclRef(); - - return SILDeclRef(overridden, kind); + return withDecl(overridden); } SILDeclRef SILDeclRef::getNextOverriddenVTableEntry() const { @@ -837,6 +896,26 @@ SILDeclRef SILDeclRef::getNextOverriddenVTableEntry() const { if (isa(overridden.getDecl()->getDeclContext())) return SILDeclRef(); + // JVPs/VJPs are overridden only if the base declaration has a + // `@differentiable` attribute with the same parameter indices. + if (derivativeFunctionIdentifier) { + auto overriddenAttrs = + overridden.getDecl()->getAttrs().getAttributes(); + for (const auto *attr : overriddenAttrs) { + if (attr->getParameterIndices() != + derivativeFunctionIdentifier->getParameterIndices()) + continue; + auto *overriddenDerivativeId = overridden.derivativeFunctionIdentifier; + overridden.derivativeFunctionIdentifier = + AutoDiffDerivativeFunctionIdentifier::get( + overriddenDerivativeId->getKind(), + overriddenDerivativeId->getParameterIndices(), + attr->getDerivativeGenericSignature(), + getDecl()->getASTContext()); + return overridden; + } + return SILDeclRef(); + } return overridden; } return SILDeclRef(); @@ -845,7 +924,7 @@ SILDeclRef SILDeclRef::getNextOverriddenVTableEntry() const { SILDeclRef SILDeclRef::getOverriddenWitnessTableEntry() const { auto bestOverridden = getOverriddenWitnessTableEntry(cast(getDecl())); - return SILDeclRef(bestOverridden, kind); + return withDecl(bestOverridden); } AbstractFunctionDecl *SILDeclRef::getOverriddenWitnessTableEntry( diff --git a/lib/SIL/SILFunctionType.cpp b/lib/SIL/SILFunctionType.cpp index 92a6102536e81..e1462d789dfda 100644 --- a/lib/SIL/SILFunctionType.cpp +++ b/lib/SIL/SILFunctionType.cpp @@ -3080,6 +3080,45 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion, ::getUncachedSILFunctionTypeForConstant(*this, expansion, constant, loweredInterfaceType); + // If the constant refers to a derivative function, get the SIL type of the + // original function and use it to compute the derivative SIL type. + // + // This is necessary because the "lowered AST derivative function type" (BC) + // may differ from the "derivative type of the lowered original function type" + // (AD): + // + // +--------------------+ lowering +--------------------+ + // | AST orig. fn type | -------(A)------> | SIL orig. fn type | + // +--------------------+ +--------------------+ + // | | + // (B, Sema) getAutoDiffDerivativeFunctionType (D, here) + // V V + // +--------------------+ lowering +--------------------+ + // | AST deriv. fn type | -------(C)------> | SIL deriv. fn type | + // +--------------------+ +--------------------+ + // + // (AD) does not always commute with (BC): + // - (BC) is the result of computing the AST derivative type (Sema), then + // lowering it via SILGen. This is the default lowering behavior, but may + // break SIL typing invariants because expected lowered derivative types are + // computed from lowered original function types. + // - (AD) is the result of lowering the original function type, then computing + // its derivative type. This is the expected lowered derivative type, + // preserving SIL typing invariants. + // + // Always use (AD) to compute lowered derivative function types. + if (auto *derivativeId = constant.derivativeFunctionIdentifier) { + // Get lowered original function type. + auto origFnConstantInfo = getConstantInfo( + TypeExpansionContext::minimal(), constant.asAutoDiffOriginalFunction()); + // Use it to compute lowered derivative function type. + auto *loweredIndices = autodiff::getLoweredParameterIndices( + derivativeId->getParameterIndices(), formalInterfaceType); + silFnType = origFnConstantInfo.SILFnType->getAutoDiffDerivativeFunctionType( + loweredIndices, /*resultIndex*/ 0, derivativeId->getKind(), *this, + LookUpConformanceInModule(&M)); + } + LLVM_DEBUG(llvm::dbgs() << "lowering type for constant "; constant.print(llvm::dbgs()); llvm::dbgs() << "\n formal type: "; diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 67b21e2807f85..d8277da05a738 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -346,6 +346,23 @@ void SILDeclRef::print(raw_ostream &OS) const { if (isForeign) OS << (isDot ? '.' : '!') << "foreign"; + + if (derivativeFunctionIdentifier) { + OS << ((isDot || isForeign) ? '.' : '!'); + switch (derivativeFunctionIdentifier->getKind()) { + case AutoDiffDerivativeFunctionKind::JVP: + OS << "jvp."; + break; + case AutoDiffDerivativeFunctionKind::VJP: + OS << "vjp."; + break; + } + OS << derivativeFunctionIdentifier->getParameterIndices()->getString(); + if (auto derivativeGenSig = + derivativeFunctionIdentifier->getDerivativeGenericSignature()) { + OS << "." << derivativeGenSig; + } + } } void SILDeclRef::dump() const { diff --git a/lib/SIL/TypeLowering.cpp b/lib/SIL/TypeLowering.cpp index 4532f7de61e61..83ee4c9ec4013 100644 --- a/lib/SIL/TypeLowering.cpp +++ b/lib/SIL/TypeLowering.cpp @@ -2006,6 +2006,15 @@ getFunctionInterfaceTypeWithCaptures(TypeConverter &TC, } CanAnyFunctionType TypeConverter::makeConstantInterfaceType(SILDeclRef c) { + if (auto *derivativeId = c.derivativeFunctionIdentifier) { + auto originalFnTy = + makeConstantInterfaceType(c.asAutoDiffOriginalFunction()); + auto *derivativeFnTy = originalFnTy->getAutoDiffDerivativeFunctionType( + derivativeId->getParameterIndices(), derivativeId->getKind(), + LookUpConformanceInModule(&M)); + return cast(derivativeFnTy->getCanonicalType()); + } + auto *vd = c.loc.dyn_cast(); switch (c.kind) { case SILDeclRef::Kind::Func: { diff --git a/lib/SILGen/SILGen.h b/lib/SILGen/SILGen.h index 738ae33548892..287d656fe1261 100644 --- a/lib/SILGen/SILGen.h +++ b/lib/SILGen/SILGen.h @@ -224,6 +224,12 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor { SILFunction *customDerivativeFn, SILFunction *originalFn, const AutoDiffConfig &config, AutoDiffDerivativeFunctionKind kind); + /// Get or create a derivative function vtable entry thunk for the given + /// SILDeclRef and derivative function type. + SILFunction * + getOrCreateAutoDiffClassMethodThunk(SILDeclRef derivativeFnRef, + CanSILFunctionType derivativeFnTy); + /// Determine whether we need to emit an ivar destroyer for the given class. /// An ivar destroyer is needed if a superclass of this class may define a /// failing designated initializer. diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index 452646e60a776..4d48e8101d411 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -4454,8 +4454,15 @@ getWitnessFunctionRef(SILGenFunction &SGF, SILLocation loc) { switch (witnessKind) { case WitnessDispatchKind::Static: + if (auto *derivativeId = witness.derivativeFunctionIdentifier) { + // TODO(TF-1139, TF-1140): Replace `undef` with `differentiable_function` + // and `differentiable_function_extract`. + auto derivativeFnSilType = SILType::getPrimitiveObjectType(witnessFTy); + return SILUndef::get(derivativeFnSilType, SGF.F); + } return SGF.emitGlobalFunctionRef(loc, witness); case WitnessDispatchKind::Dynamic: + assert(!witness.derivativeFunctionIdentifier); return SGF.emitDynamicMethodRef(loc, witness, witnessFTy).getValue(); case WitnessDispatchKind::Witness: { auto typeAndConf = diff --git a/lib/SILGen/SILGenThunk.cpp b/lib/SILGen/SILGenThunk.cpp index 6e22672eabd00..cfe7c3be00ae7 100644 --- a/lib/SILGen/SILGenThunk.cpp +++ b/lib/SILGen/SILGenThunk.cpp @@ -186,3 +186,42 @@ getOrCreateReabstractionThunk(CanSILFunctionType thunkType, loc, name, thunkDeclType, IsBare, IsTransparent, IsSerializable, ProfileCounter(), IsReabstractionThunk, IsNotDynamic); } + +SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk( + SILDeclRef derivativeFnDeclRef, CanSILFunctionType constantTy) { + auto *derivativeId = derivativeFnDeclRef.derivativeFunctionIdentifier; + assert(derivativeId); + auto *derivativeFnDecl = derivativeFnDeclRef.getDecl(); + + SILGenFunctionBuilder builder(*this); + auto originalFn = derivativeFnDeclRef.asAutoDiffOriginalFunction(); + // TODO(TF-685): Use principled thunk mangling. + // Do not simply reuse reabstraction thunk mangling. + auto name = derivativeFnDeclRef.mangle() + "_vtable_entry_thunk"; + auto *thunk = builder.getOrCreateFunction( + derivativeFnDecl, name, originalFn.getLinkage(ForDefinition), constantTy, + IsBare, IsTransparent, derivativeFnDeclRef.isSerialized(), IsNotDynamic, + ProfileCounter(), IsThunk); + if (!thunk->empty()) + return thunk; + + if (auto genSig = constantTy->getSubstGenericSignature()) + thunk->setGenericEnvironment(genSig->getGenericEnvironment()); + SILGenFunction SGF(*this, *thunk, SwiftModule); + SmallVector params; + auto loc = derivativeFnDeclRef.getAsRegularLocation(); + SGF.collectThunkParams(loc, params); + + // TODO(TF-1139, TF-1140): Replace `undef` with `differentiable_function` and + // `differentiable_function_extract`. + auto derivativeSilTy = SILType::getPrimitiveObjectType(constantTy); + auto derivativeFn = SILUndef::get(derivativeSilTy, *thunk); + SmallVector args(thunk->getArguments().begin(), + thunk->getArguments().end()); + auto apply = + SGF.emitApplyWithRethrow(loc, derivativeFn, derivativeSilTy, + SGF.getForwardingSubstitutionMap(), args); + SGF.B.createReturn(loc, apply); + + return thunk; +} diff --git a/lib/SILGen/SILGenType.cpp b/lib/SILGen/SILGenType.cpp index 23ee52279a411..3f19296f7478d 100644 --- a/lib/SILGen/SILGenType.cpp +++ b/lib/SILGen/SILGenType.cpp @@ -90,6 +90,14 @@ SILGenModule::emitVTableMethod(ClassDecl *theClass, implFn = getDynamicThunk( derived, Types.getConstantInfo(TypeExpansionContext::minimal(), derived) .SILFnType); + } else if (auto *derivativeId = derived.derivativeFunctionIdentifier) { + // For JVP/VJP methods, create a vtable entry thunk. The thunk contains an + // `differentiable_function` instruction, which is later filled during the + // differentiation transform. + auto derivedFnType = + Types.getConstantInfo(TypeExpansionContext::minimal(), derived) + .SILFnType; + implFn = getOrCreateAutoDiffClassMethodThunk(derived, derivedFnType); } else { implFn = getFunction(derived, NotForDefinition); } @@ -159,6 +167,17 @@ SILGenModule::emitVTableMethod(ClassDecl *theClass, cast(derivedDecl), base.kind == SILDeclRef::Kind::Allocator); } + // TODO(TF-685): Use proper autodiff thunk mangling. + if (auto *derivativeId = derived.derivativeFunctionIdentifier) { + switch (derivativeId->getKind()) { + case AutoDiffDerivativeFunctionKind::JVP: + name += "_jvp"; + break; + case AutoDiffDerivativeFunctionKind::VJP: + name += "_vjp"; + break; + } + } } // If we already emitted this thunk, reuse it. @@ -373,10 +392,9 @@ template class SILGenWitnessTable : public SILWitnessVisitor { // If it's not an accessor, just look for the witness. if (!reqAccessor) { if (auto witness = asDerived().getWitness(requirementRef.getDecl())) { - return addMethodImplementation(requirementRef, - SILDeclRef(witness.getDecl(), - requirementRef.kind), - witness); + return addMethodImplementation( + requirementRef, requirementRef.withDecl(witness.getDecl()), + witness); } return asDerived().addMissingMethod(requirementRef); @@ -395,10 +413,8 @@ template class SILGenWitnessTable : public SILWitnessVisitor { auto witnessAccessor = witnessStorage->getSynthesizedAccessor(reqAccessor->getAccessorKind()); - return addMethodImplementation(requirementRef, - SILDeclRef(witnessAccessor, - SILDeclRef::Kind::Func), - witness); + return addMethodImplementation( + requirementRef, requirementRef.withDecl(witnessAccessor), witness); } private: @@ -690,6 +706,20 @@ SILFunction *SILGenModule::emitProtocolWitness( conformance.isConcrete() ? conformance.getConcrete() : nullptr; std::string nameBuffer = NewMangler.mangleWitnessThunk(manglingConformance, requirement.getDecl()); + // TODO(TF-685): Proper mangling for derivative witness thunks. + if (auto *derivativeId = requirement.derivativeFunctionIdentifier) { + std::string kindString; + switch (derivativeId->getKind()) { + case AutoDiffDerivativeFunctionKind::JVP: + kindString = "jvp"; + break; + case AutoDiffDerivativeFunctionKind::VJP: + kindString = "vjp"; + break; + } + nameBuffer = "AD__" + nameBuffer + "_" + kindString + "_" + + derivativeId->getParameterIndices()->getString(); + } // If the thunked-to function is set to be always inlined, do the // same with the witness, on the theory that the user wants all diff --git a/lib/SILOptimizer/CMakeLists.txt b/lib/SILOptimizer/CMakeLists.txt index 696616bf009ac..f05c7b8a22364 100644 --- a/lib/SILOptimizer/CMakeLists.txt +++ b/lib/SILOptimizer/CMakeLists.txt @@ -34,7 +34,6 @@ add_subdirectory(UtilityPasses) add_subdirectory(Utils) add_swift_host_library(swiftSILOptimizer STATIC - SILOptimizerRequests.cpp ${SILOPTIMIZER_SOURCES}) target_link_libraries(swiftSILOptimizer PRIVATE swiftSIL) diff --git a/lib/SILOptimizer/PassManager/CMakeLists.txt b/lib/SILOptimizer/PassManager/CMakeLists.txt index 80a438a9d95d3..6db4ea12d8504 100644 --- a/lib/SILOptimizer/PassManager/CMakeLists.txt +++ b/lib/SILOptimizer/PassManager/CMakeLists.txt @@ -3,4 +3,5 @@ silopt_register_sources( Passes.cpp PassPipeline.cpp PrettyStackTrace.cpp + SILOptimizerRequests.cpp ) diff --git a/lib/SILOptimizer/SILOptimizerRequests.cpp b/lib/SILOptimizer/PassManager/SILOptimizerRequests.cpp similarity index 100% rename from lib/SILOptimizer/SILOptimizerRequests.cpp rename to lib/SILOptimizer/PassManager/SILOptimizerRequests.cpp diff --git a/test/AutoDiff/SIL/Parse/sildeclref_parse.sil b/test/AutoDiff/SIL/Parse/sildeclref_parse.sil new file mode 100644 index 0000000000000..9c7452949d5ae --- /dev/null +++ b/test/AutoDiff/SIL/Parse/sildeclref_parse.sil @@ -0,0 +1,55 @@ +// RUN: %target-sil-opt -enable-experimental-differentiable-programming %s -module-name=sildeclref_parse | %target-sil-opt -enable-experimental-differentiable-programming -module-name=sildeclref_parse | %FileCheck %s +// REQUIRES: differentiable_programming + +// Parse AutoDiff derivative SILDeclRefs via `witness_method` and `class_method` instructions. + +import Swift +import _Differentiation + +protocol Protocol { + @differentiable(wrt: (x, y)) + func f(_ x: Float, _ y: Float) -> Float +} + +class Class { + @differentiable(wrt: (x, y) where T: Differentiable) + func f(_ x: T, _ y: Float) -> T +} + +// CHECK-LABEL: sil hidden @witness_method +sil hidden @witness_method : $@convention(thin) (@in T) -> () { +bb0(%0 : $*T): + // CHECK: witness_method $T, #Protocol.f + %1 = witness_method $T, #Protocol.f : (Self) -> (Float, Float) -> Float : $@convention(witness_method: Protocol) <τ_0_0 where τ_0_0 : Protocol> (@in_guaranteed τ_0_0) -> (Float, Float) -> Float + + // CHECK: witness_method $T, #Protocol.f!jvp.SSS + %2 = witness_method $T, #Protocol.f!jvp.SSS : (Self) -> (Float, Float) -> Float : $@convention(witness_method: Protocol) <τ_0_0 where τ_0_0 : Protocol> (@in_guaranteed τ_0_0) -> (Float, Float) -> Float + + // CHECK: witness_method $T, #Protocol.f!jvp.UUS + %3 = witness_method $T, #Protocol.f!jvp.UUS : (Self) -> (Float, Float) -> Float : $@convention(witness_method: Protocol) <τ_0_0 where τ_0_0 : Protocol> (@in_guaranteed τ_0_0) -> (Float, Float) -> Float + + // CHECK: witness_method $T, #Protocol.f!vjp.SSS + %4 = witness_method $T, #Protocol.f!vjp.SSS : (Self) -> (Float, Float) -> Float : $@convention(witness_method: Protocol) <τ_0_0 where τ_0_0 : Protocol> (@in_guaranteed τ_0_0) -> (Float, Float) -> Float + + // CHECK: witness_method $T, #Protocol.f!vjp.UUS + %5 = witness_method $T, #Protocol.f!vjp.UUS : (Self) -> (Float, Float) -> Float : $@convention(witness_method: Protocol) <τ_0_0 where τ_0_0 : Protocol> (@in_guaranteed τ_0_0) -> (Float, Float) -> Float + + %6 = tuple () + return %6 : $() +} + +// CHECK-LABEL: sil hidden @class_method +sil hidden @class_method : $@convention(thin) (@guaranteed Class) -> () { +bb0(%0 : $Class): + // CHECK: class_method %0 : $Class, #Class.f + %1 = class_method %0 : $Class, #Class.f : (Class) -> (T, Float) -> T, $@convention(method) <τ_0_0> (@in_guaranteed τ_0_0, Float, @guaranteed Class<τ_0_0>) -> @out τ_0_0 + + // CHECK: class_method %0 : $Class, #Class.f!jvp.SSU + %2 = class_method %0 : $Class, #Class.f!jvp.SSU. : (Class) -> (T, Float) -> T, $@convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float, @guaranteed Class<τ_0_0>) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0, Float) -> @out τ_0_1 for <τ_0_0.TangentVector, τ_0_0.TangentVector>) + + // CHECK: class_method %0 : $Class, #Class.f!vjp.SSU + %3 = class_method %0 : $Class, #Class.f!vjp.SSU. : (Class) -> (T, Float) -> T, $@convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float, @guaranteed Class<τ_0_0>) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (@out τ_0_1, Float) for <τ_0_0.TangentVector, τ_0_0.TangentVector>) + + %6 = tuple () + return %6 : $() +} diff --git a/test/AutoDiff/SILGen/vtable.swift b/test/AutoDiff/SILGen/vtable.swift new file mode 100644 index 0000000000000..6763012887d3d --- /dev/null +++ b/test/AutoDiff/SILGen/vtable.swift @@ -0,0 +1,146 @@ +// RUN: %target-swift-frontend -enable-experimental-differentiable-programming -emit-silgen %s | %FileCheck %s +// REQUIRES: differentiable_programming + +// Test derivative function vtable entries for `@differentiable` class members: +// - Methods. +// - Accessors (from properties and subscripts). +// - Initializers. + +import _Differentiation + +// Dummy `Differentiable`-conforming type. +struct DummyTangentVector: Differentiable & AdditiveArithmetic { + // FIXME(TF-648): Dummy to make `Super.TangentVector` be nontrivial. + var _nontrivial: [Float] = [] + + static var zero: Self { Self() } + static func + (_: Self, _: Self) -> Self { Self() } + static func - (_: Self, _: Self) -> Self { Self() } + typealias TangentVector = Self +} + +class Super: Differentiable { + typealias TangentVector = DummyTangentVector + func move(along _: TangentVector) {} + + var base: Float + // FIXME(TF-648): Dummy to make `Super.TangentVector` be nontrivial. + var _nontrivial: [Float] = [] + + init(base: Float) { + self.base = base + } + + @differentiable(wrt: x) + func method(_ x: Float, _ y: Float) -> Float { + return x + } + + @differentiable(wrt: x where T: Differentiable) + func genericMethod(_ x: T, _ y: T) -> T { + return x + } + + @differentiable + var property: Float { base } + + @differentiable(wrt: x) + subscript(_ x: Float, _ y: Float) -> Float { + return x + } +} + +class Sub: Super { + override init(base: Float) { + super.init(base: base) + } + + // Override JVP for `method` wrt `x`. + @derivative(of: method, wrt: x) + @derivative(of: subscript, wrt: x) + final func jvpMethod(_ x: Float, _ y: Float) -> (value: Float, differential: (Float) -> Float) { + fatalError() + } + // Override VJP for `method` wrt `x`. + @derivative(of: method, wrt: x) + @derivative(of: subscript, wrt: x) + final func vjpMethod(_ x: Float, _ y: Float) -> (value: Float, pullback: (Float) -> (Float)) { + fatalError() + } + + // Override derivatives for `method` wrt `x`. + // FIXME(TF-1203): This `@differentiable` attribute should not be necessary to + // override derivatives. Fix `derivativeFunctionRequiresNewVTableEntry` to + // account for derived declaration `@derivative` attributes. + @differentiable(wrt: x) + // Add new derivatives for `method` wrt `(x, y)`. + @differentiable(wrt: (x, y)) + override func method(_ x: Float, _ y: Float) -> Float { + return x + } + + // Override derivatives for `property` wrt `self`. + @differentiable + override var property: Float { base } + @derivative(of: property) + final func vjpProperty() -> (value: Float, pullback: (Float) -> TangentVector) { + fatalError() + } + + // Override derivatives for `subscript` wrt `x`. + @differentiable(wrt: x) + override subscript(_ x: Float, _ y: Float) -> Float { + return x + } +} + +class SubSub: Sub {} + +// CHECK-LABEL: sil_vtable Super { +// CHECK: #Super.method: (Super) -> (Float, Float) -> Float : @$s6vtable5SuperC6methodyS2f_SftF +// CHECK: #Super.method!jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable5SuperC6methodyS2f_SftF__jvp_src_0_wrt_0_vtable_entry_thunk +// CHECK: #Super.method!vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable5SuperC6methodyS2f_SftF__vjp_src_0_wrt_0_vtable_entry_thunk +// CHECK: #Super.genericMethod: (Super) -> (T, T) -> T : @$s6vtable5SuperC13genericMethodyxx_xtlF +// CHECK: #Super.genericMethod!jvp.SUU.: (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__jvp_src_0_wrt_0_16_Differentiation14DifferentiableRzl_vtable_entry_thunk +// CHECK: #Super.genericMethod!vjp.SUU.: (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__vjp_src_0_wrt_0_16_Differentiation14DifferentiableRzl_vtable_entry_thunk +// CHECK: #Super.property!getter: (Super) -> () -> Float : @$s6vtable5SuperC8propertySfvg +// CHECK: #Super.property!getter.jvp.S: (Super) -> () -> Float : @AD__$s6vtable5SuperC8propertySfvg__jvp_src_0_wrt_0_vtable_entry_thunk +// CHECK: #Super.property!getter.vjp.S: (Super) -> () -> Float : @AD__$s6vtable5SuperC8propertySfvg__vjp_src_0_wrt_0_vtable_entry_thunk +// CHECK: #Super.subscript!getter: (Super) -> (Float, Float) -> Float : @$s6vtable5SuperCyS2f_Sftcig +// CHECK: #Super.subscript!getter.jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable5SuperCyS2f_Sftcig__jvp_src_0_wrt_0_vtable_entry_thunk +// CHECK: #Super.subscript!getter.vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable5SuperCyS2f_Sftcig__vjp_src_0_wrt_0_vtable_entry_thunk +// CHECK: } + +// CHECK-LABEL: sil_vtable Sub { +// CHECK: #Super.method: (Super) -> (Float, Float) -> Float : @$s6vtable3SubC6methodyS2f_SftF [override] +// CHECK: #Super.method!jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable3SubC6methodyS2f_SftF__jvp_src_0_wrt_0_vtable_entry_thunk [override] +// CHECK: #Super.method!vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable3SubC6methodyS2f_SftF__vjp_src_0_wrt_0_vtable_entry_thunk [override] +// CHECK: #Super.genericMethod: (Super) -> (T, T) -> T : @$s6vtable5SuperC13genericMethodyxx_xtlF [inherited] +// CHECK: #Super.genericMethod!jvp.SUU.: (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__jvp_src_0_wrt_0_16_Differentiation14DifferentiableRzl_vtable_entry_thunk [inherited] +// CHECK: #Super.genericMethod!vjp.SUU.: (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__vjp_src_0_wrt_0_16_Differentiation14DifferentiableRzl_vtable_entry_thunk [inherited] +// CHECK: #Super.property!getter: (Super) -> () -> Float : @$s6vtable3SubC8propertySfvg [override] +// CHECK: #Super.property!getter.jvp.S: (Super) -> () -> Float : @AD__$s6vtable3SubC8propertySfvg__jvp_src_0_wrt_0_vtable_entry_thunk [override] +// CHECK: #Super.property!getter.vjp.S: (Super) -> () -> Float : @AD__$s6vtable3SubC8propertySfvg__vjp_src_0_wrt_0_vtable_entry_thunk [override] +// CHECK: #Super.subscript!getter: (Super) -> (Float, Float) -> Float : @$s6vtable3SubCyS2f_Sftcig [override] +// CHECK: #Super.subscript!getter.jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable3SubCyS2f_Sftcig__jvp_src_0_wrt_0_vtable_entry_thunk [override] +// CHECK: #Super.subscript!getter.vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable3SubCyS2f_Sftcig__vjp_src_0_wrt_0_vtable_entry_thunk [override] +// CHECK: #Sub.method!jvp.SSU: (Sub) -> (Float, Float) -> Float : @AD__$s6vtable3SubC6methodyS2f_SftF__jvp_src_0_wrt_0_1_vtable_entry_thunk +// CHECK: #Sub.method!vjp.SSU: (Sub) -> (Float, Float) -> Float : @AD__$s6vtable3SubC6methodyS2f_SftF__vjp_src_0_wrt_0_1_vtable_entry_thunk +// CHECK: } + +// CHECK-LABEL: sil_vtable SubSub { +// CHECK: #Super.method: (Super) -> (Float, Float) -> Float : @$s6vtable3SubC6methodyS2f_SftF [inherited] +// CHECK: #Super.method!jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable3SubC6methodyS2f_SftF__jvp_src_0_wrt_0_vtable_entry_thunk [inherited] +// CHECK: #Super.method!vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable3SubC6methodyS2f_SftF__vjp_src_0_wrt_0_vtable_entry_thunk [inherited] +// CHECK: #Super.genericMethod: (Super) -> (T, T) -> T : @$s6vtable5SuperC13genericMethodyxx_xtlF [inherited] +// CHECK: #Super.genericMethod!jvp.SUU.: (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__jvp_src_0_wrt_0_16_Differentiation14DifferentiableRzl_vtable_entry_thunk [inherited] +// CHECK: #Super.genericMethod!vjp.SUU.: (Super) -> (T, T) -> T : @AD__$s6vtable5SuperC13genericMethodyxx_xtlF__vjp_src_0_wrt_0_16_Differentiation14DifferentiableRzl_vtable_entry_thunk [inherited] +// CHECK: #Super.property!getter: (Super) -> () -> Float : @$s6vtable3SubC8propertySfvg [inherited] +// CHECK: #Super.property!getter.jvp.S: (Super) -> () -> Float : @AD__$s6vtable3SubC8propertySfvg__jvp_src_0_wrt_0_vtable_entry_thunk [inherited] +// CHECK: #Super.property!getter.vjp.S: (Super) -> () -> Float : @AD__$s6vtable3SubC8propertySfvg__vjp_src_0_wrt_0_vtable_entry_thunk [inherited] +// CHECK: #Super.subscript!getter: (Super) -> (Float, Float) -> Float : @$s6vtable3SubCyS2f_Sftcig [inherited] +// CHECK: #Super.subscript!getter.jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable3SubCyS2f_Sftcig__jvp_src_0_wrt_0_vtable_entry_thunk [inherited] +// CHECK: #Super.subscript!getter.vjp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable3SubCyS2f_Sftcig__vjp_src_0_wrt_0_vtable_entry_thunk [inherited] +// CHECK: #Sub.method!jvp.SSU: (Sub) -> (Float, Float) -> Float : @AD__$s6vtable3SubC6methodyS2f_SftF__jvp_src_0_wrt_0_1_vtable_entry_thunk [inherited] +// CHECK: #Sub.method!vjp.SSU: (Sub) -> (Float, Float) -> Float : @AD__$s6vtable3SubC6methodyS2f_SftF__vjp_src_0_wrt_0_1_vtable_entry_thunk [inherited] +// CHECK: } diff --git a/test/AutoDiff/SILGen/witness_table.swift b/test/AutoDiff/SILGen/witness_table.swift new file mode 100644 index 0000000000000..65ebf5bf2c99c --- /dev/null +++ b/test/AutoDiff/SILGen/witness_table.swift @@ -0,0 +1,85 @@ +// RUN: %target-swift-frontend -enable-experimental-differentiable-programming -emit-silgen %s | %FileCheck %s +// REQUIRES: differentiable_programming + +// Test derivative function witness table entries for `@differentiable` protocol requirements. + +import _Differentiation + +protocol Protocol: Differentiable { + @differentiable(wrt: (self, x, y)) + @differentiable(wrt: x) + func method(_ x: Float, _ y: Double) -> Float + + @differentiable + var property: Float { get set } + + @differentiable(wrt: x) + subscript(_ x: Float, _ y: Float) -> Float { get set } +} + +// Dummy `Differentiable`-conforming type. +struct DummyTangentVector: Differentiable & AdditiveArithmetic { + static var zero: Self { Self() } + static func + (_: Self, _: Self) -> Self { Self() } + static func - (_: Self, _: Self) -> Self { Self() } + typealias TangentVector = Self +} + +struct Struct: Protocol { + typealias TangentVector = DummyTangentVector + mutating func move(along _: TangentVector) {} + + @differentiable(wrt: (self, x, y)) + @differentiable(wrt: x) + func method(_ x: Float, _ y: Double) -> Float { + return x + } + + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}method{{.*}}_jvp_SUU : $@convention(witness_method: Protocol) (Float, Double, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}method{{.*}}_vjp_SUU : $@convention(witness_method: Protocol) (Float, Double, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}method{{.*}}_jvp_SSS : $@convention(witness_method: Protocol) (Float, Double, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float, Double, @in_guaranteed τ_0_0) -> Float for ) { + + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}method{{.*}}_vjp_SSS : $@convention(witness_method: Protocol) (Float, Double, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> (Float, Double, @out τ_0_0) for ) { + + @differentiable + var property: Float { + get { 1 } + set {} + } + + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}property{{.*}}_jvp_S : $@convention(witness_method: Protocol) (@in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for ) { + + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}property{{.*}}_vjp_S : $@convention(witness_method: Protocol) (@in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for ) { + + + @differentiable(wrt: x) + subscript(_ x: Float, _ y: Float) -> Float { + get { x } + set {} + } + + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUU : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUU : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +} + +// CHECK-LABEL: sil_witness_table hidden Struct: Protocol module witness_table { +// CHECK-NEXT: base_protocol Differentiable: Struct: Differentiable module witness_table +// CHECK-NEXT: method #Protocol.method: (Self) -> (Float, Double) -> Float : @$s13witness_table6StructVAA8ProtocolA2aDP6methodyS2f_SdtFTW +// CHECK-NEXT: method #Protocol.method!jvp.SUU.: (Self) -> (Float, Double) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDP6methodyS2f_SdtFTW_jvp_SUU +// CHECK-NEXT: method #Protocol.method!vjp.SUU.: (Self) -> (Float, Double) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDP6methodyS2f_SdtFTW_vjp_SUU +// CHECK-NEXT: method #Protocol.method!jvp.SSS.: (Self) -> (Float, Double) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDP6methodyS2f_SdtFTW_jvp_SSS +// CHECK-NEXT: method #Protocol.method!vjp.SSS.: (Self) -> (Float, Double) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDP6methodyS2f_SdtFTW_vjp_SSS +// CHECK-NEXT: method #Protocol.property!getter: (Self) -> () -> Float : @$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvgTW +// CHECK-NEXT: method #Protocol.property!getter.jvp.S.: (Self) -> () -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvgTW_jvp_S +// CHECK-NEXT: method #Protocol.property!getter.vjp.S.: (Self) -> () -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvgTW_vjp_S +// CHECK-NEXT: method #Protocol.property!setter: (inout Self) -> (Float) -> () : @$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvsTW +// CHECK-NEXT: method #Protocol.property!modify: (inout Self) -> () -> () : @$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvMTW +// CHECK-NEXT: method #Protocol.subscript!getter: (Self) -> (Float, Float) -> Float : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW +// CHECK-NEXT: method #Protocol.subscript!getter.jvp.SUU.: (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SU +// CHECK-NEXT: method #Protocol.subscript!getter.vjp.SUU.: (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUU +// CHECK-NEXT: method #Protocol.subscript!setter: (inout Self) -> (Float, Float, Float) -> () : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW +// CHECK-NEXT: method #Protocol.subscript!modify: (inout Self) -> (Float, Float) -> () : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftciMTW +// CHECK: }