diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0b5f322eed1be..d78f26e89a43b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -14,6 +14,14 @@ jobs: runs-on: ubuntu-18.04 steps: + - name: Free disk space + run: | + df -h + sudo swapoff -a + sudo rm -f /swapfile + sudo apt clean + docker rmi $(docker image ls -aq) + df -h - uses: actions/checkout@v1 with: path: swift diff --git a/benchmark/utils/ObjectiveCTests/module.map b/benchmark/utils/ObjectiveCTests/module.modulemap similarity index 100% rename from benchmark/utils/ObjectiveCTests/module.map rename to benchmark/utils/ObjectiveCTests/module.modulemap diff --git a/docs/SIL.rst b/docs/SIL.rst index 337339c061798..89f917c61a29b 100644 --- a/docs/SIL.rst +++ b/docs/SIL.rst @@ -5774,6 +5774,67 @@ The rules on generic substitutions are identical to those of ``apply``. Differentiable Programming ~~~~~~~~~~~~~~~~~~~~~~~~~~ +differentiable_function +``````````````````````` +:: + + sil-instruction ::= 'differentiable_function' + sil-differentiable-function-parameter-indices + sil-value ':' sil-type + sil-differentiable-function-derivative-functions-clause? + + sil-differentiable-function-parameter-indices ::= + '[' 'parameters' [0-9]+ (' ' [0-9]+)* ']' + sil-differentiable-derivative-functions-clause ::= + 'with_derivative' + '{' sil-value ':' sil-type ',' sil-value ':' sil-type '}' + + differentiable_function [parameters 0] %0 : $(T) -> T \ + with_derivative {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)} + +Creates a ``@differentiable`` function from an original function operand and +derivative function operands (optional). There are two derivative function +kinds: a Jacobian-vector products (JVP) function and a vector-Jacobian products +(VJP) function. + +``[parameters ...]`` specifies parameter indices that the original function is +differentiable with respect to. + +The ``with_derivative`` clause specifies the derivative function operands +associated with the original function. + +The differentiation transformation canonicalizes all `differentiable_function` +instructions, generating derivative functions if necessary to fill in derivative +function operands. + +In raw SIL, the ``with_derivative`` clause is optional. In canonical SIL, the +``with_derivative`` clause is mandatory. + + +differentiable_function_extract +``````````````````````````````` +:: + + sil-instruction ::= 'differentiable_function_extract' + '[' sil-differentiable-function-extractee ']' + sil-value ':' sil-type + ('as' sil-type)? + + sil-differentiable-function-extractee ::= 'original' | 'jvp' | 'vjp' + + differentiable_function_extract [original] %0 : $@differentiable (T) -> T + differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T + differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T + differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T \ + as $(@in_constant T) -> (T, (T.TangentVector) -> T.TangentVector) + +Extracts the original function or a derivative function from the given +``@differentiable`` function. The extractee is one of the following: +``[original]``, ``[jvp]``, or ``[vjp]``. + +In lowered SIL, an explicit extractee type may be provided. This is currently +used by the LoadableByAddress transformation, which rewrites function types. + differentiability_witness_function `````````````````````````````````` :: diff --git a/include/swift/AST/ASTMangler.h b/include/swift/AST/ASTMangler.h index 434666d871a44..2aa33dec89060 100644 --- a/include/swift/AST/ASTMangler.h +++ b/include/swift/AST/ASTMangler.h @@ -370,6 +370,9 @@ class ASTMangler : public Mangler { void appendProtocolConformance(const ProtocolConformance *conformance); void appendProtocolConformanceRef(const RootProtocolConformance *conformance); + void appendAnyProtocolConformance(CanGenericSignature genericSig, + CanType conformingType, + ProtocolConformanceRef conformance); void appendConcreteProtocolConformance( const ProtocolConformance *conformance); void appendDependentProtocolConformance(const ConformanceAccessPath &path); diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index e2976e1bfc7ea..a6583e4311a06 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1675,12 +1675,11 @@ struct DeclNameRefWithLoc { DeclNameLoc Loc; }; -/// Attribute that marks a function as differentiable and optionally specifies -/// custom associated derivative functions: 'jvp' and 'vjp'. +/// Attribute that marks a function as differentiable. /// /// Examples: -/// @differentiable(jvp: jvpFoo where T : FloatingPoint) -/// @differentiable(wrt: (self, x, y), jvp: jvpFoo) +/// @differentiable(where T : FloatingPoint) +/// @differentiable(wrt: (self, x, y)) class DifferentiableAttr final : public DeclAttribute, private llvm::TrailingObjects JVP; - /// The VJP function. - Optional VJP; - /// The JVP function (optional), resolved by the type checker if JVP name is - /// specified. - FuncDecl *JVPFunction = nullptr; - /// The VJP function (optional), resolved by the type checker if VJP name is - /// specified. - FuncDecl *VJPFunction = nullptr; /// The differentiability parameter indices, resolved by the type checker. /// The bit stores whether the parameter indices have been computed. /// @@ -1720,19 +1709,22 @@ class DifferentiableAttr final /// attribute's where clause requirements. This is set only if the attribute /// has a where clause. GenericSignature DerivativeGenericSignature; + /// The source location of the implicitly inherited protocol requirement + /// `@differentiable` attribute. Used for diagnostics, not serialized. + /// + /// This is set during conformance type-checking, only for implicit + /// `@differentiable` attributes created for non-public protocol witnesses of + /// protocol requirements with `@differentiable` attributes. + SourceLoc ImplicitlyInheritedDifferentiableAttrLocation; explicit DifferentiableAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange, bool linear, ArrayRef parameters, - Optional jvp, - Optional vjp, TrailingWhereClause *clause); explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc, SourceRange baseRange, bool linear, IndexSubset *parameterIndices, - Optional jvp, - Optional vjp, GenericSignature derivativeGenericSignature); public: @@ -1740,16 +1732,12 @@ class DifferentiableAttr final SourceLoc atLoc, SourceRange baseRange, bool linear, ArrayRef params, - Optional jvp, - Optional vjp, TrailingWhereClause *clause); static DifferentiableAttr *create(AbstractFunctionDecl *original, bool implicit, SourceLoc atLoc, SourceRange baseRange, bool linear, IndexSubset *parameterIndices, - Optional jvp, - Optional vjp, GenericSignature derivativeGenSig); Decl *getOriginalDeclaration() const { return OriginalDeclaration; } @@ -1758,16 +1746,6 @@ class DifferentiableAttr final /// Should only be used by parsing and deserialization. void setOriginalDeclaration(Decl *originalDeclaration); - /// Get the optional 'jvp:' function name and location. - /// Use this instead of `getJVPFunction` to check whether the attribute has a - /// registered JVP. - Optional getJVP() const { return JVP; } - - /// Get the optional 'vjp:' function name and location. - /// Use this instead of `getVJPFunction` to check whether the attribute has a - /// registered VJP. - Optional getVJP() const { return VJP; } - private: /// Returns true if the given `@differentiable` attribute has been /// type-checked. @@ -1800,10 +1778,13 @@ class DifferentiableAttr final DerivativeGenericSignature = derivativeGenSig; } - FuncDecl *getJVPFunction() const { return JVPFunction; } - void setJVPFunction(FuncDecl *decl); - FuncDecl *getVJPFunction() const { return VJPFunction; } - void setVJPFunction(FuncDecl *decl); + SourceLoc getImplicitlyInheritedDifferentiableAttrLocation() const { + return ImplicitlyInheritedDifferentiableAttrLocation; + } + void getImplicitlyInheritedDifferentiableAttrLocation(SourceLoc loc) { + assert(isImplicit()); + ImplicitlyInheritedDifferentiableAttrLocation = loc; + } /// Get the derivative generic environment for the given `@differentiable` /// attribute and original function. @@ -1812,9 +1793,7 @@ class DifferentiableAttr final // Print the attribute to the given stream. // If `omitWrtClause` is true, omit printing the `wrt:` clause. - // If `omitDerivativeFunctions` is true, omit printing derivative functions. - void print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause = false, - bool omitDerivativeFunctions = false) const; + void print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause = false) const; static bool classof(const DeclAttribute *DA) { return DA->getKind() == DAK_Differentiable; diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index a0ec63684162b..f8f3423e9b110 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -75,6 +75,41 @@ struct AutoDiffDerivativeFunctionKind { } }; +/// A component of a SIL `@differentiable` function-typed value. +struct NormalDifferentiableFunctionTypeComponent { + enum innerty : unsigned { Original = 0, JVP = 1, VJP = 2 } rawValue; + + NormalDifferentiableFunctionTypeComponent() = default; + NormalDifferentiableFunctionTypeComponent(innerty rawValue) + : rawValue(rawValue) {} + NormalDifferentiableFunctionTypeComponent( + AutoDiffDerivativeFunctionKind kind); + explicit NormalDifferentiableFunctionTypeComponent(unsigned rawValue) + : NormalDifferentiableFunctionTypeComponent((innerty)rawValue) {} + explicit NormalDifferentiableFunctionTypeComponent(StringRef name); + operator innerty() const { return rawValue; } + + /// Returns the derivative function kind, if the component is a derivative + /// function. + Optional getAsDerivativeFunctionKind() const; +}; + +/// A component of a SIL `@differentiable(linear)` function-typed value. +struct LinearDifferentiableFunctionTypeComponent { + enum innerty : unsigned { + Original = 0, + Transpose = 1, + } rawValue; + + LinearDifferentiableFunctionTypeComponent() = default; + LinearDifferentiableFunctionTypeComponent(innerty rawValue) + : rawValue(rawValue) {} + explicit LinearDifferentiableFunctionTypeComponent(unsigned rawValue) + : LinearDifferentiableFunctionTypeComponent((innerty)rawValue) {} + explicit LinearDifferentiableFunctionTypeComponent(StringRef name); + operator innerty() const { return rawValue; } +}; + /// A derivative function configuration, uniqued in `ASTContext`. /// Identifies a specific derivative function given an original function. class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode { @@ -406,6 +441,33 @@ GenericSignature getConstrainedDerivativeGenericSignature( GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance, bool isTranspose = false); +/// Retrieve config from the function name of a variant of +/// `Builtin.applyDerivative`, e.g. `Builtin.applyDerivative_jvp_arity2`. +/// Returns true if the function name is parsed successfully. +bool getBuiltinApplyDerivativeConfig( + StringRef operationName, AutoDiffDerivativeFunctionKind &kind, + unsigned &arity, bool &rethrows); + +/// Retrieve config from the function name of a variant of +/// `Builtin.applyTranspose`, e.g. `Builtin.applyTranspose_arity2`. +/// Returns true if the function name is parsed successfully. +bool getBuiltinApplyTransposeConfig( + StringRef operationName, unsigned &arity, bool &rethrows); + +/// Retrieve config from the function name of a variant of +/// `Builtin.differentiableFunction` or `Builtin.linearFunction`, e.g. +/// `Builtin.differentiableFunction_arity1_throws`. +/// Returns true if the function name is parsed successfully. +bool getBuiltinDifferentiableOrLinearFunctionConfig( + StringRef operationName, unsigned &arity, bool &throws); + +/// Retrieve config from the function name of a variant of +/// `Builtin.differentiableFunction` or `Builtin.linearFunction`, e.g. +/// `Builtin.differentiableFunction_arity1_throws`. +/// Returns true if the function name is parsed successfully. +bool getBuiltinDifferentiableOrLinearFunctionConfig( + StringRef operationName, unsigned &arity, bool &throws); + } // end namespace autodiff } // end namespace swift diff --git a/include/swift/AST/Builtins.def b/include/swift/AST/Builtins.def index 483cb15a7cc54..1510688d37e3d 100644 --- a/include/swift/AST/Builtins.def +++ b/include/swift/AST/Builtins.def @@ -469,6 +469,18 @@ BUILTIN_SIL_OPERATION(ConvertStrongToUnownedUnsafe, "convertStrongToUnownedUnsaf /// now. BUILTIN_SIL_OPERATION(ConvertUnownedUnsafeToGuaranteed, "convertUnownedUnsafeToGuaranteed", Special) +/// applyDerivative +BUILTIN_SIL_OPERATION(ApplyDerivative, "applyDerivative", Special) + +/// applyTranspose +BUILTIN_SIL_OPERATION(ApplyTranspose, "applyTranspose", Special) + +/// differentiableFunction +BUILTIN_SIL_OPERATION(DifferentiableFunction, "differentiableFunction", Special) + +/// linearFunction +BUILTIN_SIL_OPERATION(LinearFunction, "linearFunction", Special) + #undef BUILTIN_SIL_OPERATION // BUILTIN_RUNTIME_CALL - A call into a runtime function. diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index 3821c33c631bf..e3f3df4db22d7 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -3516,6 +3516,27 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext { /// or \c nullptr if it does not have one. ConstructorDecl *getMemberwiseInitializer() const; + /// Retrieves the effective memberwise initializer for this declaration, or + /// \c nullptr if it does not have one. + /// + /// An effective memberwise initializer is either a synthesized memberwise + /// initializer or a user-defined initializer with the same type. + /// + /// The access level of the memberwise initializer is set to the minimum of: + /// - Public, by default. This enables public nominal types to have public + /// memberwise initializers. + /// - The `public` default is important for synthesized member types, e.g. + /// `TangentVector` structs synthesized during `Differentiable` derived + /// conformances. Manually extending these types to define a public + /// memberwise initializer causes a redeclaration error. + /// - The minimum access level of memberwise-initialized properties in the + /// nominal type declaration. + /// + /// Effective memberwise initializers are used only by derived conformances + /// for `Self`-returning protocol requirements like `AdditiveArithmetic.+`. + /// Such derived conformances require memberwise initialization. + ConstructorDecl *getEffectiveMemberwiseInitializer(); + /// Whether this declaration has a synthesized zero parameter default /// initializer. bool hasDefaultInitializer() const; @@ -4249,8 +4270,6 @@ class ProtocolDecl final : public NominalTypeDecl { Bits.ProtocolDecl.ExistentialTypeSupported = supported; } - ArrayRef getInheritedProtocolsSlow(); - bool hasLazyRequirementSignature() const { return Bits.ProtocolDecl.HasLazyRequirementSignature; } @@ -4261,7 +4280,8 @@ class ProtocolDecl final : public NominalTypeDecl { friend class ProtocolRequiresClassRequest; friend class ExistentialConformsToSelfRequest; friend class ExistentialTypeSupportedRequest; - + friend class InheritedProtocolsRequest; + public: ProtocolDecl(DeclContext *DC, SourceLoc ProtocolLoc, SourceLoc NameLoc, Identifier Name, MutableArrayRef Inherited, @@ -4270,12 +4290,7 @@ class ProtocolDecl final : public NominalTypeDecl { using Decl::getASTContext; /// Retrieve the set of protocols inherited from this protocol. - ArrayRef getInheritedProtocols() const { - if (Bits.ProtocolDecl.InheritedProtocolsValid) - return InheritedProtocols; - - return const_cast(this)->getInheritedProtocolsSlow(); - } + ArrayRef getInheritedProtocols() const; /// Determine whether this protocol has a superclass. bool hasSuperclass() const { return (bool)getSuperclassDecl(); } @@ -4370,6 +4385,13 @@ class ProtocolDecl final : public NominalTypeDecl { private: void computeKnownProtocolKind() const; + bool areInheritedProtocolsValid() const { + return Bits.ProtocolDecl.InheritedProtocolsValid; + } + void setInheritedProtocolsValid() { + Bits.ProtocolDecl.InheritedProtocolsValid = true; + } + public: /// If this is known to be a compiler-known protocol, returns the kind. /// Otherwise returns None. @@ -7063,6 +7085,28 @@ class PrecedenceGroupDecl : public Decl { } }; +/// The fixity of an OperatorDecl. +enum class OperatorFixity : uint8_t { + Infix, + Prefix, + Postfix +}; + +inline void simple_display(llvm::raw_ostream &out, OperatorFixity fixity) { + switch (fixity) { + case OperatorFixity::Infix: + out << "infix"; + return; + case OperatorFixity::Prefix: + out << "prefix"; + return; + case OperatorFixity::Postfix: + out << "postfix"; + return; + } + llvm_unreachable("Unhandled case in switch"); +} + /// Abstract base class of operator declarations. class OperatorDecl : public Decl { SourceLoc OperatorLoc, NameLoc; @@ -7088,6 +7132,21 @@ class OperatorDecl : public Decl { : Decl(kind, DC), OperatorLoc(OperatorLoc), NameLoc(NameLoc), name(Name), DesignatedNominalTypes(DesignatedNominalTypes) {} + /// Retrieve the operator's fixity, corresponding to the concrete subclass + /// of the OperatorDecl. + OperatorFixity getFixity() const { + switch (getKind()) { +#define DECL(Id, Name) case DeclKind::Id: llvm_unreachable("Not an operator!"); +#define OPERATOR_DECL(Id, Name) +#include "swift/AST/DeclNodes.def" + case DeclKind::InfixOperator: + return OperatorFixity::Infix; + case DeclKind::PrefixOperator: + return OperatorFixity::Prefix; + case DeclKind::PostfixOperator: + return OperatorFixity::Postfix; + } + } SourceLoc getOperatorLoc() const { return OperatorLoc; } SourceLoc getNameLoc() const { return NameLoc; } diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index bca7857545d2c..fd0fecf1cb6ae 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1582,24 +1582,14 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken, "expected a member name as second parameter in '_implements' attribute", ()) // differentiable -// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed. -ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken, - "expected a %0 function name", (StringRef)) ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken, "expected a list of parameters to differentiate with respect to", ()) -// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed. ERROR(attr_differentiable_use_wrt_not_withrespectto,none, "use 'wrt:' to specify parameters to differentiate with respect to", ()) ERROR(attr_differentiable_expected_label,none, - "expected either 'wrt:' or a function specifier label, e.g. 'jvp:', " - "or 'vjp:'", ()) + "expected 'wrt:' or 'where' in '@differentiable' attribute", ()) ERROR(attr_differentiable_unexpected_argument,none, "unexpected argument '%0' in '@differentiable' attribute", (StringRef)) -// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed. -WARNING(attr_differentiable_jvp_vjp_deprecated_warning,none, - "'jvp:' and 'vjp:' arguments in '@differentiable' attribute are " - "deprecated; use '@derivative' attribute for derivative registration " - "instead", ()) // differentiation `wrt` parameters clause ERROR(expected_colon_after_label,PointsToFirstBadToken, @@ -1628,6 +1618,17 @@ ERROR(sil_autodiff_expected_parameter_index,PointsToFirstBadToken, "expected the index of a parameter to differentiate with respect to", ()) ERROR(sil_autodiff_expected_result_index,PointsToFirstBadToken, "expected the index of a result to differentiate from", ()) +ERROR(sil_inst_autodiff_operand_list_expected_lbrace,PointsToFirstBadToken, + "expected '{' to start a derivative function list", ()) +ERROR(sil_inst_autodiff_operand_list_expected_comma,PointsToFirstBadToken, + "expected ',' between operands in a derivative function list", ()) +ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken, + "expected '}' to start a derivative function list", ()) +ERROR(sil_inst_autodiff_expected_differentiable_extractee_kind,PointsToFirstBadToken, + "expected an extractee kind attribute, which can be one of '[original]', " + "'[jvp]', and '[vjp]'", ()) +ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken, + "expected an operand of a function type", ()) ERROR(sil_inst_autodiff_expected_differentiability_witness_kind,PointsToFirstBadToken, "expected a differentiability witness kind, which can be one of '[jvp]', " "'[vjp]', or '[transpose]'", ()) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 380b28e49d980..43aca22824906 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2674,6 +2674,8 @@ ERROR(cannot_synthesize_in_crossfile_extension,none, "implementation of %0 cannot be automatically synthesized in an extension " "in a different file to the type", (Type)) +ERROR(broken_additive_arithmetic_requirement,none, + "AdditiveArithmetic protocol is broken: unexpected requirement", ()) ERROR(broken_case_iterable_requirement,none, "CaseIterable protocol is broken: unexpected requirement", ()) ERROR(broken_raw_representable_requirement,none, @@ -2914,6 +2916,8 @@ ERROR(implements_attr_protocol_not_conformed_to,none, ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none, "cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' " "attribute for transpose registration instead", ()) +ERROR(differentiable_attr_void_result,none, + "cannot differentiate void function %0", (DeclName)) ERROR(differentiable_attr_overload_not_found,none, "%0 does not have expected type %1", (DeclNameRef, Type)) // TODO(TF-482): Change duplicate `@differentiable` attribute diagnostic to also @@ -2938,12 +2942,6 @@ ERROR(differentiable_attr_result_not_differentiable,none, ERROR(differentiable_attr_protocol_req_where_clause,none, "'@differentiable' attribute on protocol requirement cannot specify " "'where' clause", ()) -ERROR(differentiable_attr_protocol_req_assoc_func,none, - "'@differentiable' attribute on protocol requirement cannot specify " - "'jvp:' or 'vjp:'", ()) -ERROR(differentiable_attr_stored_property_variable_unsupported,none, - "'@differentiable' attribute on stored property cannot specify " - "'jvp:' or 'vjp:'", ()) ERROR(differentiable_attr_class_member_dynamic_self_result_unsupported,none, "'@differentiable' attribute cannot be declared on class members " "returning 'Self'", ()) @@ -2962,6 +2960,12 @@ ERROR(overriding_decl_missing_differentiable_attr,none, "overriding declaration is missing attribute '%0'", (StringRef)) NOTE(protocol_witness_missing_differentiable_attr,none, "candidate is missing attribute '%0'", (StringRef)) +NOTE(protocol_witness_missing_differentiable_attr_nonpublic_other_file,none, + "non-public %1 %2 must have explicit '%0' attribute to satisfy " + "requirement %3 %4 (in protocol %6) because it is declared in a different " + "file than the conformance of %5 to %6", + (StringRef, DescriptiveDeclKind, DeclName, DescriptiveDeclKind, DeclName, + Type, Type)) // @derivative ERROR(derivative_attr_expected_result_tuple,none, diff --git a/include/swift/AST/FileUnit.h b/include/swift/AST/FileUnit.h index 70237b584ab51..9395d6b2df401 100644 --- a/include/swift/AST/FileUnit.h +++ b/include/swift/AST/FileUnit.h @@ -30,6 +30,9 @@ class FileUnit : public DeclContext { #pragma clang diagnostic pop virtual void anchor(); + friend class DirectOperatorLookupRequest; + friend class DirectPrecedenceGroupLookupRequest; + // FIXME: Stick this in a PointerIntPair. const FileUnitKind Kind; @@ -107,6 +110,25 @@ class FileUnit : public DeclContext { const ModuleDecl *importedModule, SmallVectorImpl &spiGroups) const {}; +protected: + /// Look up an operator declaration. Do not call directly, use + /// \c DirectOperatorLookupRequest instead. + /// + /// \param name The operator name ("+", ">>", etc.) + /// + /// \param fixity One of Prefix, Infix, or Postfix. + virtual void + lookupOperatorDirect(Identifier name, OperatorFixity fixity, + TinyPtrVector &results) const {} + + /// Look up a precedence group. Do not call directly, use + /// \c DirectPrecedenceGroupLookupRequest instead. + /// + /// \param name The precedence group name. + virtual void lookupPrecedenceGroupDirect( + Identifier name, TinyPtrVector &results) const {} + +public: /// Returns the comment attached to the given declaration. /// /// This function is an implementation detail for comment serialization. @@ -342,22 +364,6 @@ class LoadedFile : public FileUnit { return StringRef(); } - /// Look up an operator declaration. - /// - /// \param name The operator name ("+", ">>", etc.) - /// - /// \param fixity One of PrefixOperator, InfixOperator, or PostfixOperator. - virtual OperatorDecl *lookupOperator(Identifier name, DeclKind fixity) const { - return nullptr; - } - - /// Look up a precedence group. - /// - /// \param name The precedence group name. - virtual PrecedenceGroupDecl *lookupPrecedenceGroup(Identifier name) const { - return nullptr; - } - /// Returns the Swift module that overlays a Clang module. virtual ModuleDecl *getOverlayModule() const { return nullptr; } diff --git a/include/swift/AST/KnownIdentifiers.def b/include/swift/AST/KnownIdentifiers.def index 8af23bb8edcd5..b67e99f9f9919 100644 --- a/include/swift/AST/KnownIdentifiers.def +++ b/include/swift/AST/KnownIdentifiers.def @@ -204,7 +204,10 @@ IDENTIFIER_(nsError) IDENTIFIER(OSLogMessage) // Differentiable programming +IDENTIFIER(differential) +IDENTIFIER(pullback) IDENTIFIER(TangentVector) +IDENTIFIER(zero) #undef IDENTIFIER #undef IDENTIFIER_ diff --git a/include/swift/AST/KnownProtocols.def b/include/swift/AST/KnownProtocols.def index 363344f1c95b9..18a80628d71db 100644 --- a/include/swift/AST/KnownProtocols.def +++ b/include/swift/AST/KnownProtocols.def @@ -84,6 +84,7 @@ PROTOCOL_(DestructorSafeContainer) PROTOCOL(StringInterpolationProtocol) +PROTOCOL(AdditiveArithmetic) PROTOCOL(Differentiable) EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByArrayLiteral, "Array", false) diff --git a/include/swift/AST/Module.h b/include/swift/AST/Module.h index 337ae78865837..5ee0ab719a15e 100644 --- a/include/swift/AST/Module.h +++ b/include/swift/AST/Module.h @@ -335,6 +335,31 @@ class ModuleDecl : public DeclContext, public TypeDecl { void getDeclaredCrossImportBystanders( SmallVectorImpl &bystanderNames); + /// A lazily populated mapping from each declared cross import overlay this + /// module transitively underlies to its bystander and immediate underlying + /// module. + llvm::SmallDenseMap, 1> + declaredCrossImportsTransitive; + + /// Determines if the given \p overlay is a declarared cross-import overlay of + /// this module, or an of its transitively declared overlay modules. + /// + /// This is used by tooling to map overlays to their underlying modules, and t + bool isUnderlyingModuleOfCrossImportOverlay(const ModuleDecl *overlay); + + /// If \p overlay is a transitively declared cross-import overlay of this + /// module, gets the list of bystander modules that need to be imported + /// alongside this module for the overlay to be loaded. + void getAllBystandersForCrossImportOverlay( + ModuleDecl *overlay, SmallVectorImpl &bystanders); + + /// Walks and loads the declared cross-import overlays of this module, + /// transitively, to find all overlays this module underlies. + /// + /// This is used by tooling to present these overlays as part of this module. + void findDeclaredCrossImportOverlaysTransitive( + SmallVectorImpl &overlays); + /// Convenience accessor for clients that know what kind of file they're /// dealing with. SourceFile &getMainSourceFile(SourceFileKind expectedKind) const; diff --git a/include/swift/AST/NameLookupRequests.h b/include/swift/AST/NameLookupRequests.h index 107886a0c09b3..ff2f358b6d891 100644 --- a/include/swift/AST/NameLookupRequests.h +++ b/include/swift/AST/NameLookupRequests.h @@ -18,6 +18,7 @@ #include "swift/AST/SimpleRequest.h" #include "swift/AST/ASTTypeIDs.h" +#include "swift/AST/FileUnit.h" #include "swift/AST/Identifier.h" #include "swift/Basic/Statistic.h" #include "llvm/ADT/Hashing.h" @@ -162,6 +163,27 @@ class SuperclassDeclRequest : void cacheResult(ClassDecl *value) const; }; +class InheritedProtocolsRequest + : public SimpleRequest(ProtocolDecl *), + CacheKind::SeparatelyCached> { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + // Evaluation. + ArrayRef + evaluate(Evaluator &evaluator, ProtocolDecl *PD) const; + +public: + // Caching. + bool isCached() const { return true; } + Optional> getCachedResult() const; + void cacheResult(ArrayRef decls) const; +}; + /// Requests whether or not this class has designated initializers that are /// not public or @usableFromInline. class HasMissingDesignatedInitializersRequest : @@ -518,22 +540,35 @@ class DirectLookupRequest class OperatorLookupDescriptor final { public: - SourceFile *SF; + using Storage = llvm::PointerUnion; + Storage fileOrModule; Identifier name; bool isCascading; SourceLoc diagLoc; - OperatorLookupDescriptor(SourceFile *SF, Identifier name, bool isCascading, - SourceLoc diagLoc) - : SF(SF), name(name), isCascading(isCascading), diagLoc(diagLoc) {} +private: + OperatorLookupDescriptor(Storage fileOrModule, Identifier name, + bool isCascading, SourceLoc diagLoc) + : fileOrModule(fileOrModule), name(name), isCascading(isCascading), + diagLoc(diagLoc) {} + +public: + /// Retrieves the files to perform lookup in. + ArrayRef getFiles() const; + + /// If this is for a module lookup, returns the module. Otherwise returns + /// \c nullptr. + ModuleDecl *getModule() const { + return fileOrModule.dyn_cast(); + } friend llvm::hash_code hash_value(const OperatorLookupDescriptor &desc) { - return llvm::hash_combine(desc.SF, desc.name, desc.isCascading); + return llvm::hash_combine(desc.fileOrModule, desc.name, desc.isCascading); } friend bool operator==(const OperatorLookupDescriptor &lhs, const OperatorLookupDescriptor &rhs) { - return lhs.SF == rhs.SF && lhs.name == rhs.name && + return lhs.fileOrModule == rhs.fileOrModule && lhs.name == rhs.name && lhs.isCascading == rhs.isCascading; } @@ -541,6 +576,17 @@ class OperatorLookupDescriptor final { const OperatorLookupDescriptor &rhs) { return !(lhs == rhs); } + + static OperatorLookupDescriptor forFile(FileUnit *file, Identifier name, + bool isCascading, SourceLoc diagLoc) { + return OperatorLookupDescriptor(file, name, isCascading, diagLoc); + } + + static OperatorLookupDescriptor forModule(ModuleDecl *mod, Identifier name, + bool isCascading, + SourceLoc diagLoc) { + return OperatorLookupDescriptor(mod, name, isCascading, diagLoc); + } }; void simple_display(llvm::raw_ostream &out, @@ -572,6 +618,86 @@ using LookupInfixOperatorRequest = LookupOperatorRequest; using LookupPostfixOperatorRequest = LookupOperatorRequest; using LookupPrecedenceGroupRequest = LookupOperatorRequest; +/// Looks up an operator in a given file or module without looking through +/// imports. +class DirectOperatorLookupRequest + : public SimpleRequest( + OperatorLookupDescriptor, OperatorFixity), + CacheKind::Uncached> { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + llvm::Expected> + evaluate(Evaluator &evaluator, OperatorLookupDescriptor descriptor, + OperatorFixity fixity) const; +}; + +/// Looks up an precedencegroup in a given file or module without looking +/// through imports. +class DirectPrecedenceGroupLookupRequest + : public SimpleRequest( + OperatorLookupDescriptor), + CacheKind::Uncached> { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + llvm::Expected> + evaluate(Evaluator &evaluator, OperatorLookupDescriptor descriptor) const; +}; + +class LookupConformanceDescriptor final { +public: + ModuleDecl *Mod; + Type Ty; + ProtocolDecl *PD; + + LookupConformanceDescriptor(ModuleDecl *Mod, Type Ty, ProtocolDecl *PD) + : Mod(Mod), Ty(Ty), PD(PD) {} + + friend llvm::hash_code hash_value(const LookupConformanceDescriptor &desc) { + return llvm::hash_combine(desc.Mod, desc.Ty.getPointer(), desc.PD); + } + + friend bool operator==(const LookupConformanceDescriptor &lhs, + const LookupConformanceDescriptor &rhs) { + return lhs.Mod == rhs.Mod && lhs.Ty.getPointer() == rhs.Ty.getPointer() && + lhs.PD == rhs.PD; + } + + friend bool operator!=(const LookupConformanceDescriptor &lhs, + const LookupConformanceDescriptor &rhs) { + return !(lhs == rhs); + } +}; + +void simple_display(llvm::raw_ostream &out, + const LookupConformanceDescriptor &desc); + +SourceLoc extractNearestSourceLoc(const LookupConformanceDescriptor &desc); + +class LookupConformanceInModuleRequest + : public SimpleRequest { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + // Evaluation. + llvm::Expected evaluate( + Evaluator &evaluator, LookupConformanceDescriptor desc) const; +}; + #define SWIFT_TYPEID_ZONE NameLookup #define SWIFT_TYPEID_HEADER "swift/AST/NameLookupTypeIDZone.def" #include "swift/Basic/DefineTypeIDZone.h" diff --git a/include/swift/AST/NameLookupTypeIDZone.def b/include/swift/AST/NameLookupTypeIDZone.def index 5e556d25b0126..f1f680288d8b0 100644 --- a/include/swift/AST/NameLookupTypeIDZone.def +++ b/include/swift/AST/NameLookupTypeIDZone.def @@ -24,6 +24,13 @@ SWIFT_REQUEST(NameLookup, CustomAttrNominalRequest, SWIFT_REQUEST(NameLookup, DirectLookupRequest, TinyPtrVector(DirectLookupDescriptor), Uncached, NoLocationInfo) +SWIFT_REQUEST(NameLookup, DirectOperatorLookupRequest, + TinyPtrVector(OperatorLookupDescriptor, + OperatorFixity), + Uncached, NoLocationInfo) +SWIFT_REQUEST(NameLookup, DirectPrecedenceGroupLookupRequest, + TinyPtrVector(OperatorLookupDescriptor), + Uncached, NoLocationInfo) SWIFT_REQUEST(NameLookup, ExpandASTScopeRequest, ast_scope::ASTScopeImpl* (ast_scope::ASTScopeImpl*, ast_scope::ScopeCreator*), SeparatelyCached, @@ -40,6 +47,12 @@ SWIFT_REQUEST(NameLookup, InheritedDeclsReferencedRequest, DirectlyReferencedTypeDecls( llvm::PointerUnion, unsigned), Uncached, HasNearestLocation) +SWIFT_REQUEST(NameLookup, InheritedProtocolsRequest, + ArrayRef(ProtocolDecl *), SeparatelyCached, + NoLocationInfo) +SWIFT_REQUEST(NameLookup, LookupConformanceInModuleRequest, + ProtocolConformanceRef(LookupConformanceDescriptor), + Uncached, NoLocationInfo) SWIFT_REQUEST(NameLookup, LookupInModuleRequest, QualifiedLookupResult(const DeclContext *, DeclName, NLKind, namelookup::ResolutionKind, diff --git a/include/swift/AST/PrintOptions.h b/include/swift/AST/PrintOptions.h index 5772430b2ea3c..56db3d423de10 100644 --- a/include/swift/AST/PrintOptions.h +++ b/include/swift/AST/PrintOptions.h @@ -432,6 +432,13 @@ struct PrintOptions { /// The information for converting archetypes to specialized types. llvm::Optional TransformContext; + /// Before printing the name of a ModuleDecl, this callback will be called and + /// the name of the ModuleDecl it returns will be printed instead. This is + /// currently used to present cross import overlays as if they were their + /// underlying module. + std::function mapModuleToUnderlying = + [] (const ModuleDecl *D) { return D; }; + bool PrintAsMember = false; /// Whether to print parameter specifiers as 'let' and 'var'. diff --git a/include/swift/AST/SourceFile.h b/include/swift/AST/SourceFile.h index 336c53c4bfd27..a4e089e5c57be 100644 --- a/include/swift/AST/SourceFile.h +++ b/include/swift/AST/SourceFile.h @@ -436,6 +436,16 @@ class SourceFile final : public FileUnit { ObjCSelector selector, SmallVectorImpl &results) const override; +protected: + virtual void + lookupOperatorDirect(Identifier name, OperatorFixity fixity, + TinyPtrVector &results) const override; + + virtual void lookupPrecedenceGroupDirect( + Identifier name, + TinyPtrVector &results) const override; + +public: virtual void getTopLevelDecls(SmallVectorImpl &results) const override; virtual void diff --git a/include/swift/AST/TypeCheckRequests.h b/include/swift/AST/TypeCheckRequests.h index 3c39c8db91c0f..e87342fc67fe0 100644 --- a/include/swift/AST/TypeCheckRequests.h +++ b/include/swift/AST/TypeCheckRequests.h @@ -1600,6 +1600,31 @@ class SynthesizeMemberwiseInitRequest bool isCached() const { return true; } }; +/// Resolves the effective memberwise initializer for a given type. +/// +/// An effective memberwise initializer is either a synthesized memberwise +/// initializer or a user-defined initializer with the same type. +/// +/// See `NominalTypeDecl::getEffectiveMemberwiseInitializer` for details. +class ResolveEffectiveMemberwiseInitRequest + : public SimpleRequest { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + // Evaluation. + llvm::Expected evaluate(Evaluator &evaluator, + NominalTypeDecl *decl) const; + +public: + // Caching. + bool isCached() const { return true; } +}; + /// Checks whether this type has a synthesized zero parameter default /// initializer. class HasDefaultInitRequest diff --git a/include/swift/AST/TypeCheckerTypeIDZone.def b/include/swift/AST/TypeCheckerTypeIDZone.def index 8a6972ccd2c7d..39329f067b29e 100644 --- a/include/swift/AST/TypeCheckerTypeIDZone.def +++ b/include/swift/AST/TypeCheckerTypeIDZone.def @@ -214,6 +214,8 @@ SWIFT_REQUEST(TypeChecker, SPIGroupsRequest, Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, SynthesizeMemberwiseInitRequest, ConstructorDecl *(NominalTypeDecl *), Cached, NoLocationInfo) +SWIFT_REQUEST(TypeChecker, ResolveEffectiveMemberwiseInitRequest, + ConstructorDecl *(NominalTypeDecl *), Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, HasDefaultInitRequest, bool(NominalTypeDecl *), Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, SynthesizeDefaultInitRequest, diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index 7b145879fd9f4..c0da7fe307146 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -3353,6 +3353,8 @@ class AnyFunctionType : public TypeBase { IndexSubset *parameterIndices, AutoDiffLinearMapKind kind, LookupConformanceFn lookupConformance, bool makeSelfParamFirst = false); + AnyFunctionType *getWithoutDifferentiability() const; + /// True if the parameter declaration it is attached to is guaranteed /// to not persist the closure for longer than the duration of the call. bool isNoEscape() const { diff --git a/include/swift/Basic/LangOptions.h b/include/swift/Basic/LangOptions.h index 8c8da5701f6a5..0138da6ec5b29 100644 --- a/include/swift/Basic/LangOptions.h +++ b/include/swift/Basic/LangOptions.h @@ -327,6 +327,10 @@ namespace swift { /// `@differentiable` declaration attribute, etc. bool EnableExperimentalDifferentiableProgramming = false; + /// Whether to enable experimental `AdditiveArithmetic` derived + /// conformances. + bool EnableExperimentalAdditiveArithmeticDerivedConformances = false; + /// Enable verification when every SubstitutionMap is constructed. bool VerifyAllSubstitutionMaps = false; diff --git a/include/swift/Option/Options.td b/include/swift/Option/Options.td index b85b6b415a47d..4043a786023ce 100644 --- a/include/swift/Option/Options.td +++ b/include/swift/Option/Options.td @@ -495,10 +495,16 @@ def disable_bridging_pch : Flag<["-"], "disable-bridging-pch">, HelpText<"Disable automatic generation of bridging PCH files">; // Experimental feature options -def enable_experimental_differentiable_programming : Flag<["-"], "enable-experimental-differentiable-programming">, +def enable_experimental_differentiable_programming : + Flag<["-"], "enable-experimental-differentiable-programming">, Flags<[FrontendOption]>, HelpText<"Enable experimental differentiable programming features">; +def enable_experimental_additive_arithmetic_derivation : + Flag<["-"], "enable-experimental-additive-arithmetic-derivation">, + Flags<[FrontendOption]>, + HelpText<"Enable experimental 'AdditiveArithmetic' derived conformances">; + def enable_experimental_concise_pound_file : Flag<["-"], "enable-experimental-concise-pound-file">, Flags<[FrontendOption]>, diff --git a/include/swift/Parse/Parser.h b/include/swift/Parse/Parser.h index 56ea6ec896ecf..1849898d90990 100644 --- a/include/swift/Parse/Parser.h +++ b/include/swift/Parse/Parser.h @@ -1010,8 +1010,6 @@ class Parser { /// Parse the arguments inside the @differentiable attribute. bool parseDifferentiableAttributeArguments( bool &linear, SmallVectorImpl ¶ms, - Optional &jvpSpec, - Optional &vjpSpec, TrailingWhereClause *&whereClause); /// Parse a differentiability parameters clause, i.e. the 'wrt:' clause in diff --git a/include/swift/SIL/LinearLifetimeChecker.h b/include/swift/SIL/LinearLifetimeChecker.h index 85464175436fd..e4bcc02b36c04 100644 --- a/include/swift/SIL/LinearLifetimeChecker.h +++ b/include/swift/SIL/LinearLifetimeChecker.h @@ -28,103 +28,8 @@ class SILInstruction; class SILModule; class SILValue; class DeadEndBlocks; - -namespace ownership { - -struct ErrorBehaviorKind { - enum inner_t { - Invalid = 0, - ReturnFalse = 1, - PrintMessage = 2, - Assert = 4, - ReturnFalseOnLeak = 8, - PrintMessageAndReturnFalse = PrintMessage | ReturnFalse, - PrintMessageAndAssert = PrintMessage | Assert, - ReturnFalseOnLeakAssertOtherwise = ReturnFalseOnLeak | Assert, - } Value; - - ErrorBehaviorKind() : Value(Invalid) {} - ErrorBehaviorKind(inner_t Inner) : Value(Inner) { assert(Value != Invalid); } - - bool shouldAssert() const { - assert(Value != Invalid); - return Value & Assert; - } - - bool shouldReturnFalseOnLeak() const { - assert(Value != Invalid); - return Value & ReturnFalseOnLeak; - } - - bool shouldPrintMessage() const { - assert(Value != Invalid); - return Value & PrintMessage; - } - - bool shouldReturnFalse() const { - assert(Value != Invalid); - return Value & ReturnFalse; - } -}; - -} // end namespace ownership - -class LinearLifetimeError { - ownership::ErrorBehaviorKind errorBehavior; - bool foundUseAfterFree = false; - bool foundLeak = false; - bool foundOverConsume = false; - -public: - LinearLifetimeError(ownership::ErrorBehaviorKind errorBehavior) - : errorBehavior(errorBehavior) {} - - bool getFoundError() const { - return foundUseAfterFree || foundLeak || foundOverConsume; - } - - bool getFoundLeak() const { return foundLeak; } - - bool getFoundUseAfterFree() const { return foundUseAfterFree; } - - bool getFoundOverConsume() const { return foundOverConsume; } - - void handleLeak(llvm::function_ref &&messagePrinterFunc) { - foundLeak = true; - - if (errorBehavior.shouldPrintMessage()) - messagePrinterFunc(); - - if (errorBehavior.shouldReturnFalseOnLeak()) - return; - - // We already printed out our error if we needed to, so don't pass it along. - handleError([]() {}); - } - - void handleOverConsume(llvm::function_ref &&messagePrinterFunc) { - foundOverConsume = true; - handleError(std::move(messagePrinterFunc)); - } - - void handleUseAfterFree(llvm::function_ref &&messagePrinterFunc) { - foundUseAfterFree = true; - handleError(std::move(messagePrinterFunc)); - } - -private: - void handleError(llvm::function_ref &&messagePrinterFunc) { - if (errorBehavior.shouldPrintMessage()) - messagePrinterFunc(); - - if (errorBehavior.shouldReturnFalse()) { - return; - } - - assert(errorBehavior.shouldAssert() && "At this point, we should assert"); - llvm_unreachable("triggering standard assertion failure routine"); - } -}; +class SILOwnershipVerifier; +class SILValueOwnershipChecker; /// A class used to validate linear lifetime with respect to an SSA-like /// definition. @@ -140,6 +45,14 @@ class LinearLifetimeError { /// uses must not be reachable from each other and jointly post-dominate all /// consuming uses as well as the defining block/instruction. class LinearLifetimeChecker { +public: + class Error; + struct ErrorBehaviorKind; + +private: + friend class SILOwnershipVerifier; + friend class SILValueOwnershipChecker; + SmallPtrSetImpl &visitedBlocks; DeadEndBlocks &deadEndBlocks; @@ -148,6 +61,24 @@ class LinearLifetimeChecker { DeadEndBlocks &deadEndBlocks) : visitedBlocks(visitedBlocks), deadEndBlocks(deadEndBlocks) {} + /// Returns true that \p value forms a linear lifetime with consuming uses \p + /// consumingUses, non consuming uses \p nonConsumingUses. Returns false + /// otherwise. + bool validateLifetime(SILValue value, ArrayRef consumingUses, + ArrayRef nonConsumingUses); + + /// Given a value and a consuming use of that value, compute a non-unique + /// minimal set of insertion points that together with \p consumingUse + /// post-dominate and end the lifetime of \p value. + /// + /// Returns true if we completed the consuming use set and discovered that \p + /// consumingUse is not strongly control equivalent to value (meaning + /// consumingUse is not in the same loop in the loop nest as value). + bool completeConsumingUseSet( + SILValue value, Operand *consumingUse, + function_ref visitor); + +private: /// Returns true if: /// /// 1. No consuming uses are reachable from any other consuming use, from any @@ -164,22 +95,19 @@ class LinearLifetimeChecker { /// error. /// \p leakingBlocks If non-null a list of blocks where the value was detected /// to leak. Can be used to insert missing destroys. - LinearLifetimeError - checkValue(SILValue value, ArrayRef consumingUses, - ArrayRef nonConsumingUses, - ownership::ErrorBehaviorKind errorBehavior, - SmallVectorImpl *leakingBlocks = nullptr); - - /// Returns true that \p value forms a linear lifetime with consuming uses \p - /// consumingUses, non consuming uses \p nonConsumingUses. Returns false - /// otherwise. - bool validateLifetime(SILValue value, ArrayRef consumingUses, - ArrayRef nonConsumingUses) { - return !checkValue(value, consumingUses, nonConsumingUses, - ownership::ErrorBehaviorKind::ReturnFalse, - nullptr /*leakingBlocks*/) - .getFoundError(); - } + Error checkValue(SILValue value, ArrayRef consumingUses, + ArrayRef nonConsumingUses, + ErrorBehaviorKind errorBehavior); + + Error checkValue(SILValue value, ArrayRef consumingUses, + ArrayRef nonConsumingUses, + ErrorBehaviorKind errorBehavior, + function_ref leakingBlockCallback); + + Error checkValueImpl( + SILValue value, ArrayRef consumingUses, + ArrayRef nonConsumingUses, ErrorBehaviorKind errorBehavior, + Optional> leakingBlockCallback); }; } // namespace swift diff --git a/include/swift/SIL/SILBuilder.h b/include/swift/SIL/SILBuilder.h index 6401020d37703..95c1cefe217f6 100644 --- a/include/swift/SIL/SILBuilder.h +++ b/include/swift/SIL/SILBuilder.h @@ -2164,6 +2164,23 @@ class SILBuilder { // Differentiable programming instructions //===--------------------------------------------------------------------===// + DifferentiableFunctionInst *createDifferentiableFunction( + SILLocation Loc, IndexSubset *ParameterIndices, SILValue OriginalFunction, + Optional> JVPAndVJPFunctions = None) { + return insert(DifferentiableFunctionInst::create( + getModule(), getSILDebugLocation(Loc), ParameterIndices, + OriginalFunction, JVPAndVJPFunctions, hasOwnership())); + } + + /// Note: explicit extractee type may be specified only in lowered SIL. + DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract( + SILLocation Loc, NormalDifferentiableFunctionTypeComponent Extractee, + SILValue Function, Optional ExtracteeType = None) { + return insert(new (getModule()) DifferentiableFunctionExtractInst( + getModule(), getSILDebugLocation(Loc), Extractee, Function, + ExtracteeType)); + } + /// Note: explicit function type may be specified only in lowered SIL. DifferentiabilityWitnessFunctionInst *createDifferentiabilityWitnessFunction( SILLocation Loc, DifferentiabilityWitnessFunctionKind WitnessKind, diff --git a/include/swift/SIL/SILCloner.h b/include/swift/SIL/SILCloner.h index 503852f5ba021..a94c1662cd7ce 100644 --- a/include/swift/SIL/SILCloner.h +++ b/include/swift/SIL/SILCloner.h @@ -2827,6 +2827,33 @@ void SILCloner::visitKeyPathInst(KeyPathInst *Inst) { opValues, getOpType(Inst->getType()))); } +template +void SILCloner::visitDifferentiableFunctionInst( + DifferentiableFunctionInst *Inst) { + getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope())); + Optional> derivativeFns = None; + if (Inst->hasDerivativeFunctions()) + derivativeFns = std::make_pair(getOpValue(Inst->getJVPFunction()), + getOpValue(Inst->getVJPFunction())); + recordClonedInstruction( + Inst, getBuilder().createDifferentiableFunction( + getOpLocation(Inst->getLoc()), Inst->getParameterIndices(), + getOpValue(Inst->getOriginalFunction()), derivativeFns)); +} + +template +void SILCloner::visitDifferentiableFunctionExtractInst( + DifferentiableFunctionExtractInst *Inst) { + getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope())); + Optional explicitExtracteeType = None; + if (Inst->hasExplicitExtracteeType()) + explicitExtracteeType = Inst->getType(); + recordClonedInstruction( + Inst, getBuilder().createDifferentiableFunctionExtract( + getOpLocation(Inst->getLoc()), Inst->getExtractee(), + getOpValue(Inst->getOperand()), explicitExtracteeType)); +} + template void SILCloner::visitDifferentiabilityWitnessFunctionInst( DifferentiabilityWitnessFunctionInst *Inst) { diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index 3366328915a74..5563c3b684980 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -7976,6 +7976,126 @@ class TryApplyInst final const GenericSpecializationInformation *SpecializationInfo); }; +/// DifferentiableFunctionInst - creates a `@differentiable` function-typed +/// value from an original function operand and derivative function operands +/// (optional). The differentiation transform canonicalizes +/// `differentiable_function` instructions, filling in derivative function +/// operands if missing. +class DifferentiableFunctionInst final + : public InstructionBaseWithTrailingOperands< + SILInstructionKind::DifferentiableFunctionInst, + DifferentiableFunctionInst, OwnershipForwardingSingleValueInst> { +private: + friend SILBuilder; + /// Differentiability parameter indices. + IndexSubset *ParameterIndices; + /// Indicates whether derivative function operands (JVP/VJP) exist. + bool HasDerivativeFunctions; + + DifferentiableFunctionInst(SILDebugLocation DebugLoc, + IndexSubset *ParameterIndices, + SILValue OriginalFunction, + ArrayRef DerivativeFunctions, + bool HasOwnership); + + static SILType getDifferentiableFunctionType(SILValue OriginalFunction, + IndexSubset *ParameterIndices); + + static ValueOwnershipKind + getMergedOwnershipKind(SILValue OriginalFunction, + ArrayRef DerivativeFunctions); + +public: + static DifferentiableFunctionInst * + create(SILModule &Module, SILDebugLocation Loc, IndexSubset *ParameterIndices, + SILValue OriginalFunction, + Optional> VJPAndJVPFunctions, + bool HasOwnership); + + /// Returns the original function operand. + SILValue getOriginalFunction() const { return getOperand(0); } + + /// Returns differentiability parameter indices. + IndexSubset *getParameterIndices() const { return ParameterIndices; } + + /// Returns true if derivative functions (JVP/VJP) exist. + bool hasDerivativeFunctions() const { return HasDerivativeFunctions; } + + /// Returns the derivative function operands if they exist. + /// Otherwise, return `None`. + Optional> + getOptionalDerivativeFunctionPair() const { + if (!HasDerivativeFunctions) + return None; + return std::make_pair(getOperand(1), getOperand(2)); + } + + ArrayRef getDerivativeFunctionArray() const { + return getAllOperands().drop_front(); + } + + /// Returns the JVP function operand. + SILValue getJVPFunction() const { + assert(HasDerivativeFunctions); + return getOperand(1); + } + + /// Returns the VJP function operand. + SILValue getVJPFunction() const { + assert(HasDerivativeFunctions); + return getOperand(2); + } + + /// Returns the derivative function operand (JVP or VJP) with the given kind. + SILValue getDerivativeFunction(AutoDiffDerivativeFunctionKind kind) const { + switch (kind) { + case AutoDiffDerivativeFunctionKind::JVP: + return getJVPFunction(); + case AutoDiffDerivativeFunctionKind::VJP: + return getVJPFunction(); + } + } +}; + +/// DifferentiableFunctionExtractInst - extracts either the original or +/// derivative function value from a `@differentiable` function. +class DifferentiableFunctionExtractInst + : public UnaryInstructionBase< + SILInstructionKind::DifferentiableFunctionExtractInst, + SingleValueInstruction> { +private: + /// The extractee. + NormalDifferentiableFunctionTypeComponent Extractee; + /// True if the instruction has an explicit extractee type. + bool HasExplicitExtracteeType; + + static SILType + getExtracteeType(SILValue function, + NormalDifferentiableFunctionTypeComponent extractee, + SILModule &module); + +public: + /// Note: explicit extractee type may be specified only in lowered SIL. + explicit DifferentiableFunctionExtractInst( + SILModule &module, SILDebugLocation debugLoc, + NormalDifferentiableFunctionTypeComponent extractee, SILValue function, + Optional extracteeType = None); + + NormalDifferentiableFunctionTypeComponent getExtractee() const { + return Extractee; + } + + AutoDiffDerivativeFunctionKind getDerivativeFunctionKind() const { + auto kind = Extractee.getAsDerivativeFunctionKind(); + assert(kind); + return *kind; + } + + bool hasExplicitExtracteeType() const { return HasExplicitExtracteeType; } +}; + +/// DifferentiabilityWitnessFunctionInst - Looks up a differentiability witness +/// function for a given original function. class DifferentiabilityWitnessFunctionInst : public InstructionBase< SILInstructionKind::DifferentiabilityWitnessFunctionInst, diff --git a/include/swift/SIL/SILNodes.def b/include/swift/SIL/SILNodes.def index 7b448d3a5e645..7eb0202153c3f 100644 --- a/include/swift/SIL/SILNodes.def +++ b/include/swift/SIL/SILNodes.def @@ -692,6 +692,11 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction) SingleValueInstruction, None, DoesNotRelease) // Differentiable programming + SINGLE_VALUE_INST(DifferentiableFunctionInst, differentiable_function, + SingleValueInstruction, None, DoesNotRelease) + SINGLE_VALUE_INST(DifferentiableFunctionExtractInst, + differentiable_function_extract, + SingleValueInstruction, None, DoesNotRelease) SINGLE_VALUE_INST(DifferentiabilityWitnessFunctionInst, differentiability_witness_function, SingleValueInstruction, None, DoesNotRelease) diff --git a/include/swift/Serialization/SerializedModuleLoader.h b/include/swift/Serialization/SerializedModuleLoader.h index b3c7bdbd4b117..b508bb08dfd56 100644 --- a/include/swift/Serialization/SerializedModuleLoader.h +++ b/include/swift/Serialization/SerializedModuleLoader.h @@ -328,12 +328,16 @@ class SerializedASTFile final : public LoadedFile { lookupNestedType(Identifier name, const NominalTypeDecl *parent) const override; - virtual OperatorDecl *lookupOperator(Identifier name, - DeclKind fixity) const override; +protected: + virtual void + lookupOperatorDirect(Identifier name, OperatorFixity fixity, + TinyPtrVector &results) const override; - virtual PrecedenceGroupDecl * - lookupPrecedenceGroup(Identifier name) const override; + virtual void lookupPrecedenceGroupDirect( + Identifier name, + TinyPtrVector &results) const override; +public: virtual void lookupVisibleDecls(ModuleDecl::AccessPathTy accessPath, VisibleDeclConsumer &consumer, NLKind lookupKind) const override; diff --git a/lib/AST/ASTMangler.cpp b/lib/AST/ASTMangler.cpp index 1a30c3a030520..7ed6f750fa345 100644 --- a/lib/AST/ASTMangler.cpp +++ b/lib/AST/ASTMangler.cpp @@ -2801,6 +2801,35 @@ void ASTMangler::appendDependentProtocolConformance( } } +void ASTMangler::appendAnyProtocolConformance( + CanGenericSignature genericSig, + CanType conformingType, + ProtocolConformanceRef conformance) { + if (conformingType->isTypeParameter()) { + assert(genericSig && "Need a generic signature to resolve conformance"); + auto path = genericSig->getConformanceAccessPath(conformingType, + conformance.getAbstract()); + appendDependentProtocolConformance(path); + } else if (auto opaqueType = conformingType->getAs()) { + GenericSignature opaqueSignature = opaqueType->getBoundSignature(); + GenericTypeParamType *opaqueTypeParam = opaqueSignature->getGenericParams().back(); + ConformanceAccessPath conformanceAccessPath = + opaqueSignature->getConformanceAccessPath(opaqueTypeParam, + conformance.getAbstract()); + + // Append the conformance access path with the signature of the opaque type. + { + llvm::SaveAndRestore savedSignature( + CurGenericSignature, opaqueSignature.getCanonicalSignature()); + appendDependentProtocolConformance(conformanceAccessPath); + } + appendType(conformingType); + appendOperator("HO"); + } else { + appendConcreteProtocolConformance(conformance.getConcrete()); + } +} + void ASTMangler::appendConcreteProtocolConformance( const ProtocolConformance *conformance) { auto module = conformance->getDeclContext()->getParentModule(); @@ -2841,30 +2870,15 @@ void ASTMangler::appendConcreteProtocolConformance( CanType canType = type->getCanonicalType(CurGenericSignature); auto proto = conditionalReq.getSecondType()->castTo()->getDecl(); - if (canType->isTypeParameter()) { - assert(CurGenericSignature && - "Need a generic signature to resolve conformance"); - auto conformanceAccessPath = - CurGenericSignature->getConformanceAccessPath(type, proto); - appendDependentProtocolConformance(conformanceAccessPath); - } else if (auto opaqueType = canType->getAs()) { - GenericSignature opaqueSignature = opaqueType->getBoundSignature(); - GenericTypeParamType *opaqueTypeParam = opaqueSignature->getGenericParams().back(); - ConformanceAccessPath conformanceAccessPath = - opaqueSignature->getConformanceAccessPath(opaqueTypeParam, proto); - - // Append the conformance access path with the signature of the opaque type. - { - llvm::SaveAndRestore savedSignature( - CurGenericSignature, opaqueSignature.getCanonicalSignature()); - appendDependentProtocolConformance(conformanceAccessPath); - } - appendType(canType); - appendOperator("HO"); + + ProtocolConformanceRef conformance; + + if (canType->isTypeParameter() || canType->is()){ + conformance = ProtocolConformanceRef(proto); } else { - auto conditionalConf = module->lookupConformance(canType, proto); - appendConcreteProtocolConformance(conditionalConf.getConcrete()); + conformance = module->lookupConformance(canType, proto); } + appendAnyProtocolConformance(CurGenericSignature, canType, conformance); appendListSeparator(firstRequirement); break; } diff --git a/lib/AST/ASTPrinter.cpp b/lib/AST/ASTPrinter.cpp index 7707794ee6535..6860ffbda7b47 100644 --- a/lib/AST/ASTPrinter.cpp +++ b/lib/AST/ASTPrinter.cpp @@ -3586,7 +3586,8 @@ class TypePrinter : public TypeVisitor { template void printModuleContext(T *Ty) { FileUnit *File = cast(Ty->getDecl()->getModuleScopeContext()); - ModuleDecl *Mod = File->getParentModule(); + const ModuleDecl *Mod = + Options.mapModuleToUnderlying(File->getParentModule()); Identifier Name = Mod->getName(); if (Options.UseExportedModuleNames) diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp index a217599fc2621..081c6d8969733 100644 --- a/lib/AST/Attr.cpp +++ b/lib/AST/Attr.cpp @@ -526,12 +526,9 @@ static std::string getDifferentiationParametersClauseString( /// Print the arguments of the given `@differentiable` attribute. /// - If `omitWrtClause` is true, omit printing the `wrt:` differentiation /// parameters clause. -/// - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative -/// functions. static void printDifferentiableAttrArguments( const DifferentiableAttr *attr, ASTPrinter &printer, PrintOptions Options, - const Decl *D, bool omitWrtClause = false, - bool omitDerivativeFunctions = false) { + const Decl *D, bool omitWrtClause = false) { assert(D); // Create a temporary string for the attribute argument text. std::string attrArgText; @@ -574,19 +571,6 @@ static void printDifferentiableAttrArguments( stream << diffParamsString; } } - // Print derivative function names, unless they are to be omitted. - if (!omitDerivativeFunctions) { - // Print jvp function name, if specified. - if (auto jvp = attr->getJVP()) { - printCommaIfNecessary(); - stream << "jvp: " << jvp->Name; - } - // Print vjp function name, if specified. - if (auto vjp = attr->getVJP()) { - printCommaIfNecessary(); - stream << "vjp: " << vjp->Name; - } - } // Print 'where' clause, if any. // First, filter out requirements satisfied by the original function's // generic signature. They should not be printed. @@ -1616,12 +1600,9 @@ SPIAccessControlAttr::create(ASTContext &context, DifferentiableAttr::DifferentiableAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange, bool linear, ArrayRef params, - Optional jvp, - Optional vjp, TrailingWhereClause *clause) : DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit), - Linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)), - VJP(std::move(vjp)), WhereClause(clause) { + Linear(linear), NumParsedParameters(params.size()), WhereClause(clause) { std::copy(params.begin(), params.end(), getTrailingObjects()); } @@ -1630,12 +1611,9 @@ DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc, SourceRange baseRange, bool linear, IndexSubset *parameterIndices, - Optional jvp, - Optional vjp, GenericSignature derivativeGenSig) : DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit), - OriginalDeclaration(original), Linear(linear), JVP(std::move(jvp)), - VJP(std::move(vjp)) { + OriginalDeclaration(original), Linear(linear) { setParameterIndices(parameterIndices); setDerivativeGenericSignature(derivativeGenSig); } @@ -1645,29 +1623,23 @@ DifferentiableAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc, SourceRange baseRange, bool linear, ArrayRef parameters, - Optional jvp, - Optional vjp, TrailingWhereClause *clause) { unsigned size = totalSizeToAlloc(parameters.size()); void *mem = context.Allocate(size, alignof(DifferentiableAttr)); return new (mem) DifferentiableAttr(implicit, atLoc, baseRange, linear, - parameters, std::move(jvp), - std::move(vjp), clause); + parameters, clause); } DifferentiableAttr * DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit, SourceLoc atLoc, SourceRange baseRange, bool linear, IndexSubset *parameterIndices, - Optional jvp, - Optional vjp, GenericSignature derivativeGenSig) { auto &ctx = original->getASTContext(); void *mem = ctx.Allocate(sizeof(DifferentiableAttr), alignof(DifferentiableAttr)); return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange, - linear, parameterIndices, std::move(jvp), - std::move(vjp), derivativeGenSig); + linear, parameterIndices, derivativeGenSig); } void DifferentiableAttr::setOriginalDeclaration(Decl *originalDeclaration) { @@ -1701,18 +1673,6 @@ void DifferentiableAttr::setParameterIndices(IndexSubset *paramIndices) { std::move(paramIndices)); } -void DifferentiableAttr::setJVPFunction(FuncDecl *decl) { - JVPFunction = decl; - if (decl && !JVP) - JVP = {decl->createNameRef(), DeclNameLoc(decl->getNameLoc())}; -} - -void DifferentiableAttr::setVJPFunction(FuncDecl *decl) { - VJPFunction = decl; - if (decl && !VJP) - VJP = {decl->createNameRef(), DeclNameLoc(decl->getNameLoc())}; -} - GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment( AbstractFunctionDecl *original) const { GenericEnvironment *derivativeGenEnv = original->getGenericEnvironment(); @@ -1722,12 +1682,10 @@ GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment( } void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D, - bool omitWrtClause, - bool omitDerivativeFunctions) const { + bool omitWrtClause) const { StreamPrinter P(OS); P << "@" << getAttrName(); - printDifferentiableAttrArguments(this, P, PrintOptions(), D, omitWrtClause, - omitDerivativeFunctions); + printDifferentiableAttrArguments(this, P, PrintOptions(), D, omitWrtClause); } DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc, diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index c806382b65d31..5dc2c5f6c73ce 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -28,6 +28,50 @@ AutoDiffDerivativeFunctionKind::AutoDiffDerivativeFunctionKind( rawValue = *result; } +NormalDifferentiableFunctionTypeComponent:: + NormalDifferentiableFunctionTypeComponent( + AutoDiffDerivativeFunctionKind kind) { + switch (kind) { + case AutoDiffDerivativeFunctionKind::JVP: + rawValue = JVP; + return; + case AutoDiffDerivativeFunctionKind::VJP: + rawValue = VJP; + return; + } +} + +NormalDifferentiableFunctionTypeComponent:: + NormalDifferentiableFunctionTypeComponent(StringRef string) { + Optional result = llvm::StringSwitch>(string) + .Case("original", Original) + .Case("jvp", JVP) + .Case("vjp", VJP); + assert(result && "Invalid string"); + rawValue = *result; +} + +Optional +NormalDifferentiableFunctionTypeComponent::getAsDerivativeFunctionKind() const { + switch (rawValue) { + case Original: + return None; + case JVP: + return {AutoDiffDerivativeFunctionKind::JVP}; + case VJP: + return {AutoDiffDerivativeFunctionKind::VJP}; + } +} + +LinearDifferentiableFunctionTypeComponent:: + LinearDifferentiableFunctionTypeComponent(StringRef string) { + Optional result = llvm::StringSwitch>(string) + .Case("original", Original) + .Case("transpose", Transpose); + assert(result && "Invalid string"); + rawValue = *result; +} + DifferentiabilityWitnessFunctionKind::DifferentiabilityWitnessFunctionKind( StringRef string) { Optional result = llvm::StringSwitch>(string) @@ -235,6 +279,77 @@ GenericSignature autodiff::getConstrainedDerivativeGenericSignature( nullptr); } +// Given the rest of a `Builtin.applyDerivative_{jvp|vjp}` or +// `Builtin.applyTranspose` operation name, attempts to parse the arity and +// throwing-ness from the operation name. Modifies the operation name argument +// in place as substrings get dropped. +static void parseAutoDiffBuiltinCommonConfig( + StringRef &operationName, unsigned &arity, bool &throws) { + // Parse '_arity'. + constexpr char arityPrefix[] = "_arity"; + if (operationName.startswith(arityPrefix)) { + operationName = operationName.drop_front(sizeof(arityPrefix) - 1); + auto arityStr = operationName.take_while(llvm::isDigit); + operationName = operationName.drop_front(arityStr.size()); + auto converted = llvm::to_integer(arityStr, arity); + assert(converted); (void)converted; + assert(arity > 0); + } else { + arity = 1; + } + // Parse '_throws'. + constexpr char throwsPrefix[] = "_throws"; + if (operationName.startswith(throwsPrefix)) { + operationName = operationName.drop_front(sizeof(throwsPrefix) - 1); + throws = true; + } else { + throws = false; + } +} + +bool autodiff::getBuiltinApplyDerivativeConfig( + StringRef operationName, AutoDiffDerivativeFunctionKind &kind, + unsigned &arity, bool &throws) { + constexpr char prefix[] = "applyDerivative"; + if (!operationName.startswith(prefix)) + return false; + operationName = operationName.drop_front(sizeof(prefix) - 1); + // Parse 'jvp' or 'vjp'. + constexpr char jvpPrefix[] = "_jvp"; + constexpr char vjpPrefix[] = "_vjp"; + if (operationName.startswith(jvpPrefix)) + kind = AutoDiffDerivativeFunctionKind::JVP; + else if (operationName.startswith(vjpPrefix)) + kind = AutoDiffDerivativeFunctionKind::VJP; + operationName = operationName.drop_front(sizeof(jvpPrefix) - 1); + parseAutoDiffBuiltinCommonConfig(operationName, arity, throws); + return operationName.empty(); +} + +bool autodiff::getBuiltinApplyTransposeConfig( + StringRef operationName, unsigned &arity, bool &throws) { + constexpr char prefix[] = "applyTranspose"; + if (!operationName.startswith(prefix)) + return false; + operationName = operationName.drop_front(sizeof(prefix) - 1); + parseAutoDiffBuiltinCommonConfig(operationName, arity, throws); + return operationName.empty(); +} + +bool autodiff::getBuiltinDifferentiableOrLinearFunctionConfig( + StringRef operationName, unsigned &arity, bool &throws) { + constexpr char differentiablePrefix[] = "differentiableFunction"; + constexpr char linearPrefix[] = "linearFunction"; + if (operationName.startswith(differentiablePrefix)) + operationName = operationName.drop_front(sizeof(differentiablePrefix) - 1); + else if (operationName.startswith(linearPrefix)) + operationName = operationName.drop_front(sizeof(linearPrefix) - 1); + else + return false; + parseAutoDiffBuiltinCommonConfig(operationName, arity, throws); + return operationName.empty(); +} + Type TangentSpace::getType() const { switch (kind) { case Kind::TangentVector: diff --git a/lib/AST/Availability.cpp b/lib/AST/Availability.cpp index c4f4bb276c874..4512b15047735 100644 --- a/lib/AST/Availability.cpp +++ b/lib/AST/Availability.cpp @@ -304,6 +304,9 @@ AvailabilityContext ASTContext::getSwift52Availability() { AvailabilityContext ASTContext::getSwift53Availability() { auto target = LangOpts.Target; + if (target.getArchName() == "arm64e") + return AvailabilityContext::alwaysAvailable(); + if (target.isMacOSX() ) { return AvailabilityContext( VersionRange::allGTE(llvm::VersionTuple(10, 99, 0))); diff --git a/lib/AST/Builtins.cpp b/lib/AST/Builtins.cpp index 343dab63d4224..8492c7a16b702 100644 --- a/lib/AST/Builtins.cpp +++ b/lib/AST/Builtins.cpp @@ -19,6 +19,7 @@ #include "swift/AST/FileUnit.h" #include "swift/AST/Module.h" #include "swift/AST/ParameterList.h" +#include "swift/AST/TypeCheckRequests.h" #include "swift/Basic/LLVMContext.h" #include "swift/Strings.h" #include "llvm/ADT/SmallString.h" @@ -184,7 +185,9 @@ static FuncDecl * getBuiltinGenericFunction(Identifier Id, ArrayRef ArgParamTypes, Type ResType, - GenericParamList *GenericParams) { + GenericParamList *GenericParams, + GenericSignature Sig, + bool Rethrows = false) { assert(GenericParams && "Missing generic parameters"); auto &Context = ResType->getASTContext(); @@ -213,13 +216,16 @@ getBuiltinGenericFunction(Identifier Id, StaticSpellingKind::None, /*FuncLoc=*/SourceLoc(), Name, /*NameLoc=*/SourceLoc(), - /*Throws=*/false, /*ThrowsLoc=*/SourceLoc(), + /*Throws=*/ Rethrows, /*ThrowsLoc=*/SourceLoc(), GenericParams, paramList, TypeLoc::withoutLoc(ResType), DC); func->setImplicit(); func->setAccess(AccessLevel::Public); + func->setGenericSignature(Sig); + if (Rethrows) + func->getAttrs().add(new (Context) RethrowsAttr(/*ThrowsLoc*/ SourceLoc())); return func; } @@ -446,11 +452,21 @@ namespace { GenericParamList *TheGenericParamList; SmallVector InterfaceParams; Type InterfaceResult; + bool Rethrows = false; + + // Accumulate params and requirements here, so that we can make the + // appropriate `AbstractGenericSignatureRequest` when `build()` is called. + SmallVector genericParamTypes; + SmallVector addedRequirements; public: BuiltinFunctionBuilder(ASTContext &ctx, unsigned numGenericParams = 1) : Context(ctx) { TheGenericParamList = getGenericParams(ctx, numGenericParams); + for (auto gp : TheGenericParamList->getParams()) { + genericParamTypes.push_back( + gp->getDeclaredInterfaceType()->castTo()); + } } template @@ -466,10 +482,27 @@ namespace { InterfaceResult = generator.build(*this); } + template + void addConformanceRequirement(const G &generator, ProtocolDecl *proto) { + Requirement req(RequirementKind::Conformance, + generator.build(*this), + proto->getDeclaredType()); + addedRequirements.push_back(req); + } + + void setRethrows(bool rethrows = true) { + Rethrows = rethrows; + } + FuncDecl *build(Identifier name) { + auto GenericSig = evaluateOrDefault( + Context.evaluator, + AbstractGenericSignatureRequest{ + nullptr, std::move(genericParamTypes), std::move(addedRequirements)}, + nullptr); return getBuiltinGenericFunction(name, InterfaceParams, InterfaceResult, - TheGenericParamList); + TheGenericParamList, GenericSig); } // Don't use these generator classes directly; call the make{...} @@ -942,6 +975,296 @@ static ValueDecl *getGetObjCTypeEncodingOperation(ASTContext &Context, return builder.build(Id); } +static ValueDecl *getAutoDiffApplyDerivativeFunction( + ASTContext &Context, Identifier Id, AutoDiffDerivativeFunctionKind kind, + unsigned arity, bool throws) { + assert(arity >= 1); + // JVP: + // <...T...(arity), R> (@differentiable (...T) throws -> R, ...T) + // rethrows -> (R, (...T.TangentVector) -> R.TangentVector) + // VJP: + // <...T...(arity), R> (@differentiable (...T) throws -> R, ...T) + // rethrows -> (R, (R.TangentVector) -> ...T.TangentVector) + unsigned numGenericParams = 1 + arity; + BuiltinFunctionBuilder builder(Context, numGenericParams); + // Get the `Differentiable` protocol. + auto *diffableProto = Context.getProtocol(KnownProtocolKind::Differentiable); + // Create type parameters and add conformance constraints. + auto fnResultGen = makeGenericParam(arity); + builder.addConformanceRequirement(fnResultGen, diffableProto); + SmallVector fnParamGens; + for (auto i : range(arity)) { + auto T = makeGenericParam(i); + builder.addConformanceRequirement(T, diffableProto); + fnParamGens.push_back(T); + } + // Generator for the first argument, i.e. the `@differentiable` function. + BuiltinFunctionBuilder::LambdaGenerator firstArgGen { + // Generator for the function type at the argument position, i.e. the + // function being differentiated. + [=, &fnParamGens](BuiltinFunctionBuilder &builder) -> Type { + FunctionType::ExtInfo ext; + auto extInfo = FunctionType::ExtInfo() + .withDifferentiabilityKind(DifferentiabilityKind::Normal) + .withNoEscape().withThrows(throws); + SmallVector params; + for (auto ¶mGen : fnParamGens) + params.push_back(FunctionType::Param(paramGen.build(builder))); + auto innerFunction = FunctionType::get(params, + fnResultGen.build(builder)); + return innerFunction->withExtInfo(extInfo); + } + }; + // Eagerly build the type of the first arg, then use that to compute the type + // of the result. + auto *diffFnType = + firstArgGen.build(builder)->castTo(); + diffFnType = diffFnType->getWithoutDifferentiability()->withExtInfo( + diffFnType->getExtInfo().withNoEscape(false)); + auto *paramIndices = IndexSubset::get( + Context, SmallBitVector(diffFnType->getNumParams(), true)); + // Generator for the resultant function type, i.e. the AD derivative function. + BuiltinFunctionBuilder::LambdaGenerator resultGen{ + [=, &Context](BuiltinFunctionBuilder &builder) -> Type { + auto derivativeFnTy = diffFnType->getAutoDiffDerivativeFunctionType( + paramIndices, kind, + LookUpConformanceInModule(Context.TheBuiltinModule)); + return derivativeFnTy->getResult(); + }}; + builder.addParameter(firstArgGen); + for (auto argGen : fnParamGens) + builder.addParameter(argGen); + if (throws) + builder.setRethrows(); + builder.setResult(resultGen); + return builder.build(Id); +} + +static ValueDecl *getAutoDiffApplyTransposeFunction( + ASTContext &Context, Identifier Id, unsigned arity, bool throws) { + assert(arity >= 1); + // <...T...(arity), R> + // (@differentiable (...T) throws -> R, ...R.TangentVector) + // rethrows -> (...T.TangentVector) + unsigned numGenericParams = 1 + arity; + BuiltinFunctionBuilder builder(Context, numGenericParams); + auto *diffableProto = Context.getProtocol(KnownProtocolKind::Differentiable); + auto *addArithProto = + Context.getProtocol(KnownProtocolKind::AdditiveArithmetic); + // Create type parameters and add conformance constraints. + auto linearFnResultGen = makeGenericParam(arity); + builder.addConformanceRequirement(linearFnResultGen, diffableProto); + builder.addConformanceRequirement(linearFnResultGen, addArithProto); + SmallVector linearFnParamGens; + for (auto i : range(arity)) { + auto T = makeGenericParam(i); + builder.addConformanceRequirement(T, diffableProto); + builder.addConformanceRequirement(T, addArithProto); + linearFnParamGens.push_back(T); + } + // Generator for the first argument, i.e. the `@differentiable(linear)` + // function. + BuiltinFunctionBuilder::LambdaGenerator firstArgGen { + // Generator for the function type at the argument position, i.e. the + // function being differentiated. + [=, &linearFnParamGens](BuiltinFunctionBuilder &builder) -> Type { + FunctionType::ExtInfo ext; + auto extInfo = FunctionType::ExtInfo() + .withDifferentiabilityKind(DifferentiabilityKind::Linear) + .withNoEscape().withThrows(throws); + SmallVector params; + for (auto ¶mGen : linearFnParamGens) + params.push_back(FunctionType::Param(paramGen.build(builder))); + auto innerFunction = FunctionType::get(params, + linearFnResultGen.build(builder)); + return innerFunction->withExtInfo(extInfo); + } + }; + builder.addParameter(firstArgGen); + builder.addParameter(linearFnResultGen); + if (throws) + builder.setRethrows(); + if (arity == 1) + builder.setResult(linearFnParamGens.front()); + else { + BuiltinFunctionBuilder::LambdaGenerator tupleResultGen { + [&](BuiltinFunctionBuilder &builder) -> Type { + SmallVector tupleElts; + for (auto linearFnParamGen : linearFnParamGens) + tupleElts.push_back(linearFnParamGen.build(builder)); + return TupleType::get(tupleElts, Context); + } + }; + builder.setResult(tupleResultGen); + } + return builder.build(Id); +} + +static ValueDecl *getDifferentiableFunctionConstructor( + ASTContext &Context, Identifier Id, unsigned arity, bool throws) { + assert(arity >= 1); + unsigned numGenericParams = 1 + arity; + BuiltinFunctionBuilder builder(Context, numGenericParams); + // Get the `Differentiable` and `AdditiveArithmetic` protocols. + auto *diffableProto = + Context.getProtocol(KnownProtocolKind::Differentiable); + auto *tangentVectorDecl = + diffableProto->getAssociatedType(Context.Id_TangentVector); + assert(tangentVectorDecl); + // Create type parameters and add conformance constraints. + auto origResultGen = makeGenericParam(arity); + builder.addConformanceRequirement(origResultGen, diffableProto); + SmallVector fnArgGens; + for (auto i : range(arity)) { + auto T = makeGenericParam(i); + builder.addConformanceRequirement(T, diffableProto); + fnArgGens.push_back(T); + } + + BuiltinFunctionBuilder::LambdaGenerator origFnGen { + [=, &fnArgGens](BuiltinFunctionBuilder &builder) -> Type { + SmallVector params; + for (auto ¶mGen : fnArgGens) + params.push_back(FunctionType::Param(paramGen.build(builder))); + return FunctionType::get(params, origResultGen.build(builder)) + ->withExtInfo( + FunctionType::ExtInfo(FunctionTypeRepresentation::Swift, throws)); + } + }; + + BuiltinFunctionBuilder::LambdaGenerator jvpGen { + [=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type { + SmallVector params; + for (auto ¶mGen : fnArgGens) + params.push_back(FunctionType::Param(paramGen.build(builder))); + auto origResultType = origResultGen.build(builder); + SmallVector differentialParams; + for (auto ¶m : params) { + auto tanType = DependentMemberType::get( + param.getPlainType(), tangentVectorDecl); + differentialParams.push_back(FunctionType::Param(tanType)); + } + auto differentialResultType = DependentMemberType::get( + origResultType, tangentVectorDecl); + auto differentialType = + FunctionType::get({differentialParams}, differentialResultType); + auto jvpResultType = TupleType::get( + {TupleTypeElt(origResultType, Context.Id_value), + TupleTypeElt(differentialType, Context.Id_differential)}, Context); + return FunctionType::get(params, jvpResultType) + ->withExtInfo( + FunctionType::ExtInfo(FunctionTypeRepresentation::Swift, throws)); + } + }; + + BuiltinFunctionBuilder::LambdaGenerator vjpGen { + [=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type { + SmallVector params; + for (auto ¶mGen : fnArgGens) + params.push_back(FunctionType::Param(paramGen.build(builder))); + auto origResultType = origResultGen.build(builder); + SmallVector pullbackResultTupleElts; + for (auto ¶m : params) { + auto tanType = DependentMemberType::get( + param.getPlainType(), tangentVectorDecl); + pullbackResultTupleElts.push_back(TupleTypeElt(tanType)); + } + auto pullbackParam = FunctionType::Param( + DependentMemberType::get(origResultType, tangentVectorDecl)); + auto pullbackType = FunctionType::get( + {pullbackParam}, + pullbackResultTupleElts.size() == 1 + ? pullbackResultTupleElts.front().getType() + : TupleType::get(pullbackResultTupleElts, Context)); + auto vjpResultType = TupleType::get( + {TupleTypeElt(origResultType, Context.Id_value), + TupleTypeElt(pullbackType, Context.Id_pullback)}, Context); + return FunctionType::get(params, vjpResultType) + ->withExtInfo( + FunctionType::ExtInfo(FunctionTypeRepresentation::Swift, throws)); + } + }; + + BuiltinFunctionBuilder::LambdaGenerator resultGen { + [&](BuiltinFunctionBuilder &builder) -> Type { + auto origFnType = origFnGen.build(builder)->castTo(); + return origFnType->withExtInfo( + origFnType->getExtInfo() + .withDifferentiabilityKind(DifferentiabilityKind::Normal)); + } + }; + + builder.addParameter(origFnGen, ValueOwnership::Owned); + builder.addParameter(jvpGen, ValueOwnership::Owned); + builder.addParameter(vjpGen, ValueOwnership::Owned); + builder.setResult(resultGen); + return builder.build(Id); +} + +static ValueDecl *getLinearFunctionConstructor( + ASTContext &Context, Identifier Id, unsigned arity, bool throws) { + assert(arity >= 1); + unsigned numGenericParams = 1 + arity; + BuiltinFunctionBuilder builder(Context, numGenericParams); + // Get the `Differentiable` and `AdditiveArithmetic` protocols. + auto *diffableProto = + Context.getProtocol(KnownProtocolKind::Differentiable); + auto *addArithProto = + Context.getProtocol(KnownProtocolKind::AdditiveArithmetic); + // Create type parameters and add conformance constraints. + auto origResultGen = makeGenericParam(arity); + builder.addConformanceRequirement(origResultGen, diffableProto); + builder.addConformanceRequirement(origResultGen, addArithProto); + SmallVector fnArgGens; + for (auto i : range(arity)) { + auto T = makeGenericParam(i); + builder.addConformanceRequirement(T, diffableProto); + builder.addConformanceRequirement(T, addArithProto); + fnArgGens.push_back(T); + } + + BuiltinFunctionBuilder::LambdaGenerator origFnGen { + [=, &fnArgGens](BuiltinFunctionBuilder &builder) -> Type { + SmallVector params; + for (auto ¶mGen : fnArgGens) + params.push_back(FunctionType::Param(paramGen.build(builder))); + return FunctionType::get(params, origResultGen.build(builder)) + ->withExtInfo( + FunctionType::ExtInfo(FunctionTypeRepresentation::Swift, throws)); + } + }; + + BuiltinFunctionBuilder::LambdaGenerator transposeFnGen { + [=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type { + auto origResultType = origResultGen.build(builder); + SmallVector resultTupleElts; + for (auto ¶mGen : fnArgGens) + resultTupleElts.push_back(paramGen.build(builder)); + return FunctionType::get( + {FunctionType::Param(origResultType)}, + resultTupleElts.size() == 1 + ? resultTupleElts.front().getType() + : TupleType::get(resultTupleElts, Context)); + } + }; + + BuiltinFunctionBuilder::LambdaGenerator resultGen { + [&](BuiltinFunctionBuilder &builder) -> Type { + auto origFnType = origFnGen.build(builder)->castTo(); + return origFnType->withExtInfo( + origFnType->getExtInfo() + .withDifferentiabilityKind(DifferentiabilityKind::Linear)); + } + }; + + builder.addParameter(origFnGen, ValueOwnership::Owned); + builder.addParameter(transposeFnGen, ValueOwnership::Owned); + builder.setResult(resultGen); + return builder.build(Id); +} + + + static ValueDecl *getGlobalStringTablePointer(ASTContext &Context, Identifier Id) { // String -> Builtin.RawPointer @@ -1758,6 +2081,40 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) { return getAllocWithTailElemsOperation(Context, Id, NumTailTypes); } + if (OperationName.startswith("applyDerivative_")) { + AutoDiffDerivativeFunctionKind kind; + unsigned arity; + bool throws; + if (!autodiff::getBuiltinApplyDerivativeConfig( + OperationName, kind, arity, throws)) + return nullptr; + return getAutoDiffApplyDerivativeFunction(Context, Id, kind, arity, + throws); + } + if (OperationName.startswith("applyTranspose_")) { + unsigned arity; + bool throws; + if (!autodiff::getBuiltinApplyTransposeConfig( + OperationName, arity, throws)) + return nullptr; + return getAutoDiffApplyTransposeFunction(Context, Id, arity, throws); + } + if (OperationName.startswith("differentiableFunction_")) { + unsigned arity; + bool throws; + if (!autodiff::getBuiltinDifferentiableOrLinearFunctionConfig( + OperationName, arity, throws)) + return nullptr; + return getDifferentiableFunctionConstructor(Context, Id, arity, throws); + } + if (OperationName.startswith("linearFunction_")) { + unsigned arity; + bool throws; + if (!autodiff::getBuiltinDifferentiableOrLinearFunctionConfig( + OperationName, arity, throws)) + return nullptr; + return getLinearFunctionConstructor(Context, Id, arity, throws); + } auto BV = llvm::StringSwitch(OperationName) #define BUILTIN(id, name, Attrs) .Case(name, BuiltinValueKind::id) @@ -2028,6 +2385,12 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) { case BuiltinValueKind::UnsafeGuaranteedEnd: return getUnsafeGuaranteedEnd(Context, Id); + case BuiltinValueKind::ApplyDerivative: + case BuiltinValueKind::ApplyTranspose: + case BuiltinValueKind::DifferentiableFunction: + case BuiltinValueKind::LinearFunction: + llvm_unreachable("Handled above"); + case BuiltinValueKind::OnFastPath: return getOnFastPath(Context, Id); diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 9df38af966127..597308fcdeb81 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -4059,6 +4059,14 @@ ConstructorDecl *NominalTypeDecl::getMemberwiseInitializer() const { ctx.evaluator, SynthesizeMemberwiseInitRequest{mutableThis}, nullptr); } +ConstructorDecl *NominalTypeDecl::getEffectiveMemberwiseInitializer() { + auto &ctx = getASTContext(); + auto *mutableThis = const_cast(this); + return evaluateOrDefault(ctx.evaluator, + ResolveEffectiveMemberwiseInitRequest{mutableThis}, + nullptr); +} + bool NominalTypeDecl::hasDefaultInitializer() const { // Currently only structs and classes can have default initializers. if (!isa(this) && !isa(this)) @@ -4627,26 +4635,11 @@ ProtocolDecl::ProtocolDecl(DeclContext *DC, SourceLoc ProtocolLoc, setTrailingWhereClause(TrailingWhere); } -ArrayRef -ProtocolDecl::getInheritedProtocolsSlow() { - Bits.ProtocolDecl.InheritedProtocolsValid = true; - - llvm::SmallVector result; - SmallPtrSet known; - known.insert(this); - bool anyObject = false; - for (const auto found : - getDirectlyInheritedNominalTypeDecls( - const_cast(this), anyObject)) { - if (auto proto = dyn_cast(found.Item)) { - if (known.insert(proto).second) - result.push_back(proto); - } - } - - auto &ctx = getASTContext(); - InheritedProtocols = ctx.AllocateCopy(result); - return InheritedProtocols; +ArrayRef ProtocolDecl::getInheritedProtocols() const { + auto *mutThis = const_cast(this); + return evaluateOrDefault(getASTContext().evaluator, + InheritedProtocolsRequest{mutThis}, + {}); } llvm::TinyPtrVector diff --git a/lib/AST/Module.cpp b/lib/AST/Module.cpp index 6decda7cc3856..0dd2b4448ebf3 100644 --- a/lib/AST/Module.cpp +++ b/lib/AST/Module.cpp @@ -821,7 +821,19 @@ ModuleDecl::lookupExistentialConformance(Type type, ProtocolDecl *protocol) { ProtocolConformanceRef ModuleDecl::lookupConformance(Type type, ProtocolDecl *protocol) { - ASTContext &ctx = getASTContext(); + return evaluateOrDefault( + getASTContext().evaluator, + LookupConformanceInModuleRequest{{this, type, protocol}}, + ProtocolConformanceRef::forInvalid()); +} + +llvm::Expected +LookupConformanceInModuleRequest::evaluate( + Evaluator &evaluator, LookupConformanceDescriptor desc) const { + auto *mod = desc.Mod; + auto type = desc.Ty; + auto protocol = desc.PD; + ASTContext &ctx = mod->getASTContext(); // A dynamic Self type conforms to whatever its underlying type // conforms to. @@ -839,7 +851,7 @@ ProtocolConformanceRef ModuleDecl::lookupConformance(Type type, // able to be resolved by a substitution that makes the archetype // concrete. if (auto super = archetype->getSuperclass()) { - if (auto inheritedConformance = lookupConformance(super, protocol)) { + if (auto inheritedConformance = mod->lookupConformance(super, protocol)) { return ProtocolConformanceRef(ctx.getInheritedConformance( type, inheritedConformance.getConcrete())); } @@ -857,7 +869,7 @@ ProtocolConformanceRef ModuleDecl::lookupConformance(Type type, // existential's list of conformances and the existential conforms to // itself. if (type->isExistentialType()) - return lookupExistentialConformance(type, protocol); + return mod->lookupExistentialConformance(type, protocol); // Type variables have trivial conformances. if (type->isTypeVariableOrMember()) @@ -877,7 +889,7 @@ ProtocolConformanceRef ModuleDecl::lookupConformance(Type type, // Find the (unspecialized) conformance. SmallVector conformances; - if (!nominal->lookupConformance(this, protocol, conformances)) + if (!nominal->lookupConformance(mod, protocol, conformances)) return ProtocolConformanceRef::forInvalid(); // FIXME: Ambiguity resolution. @@ -897,7 +909,7 @@ ProtocolConformanceRef ModuleDecl::lookupConformance(Type type, auto superclassTy = type->getSuperclassForDecl(conformingClass); // Compute the conformance for the inherited type. - auto inheritedConformance = lookupConformance(superclassTy, protocol); + auto inheritedConformance = mod->lookupConformance(superclassTy, protocol); assert(inheritedConformance && "We already found the inherited conformance"); @@ -918,7 +930,7 @@ ProtocolConformanceRef ModuleDecl::lookupConformance(Type type, if (!explicitConformanceType->isEqual(type)) { // Gather the substitutions we need to map the generic conformance to // the specialized conformance. - auto subMap = type->getContextSubstitutionMap(this, explicitConformanceDC); + auto subMap = type->getContextSubstitutionMap(mod, explicitConformanceDC); // Create the specialized conformance entry. auto result = ctx.getSpecializedConformance(type, conformance, subMap); @@ -945,39 +957,50 @@ namespace { template <> struct OperatorLookup { constexpr static auto map_ptr = &SourceFile::PrefixOperators; - template - static PrefixOperatorDecl *lookup(T &container, Identifier name) { - return cast_or_null( - container.lookupOperator(name, DeclKind::PrefixOperator)); + static PrefixOperatorDecl *lookup(Evaluator &eval, + const OperatorLookupDescriptor &desc) { + // We can return the first prefix operator. All prefix operators of the + // same name are equivalent. + DirectOperatorLookupRequest req{desc, OperatorFixity::Prefix}; + auto results = evaluateOrDefault(eval, req, {}); + return results.empty() ? nullptr : cast(results[0]); } }; template <> struct OperatorLookup { constexpr static auto map_ptr = &SourceFile::InfixOperators; - template - static InfixOperatorDecl *lookup(T &container, Identifier name) { - return cast_or_null( - container.lookupOperator(name, DeclKind::InfixOperator)); + static InfixOperatorDecl *lookup(Evaluator &eval, + const OperatorLookupDescriptor &desc) { + // Return the first result if it exists. + DirectOperatorLookupRequest req{desc, OperatorFixity::Infix}; + auto results = evaluateOrDefault(eval, req, {}); + return results.empty() ? nullptr : cast(results[0]); } }; template <> struct OperatorLookup { constexpr static auto map_ptr = &SourceFile::PostfixOperators; - template - static PostfixOperatorDecl *lookup(T &container, Identifier name) { - return cast_or_null( - container.lookupOperator(name, DeclKind::PostfixOperator)); + static PostfixOperatorDecl *lookup(Evaluator &eval, + const OperatorLookupDescriptor &desc) { + // We can return the first postfix operator. All postfix operators of the + // same name are equivalent. + DirectOperatorLookupRequest req{desc, OperatorFixity::Postfix}; + auto results = evaluateOrDefault(eval, req, {}); + return results.empty() ? nullptr : cast(results[0]); } }; template <> struct OperatorLookup { constexpr static auto map_ptr = &SourceFile::PrecedenceGroups; - template - static PrecedenceGroupDecl *lookup(T &container, Identifier name) { - return container.lookupPrecedenceGroup(name); + static PrecedenceGroupDecl *lookup(Evaluator &eval, + const OperatorLookupDescriptor &desc) { + // Return the first result if it exists. + auto results = + evaluateOrDefault(eval, DirectPrecedenceGroupLookupRequest{desc}, {}); + return results.empty() ? nullptr : results[0]; } }; } // end anonymous namespace @@ -1014,7 +1037,8 @@ void SourceFile::setSyntaxRoot(syntax::SourceFileSyntax &&Root) { template static Optional -lookupOperatorDeclForName(ModuleDecl *M, SourceLoc Loc, Identifier Name); +lookupOperatorDeclForName(ModuleDecl *M, SourceLoc Loc, Identifier Name, + bool isCascading); template using ImportedOperatorsMap = llvm::SmallDenseMap; @@ -1031,9 +1055,8 @@ checkOperatorConflicts(const SourceFile &SF, SourceLoc loc, if (loc.isValid()) { ASTContext &C = SF.getASTContext(); C.Diags.diagnose(loc, diag::ambiguous_operator_decls); - C.Diags.diagnose(start->first->getLoc(), - diag::found_this_operator_decl); - C.Diags.diagnose(i->first->getLoc(), diag::found_this_operator_decl); + start->first->diagnose(diag::found_this_operator_decl); + i->first->diagnose(diag::found_this_operator_decl); } return end; } @@ -1053,8 +1076,7 @@ checkOperatorConflicts(const SourceFile &SF, SourceLoc loc, ASTContext &C = SF.getASTContext(); C.Diags.diagnose(loc, diag::ambiguous_precedence_groups); for (auto &entry : importedGroups) { - C.Diags.diagnose(entry.first->getLoc(), - diag::found_this_precedence_group); + entry.first->diagnose(diag::found_this_precedence_group); } } return importedGroups.end(); @@ -1065,7 +1087,8 @@ checkOperatorConflicts(const SourceFile &SF, SourceLoc loc, template static Optional lookupOperatorDeclForName(const FileUnit &File, SourceLoc Loc, - Identifier Name, bool includePrivate) { + Identifier Name, bool includePrivate, + bool isCascading) { switch (File.getKind()) { case FileUnitKind::Builtin: // The Builtin module declares no operators. @@ -1074,8 +1097,13 @@ lookupOperatorDeclForName(const FileUnit &File, SourceLoc Loc, break; case FileUnitKind::SerializedAST: case FileUnitKind::ClangModule: - case FileUnitKind::DWARFModule: - return OperatorLookup::lookup(cast(File), Name); + case FileUnitKind::DWARFModule: { + auto &eval = File.getASTContext().evaluator; + auto desc = OperatorLookupDescriptor::forFile(const_cast(&File), + Name, isCascading, + /*diagLoc*/ SourceLoc()); + return OperatorLookup::lookup(eval, desc); + } } auto &SF = cast(File); @@ -1104,7 +1132,8 @@ lookupOperatorDeclForName(const FileUnit &File, SourceLoc Loc, continue; Optional maybeOp = - lookupOperatorDeclForName(imported.module.second, Loc, Name); + lookupOperatorDeclForName(imported.module.second, Loc, Name, + isCascading); if (!maybeOp) return None; @@ -1137,10 +1166,12 @@ lookupOperatorDeclForName(const FileUnit &File, SourceLoc Loc, template static Optional -lookupOperatorDeclForName(ModuleDecl *M, SourceLoc Loc, Identifier Name) { +lookupOperatorDeclForName(ModuleDecl *M, SourceLoc Loc, Identifier Name, + bool isCascading) { OP_DECL *result = nullptr; for (const FileUnit *File : M->getFiles()) { - auto next = lookupOperatorDeclForName(*File, Loc, Name, false); + auto next = lookupOperatorDeclForName(*File, Loc, Name, false, + isCascading); if (!next.hasValue()) return next; @@ -1156,31 +1187,31 @@ lookupOperatorDeclForName(ModuleDecl *M, SourceLoc Loc, Identifier Name) { template llvm::Expected LookupOperatorRequest::evaluate( Evaluator &evaluator, OperatorLookupDescriptor desc) const { - auto result = lookupOperatorDeclForName(*desc.SF, desc.diagLoc, - desc.name, - /*includePrivate*/ true); + auto *file = desc.fileOrModule.get(); + auto result = + lookupOperatorDeclForName(*file, desc.diagLoc, desc.name, + /*includePrivate*/ true, + desc.isCascading); if (!result.hasValue()) return nullptr; - if (auto *tracker = desc.SF->getReferencedNameTracker()) { - if (!result.getValue() || - result.getValue()->getDeclContext()->getModuleScopeContext() != - desc.SF) { - tracker->addTopLevelName(desc.name, desc.isCascading); - } + + if (!result.getValue() || + result.getValue()->getDeclContext()->getModuleScopeContext() != file) { + namelookup::recordLookupOfTopLevelName(file, desc.name, desc.isCascading); } if (!result.getValue()) { - result = lookupOperatorDeclForName(desc.SF->getParentModule(), - desc.diagLoc, - desc.name); + result = lookupOperatorDeclForName(file->getParentModule(), + desc.diagLoc, desc.name, + desc.isCascading); } return result.hasValue() ? result.getValue() : nullptr; } - #define LOOKUP_OPERATOR(Kind) \ Kind##Decl *ModuleDecl::lookup##Kind(Identifier name, SourceLoc loc) { \ auto result = \ - lookupOperatorDeclForName(this, loc, name); \ + lookupOperatorDeclForName(this, loc, name, \ + /*isCascading*/ false); \ return result ? *result : nullptr; \ } \ template llvm::Expected \ @@ -1193,6 +1224,75 @@ LOOKUP_OPERATOR(PostfixOperator) LOOKUP_OPERATOR(PrecedenceGroup) #undef LOOKUP_OPERATOR +llvm::Expected> +DirectOperatorLookupRequest::evaluate(Evaluator &evaluator, + OperatorLookupDescriptor descriptor, + OperatorFixity fixity) const { + // Query each file. + // TODO: Module-level caching. + TinyPtrVector results; + for (auto *file : descriptor.getFiles()) + file->lookupOperatorDirect(descriptor.name, fixity, results); + + return std::move(results); +} + +void SourceFile::lookupOperatorDirect( + Identifier name, OperatorFixity fixity, + TinyPtrVector &results) const { + OperatorDecl *op = nullptr; + switch (fixity) { + case OperatorFixity::Infix: { + auto result = InfixOperators.find(name); + if (result != InfixOperators.end()) + op = result->second.getPointer(); + break; + } + case OperatorFixity::Postfix: { + auto result = PostfixOperators.find(name); + if (result != PostfixOperators.end()) + op = result->second.getPointer(); + break; + } + case OperatorFixity::Prefix: { + auto result = PrefixOperators.find(name); + if (result != PrefixOperators.end()) + op = result->second.getPointer(); + break; + } + } + + // We currently can use the operator maps to cache lookup results from other + // modules. Make sure we only return results from the source file. + if (op && op->getDeclContext()->getParentSourceFile() == this) + results.push_back(op); +} + +llvm::Expected> +DirectPrecedenceGroupLookupRequest::evaluate( + Evaluator &evaluator, OperatorLookupDescriptor descriptor) const { + // Query each file. + // TODO: Module-level caching. + TinyPtrVector results; + for (auto *file : descriptor.getFiles()) + file->lookupPrecedenceGroupDirect(descriptor.name, results); + + return std::move(results); +} + +void SourceFile::lookupPrecedenceGroupDirect( + Identifier name, TinyPtrVector &results) const { + auto result = PrecedenceGroups.find(name); + if (result == PrecedenceGroups.end()) + return; + + // We currently can use the operator maps to cache lookup results from other + // modules. Make sure we only return results from the source file. + auto *group = result->second.getPointer(); + if (group->getDeclContext()->getParentSourceFile() == this) + results.push_back(group); +} + void ModuleDecl::getImportedModules(SmallVectorImpl &modules, ModuleDecl::ImportFilter filter) const { FORWARD(getImportedModules, (modules, filter)); @@ -1597,6 +1697,80 @@ void ModuleDecl::getDeclaredCrossImportBystanders( otherModules.push_back(std::get<0>(pair)); } +using TransitiveOverlays = + llvm::SmallDenseMap, 1>; + +static void populateTransitiveCrossImports(ModuleDecl *base, + TransitiveOverlays &result) { + if (!result.empty() || !base->mightDeclareCrossImportOverlays()) + return; + + SmallVector bystanders; + SmallVector overlays; + SmallVector worklist; + SourceLoc diagLoc; // ignored + + worklist.push_back(base); + while (!worklist.empty()) { + ModuleDecl *current = worklist.back(); + worklist.pop_back(); + if (!current->mightDeclareCrossImportOverlays()) + continue; + bystanders.clear(); + current->getDeclaredCrossImportBystanders(bystanders); + for (Identifier bystander: bystanders) { + overlays.clear(); + current->findDeclaredCrossImportOverlays(bystander, overlays, diagLoc); + for (Identifier overlay: overlays) { + if (!overlay.str().startswith("_")) + continue; + ModuleDecl *overlayMod = + base->getASTContext().getModuleByName(overlay.str()); + if (!overlayMod) + continue; + if (result.insert({overlayMod, {bystander, current}}).second) + worklist.push_back(overlayMod); + } + } + } +} + +bool ModuleDecl::isUnderlyingModuleOfCrossImportOverlay( + const ModuleDecl *overlay) { + if (!overlay->getNameStr().startswith("_")) + return false; + + populateTransitiveCrossImports(this, declaredCrossImportsTransitive); + return declaredCrossImportsTransitive.find(overlay) != + declaredCrossImportsTransitive.end(); +} + +void ModuleDecl::getAllBystandersForCrossImportOverlay( + ModuleDecl *overlay, SmallVectorImpl &bystanders) { + if (!overlay->getNameStr().startswith("_")) + return; + + populateTransitiveCrossImports(this, declaredCrossImportsTransitive); + + auto end = declaredCrossImportsTransitive.end(); + for (auto i = declaredCrossImportsTransitive.find(overlay); + i != end; + i = declaredCrossImportsTransitive.find(i->second.second)) { + bystanders.push_back(i->second.first); + } +} + +void ModuleDecl::findDeclaredCrossImportOverlaysTransitive( + SmallVectorImpl &overlayModules) { + populateTransitiveCrossImports(this, declaredCrossImportsTransitive); + std::transform(declaredCrossImportsTransitive.begin(), + declaredCrossImportsTransitive.end(), + std::back_inserter(overlayModules), + [](TransitiveOverlays::iterator::value_type &i) { + return i.first; + }); +} + namespace { struct OverlayFileContents { struct Module { diff --git a/lib/AST/NameLookup.cpp b/lib/AST/NameLookup.cpp index 66700d8b80d36..ef23b20fbbb14 100644 --- a/lib/AST/NameLookup.cpp +++ b/lib/AST/NameLookup.cpp @@ -2235,6 +2235,23 @@ SuperclassDeclRequest::evaluate(Evaluator &evaluator, return nullptr; } +ArrayRef +InheritedProtocolsRequest::evaluate(Evaluator &evaluator, + ProtocolDecl *PD) const { + llvm::SmallVector result; + SmallPtrSet known; + known.insert(PD); + bool anyObject = false; + for (const auto found : getDirectlyInheritedNominalTypeDecls(PD, anyObject)) { + if (auto proto = dyn_cast(found.Item)) { + if (known.insert(proto).second) + result.push_back(proto); + } + } + + return PD->getASTContext().AllocateCopy(result); +} + llvm::Expected ExtendedNominalRequest::evaluate(Evaluator &evaluator, ExtensionDecl *ext) const { diff --git a/lib/AST/NameLookupRequests.cpp b/lib/AST/NameLookupRequests.cpp index 460bc05c5e3d9..9db327cb6b3c4 100644 --- a/lib/AST/NameLookupRequests.cpp +++ b/lib/AST/NameLookupRequests.cpp @@ -68,6 +68,25 @@ void SuperclassDeclRequest::cacheResult(ClassDecl *value) const { protocolDecl->LazySemanticInfo.SuperclassDecl.setPointerAndInt(value, true); } +//----------------------------------------------------------------------------// +// InheritedProtocolsRequest computation. +//----------------------------------------------------------------------------// + +Optional> +InheritedProtocolsRequest::getCachedResult() const { + auto proto = std::get<0>(getStorage()); + if (!proto->areInheritedProtocolsValid()) + return None; + + return proto->InheritedProtocols; +} + +void InheritedProtocolsRequest::cacheResult(ArrayRef PDs) const { + auto proto = std::get<0>(getStorage()); + proto->InheritedProtocols = PDs; + proto->setInheritedProtocolsValid(); +} + //----------------------------------------------------------------------------// // Missing designated initializers computation //----------------------------------------------------------------------------// @@ -211,18 +230,45 @@ SourceLoc swift::extractNearestSourceLoc(const DirectLookupDescriptor &desc) { // LookupOperatorRequest computation. //----------------------------------------------------------------------------// +ArrayRef OperatorLookupDescriptor::getFiles() const { + if (auto *module = getModule()) + return module->getFiles(); + + // Return an ArrayRef pointing to the FileUnit in the union. + return llvm::makeArrayRef(*fileOrModule.getAddrOfPtr1()); +} + void swift::simple_display(llvm::raw_ostream &out, const OperatorLookupDescriptor &desc) { out << "looking up operator "; simple_display(out, desc.name); out << " in "; - simple_display(out, desc.SF); + simple_display(out, desc.fileOrModule); } SourceLoc swift::extractNearestSourceLoc(const OperatorLookupDescriptor &desc) { return desc.diagLoc; } +//----------------------------------------------------------------------------// +// LookupConformanceInModuleRequest computation. +//----------------------------------------------------------------------------// + +void swift::simple_display(llvm::raw_ostream &out, + const LookupConformanceDescriptor &desc) { + out << "looking up conformance to "; + simple_display(out, desc.PD); + out << " for "; + out << desc.Ty.getString(); + out << " in "; + simple_display(out, desc.Mod); +} + +SourceLoc +swift::extractNearestSourceLoc(const LookupConformanceDescriptor &desc) { + return SourceLoc(); +} + // Define request evaluation functions for each of the name lookup requests. static AbstractRequestFunction *nameLookupRequestFunctions[] = { #define SWIFT_REQUEST(Zone, Name, Sig, Caching, LocOptions) \ diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 988d6fea07a85..032fed0c2b9af 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -4018,14 +4018,15 @@ TypeSubstitutionMap TypeBase::getMemberSubstitutions( isa(member)) { auto *innerDC = member->getInnermostDeclContext(); if (innerDC->isInnermostContextGeneric()) { - auto sig = innerDC->getGenericSignatureOfContext(); - for (auto param : sig->getInnermostGenericParams()) { - auto *genericParam = param->getCanonicalType() - ->castTo(); - substitutions[genericParam] = - (genericEnv - ? genericEnv->mapTypeIntoContext(param) - : param); + if (auto sig = innerDC->getGenericSignatureOfContext()) { + for (auto param : sig->getInnermostGenericParams()) { + auto *genericParam = param->getCanonicalType() + ->castTo(); + substitutions[genericParam] = + (genericEnv + ? genericEnv->mapTypeIntoContext(param) + : param); + } } } } @@ -5012,6 +5013,22 @@ CanType swift::substOpaqueTypesWithUnderlyingTypes(CanType ty, return ty.subst(replacer, replacer, flags)->getCanonicalType(); } +AnyFunctionType *AnyFunctionType::getWithoutDifferentiability() const { + SmallVector newParams; + for (auto ¶m : getParams()) { + Param newParam(param.getPlainType(), param.getLabel(), + param.getParameterFlags().withNoDerivative(false)); + newParams.push_back(newParam); + } + auto nonDiffExtInfo = getExtInfo() + .withDifferentiabilityKind(DifferentiabilityKind::NonDifferentiable); + if (isa(this)) + return FunctionType::get(newParams, getResult(), nonDiffExtInfo); + assert(isa(this)); + return GenericFunctionType::get(getOptGenericSignature(), newParams, + getResult(), nonDiffExtInfo); +} + Optional TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) { assert(lookupConformance); diff --git a/lib/Basic/Platform.cpp b/lib/Basic/Platform.cpp index 0ae286fffc297..5f59d5a756421 100644 --- a/lib/Basic/Platform.cpp +++ b/lib/Basic/Platform.cpp @@ -397,6 +397,9 @@ swift::getSwiftRuntimeCompatibilityVersionForTarget( const llvm::Triple &Triple) { unsigned Major, Minor, Micro; + if (Triple.getArchName() == "arm64e") + return llvm::VersionTuple(5, 3); + if (Triple.isMacOSX()) { Triple.getMacOSXVersion(Major, Minor, Micro); if (Major == 10) { diff --git a/lib/Frontend/CompilerInvocation.cpp b/lib/Frontend/CompilerInvocation.cpp index 3cc2a481a5fe8..3054aa2ef4a4a 100644 --- a/lib/Frontend/CompilerInvocation.cpp +++ b/lib/Frontend/CompilerInvocation.cpp @@ -442,8 +442,15 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args, if (Args.hasArg(OPT_fine_grained_dependency_include_intrafile)) Opts.FineGrainedDependenciesIncludeIntrafileOnes = true; - if (Args.hasArg(OPT_enable_experimental_differentiable_programming)) + if (Args.hasArg(OPT_enable_experimental_additive_arithmetic_derivation)) + Opts.EnableExperimentalAdditiveArithmeticDerivedConformances = true; + + if (Args.hasArg(OPT_enable_experimental_differentiable_programming)) { Opts.EnableExperimentalDifferentiableProgramming = true; + // Differentiable programming implies `AdditiveArithmetic` derived + // conformances. + Opts.EnableExperimentalAdditiveArithmeticDerivedConformances = true; + } Opts.DebuggerSupport |= Args.hasArg(OPT_debugger_support); if (Opts.DebuggerSupport) diff --git a/lib/Frontend/DiagnosticVerifier.cpp b/lib/Frontend/DiagnosticVerifier.cpp index 88e4f079f4108..62b9bae5a3542 100644 --- a/lib/Frontend/DiagnosticVerifier.cpp +++ b/lib/Frontend/DiagnosticVerifier.cpp @@ -530,7 +530,7 @@ DiagnosticVerifier::verifyFile(unsigned BufferID, bool shouldAutoApplyFixes) { // fixit to add them in. auto actual = renderFixits(FoundDiagnostic.getFixIts(), InputFile); auto replStartLoc = SMLoc::getFromPointer(expected.ExpectedEnd - 8); // {{none}} length - auto replEndLoc = SMLoc::getFromPointer(expected.ExpectedEnd - 1); + auto replEndLoc = SMLoc::getFromPointer(expected.ExpectedEnd); llvm::SMFixIt fix(llvm::SMRange(replStartLoc, replEndLoc), actual); addError(replStartLoc.getPointer(), "expected no fix-its; actual fix-it seen: " + actual, fix); diff --git a/lib/IDE/ModuleInterfacePrinting.cpp b/lib/IDE/ModuleInterfacePrinting.cpp index 975ea7b7243a4..88a1e41d7882e 100644 --- a/lib/IDE/ModuleInterfacePrinting.cpp +++ b/lib/IDE/ModuleInterfacePrinting.cpp @@ -238,6 +238,295 @@ swift::ide::findGroupNameForUSR(ModuleDecl *M, StringRef USR) { return None; } +/// Prints a single decl using the \p Printer and \p Options provided. +/// +/// \returns Whether the given decl was printed. +static bool printModuleInterfaceDecl(Decl *D, + ASTPrinter &Printer, + PrintOptions &Options, + bool PrintSynthesizedExtensions) { + if (!Options.shouldPrint(D)) { + Printer.callAvoidPrintDeclPost(D); + return false; + } + if (auto Ext = dyn_cast(D)) { + // Clang extensions (categories) are always printed in source order. + // Swift extensions are printed with their associated type unless it's + // a cross-module extension. + if (!extensionHasClangNode(Ext)) { + auto ExtendedNominal = Ext->getExtendedNominal(); + if (Ext->getModuleContext() == ExtendedNominal->getModuleContext()) + return false; + } + } + std::unique_ptr pAnalyzer; + if (auto NTD = dyn_cast(D)) { + if (PrintSynthesizedExtensions) { + pAnalyzer.reset(new SynthesizedExtensionAnalyzer(NTD, Options)); + Options.BracketOptions = { + NTD, true, true, + !pAnalyzer->hasMergeGroup( + SynthesizedExtensionAnalyzer::MergeGroupKind::MergeableWithTypeDef + ) + }; + } + } + if (D->print(Printer, Options)) { + if (Options.BracketOptions.shouldCloseNominal(D)) + Printer << "\n"; + Options.BracketOptions = BracketOptions(); + if (auto NTD = dyn_cast(D)) { + std::queue SubDecls{{NTD}}; + + while (!SubDecls.empty()) { + auto NTD = SubDecls.front(); + SubDecls.pop(); + + // Add sub-types of NTD. + for (auto Sub : NTD->getMembers()) + if (auto N = dyn_cast(Sub)) + SubDecls.push(N); + + // Print Ext and add sub-types of Ext. + for (auto Ext : NTD->getExtensions()) { + if (!PrintSynthesizedExtensions) { + if (!Options.shouldPrint(Ext)) { + Printer.callAvoidPrintDeclPost(Ext); + continue; + } + if (extensionHasClangNode(Ext)) + continue; // will be printed in its source location, see above. + Printer << "\n"; + Ext->print(Printer, Options); + Printer << "\n"; + } + for (auto Sub : Ext->getMembers()) + if (auto N = dyn_cast(Sub)) + SubDecls.push(N); + } + if (!PrintSynthesizedExtensions) + continue; + + bool IsTopLevelDecl = D == NTD; + + // If printed Decl is the top-level, merge the constraint-free extensions + // into the main body. + if (IsTopLevelDecl) { + // Print the part that should be merged with the type decl. + pAnalyzer->forEachExtensionMergeGroup( + SynthesizedExtensionAnalyzer::MergeGroupKind::MergeableWithTypeDef, + [&](ArrayRef Decls) { + for (auto ET : Decls) { + Options.BracketOptions = { + ET.Ext, false, Decls.back().Ext == ET.Ext, true + }; + if (ET.IsSynthesized) + Options.initForSynthesizedExtension(NTD); + ET.Ext->print(Printer, Options); + if (ET.IsSynthesized) + Options.clearSynthesizedExtension(); + if (Options.BracketOptions.shouldCloseExtension(ET.Ext)) + Printer << "\n"; + } + }); + } + + // If the printed Decl is not the top-level one, reset analyzer. + if (!IsTopLevelDecl) + pAnalyzer.reset(new SynthesizedExtensionAnalyzer(NTD, Options)); + + // Print the rest as synthesized extensions. + pAnalyzer->forEachExtensionMergeGroup( + // For top-level decls, only constraint extensions need to be + // printed, since the rest are merged into the main body. + IsTopLevelDecl + ? SynthesizedExtensionAnalyzer::MergeGroupKind::UnmergeableWithTypeDef + : SynthesizedExtensionAnalyzer::MergeGroupKind::All, + [&](ArrayRef Decls) { + // Whether we've started the extension merge group in printing. + bool Opened = false; + for (auto ET : Decls) { + Options.BracketOptions = { + ET.Ext, !Opened, Decls.back().Ext == ET.Ext, true + }; + if (Options.BracketOptions.shouldOpenExtension(ET.Ext)) + Printer << "\n"; + if (ET.IsSynthesized) { + if (ET.EnablingExt) + Options.initForSynthesizedExtension(ET.EnablingExt); + else + Options.initForSynthesizedExtension(NTD); + } + // Set opened if we actually printed this extension. + Opened |= ET.Ext->print(Printer, Options); + if (ET.IsSynthesized) + Options.clearSynthesizedExtension(); + if (Options.BracketOptions.shouldCloseExtension(ET.Ext)) + Printer << "\n"; + } + }); + Options.BracketOptions = BracketOptions(); + } + } + return true; + } + return false; +} + +/// Sorts import declarations for display. +static bool compareImports(ImportDecl *LHS, ImportDecl *RHS) { + auto LHSPath = LHS->getFullAccessPath(); + auto RHSPath = RHS->getFullAccessPath(); + for (unsigned i: range(std::min(LHSPath.size(), RHSPath.size()))) { + if (int Ret = LHSPath[i].Item.str().compare(RHSPath[i].Item.str())) + return Ret < 0; + } + return false; +}; + +/// Sorts Swift declarations for display. +static bool compareSwiftDecls(Decl *LHS, Decl *RHS) { + auto *LHSValue = dyn_cast(LHS); + auto *RHSValue = dyn_cast(RHS); + + if (LHSValue && RHSValue) { + auto LHSName = LHSValue->getBaseName(); + auto RHSName = RHSValue->getBaseName(); + if (int Ret = LHSName.compare(RHSName)) + return Ret < 0; + // FIXME: not sufficient to establish a total order for overloaded decls. + } + return LHS->getKind() < RHS->getKind(); +}; + +static std::pair, ArrayRef> +getDeclsFromOverlay(ModuleDecl *Overlay, ModuleDecl *Underlying, + SmallVectorImpl &Decls, AccessLevel AccessFilter) { + Overlay->getDisplayDecls(Decls); + + // Collector the imports of the underlying module so we can filter them out. + SmallPtrSet PrevImported; + SmallVector UnderlyingDecls; + Underlying->getDisplayDecls(UnderlyingDecls); + for (auto *D: UnderlyingDecls) { + if (auto *ID = dyn_cast(D)) + PrevImported.insert(ID->getModule()); + } + + // Filter out inaccessible decls and any imports of, or shared with the + // underlying module. + auto NewEnd = std::partition(Decls.begin(), Decls.end(), [&](Decl *D) { + if (auto *ID = dyn_cast(D)) { + // Ignore imports of the underlying module, or any cross-import + // that would map back to it. + if (ID->getModule() == Underlying || + Underlying->isUnderlyingModuleOfCrossImportOverlay(ID->getModule())) + return false; + // Ignore an imports of modules also imported by the underlying module. + if (PrevImported.find(ID->getModule()) != PrevImported.end()) + return false; + } + if (auto *VD = dyn_cast(D)) { + if (AccessFilter > AccessLevel::Private && + VD->getFormalAccess() < AccessFilter) + return false; + } + return true; + }); + if (NewEnd != Decls.end()) + Decls.erase(NewEnd, Decls.end()); + + // Separate out the import declarations and sort + MutableArrayRef Imports, Remainder; + auto ImportsEnd = std::partition(Decls.begin(), Decls.end(), [](Decl *D) { + return isa(D); + }); + if (ImportsEnd != Decls.begin()) { + Imports = {Decls.begin(), ImportsEnd}; + Remainder = {ImportsEnd, Decls.end()}; + std::sort(Imports.begin(), Imports.end(), [](Decl *LHS, Decl *RHS) { + return compareImports(cast(LHS), cast(RHS)); + }); + } else { + Remainder = Decls; + } + std::sort(Remainder.begin(), Remainder.end(), compareSwiftDecls); + return {Imports, Remainder}; +} + +static void printCrossImportOverlays(ModuleDecl *Underlying, ASTContext &Ctx, + ASTPrinter &Printer, + PrintOptions Options, + bool PrintSynthesizedExtensions) { + // If we end up printing decls from any cross-import overlay modules, make + // sure we map any qualifying module references to the underlying module. + Options.mapModuleToUnderlying = [&](const ModuleDecl *M) { + if (Underlying->isUnderlyingModuleOfCrossImportOverlay(M)) + return const_cast(Underlying); + return M; + }; + + SmallVector OverlayDecls; + SmallVector Bystanders; + + auto PrintDecl = [&](Decl *D) { + return printModuleInterfaceDecl(D, Printer, Options, + PrintSynthesizedExtensions); + }; + + SmallVector OverlayModules; + Underlying->findDeclaredCrossImportOverlaysTransitive(OverlayModules); + std::sort(OverlayModules.begin(), OverlayModules.end(), + [](ModuleDecl *LHS, ModuleDecl *RHS) { + return LHS->getNameStr() < RHS->getNameStr(); + }); + + for (auto *Overlay: OverlayModules) { + OverlayDecls.clear(); + auto DeclLists = getDeclsFromOverlay(Overlay, Underlying, OverlayDecls, + Options.AccessFilter); + + // Ignore overlays without any decls + if (OverlayDecls.empty()) + continue; + + Bystanders.clear(); + Underlying->getAllBystandersForCrossImportOverlay(Overlay, Bystanders); + assert(!Bystanders.empty() && "Overlay with no bystanders?"); + std::sort(Bystanders.begin(), Bystanders.end(), + [](Identifier LHS, Identifier RHS) { + return LHS.str() < RHS.str(); + }); + + std::string BystanderList; + for (size_t I: range(Bystanders.size())) { + if (I == Bystanders.size() - 1) { + if (I != 0) + BystanderList += " and "; + } else if (I != 0) { + BystanderList += ", "; + } + BystanderList += Bystanders[I].str(); + } + + Printer << "\n// MARK: - " << BystanderList << " Additions\n\n"; + for (auto *Import : DeclLists.first) + PrintDecl(Import); + Printer << "\n"; + + std::string PerDeclComment = "// Available when " + BystanderList; + PerDeclComment += Bystanders.size() == 1 ? " is" : " are"; + PerDeclComment += " imported with " + Underlying->getNameStr().str(); + + for (auto *D : DeclLists.second) { + // FIXME: only print this comment if the decl is actually printed. + Printer << PerDeclComment << "\n"; + if (PrintDecl(D)) + Printer << "\n"; + } + } +} + void swift::ide::printSubmoduleInterface( ModuleDecl *M, ArrayRef FullModuleName, @@ -246,18 +535,17 @@ void swift::ide::printSubmoduleInterface( ASTPrinter &Printer, const PrintOptions &Options, const bool PrintSynthesizedExtensions) { + auto &SwiftContext = M->getASTContext(); + auto &Importer = + static_cast(*SwiftContext.getClangModuleLoader()); + auto AdjustedOptions = Options; adjustPrintOptions(AdjustedOptions); SmallVector Decls; M->getDisplayDecls(Decls); - auto &SwiftContext = M->getASTContext(); - auto &Importer = - static_cast(*SwiftContext.getClangModuleLoader()); - const clang::Module *InterestingClangModule = nullptr; - SmallVector ImportDecls; llvm::DenseSet ClangModulesForImports; SmallVector SwiftDecls; @@ -429,9 +717,8 @@ void swift::ide::printSubmoduleInterface( ImportDecls.push_back(createImportDecl(M->getASTContext(), M, SM, {})); } + // Sort imported clang declarations in source order *within a submodule*. auto &ClangSourceManager = Importer.getClangASTContext().getSourceManager(); - - // Sort imported declarations in source order *within a submodule*. for (auto &P : ClangDecls) { std::stable_sort(P.second.begin(), P.second.end(), [&](std::pair LHS, @@ -442,39 +729,12 @@ void swift::ide::printSubmoduleInterface( } // Sort Swift declarations so that we print them in a consistent order. - std::sort(ImportDecls.begin(), ImportDecls.end(), - [](ImportDecl *LHS, ImportDecl *RHS) -> bool { - auto LHSPath = LHS->getFullAccessPath(); - auto RHSPath = RHS->getFullAccessPath(); - for (unsigned i = 0, e = std::min(LHSPath.size(), RHSPath.size()); i != e; - i++) { - if (int Ret = LHSPath[i].Item.str().compare(RHSPath[i].Item.str())) - return Ret < 0; - } - return false; - }); + std::sort(ImportDecls.begin(), ImportDecls.end(), compareImports); // If the group name is specified, we sort them according to their source order, // which is the order preserved by getTopLevelDecls. - if (GroupNames.empty()) { - std::stable_sort(SwiftDecls.begin(), SwiftDecls.end(), - [&](Decl *LHS, Decl *RHS) -> bool { - auto *LHSValue = dyn_cast(LHS); - auto *RHSValue = dyn_cast(RHS); - - if (LHSValue && RHSValue) { - auto LHSName = LHSValue->getBaseName(); - auto RHSName = RHSValue->getBaseName(); - if (int Ret = LHSName.compare(RHSName)) - return Ret < 0; - // FIXME: this is not sufficient to establish a total order for overloaded - // decls. - return LHS->getKind() < RHS->getKind(); - } - - return LHS->getKind() < RHS->getKind(); - }); - } + if (GroupNames.empty()) + std::stable_sort(SwiftDecls.begin(), SwiftDecls.end(), compareSwiftDecls); ASTPrinter *PrinterToUse = &Printer; @@ -482,136 +742,9 @@ void swift::ide::printSubmoduleInterface( if (Options.PrintRegularClangComments) PrinterToUse = &RegularCommentPrinter; - auto PrintDecl = [&](Decl *D) -> bool { - ASTPrinter &Printer = *PrinterToUse; - if (!AdjustedOptions.shouldPrint(D)) { - Printer.callAvoidPrintDeclPost(D); - return false; - } - if (auto Ext = dyn_cast(D)) { - // Clang extensions (categories) are always printed in source order. - // Swift extensions are printed with their associated type unless it's - // a cross-module extension. - if (!extensionHasClangNode(Ext)) { - auto ExtendedNominal = Ext->getExtendedNominal(); - if (Ext->getModuleContext() == ExtendedNominal->getModuleContext()) - return false; - } - } - std::unique_ptr pAnalyzer; - if (auto NTD = dyn_cast(D)) { - if (PrintSynthesizedExtensions) { - pAnalyzer.reset(new SynthesizedExtensionAnalyzer(NTD, AdjustedOptions)); - AdjustedOptions.BracketOptions = {NTD, true, true, - !pAnalyzer->hasMergeGroup(SynthesizedExtensionAnalyzer:: - MergeGroupKind::MergeableWithTypeDef)}; - } - } - if (D->print(Printer, AdjustedOptions)) { - if (AdjustedOptions.BracketOptions.shouldCloseNominal(D)) - Printer << "\n"; - AdjustedOptions.BracketOptions = BracketOptions(); - if (auto NTD = dyn_cast(D)) { - std::queue SubDecls{{NTD}}; - - while (!SubDecls.empty()) { - auto NTD = SubDecls.front(); - SubDecls.pop(); - - // Add sub-types of NTD. - for (auto Sub : NTD->getMembers()) - if (auto N = dyn_cast(Sub)) - SubDecls.push(N); - - // Print Ext and add sub-types of Ext. - for (auto Ext : NTD->getExtensions()) { - if (!PrintSynthesizedExtensions) { - if (!AdjustedOptions.shouldPrint(Ext)) { - Printer.callAvoidPrintDeclPost(Ext); - continue; - } - if (extensionHasClangNode(Ext)) - continue; // will be printed in its source location, see above. - Printer << "\n"; - Ext->print(Printer, AdjustedOptions); - Printer << "\n"; - } - for (auto Sub : Ext->getMembers()) - if (auto N = dyn_cast(Sub)) - SubDecls.push(N); - } - if (!PrintSynthesizedExtensions) - continue; - - bool IsTopLevelDecl = D == NTD; - - // If printed Decl is the top-level, merge the constraint-free extensions - // into the main body. - if (IsTopLevelDecl) { - // Print the part that should be merged with the type decl. - pAnalyzer->forEachExtensionMergeGroup( - SynthesizedExtensionAnalyzer::MergeGroupKind:: - MergeableWithTypeDef, - [&](ArrayRef Decls) { - for (auto ET : Decls) { - AdjustedOptions.BracketOptions = { - ET.Ext, false, Decls.back().Ext == ET.Ext, true}; - if (ET.IsSynthesized) - AdjustedOptions.initForSynthesizedExtension(NTD); - ET.Ext->print(Printer, AdjustedOptions); - if (ET.IsSynthesized) - AdjustedOptions.clearSynthesizedExtension(); - if (AdjustedOptions.BracketOptions.shouldCloseExtension( - ET.Ext)) - Printer << "\n"; - } - }); - } - - // If the printed Decl is not the top-level one, reset analyzer. - if (!IsTopLevelDecl) - pAnalyzer.reset(new SynthesizedExtensionAnalyzer(NTD, AdjustedOptions)); - - // Print the rest as synthesized extensions. - pAnalyzer->forEachExtensionMergeGroup( - // For top-level decls, only constraint extensions need to be - // printed, since the rest are merged into the main body. - IsTopLevelDecl ? SynthesizedExtensionAnalyzer::MergeGroupKind:: - UnmergeableWithTypeDef - : - // For sub-decls, all extensions should be printed. - SynthesizedExtensionAnalyzer::MergeGroupKind::All, - [&](ArrayRef Decls) { - // Whether we've started the extension merge group in printing. - bool Opened = false; - for (auto ET : Decls) { - AdjustedOptions.BracketOptions = { ET.Ext, !Opened, - Decls.back().Ext == ET.Ext, true}; - if (AdjustedOptions.BracketOptions.shouldOpenExtension( - ET.Ext)) - Printer << "\n"; - if (ET.IsSynthesized) { - if (ET.EnablingExt) - AdjustedOptions.initForSynthesizedExtension( - ET.EnablingExt); - else - AdjustedOptions.initForSynthesizedExtension(NTD); - } - // Set opened if we actually printed this extension. - Opened |= ET.Ext->print(Printer, AdjustedOptions); - if (ET.IsSynthesized) - AdjustedOptions.clearSynthesizedExtension(); - if (AdjustedOptions.BracketOptions.shouldCloseExtension( - ET.Ext)) - Printer << "\n"; - } - }); - AdjustedOptions.BracketOptions = BracketOptions(); - } - } - return true; - } - return false; + auto PrintDecl = [&](Decl *D) { + return printModuleInterfaceDecl(D, *PrinterToUse, AdjustedOptions, + PrintSynthesizedExtensions); }; // Imports from the stdlib are internal details that don't need to be exposed. @@ -646,6 +779,14 @@ void swift::ide::printSubmoduleInterface( if (PrintDecl(D)) Printer << "\n"; } + + // If we're printing the entire target module (not specific sub-groups), + // also print the decls from any underscored Swift cross-import overlays it + // is the underlying module of, transitively. + if (GroupNames.empty()) { + printCrossImportOverlays(M, SwiftContext, *PrinterToUse, AdjustedOptions, + PrintSynthesizedExtensions); + } } } diff --git a/lib/IRGen/CMakeLists.txt b/lib/IRGen/CMakeLists.txt index 493e608d8fd3c..1e73fff73d49b 100644 --- a/lib/IRGen/CMakeLists.txt +++ b/lib/IRGen/CMakeLists.txt @@ -16,6 +16,7 @@ add_swift_host_library(swiftIRGen STATIC GenControl.cpp GenCoverage.cpp GenDecl.cpp + GenDiffFunc.cpp GenDiffWitness.cpp GenEnum.cpp GenExistential.cpp diff --git a/lib/IRGen/GenClass.cpp b/lib/IRGen/GenClass.cpp index 6e371750f7465..0b62510c2c516 100644 --- a/lib/IRGen/GenClass.cpp +++ b/lib/IRGen/GenClass.cpp @@ -1644,12 +1644,15 @@ namespace { void buildExtMethodTypes(ConstantArrayBuilder &array, ArrayRef methods) { assert(isBuildingProtocol()); - + llvm::StringSet<> uniqueSelectors; for (auto descriptor : methods) { assert(descriptor.getKind() == MethodDescriptor::Kind::Method && "cannot emit descriptor for non-method"); auto method = descriptor.getMethod(); - array.add(getMethodTypeExtendedEncoding(IGM, method)); + auto *encodingOrNullIfDuplicate = + getMethodTypeExtendedEncoding(IGM, method, uniqueSelectors); + if (encodingOrNullIfDuplicate != nullptr) + array.add(encodingOrNullIfDuplicate); } } diff --git a/lib/IRGen/GenDiffFunc.cpp b/lib/IRGen/GenDiffFunc.cpp new file mode 100644 index 0000000000000..349160231ac08 --- /dev/null +++ b/lib/IRGen/GenDiffFunc.cpp @@ -0,0 +1,350 @@ +//===- GenDiffFunc.cpp - Swift IR Generation For @differentiable Functions ===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// This file implements IR generation for `@differentiable` function types in +// Swift. +// +//===----------------------------------------------------------------------===// + +#include "swift/AST/Decl.h" +#include "swift/AST/Pattern.h" +#include "swift/AST/Types.h" +#include "swift/SIL/SILModule.h" +#include "swift/SIL/SILType.h" +#include "llvm/IR/DerivedTypes.h" + +#include "Explosion.h" +#include "GenHeap.h" +#include "GenRecord.h" +#include "GenType.h" +#include "IRGenFunction.h" +#include "IRGenModule.h" +#include "IndirectTypeInfo.h" +#include "NonFixedTypeInfo.h" + +#pragma clang diagnostic ignored "-Winconsistent-missing-override" + +using namespace swift; +using namespace irgen; + +//----------------------------------------------------------------------------// +// `@differentiable` (non-linear) function type info +//----------------------------------------------------------------------------// + +namespace { +class DifferentiableFuncFieldInfo final + : public RecordField { +public: + DifferentiableFuncFieldInfo( + NormalDifferentiableFunctionTypeComponent component, const TypeInfo &type, + IndexSubset *parameterIndices) + : RecordField(type), component(component), + parameterIndices(parameterIndices) {} + + /// The field index. + const NormalDifferentiableFunctionTypeComponent component; + + /// The parameter indices. + IndexSubset *parameterIndices; + + std::string getFieldName() const { + switch (component) { + case NormalDifferentiableFunctionTypeComponent::Original: + return "original"; + case NormalDifferentiableFunctionTypeComponent::JVP: + return "jvp"; + case NormalDifferentiableFunctionTypeComponent::VJP: + return "vjp"; + } + } + + SILType getType(IRGenModule &IGM, SILType t) const { + auto fnTy = t.castTo(); + auto origFnTy = fnTy->getWithoutDifferentiability(); + if (component == NormalDifferentiableFunctionTypeComponent::Original) + return SILType::getPrimitiveObjectType(origFnTy); + auto kind = *component.getAsDerivativeFunctionKind(); + auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType( + parameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(), + LookUpConformanceInModule(IGM.getSwiftModule())); + return SILType::getPrimitiveObjectType(assocTy); + } +}; + +class DifferentiableFuncTypeInfo final + : public RecordTypeInfo { + using super = RecordTypeInfo; + +public: + DifferentiableFuncTypeInfo(ArrayRef fields, + unsigned explosionSize, llvm::Type *ty, Size size, + SpareBitVector &&spareBits, Alignment align, + IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize) + : super(fields, explosionSize, ty, size, std::move(spareBits), align, + isPOD, alwaysFixedSize) {} + + Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T, + const DifferentiableFuncFieldInfo &field) const { + return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T)); + } + + void initializeFromParams(IRGenFunction &IGF, Explosion ¶ms, Address src, + SILType T, bool isOutlined) const override { + llvm_unreachable("unexploded @differentiable function as argument?"); + } + + void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering, + Size offset) const override { + for (auto &field : getFields()) { + auto fieldOffset = offset + field.getFixedByteOffset(); + cast(field.getTypeInfo()) + .addToAggLowering(IGM, lowering, fieldOffset); + } + } + + TypeLayoutEntry *buildTypeLayoutEntry(IRGenModule &IGM, + SILType T) const override { + return IGM.typeLayoutCache.getOrCreateScalarEntry(*this, T); + } + + llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; } + llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const { + return None; + } +}; + +class DifferentiableFuncTypeBuilder + : public RecordTypeBuilder { + + SILFunctionType *originalType; + IndexSubset *parameterIndices; + +public: + DifferentiableFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy) + : RecordTypeBuilder(IGM), + originalType(fnTy->getWithoutDifferentiability()), + parameterIndices(fnTy->getDifferentiabilityParameterIndices()) { + assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Normal); + } + + TypeInfo *createFixed(ArrayRef fields, + StructLayout &&layout) { + llvm_unreachable("@differentiable functions are always loadable"); + } + + DifferentiableFuncTypeInfo * + createLoadable(ArrayRef fields, + StructLayout &&layout, unsigned explosionSize) { + return DifferentiableFuncTypeInfo::create( + fields, explosionSize, layout.getType(), layout.getSize(), + std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(), + layout.isAlwaysFixedSize()); + } + + TypeInfo *createNonFixed(ArrayRef fields, + FieldsAreABIAccessible_t fieldsAccessible, + StructLayout &&layout) { + llvm_unreachable("@differentiable functions are always loadable"); + } + + DifferentiableFuncFieldInfo + getFieldInfo(unsigned index, + NormalDifferentiableFunctionTypeComponent component, + const TypeInfo &fieldTI) { + return DifferentiableFuncFieldInfo(component, fieldTI, parameterIndices); + } + + SILType getType(NormalDifferentiableFunctionTypeComponent component) { + if (component == NormalDifferentiableFunctionTypeComponent::Original) + return SILType::getPrimitiveObjectType(originalType->getCanonicalType()); + auto kind = *component.getAsDerivativeFunctionKind(); + auto assocTy = originalType->getAutoDiffDerivativeFunctionType( + parameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(), + LookUpConformanceInModule(IGM.getSwiftModule())); + return SILType::getPrimitiveObjectType(assocTy); + } + + StructLayout performLayout(ArrayRef fieldTypes) { + return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject, + LayoutStrategy::Universal, fieldTypes); + } +}; +} // end anonymous namespace + +//----------------------------------------------------------------------------// +// `@differentiable(linear)` function type info +//----------------------------------------------------------------------------// + +namespace { +class LinearFuncFieldInfo final : public RecordField { +public: + LinearFuncFieldInfo(LinearDifferentiableFunctionTypeComponent component, + const TypeInfo &type, IndexSubset *parameterIndices) + : RecordField(type), component(component), + parameterIndices(parameterIndices) {} + + /// The field index. + const LinearDifferentiableFunctionTypeComponent component; + + /// The parameter indices. + IndexSubset *parameterIndices; + + std::string getFieldName() const { + switch (component) { + case LinearDifferentiableFunctionTypeComponent::Original: + return "original"; + case LinearDifferentiableFunctionTypeComponent::Transpose: + return "transpose"; + } + } + + SILType getType(IRGenModule &IGM, SILType t) const { + auto fnTy = t.castTo(); + auto origFnTy = fnTy->getWithoutDifferentiability(); + switch (component) { + case LinearDifferentiableFunctionTypeComponent::Original: + return SILType::getPrimitiveObjectType(origFnTy); + case LinearDifferentiableFunctionTypeComponent::Transpose: + auto transposeTy = origFnTy->getAutoDiffTransposeFunctionType( + parameterIndices, IGM.getSILTypes(), + LookUpConformanceInModule(IGM.getSwiftModule())); + return SILType::getPrimitiveObjectType(transposeTy); + } + } +}; + +class LinearFuncTypeInfo final + : public RecordTypeInfo { + using super = + RecordTypeInfo; + +public: + LinearFuncTypeInfo(ArrayRef fields, + unsigned explosionSize, llvm::Type *ty, Size size, + SpareBitVector &&spareBits, Alignment align, IsPOD_t isPOD, + IsFixedSize_t alwaysFixedSize) + : super(fields, explosionSize, ty, size, std::move(spareBits), align, + isPOD, alwaysFixedSize) {} + + Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T, + const LinearFuncFieldInfo &field) const { + return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T)); + } + + void initializeFromParams(IRGenFunction &IGF, Explosion ¶ms, Address src, + SILType T, bool isOutlined) const override { + llvm_unreachable("unexploded @differentiable function as argument?"); + } + + void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering, + Size offset) const override { + for (auto &field : getFields()) { + auto fieldOffset = offset + field.getFixedByteOffset(); + cast(field.getTypeInfo()) + .addToAggLowering(IGM, lowering, fieldOffset); + } + } + + TypeLayoutEntry *buildTypeLayoutEntry(IRGenModule &IGM, + SILType T) const override { + return IGM.typeLayoutCache.getOrCreateScalarEntry(*this, T); + } + + llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; } + llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const { + return None; + } +}; + +class LinearFuncTypeBuilder + : public RecordTypeBuilder { + + SILFunctionType *originalType; + IndexSubset *parameterIndices; + +public: + LinearFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy) + : RecordTypeBuilder(IGM), + originalType(fnTy->getWithoutDifferentiability()), + parameterIndices(fnTy->getDifferentiabilityParameterIndices()) { + assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Linear); + } + + TypeInfo *createFixed(ArrayRef fields, + StructLayout &&layout) { + llvm_unreachable("@differentiable functions are always loadable"); + } + + LinearFuncTypeInfo *createLoadable(ArrayRef fields, + StructLayout &&layout, + unsigned explosionSize) { + return LinearFuncTypeInfo::create( + fields, explosionSize, layout.getType(), layout.getSize(), + std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(), + layout.isAlwaysFixedSize()); + } + + TypeInfo *createNonFixed(ArrayRef fields, + FieldsAreABIAccessible_t fieldsAccessible, + StructLayout &&layout) { + llvm_unreachable("@differentiable functions are always loadable"); + } + + LinearFuncFieldInfo + getFieldInfo(unsigned index, LinearDifferentiableFunctionTypeComponent field, + const TypeInfo &fieldTI) { + return LinearFuncFieldInfo(field, fieldTI, parameterIndices); + } + + SILType getType(LinearDifferentiableFunctionTypeComponent component) { + switch (component) { + case LinearDifferentiableFunctionTypeComponent::Original: + return SILType::getPrimitiveObjectType(originalType->getCanonicalType()); + case LinearDifferentiableFunctionTypeComponent::Transpose: + auto transposeTy = originalType->getAutoDiffTransposeFunctionType( + parameterIndices, IGM.getSILTypes(), + LookUpConformanceInModule(IGM.getSwiftModule())); + return SILType::getPrimitiveObjectType(transposeTy); + } + } + + StructLayout performLayout(ArrayRef fieldTypes) { + return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject, + LayoutStrategy::Universal, fieldTypes); + } +}; +} // end anonymous namespace + +//----------------------------------------------------------------------------// +// Type converter entry points +//----------------------------------------------------------------------------// + +const TypeInfo * +TypeConverter::convertNormalDifferentiableFunctionType(SILFunctionType *type) { + DifferentiableFuncTypeBuilder builder(IGM, type); + return builder.layout({NormalDifferentiableFunctionTypeComponent::Original, + NormalDifferentiableFunctionTypeComponent::JVP, + NormalDifferentiableFunctionTypeComponent::VJP}); +} + +const TypeInfo * +TypeConverter::convertLinearDifferentiableFunctionType(SILFunctionType *type) { + LinearFuncTypeBuilder builder(IGM, type); + return builder.layout({LinearDifferentiableFunctionTypeComponent::Original, + LinearDifferentiableFunctionTypeComponent::Transpose}); +} diff --git a/lib/IRGen/GenFunc.cpp b/lib/IRGen/GenFunc.cpp index 1a58310081503..c2704f088025e 100644 --- a/lib/IRGen/GenFunc.cpp +++ b/lib/IRGen/GenFunc.cpp @@ -498,6 +498,16 @@ Address irgen::projectBlockStorageCapture(IRGenFunction &IGF, } const TypeInfo *TypeConverter::convertFunctionType(SILFunctionType *T) { + // Handle `@differentiable` and `@differentiable(linear)` functions. + switch (T->getDifferentiabilityKind()) { + case DifferentiabilityKind::Normal: + return convertNormalDifferentiableFunctionType(T); + case DifferentiabilityKind::Linear: + return convertLinearDifferentiableFunctionType(T); + case DifferentiabilityKind::NonDifferentiable: + break; + } + switch (T->getRepresentation()) { case SILFunctionType::Representation::Block: return new BlockTypeInfo(CanSILFunctionType(T), diff --git a/lib/IRGen/GenMeta.cpp b/lib/IRGen/GenMeta.cpp index bb4d4cc506c7d..27fd6fc10c996 100644 --- a/lib/IRGen/GenMeta.cpp +++ b/lib/IRGen/GenMeta.cpp @@ -4634,6 +4634,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) { case KnownProtocolKind::Encodable: case KnownProtocolKind::Decodable: case KnownProtocolKind::StringInterpolationProtocol: + case KnownProtocolKind::AdditiveArithmetic: case KnownProtocolKind::Differentiable: return SpecialProtocol::None; } diff --git a/lib/IRGen/GenObjC.cpp b/lib/IRGen/GenObjC.cpp index fd742cb91f0ff..0041e3ef3e68c 100644 --- a/lib/IRGen/GenObjC.cpp +++ b/lib/IRGen/GenObjC.cpp @@ -1409,7 +1409,13 @@ void irgen::emitObjCIVarInitDestroyDescriptor( llvm::Constant * irgen::getMethodTypeExtendedEncoding(IRGenModule &IGM, - AbstractFunctionDecl *method) { + AbstractFunctionDecl *method, + llvm::StringSet<> &uniqueSelectors) { + // Don't emit a selector twice. + Selector selector(method); + if (!uniqueSelectors.insert(selector.str()).second) + return nullptr; + CanSILFunctionType methodType = getObjCMethodType(IGM, method); return getObjCEncodingForMethod(IGM, methodType, true /*Extended*/, method); } diff --git a/lib/IRGen/GenObjC.h b/lib/IRGen/GenObjC.h index c9dd3ab03b68f..33d582d59eee5 100644 --- a/lib/IRGen/GenObjC.h +++ b/lib/IRGen/GenObjC.h @@ -177,10 +177,12 @@ namespace irgen { CanSILFunctionType invokeTy); /// Produces extended encoding of method type. - /// \returns the encoded type. - llvm::Constant *getMethodTypeExtendedEncoding(IRGenModule &IGM, - AbstractFunctionDecl *method); - + /// \returns the encoded type or null if it is a duplicate (exists in + /// \p uniqueSelectors). + llvm::Constant * + getMethodTypeExtendedEncoding(IRGenModule &IGM, AbstractFunctionDecl *method, + llvm::StringSet<> &uniqueSelectors); + /// Build an Objective-C method descriptor for the given getter method. void emitObjCGetterDescriptor(IRGenModule &IGM, ConstantArrayBuilder &descriptors, diff --git a/lib/IRGen/GenProto.cpp b/lib/IRGen/GenProto.cpp index 3a6567b94b21c..97db4b0a6b3c4 100644 --- a/lib/IRGen/GenProto.cpp +++ b/lib/IRGen/GenProto.cpp @@ -2737,12 +2737,8 @@ static void addAbstractConditionalRequirements( SpecializedProtocolConformance *specializedConformance, llvm::SetVector &requirements) { auto subMap = specializedConformance->getSubstitutionMap(); - auto condRequirements = - specializedConformance->getConditionalRequirementsIfAvailable(); - if (!condRequirements) - return; - - for (auto req : *condRequirements) { + auto condRequirements = specializedConformance->getConditionalRequirements(); + for (auto req : condRequirements) { if (req.getKind() != RequirementKind::Conformance) continue; auto *proto = diff --git a/lib/IRGen/GenReflection.cpp b/lib/IRGen/GenReflection.cpp index c1264e1dfce42..a8a291c832809 100644 --- a/lib/IRGen/GenReflection.cpp +++ b/lib/IRGen/GenReflection.cpp @@ -827,13 +827,13 @@ static bool deploymentTargetHasRemoteMirrorZeroSizedTypeDescriptorBug(IRGenModule &IGM) { auto target = IGM.Context.LangOpts.Target; - if (target.isMacOSX() && target.isMacOSXVersionLT(10, 16, 0)) { + if (target.isMacOSX() && target.isMacOSXVersionLT(10, 15, 4)) { return true; } - if (target.isiOS() && target.isOSVersionLT(14)) { // includes tvOS + if (target.isiOS() && target.isOSVersionLT(13, 4)) { // includes tvOS return true; } - if (target.isWatchOS() && target.isOSVersionLT(7)) { + if (target.isWatchOS() && target.isOSVersionLT(6, 2)) { return true; } diff --git a/lib/IRGen/GenType.h b/lib/IRGen/GenType.h index dd7086aa1ea97..3b03e931f42d0 100644 --- a/lib/IRGen/GenType.h +++ b/lib/IRGen/GenType.h @@ -145,6 +145,8 @@ class TypeConverter { const TypeInfo *convertEnumType(TypeBase *key, CanType type, EnumDecl *D); const TypeInfo *convertStructType(TypeBase *key, CanType type, StructDecl *D); const TypeInfo *convertFunctionType(SILFunctionType *T); + const TypeInfo *convertNormalDifferentiableFunctionType(SILFunctionType *T); + const TypeInfo *convertLinearDifferentiableFunctionType(SILFunctionType *T); const TypeInfo *convertBlockStorageType(SILBlockStorageType *T); const TypeInfo *convertBoxType(SILBoxType *T); const TypeInfo *convertArchetypeType(ArchetypeType *T); diff --git a/lib/IRGen/IRGenMangler.cpp b/lib/IRGen/IRGenMangler.cpp index 7dbac2eaf2f82..e83e186519ed2 100644 --- a/lib/IRGen/IRGenMangler.cpp +++ b/lib/IRGen/IRGenMangler.cpp @@ -308,15 +308,7 @@ std::string IRGenMangler::mangleSymbolNameForMangledConformanceAccessorString( if (genericSig) appendGenericSignature(genericSig); - if (type) - appendType(type); - - if (conformance.isConcrete()) - appendConcreteProtocolConformance(conformance.getConcrete()); - else if (conformance.isAbstract()) - appendProtocolName(conformance.getAbstract()); - else - assert(conformance.isInvalid() && "Unknown protocol conformance"); + appendAnyProtocolConformance(genericSig, type, conformance); return finalize(); } diff --git a/lib/IRGen/IRGenSIL.cpp b/lib/IRGen/IRGenSIL.cpp index 38577f4c8c4c7..a1acf3489a5ad 100644 --- a/lib/IRGen/IRGenSIL.cpp +++ b/lib/IRGen/IRGenSIL.cpp @@ -1044,6 +1044,9 @@ class IRGenSILFunction : void visitKeyPathInst(KeyPathInst *I); + void visitDifferentiableFunctionInst(DifferentiableFunctionInst *i); + void + visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *i); void visitDifferentiabilityWitnessFunctionInst( DifferentiabilityWitnessFunctionInst *i); @@ -1814,6 +1817,62 @@ void IRGenSILFunction::visitSILBasicBlock(SILBasicBlock *BB) { assert(Builder.hasPostTerminatorIP() && "SIL bb did not terminate block?!"); } +void IRGenSILFunction::visitDifferentiableFunctionInst( + DifferentiableFunctionInst *i) { + auto origFnExp = getLoweredExplosion(i->getOriginalFunction()); + Explosion e; + e.add(origFnExp.claimAll()); + // TODO(TF-1211): Uncomment assertions after upstreaming differentiation + // transform. + // The mandatory differentiation transform canonicalizes + // `differentiable_function` instructions and ensures that derivative operands + // are populated. + /* + assert(i->hasDerivativeFunctions()); + for (auto &derivFnOperand : i->getDerivativeFunctionArray()) + e.add(getLoweredExplosion(derivFnOperand.get()).claimAll()); + setLoweredExplosion(i, e); + */ + // Note: code below is a temporary measure until TF-1211. Derivative function + // operands should always exist after the differentiation transform. + auto getDerivativeExplosion = [&](AutoDiffDerivativeFunctionKind kind) { + // If the derivative value exists, get its explosion. + if (i->hasDerivativeFunctions()) + return getLoweredExplosion(i->getDerivativeFunction(kind)); + // Otherwise, create an undef explosion. + auto origFnType = + i->getOriginalFunction()->getType().castTo(); + auto derivativeFnType = origFnType->getAutoDiffDerivativeFunctionType( + i->getParameterIndices(), /*resultIndex*/ 0, kind, i->getModule().Types, + LookUpConformanceInModule(i->getModule().getSwiftModule())); + auto *undef = SILUndef::get( + SILType::getPrimitiveObjectType(derivativeFnType), *i->getFunction()); + return getLoweredExplosion(undef); + }; + auto jvpExp = getDerivativeExplosion(AutoDiffDerivativeFunctionKind::JVP); + e.add(jvpExp.claimAll()); + auto vjpExp = getDerivativeExplosion(AutoDiffDerivativeFunctionKind::VJP); + e.add(vjpExp.claimAll()); + setLoweredExplosion(i, e); +} + +void IRGenSILFunction::visitDifferentiableFunctionExtractInst( + DifferentiableFunctionExtractInst *i) { + unsigned structFieldOffset = i->getExtractee().rawValue; + unsigned fieldSize = 1; + auto fnRepr = i->getOperand()->getType().getFunctionRepresentation(); + if (fnRepr == SILFunctionTypeRepresentation::Thick) { + structFieldOffset *= 2; + fieldSize = 2; + } + auto diffFnExp = getLoweredExplosion(i->getOperand()); + assert(diffFnExp.size() == fieldSize * 3); + Explosion e; + e.add(diffFnExp.getRange(structFieldOffset, structFieldOffset + fieldSize)); + (void)diffFnExp.claimAll(); + setLoweredExplosion(i, e); +} + void IRGenSILFunction::visitDifferentiabilityWitnessFunctionInst( DifferentiabilityWitnessFunctionInst *i) { llvm::Value *diffWitness = diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index ff4ff034f94d2..8ec55b7eaa1ce 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -816,12 +816,8 @@ Parser::parseImplementsAttribute(SourceLoc AtLoc, SourceLoc Loc) { /// \verbatim /// differentiable-attribute-arguments: /// '(' (differentiability-params-clause ',')? -/// (differentiable-attr-func-specifier ',')? -/// differentiable-attr-func-specifier? /// where-clause? /// ')' -/// differentiable-attr-func-specifier: -/// ('jvp' | 'vjp') ':' decl-name /// \endverbatim ParserResult Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) { @@ -829,15 +825,12 @@ Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) { SourceLoc lParenLoc = loc, rParenLoc = loc; bool linear = false; SmallVector parameters; - Optional jvpSpec; - Optional vjpSpec; TrailingWhereClause *whereClause = nullptr; // Parse '('. if (consumeIf(tok::l_paren, lParenLoc)) { // Parse @differentiable attribute arguments. - if (parseDifferentiableAttributeArguments(linear, parameters, jvpSpec, - vjpSpec, whereClause)) + if (parseDifferentiableAttributeArguments(linear, parameters, whereClause)) return makeParserError(); // Parse ')'. if (!consumeIf(tok::r_paren, rParenLoc)) { @@ -849,7 +842,7 @@ Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) { return ParserResult(DifferentiableAttr::create( Context, /*implicit*/ false, atLoc, SourceRange(loc, rParenLoc), linear, - parameters, jvpSpec, vjpSpec, whereClause)); + parameters, whereClause)); } // Attribute parsing error helper. @@ -963,8 +956,7 @@ bool Parser::parseDifferentiabilityParametersClause( bool Parser::parseDifferentiableAttributeArguments( bool &linear, SmallVectorImpl ¶meters, - Optional &jvpSpec, - Optional &vjpSpec, TrailingWhereClause *&whereClause) { + TrailingWhereClause *&whereClause) { StringRef AttrName = "differentiable"; // Parse trailing comma, if it exists, and check for errors. @@ -975,9 +967,8 @@ bool Parser::parseDifferentiableAttributeArguments( diagnose(Tok, diag::unexpected_separator, ","); return true; } - // Check that token after comma is 'wrt' or a function specifier label. - if (isIdentifier(Tok, "wrt") || isIdentifier(Tok, "jvp") || - isIdentifier(Tok, "vjp")) { + // Check that token after comma is 'wrt'. + if (isIdentifier(Tok, "wrt")) { return false; } diagnose(Tok, diag::attr_differentiable_expected_label); @@ -1021,66 +1012,6 @@ bool Parser::parseDifferentiableAttributeArguments( return errorAndSkipUntilConsumeRightParen(*this, AttrName); } - // Function that parses a label and a function specifier, e.g. 'vjp: foo(_:)'. - // Return true on error. - auto parseFuncSpec = [&](StringRef label, DeclNameRefWithLoc &result, - bool &terminateParsingArgs) -> bool { - // Parse label. - if (parseSpecificIdentifier(label, diag::attr_missing_label, label, - AttrName) || - parseToken(tok::colon, diag::expected_colon_after_label, label)) - return true; - // Parse the name of the function. - SyntaxParsingContext FuncDeclNameContext( - SyntaxContext, SyntaxKind::FunctionDeclName); - Diagnostic funcDiag(diag::attr_differentiable_expected_function_name.ID, - { label }); - result.Name = parseDeclNameRef(result.Loc, funcDiag, - DeclNameFlag::AllowZeroArgCompoundNames | DeclNameFlag::AllowOperators); - // Emit warning for deprecated `jvp:` and `vjp:` arguments. - // TODO(TF-1001): Remove deprecated `jvp:` and `vjp:` arguments. - if (result.Loc.isValid()) { - diagnose(result.Loc.getStartLoc(), - diag::attr_differentiable_jvp_vjp_deprecated_warning) - .highlight(result.Loc.getSourceRange()); - } - // If no trailing comma or 'where' clause, terminate parsing arguments. - if (Tok.isNot(tok::comma, tok::kw_where)) - terminateParsingArgs = true; - return !result.Name; - }; - - // Store whether to terminate parsing arguments. - bool terminateParsingArgs = false; - - // Parse 'jvp: ' (optional). - if (isIdentifier(Tok, "jvp")) { - SyntaxParsingContext JvpContext( - SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier); - jvpSpec = DeclNameRefWithLoc(); - if (parseFuncSpec("jvp", *jvpSpec, terminateParsingArgs)) - return errorAndSkipUntilConsumeRightParen(*this, AttrName); - if (terminateParsingArgs) - return false; - if (consumeIfTrailingComma()) - return errorAndSkipUntilConsumeRightParen(*this, AttrName); - } - - // Parse 'vjp: ' (optional). - if (isIdentifier(Tok, "vjp")) { - SyntaxParsingContext VjpContext( - SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier); - vjpSpec = DeclNameRefWithLoc(); - if (parseFuncSpec("vjp", *vjpSpec, terminateParsingArgs)) - return errorAndSkipUntilConsumeRightParen(*this, AttrName); - if (terminateParsingArgs) - return false; - // Note: intentionally parse trailing comma here, even though it's the last - // function specifier. `consumeIfTrailingComma` will emit an error. - if (consumeIfTrailingComma()) - return errorAndSkipUntilConsumeRightParen(*this, AttrName); - } - // If parser has not advanced and token is not 'where' or ')', emit error. if (Tok.getLoc() == startingLoc && Tok.isNot(tok::kw_where, tok::r_paren)) { diagnose(Tok, diag::attr_differentiable_expected_label); diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 2bc8370406fee..1e22b48fef23c 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -5022,6 +5022,88 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B, blockType, subMap); break; } + case SILInstructionKind::DifferentiableFunctionInst: { + // e.g. differentiable_function [parameters 0 1 2] %0 : $T + // + // e.g. differentiable_function [parameters 0 1 2] %0 : $T with_derivative + // {%1 : $T, %2 : $T} + // ^~ jvp ^~ vjp + // Parse `[parameters ...]`. + SmallVector parameterIndices; + if (parseIndexList(P, "parameters", parameterIndices, + diag::sil_autodiff_expected_parameter_index)) + return true; + // Parse the original function value. + SILValue original; + SourceLoc originalOperandLoc; + if (parseTypedValueRef(original, originalOperandLoc, B)) + return true; + auto fnType = original->getType().getAs(); + if (!fnType) { + P.diagnose(originalOperandLoc, + diag::sil_inst_autodiff_expected_function_type_operand); + return true; + } + Optional> derivativeFunctions = None; + // Parse an optional operand list + // `with_derivative { , }`. + if (P.Tok.is(tok::identifier) && P.Tok.getText() == "with_derivative") { + P.consumeToken(tok::identifier); + // Parse derivative function values as an operand list. + // FIXME(rxwei): Change this to *not* require a type signature once + // we can infer derivative function types. + derivativeFunctions = std::make_pair(SILValue(), SILValue()); + if (P.parseToken( + tok::l_brace, + diag::sil_inst_autodiff_operand_list_expected_lbrace) || + parseTypedValueRef(derivativeFunctions->first, B) || + P.parseToken(tok::comma, + diag::sil_inst_autodiff_operand_list_expected_comma) || + parseTypedValueRef(derivativeFunctions->second, B) || + P.parseToken(tok::r_brace, + diag::sil_inst_autodiff_operand_list_expected_rbrace)) + return true; + } + if (parseSILDebugLocation(InstLoc, B)) + return true; + auto *parameterIndicesSubset = IndexSubset::get( + P.Context, fnType->getNumParameters(), parameterIndices); + ResultVal = B.createDifferentiableFunction( + InstLoc, parameterIndicesSubset, original, derivativeFunctions); + break; + } + case SILInstructionKind::DifferentiableFunctionExtractInst: { + // Parse the rest of the instruction: an extractee, a differentiable + // function operand, an optional explicit extractee type, and a debug + // location. + NormalDifferentiableFunctionTypeComponent extractee; + StringRef extracteeNames[3] = {"original", "jvp", "vjp"}; + SILValue functionOperand; + SourceLoc lastLoc; + if (P.parseToken( + tok::l_square, + diag::sil_inst_autodiff_expected_differentiable_extractee_kind) || + parseSILIdentifierSwitch( + extractee, extracteeNames, + diag::sil_inst_autodiff_expected_differentiable_extractee_kind) || + P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare, + "extractee kind")) + return true; + if (parseTypedValueRef(functionOperand, B)) + return true; + // Parse an optional explicit extractee type. + Optional extracteeType = None; + if (P.consumeIf(tok::kw_as)) { + extracteeType = SILType(); + if (parseSILType(*extracteeType)) + return true; + } + if (parseSILDebugLocation(InstLoc, B)) + return true; + ResultVal = B.createDifferentiableFunctionExtract( + InstLoc, extractee, functionOperand, extracteeType); + break; + } case SILInstructionKind::DifferentiabilityWitnessFunctionInst: { // e.g. differentiability_witness_function // [jvp] [parameters 0 1] [results 0] diff --git a/lib/SIL/InstructionUtils.cpp b/lib/SIL/InstructionUtils.cpp index 63708a88dd1c4..b51e15605b172 100644 --- a/lib/SIL/InstructionUtils.cpp +++ b/lib/SIL/InstructionUtils.cpp @@ -455,6 +455,13 @@ void swift::findClosuresForFunctionValue( worklistInsert(SVI->getOperand(0)); continue; } + // Look through `differentiable_function` operands, which are all + // function-typed. + if (auto *DFI = dyn_cast(I)) { + for (auto &fn : DFI->getAllOperands()) + worklistInsert(fn.get()); + continue; + } } // Look through Optionals. if (V->getType().getOptionalObjectType()) { diff --git a/lib/SIL/LinearLifetimeChecker.cpp b/lib/SIL/LinearLifetimeChecker.cpp index 1b7fc4acaf205..2120762abeb70 100644 --- a/lib/SIL/LinearLifetimeChecker.cpp +++ b/lib/SIL/LinearLifetimeChecker.cpp @@ -21,7 +21,7 @@ //===----------------------------------------------------------------------===// #define DEBUG_TYPE "sil-linear-lifetime-checker" -#include "swift/SIL/LinearLifetimeChecker.h" +#include "LinearLifetimeCheckerPrivate.h" #include "swift/SIL/BasicBlockUtils.h" #include "swift/SIL/OwnershipUtils.h" #include "swift/SIL/SILBasicBlock.h" @@ -31,7 +31,6 @@ #include "llvm/Support/Debug.h" using namespace swift; -using namespace swift::ownership; //===----------------------------------------------------------------------===// // Declarations @@ -51,15 +50,15 @@ struct State { /// The result error object that use to signal either that no errors were /// found or if errors are found the specific type of error that was found. - LinearLifetimeError error; + LinearLifetimeChecker::Error error; /// The blocks that we have already visited. SmallPtrSetImpl &visitedBlocks; - /// If non-null a list that we should place any detected leaking blocks for + /// If non-null a callback that we should pass any detected leaking blocks for /// our caller. The intention is that this can be used in a failing case to /// put in missing destroys. - SmallVectorImpl *leakingBlocks; + Optional> leakingBlockCallback; /// The list of passed in consuming uses. ArrayRef consumingUses; @@ -82,20 +81,22 @@ struct State { SmallSetVector successorBlocksThatMustBeVisited; State(SILValue value, SmallPtrSetImpl &visitedBlocks, - ErrorBehaviorKind errorBehavior, - SmallVectorImpl *leakingBlocks, + LinearLifetimeChecker::ErrorBehaviorKind errorBehavior, + Optional> leakingBlockCallback, ArrayRef consumingUses, ArrayRef nonConsumingUses) : value(value), beginBlock(value->getParentBlock()), error(errorBehavior), - visitedBlocks(visitedBlocks), leakingBlocks(leakingBlocks), + visitedBlocks(visitedBlocks), + leakingBlockCallback(leakingBlockCallback), consumingUses(consumingUses), nonConsumingUses(nonConsumingUses) {} State(SILBasicBlock *beginBlock, SmallPtrSetImpl &visitedBlocks, - ErrorBehaviorKind errorBehavior, - SmallVectorImpl *leakingBlocks, + LinearLifetimeChecker::ErrorBehaviorKind errorBehavior, + Optional> leakingBlockCallback, ArrayRef consumingUses, ArrayRef nonConsumingUses) : value(), beginBlock(beginBlock), error(errorBehavior), - visitedBlocks(visitedBlocks), leakingBlocks(leakingBlocks), + visitedBlocks(visitedBlocks), + leakingBlockCallback(leakingBlockCallback), consumingUses(consumingUses), nonConsumingUses(nonConsumingUses) {} void initializeAllNonConsumingUses(ArrayRef nonConsumingUsers); @@ -429,9 +430,10 @@ void State::checkDataflowEndState(DeadEndBlocks &deBlocks) { if (!successorBlocksThatMustBeVisited.empty()) { // If we are asked to store any leaking blocks, put them in the leaking // blocks array. - if (leakingBlocks) { - llvm::copy(successorBlocksThatMustBeVisited, - std::back_inserter(*leakingBlocks)); + if (leakingBlockCallback) { + for (auto *block : successorBlocksThatMustBeVisited) { + (*leakingBlockCallback)(block); + } } // If we are supposed to error on leaks, do so now. @@ -495,15 +497,15 @@ void State::checkDataflowEndState(DeadEndBlocks &deBlocks) { // Top Level Entrypoints //===----------------------------------------------------------------------===// -LinearLifetimeError LinearLifetimeChecker::checkValue( +LinearLifetimeChecker::Error LinearLifetimeChecker::checkValueImpl( SILValue value, ArrayRef consumingUses, ArrayRef nonConsumingUses, ErrorBehaviorKind errorBehavior, - SmallVectorImpl *leakingBlocks) { + Optional> leakingBlockCallback) { assert((!consumingUses.empty() || !deadEndBlocks.empty()) && "Must have at least one consuming user?!"); - State state(value, visitedBlocks, errorBehavior, leakingBlocks, consumingUses, - nonConsumingUses); + State state(value, visitedBlocks, errorBehavior, leakingBlockCallback, + consumingUses, nonConsumingUses); // First add our non-consuming uses and their blocks to the // blocksWithNonConsumingUses map. While we do this, if we have multiple uses @@ -588,3 +590,41 @@ LinearLifetimeError LinearLifetimeChecker::checkValue( state.checkDataflowEndState(deadEndBlocks); return state.error; } + +LinearLifetimeChecker::Error LinearLifetimeChecker::checkValue( + SILValue value, ArrayRef consumingUses, + ArrayRef nonConsumingUses, ErrorBehaviorKind errorBehavior) { + return checkValueImpl(value, consumingUses, nonConsumingUses, errorBehavior, + None); +} + +LinearLifetimeChecker::Error LinearLifetimeChecker::checkValue( + SILValue value, ArrayRef consumingUses, + ArrayRef nonConsumingUses, ErrorBehaviorKind errorBehavior, + function_ref leakingBlocksCallback) { + return checkValueImpl(value, consumingUses, nonConsumingUses, errorBehavior, + leakingBlocksCallback); +} + +bool LinearLifetimeChecker::completeConsumingUseSet( + SILValue value, Operand *consumingUse, + function_ref visitor) { + auto error = + checkValue(value, {consumingUse}, {}, ErrorBehaviorKind::ReturnFalse, + [&](SILBasicBlock *block) { return visitor(block->begin()); }); + + if (!error.getFoundError()) { + return false; + } + + // Return true if we found an over consume (meaning our use is in a loop). + return error.getFoundOverConsume(); +} + +bool LinearLifetimeChecker::validateLifetime( + SILValue value, ArrayRef consumingUses, + ArrayRef nonConsumingUses) { + return !checkValue(value, consumingUses, nonConsumingUses, + ErrorBehaviorKind::ReturnFalse) + .getFoundError(); +} diff --git a/lib/SIL/LinearLifetimeCheckerPrivate.h b/lib/SIL/LinearLifetimeCheckerPrivate.h new file mode 100644 index 0000000000000..ce08a650530c8 --- /dev/null +++ b/lib/SIL/LinearLifetimeCheckerPrivate.h @@ -0,0 +1,114 @@ +//===--- LinearLifetimeCheckerPrivate.h -----------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +#ifndef SWIFT_SIL_LINEARLIFETIMECHECKER_PRIVATE_H +#define SWIFT_SIL_LINEARLIFETIMECHECKER_PRIVATE_H + +#include "swift/SIL/LinearLifetimeChecker.h" + +namespace swift { + +struct LinearLifetimeChecker::ErrorBehaviorKind { + enum inner_t { + Invalid = 0, + ReturnFalse = 1, + PrintMessage = 2, + Assert = 4, + ReturnFalseOnLeak = 8, + PrintMessageAndReturnFalse = PrintMessage | ReturnFalse, + PrintMessageAndAssert = PrintMessage | Assert, + ReturnFalseOnLeakAssertOtherwise = ReturnFalseOnLeak | Assert, + } Value; + + ErrorBehaviorKind() : Value(Invalid) {} + ErrorBehaviorKind(inner_t Inner) : Value(Inner) { assert(Value != Invalid); } + + bool shouldAssert() const { + assert(Value != Invalid); + return Value & Assert; + } + + bool shouldReturnFalseOnLeak() const { + assert(Value != Invalid); + return Value & ReturnFalseOnLeak; + } + + bool shouldPrintMessage() const { + assert(Value != Invalid); + return Value & PrintMessage; + } + + bool shouldReturnFalse() const { + assert(Value != Invalid); + return Value & ReturnFalse; + } +}; + +class LinearLifetimeChecker::Error { + ErrorBehaviorKind errorBehavior; + bool foundUseAfterFree = false; + bool foundLeak = false; + bool foundOverConsume = false; + +public: + Error(ErrorBehaviorKind errorBehavior) : errorBehavior(errorBehavior) {} + + bool getFoundError() const { + return foundUseAfterFree || foundLeak || foundOverConsume; + } + + bool getFoundLeak() const { return foundLeak; } + + bool getFoundUseAfterFree() const { return foundUseAfterFree; } + + bool getFoundOverConsume() const { return foundOverConsume; } + + void handleLeak(llvm::function_ref &&messagePrinterFunc) { + foundLeak = true; + + if (errorBehavior.shouldPrintMessage()) + messagePrinterFunc(); + + if (errorBehavior.shouldReturnFalseOnLeak()) + return; + + // We already printed out our error if we needed to, so don't pass it along. + handleError([]() {}); + } + + void handleOverConsume(llvm::function_ref &&messagePrinterFunc) { + foundOverConsume = true; + handleError(std::move(messagePrinterFunc)); + } + + void handleUseAfterFree(llvm::function_ref &&messagePrinterFunc) { + foundUseAfterFree = true; + handleError(std::move(messagePrinterFunc)); + } + +private: + void handleError(llvm::function_ref &&messagePrinterFunc) { + if (errorBehavior.shouldPrintMessage()) + messagePrinterFunc(); + + if (errorBehavior.shouldReturnFalse()) { + return; + } + + assert(errorBehavior.shouldAssert() && "At this point, we should assert"); + llvm_unreachable("triggering standard assertion failure routine"); + } +}; + +} // namespace swift + +#endif diff --git a/lib/SIL/OperandOwnership.cpp b/lib/SIL/OperandOwnership.cpp index c6a4a8fe11da2..006f1ff755fd9 100644 --- a/lib/SIL/OperandOwnership.cpp +++ b/lib/SIL/OperandOwnership.cpp @@ -10,8 +10,8 @@ // //===----------------------------------------------------------------------===// +#include "LinearLifetimeCheckerPrivate.h" #include "swift/SIL/ApplySite.h" -#include "swift/SIL/LinearLifetimeChecker.h" #include "swift/SIL/OwnershipUtils.h" #include "swift/SIL/SILBuiltinVisitor.h" #include "swift/SIL/SILInstruction.h" @@ -19,7 +19,6 @@ #include "swift/SIL/SILVisitor.h" using namespace swift; -using namespace swift::ownership; //===----------------------------------------------------------------------===// // OperandOwnershipKindClassifier @@ -37,7 +36,7 @@ class OperandOwnershipKindClassifier LLVM_ATTRIBUTE_UNUSED SILModule &mod; const Operand &op; - ErrorBehaviorKind errorBehavior; + LinearLifetimeChecker::ErrorBehaviorKind errorBehavior; bool checkingSubObject; public: @@ -48,9 +47,10 @@ class OperandOwnershipKindClassifier /// should be the subobject and Value should be the parent object. An example /// of where one would want to do this is in the case of value projections /// like struct_extract. - OperandOwnershipKindClassifier(SILModule &mod, const Operand &op, - ErrorBehaviorKind errorBehavior, - bool checkingSubObject) + OperandOwnershipKindClassifier( + SILModule &mod, const Operand &op, + LinearLifetimeChecker::ErrorBehaviorKind errorBehavior, + bool checkingSubObject) : mod(mod), op(op), errorBehavior(errorBehavior), checkingSubObject(checkingSubObject) {} @@ -348,6 +348,7 @@ FORWARD_ANY_OWNERSHIP_INST(UncheckedEnumData) FORWARD_ANY_OWNERSHIP_INST(DestructureStruct) FORWARD_ANY_OWNERSHIP_INST(DestructureTuple) FORWARD_ANY_OWNERSHIP_INST(InitExistentialRef) +FORWARD_ANY_OWNERSHIP_INST(DifferentiableFunction) #undef FORWARD_ANY_OWNERSHIP_INST // An instruction that forwards a constant ownership or trivial ownership. @@ -366,6 +367,8 @@ FORWARD_ANY_OWNERSHIP_INST(InitExistentialRef) } FORWARD_CONSTANT_OR_NONE_OWNERSHIP_INST(Guaranteed, MustBeLive, TupleExtract) FORWARD_CONSTANT_OR_NONE_OWNERSHIP_INST(Guaranteed, MustBeLive, StructExtract) +FORWARD_CONSTANT_OR_NONE_OWNERSHIP_INST(Guaranteed, MustBeLive, + DifferentiableFunctionExtract) FORWARD_CONSTANT_OR_NONE_OWNERSHIP_INST(Owned, MustBeInvalidated, MarkUninitialized) #undef CONSTANT_OR_NONE_OWNERSHIP_INST @@ -1042,8 +1045,9 @@ OperandOwnershipKindClassifier::visitBuiltinInst(BuiltinInst *bi) { OperandOwnershipKindMap Operand::getOwnershipKindMap(bool isForwardingSubValue) const { - OperandOwnershipKindClassifier classifier(getUser()->getModule(), *this, - ErrorBehaviorKind::ReturnFalse, - isForwardingSubValue); + OperandOwnershipKindClassifier classifier( + getUser()->getModule(), *this, + LinearLifetimeChecker::ErrorBehaviorKind::ReturnFalse, + isForwardingSubValue); return classifier.visit(const_cast(getUser())); } diff --git a/lib/SIL/OwnershipUtils.cpp b/lib/SIL/OwnershipUtils.cpp index 5f671d9cf613c..23dfe8e50e975 100644 --- a/lib/SIL/OwnershipUtils.cpp +++ b/lib/SIL/OwnershipUtils.cpp @@ -30,6 +30,7 @@ bool swift::isOwnershipForwardingValueKind(SILNodeKind kind) { case SILNodeKind::TupleInst: case SILNodeKind::StructInst: case SILNodeKind::EnumInst: + case SILNodeKind::DifferentiableFunctionInst: case SILNodeKind::OpenExistentialRefInst: case SILNodeKind::UpcastInst: case SILNodeKind::UncheckedRefCastInst: @@ -58,6 +59,7 @@ bool swift::isGuaranteedForwardingValueKind(SILNodeKind kind) { switch (kind) { case SILNodeKind::TupleExtractInst: case SILNodeKind::StructExtractInst: + case SILNodeKind::DifferentiableFunctionExtractInst: case SILNodeKind::OpenExistentialValueInst: case SILNodeKind::OpenExistentialBoxValueInst: return true; diff --git a/lib/SIL/SILDeclRef.cpp b/lib/SIL/SILDeclRef.cpp index 18c3d1d71ffb1..728353be8283e 100644 --- a/lib/SIL/SILDeclRef.cpp +++ b/lib/SIL/SILDeclRef.cpp @@ -293,15 +293,8 @@ SILLinkage SILDeclRef::getLinkage(ForDefinition_t forDefinition) const { limit = Limit::NeverPublic; } - // The property wrapper backing initializer is never public for resilient - // properties. - if (kind == SILDeclRef::Kind::PropertyWrapperBackingInitializer) { - if (cast(d)->isResilient()) - limit = Limit::NeverPublic; - } - // Stored property initializers get the linkage of their containing type. - if (isStoredPropertyInitializer()) { + if (isStoredPropertyInitializer() || isPropertyWrapperBackingInitializer()) { // Three cases: // // 1) Type is formally @_fixed_layout/@frozen. Root initializers can be @@ -483,7 +476,7 @@ IsSerialized_t SILDeclRef::isSerialized() const { // Stored property initializers are inlinable if the type is explicitly // marked as @frozen. - if (isStoredPropertyInitializer()) { + if (isStoredPropertyInitializer() || isPropertyWrapperBackingInitializer()) { auto *nominal = cast(d->getDeclContext()); auto scope = nominal->getFormalAccessScope(/*useDC=*/nullptr, @@ -794,18 +787,6 @@ static bool derivativeFunctionRequiresNewVTableEntry(SILDeclRef declRef) { declRef.derivativeFunctionIdentifier->getParameterIndices(); }); assert(derivedDiffAttr && "Expected `@differentiable` attribute"); - // If the derived `@differentiable` attribute specifies a derivative function, - // then a new vtable entry is needed. Return true. - switch (declRef.derivativeFunctionIdentifier->getKind()) { - case AutoDiffDerivativeFunctionKind::JVP: - if (!overridden.requiresNewVTableEntry() && derivedDiffAttr->getJVP()) - return true; - break; - case AutoDiffDerivativeFunctionKind::VJP: - if (!overridden.requiresNewVTableEntry() && derivedDiffAttr->getVJP()) - return true; - break; - } // Otherwise, if the base `@differentiable` attribute specifies a derivative // function, then the derivative is inherited and no new vtable entry is // needed. Return false. @@ -1130,8 +1111,7 @@ bool SILDeclRef::canBeDynamicReplacement() const { bool SILDeclRef::isDynamicallyReplaceable() const { if (kind == SILDeclRef::Kind::DefaultArgGenerator) return false; - if (isStoredPropertyInitializer() || - kind == SILDeclRef::Kind::PropertyWrapperBackingInitializer) + if (isStoredPropertyInitializer() || isPropertyWrapperBackingInitializer()) return false; // Class allocators are not dynamic replaceable. diff --git a/lib/SIL/SILInstructions.cpp b/lib/SIL/SILInstructions.cpp index 80447c02adfe3..9b2e6136da4b3 100644 --- a/lib/SIL/SILInstructions.cpp +++ b/lib/SIL/SILInstructions.cpp @@ -607,6 +607,95 @@ TryApplyInst *TryApplyInst::create( normalBB, errorBB, specializationInfo); } +SILType DifferentiableFunctionInst::getDifferentiableFunctionType( + SILValue OriginalFunction, IndexSubset *ParameterIndices) { + auto fnTy = OriginalFunction->getType().castTo(); + auto diffTy = fnTy->getWithDifferentiability(DifferentiabilityKind::Normal, + ParameterIndices); + return SILType::getPrimitiveObjectType(diffTy); +} + +ValueOwnershipKind DifferentiableFunctionInst::getMergedOwnershipKind( + SILValue OriginalFunction, ArrayRef DerivativeFunctions) { + if (DerivativeFunctions.empty()) + return OriginalFunction.getOwnershipKind(); + return *mergeSILValueOwnership( + {OriginalFunction, DerivativeFunctions[0], DerivativeFunctions[1]}); +} + +DifferentiableFunctionInst::DifferentiableFunctionInst( + SILDebugLocation Loc, IndexSubset *ParameterIndices, + SILValue OriginalFunction, ArrayRef DerivativeFunctions, + bool HasOwnership) + : InstructionBaseWithTrailingOperands( + OriginalFunction, DerivativeFunctions, Loc, + getDifferentiableFunctionType(OriginalFunction, ParameterIndices), + HasOwnership + ? getMergedOwnershipKind(OriginalFunction, DerivativeFunctions) + : ValueOwnershipKind(ValueOwnershipKind::None)), + ParameterIndices(ParameterIndices), + HasDerivativeFunctions(!DerivativeFunctions.empty()) { + assert(DerivativeFunctions.empty() || DerivativeFunctions.size() == 2); +} + +DifferentiableFunctionInst *DifferentiableFunctionInst::create( + SILModule &Module, SILDebugLocation Loc, IndexSubset *ParameterIndices, + SILValue OriginalFunction, + Optional> VJPAndJVPFunctions, + bool HasOwnership) { + auto derivativeFunctions = + VJPAndJVPFunctions.hasValue() + ? ArrayRef( + reinterpret_cast(VJPAndJVPFunctions.getPointer()), + 2) + : ArrayRef(); + size_t size = totalSizeToAlloc(1 + derivativeFunctions.size()); + void *buffer = Module.allocateInst(size, alignof(DifferentiableFunctionInst)); + return ::new (buffer) + DifferentiableFunctionInst(Loc, ParameterIndices, OriginalFunction, + derivativeFunctions, HasOwnership); +} + +SILType DifferentiableFunctionExtractInst::getExtracteeType( + SILValue function, NormalDifferentiableFunctionTypeComponent extractee, + SILModule &module) { + auto fnTy = function->getType().castTo(); + assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Normal); + auto originalFnTy = fnTy->getWithoutDifferentiability(); + auto kindOpt = extractee.getAsDerivativeFunctionKind(); + if (!kindOpt) { + assert(extractee == NormalDifferentiableFunctionTypeComponent::Original); + return SILType::getPrimitiveObjectType(originalFnTy); + } + auto resultFnTy = originalFnTy->getAutoDiffDerivativeFunctionType( + fnTy->getDifferentiabilityParameterIndices(), /*resultIndex*/ 0, *kindOpt, + module.Types, LookUpConformanceInModule(module.getSwiftModule())); + return SILType::getPrimitiveObjectType(resultFnTy); +} + +DifferentiableFunctionExtractInst::DifferentiableFunctionExtractInst( + SILModule &module, SILDebugLocation debugLoc, + NormalDifferentiableFunctionTypeComponent extractee, SILValue function, + Optional extracteeType) + : UnaryInstructionBase(debugLoc, function, + extracteeType + ? *extracteeType + : getExtracteeType(function, extractee, module)), + Extractee(extractee), HasExplicitExtracteeType(extracteeType.hasValue()) { +#ifndef NDEBUG + if (extracteeType.hasValue()) { + // Note: explicit extractee type is used to avoid inconsistent typing in: + // - Canonical SIL, due to generic specialization. + // - Lowered SIL, due to LoadableByAddress. + // See `TypeSubstCloner::visitDifferentiableFunctionExtractInst` for an + // explanation of how explicit extractee type is used. + assert((module.getStage() == SILStage::Canonical || + module.getStage() == SILStage::Lowered) && + "Explicit type is valid only in canonical or lowered SIL"); + } +#endif +} + SILType DifferentiabilityWitnessFunctionInst::getDifferentiabilityWitnessType( SILModule &module, DifferentiabilityWitnessFunctionKind witnessKind, SILDifferentiabilityWitness *witness) { diff --git a/lib/SIL/SILModule.cpp b/lib/SIL/SILModule.cpp index 80c6cac974970..0b99e836713a7 100644 --- a/lib/SIL/SILModule.cpp +++ b/lib/SIL/SILModule.cpp @@ -308,6 +308,14 @@ const BuiltinInfo &SILModule::getBuiltinInfo(Identifier ID) { Info.ID = BuiltinValueKind::AtomicStore; else if (OperationName.startswith("allocWithTailElems_")) Info.ID = BuiltinValueKind::AllocWithTailElems; + else if (OperationName.startswith("applyDerivative_")) + Info.ID = BuiltinValueKind::ApplyDerivative; + else if (OperationName.startswith("applyTranspose_")) + Info.ID = BuiltinValueKind::ApplyTranspose; + else if (OperationName.startswith("differentiableFunction_")) + Info.ID = BuiltinValueKind::DifferentiableFunction; + else if (OperationName.startswith("linearFunction_")) + Info.ID = BuiltinValueKind::LinearFunction; else Info.ID = llvm::StringSwitch(OperationName) #define BUILTIN(id, name, attrs) .Case(name, BuiltinValueKind::id) diff --git a/lib/SIL/SILOwnershipVerifier.cpp b/lib/SIL/SILOwnershipVerifier.cpp index acac56b53c601..57ac485ee48a7 100644 --- a/lib/SIL/SILOwnershipVerifier.cpp +++ b/lib/SIL/SILOwnershipVerifier.cpp @@ -12,6 +12,7 @@ #define DEBUG_TYPE "sil-ownership-verifier" +#include "LinearLifetimeCheckerPrivate.h" #include "swift/AST/ASTContext.h" #include "swift/AST/AnyFunctionRef.h" #include "swift/AST/Decl.h" @@ -25,7 +26,6 @@ #include "swift/SIL/Dominance.h" #include "swift/SIL/DynamicCasts.h" #include "swift/SIL/InstructionUtils.h" -#include "swift/SIL/LinearLifetimeChecker.h" #include "swift/SIL/OwnershipUtils.h" #include "swift/SIL/PrettyStackTrace.h" #include "swift/SIL/Projection.h" @@ -45,7 +45,6 @@ #include using namespace swift; -using namespace swift::ownership; // This is an option to put the SILOwnershipVerifier in testing mode. This // causes the following: @@ -73,7 +72,7 @@ static llvm::cl::opt // SILValueOwnershipChecker //===----------------------------------------------------------------------===// -namespace { +namespace swift { // TODO: This class uses a bunch of global state like variables. It should be // refactored into a large state object that is used by functions. @@ -89,7 +88,7 @@ class SILValueOwnershipChecker { SILValue value; /// The action that the checker should perform on detecting an error. - ErrorBehaviorKind errorBehavior; + LinearLifetimeChecker::ErrorBehaviorKind errorBehavior; /// The list of lifetime ending users that we found. Only valid if check is /// successful. @@ -113,7 +112,7 @@ class SILValueOwnershipChecker { public: SILValueOwnershipChecker( DeadEndBlocks &deadEndBlocks, SILValue value, - ErrorBehaviorKind errorBehavior, + LinearLifetimeChecker::ErrorBehaviorKind errorBehavior, llvm::SmallPtrSetImpl &visitedBlocks) : result(), deadEndBlocks(deadEndBlocks), value(value), errorBehavior(errorBehavior), visitedBlocks(visitedBlocks) { @@ -173,7 +172,7 @@ class SILValueOwnershipChecker { SmallVectorImpl &implicitRegularUsers); }; -} // end anonymous namespace +} // namespace swift bool SILValueOwnershipChecker::check() { if (result.hasValue()) @@ -879,11 +878,11 @@ void SILInstruction::verifyOperandOwnership() const { if (isa(this)) return; - ErrorBehaviorKind errorBehavior; + LinearLifetimeChecker::ErrorBehaviorKind errorBehavior; if (IsSILOwnershipVerifierTestingEnabled) { - errorBehavior = ErrorBehaviorKind::PrintMessageAndReturnFalse; + errorBehavior = decltype(errorBehavior)::PrintMessageAndReturnFalse; } else { - errorBehavior = ErrorBehaviorKind::PrintMessageAndAssert; + errorBehavior = decltype(errorBehavior)::PrintMessageAndAssert; } for (const Operand &op : getAllOperands()) { // Skip type dependence operands. @@ -954,11 +953,11 @@ void SILValue::verifyOwnership(DeadEndBlocks *deadEndBlocks) const { if (!f->hasOwnership() || !f->shouldVerifyOwnership()) return; - ErrorBehaviorKind errorBehavior; + LinearLifetimeChecker::ErrorBehaviorKind errorBehavior; if (IsSILOwnershipVerifierTestingEnabled) { - errorBehavior = ErrorBehaviorKind::PrintMessageAndReturnFalse; + errorBehavior = decltype(errorBehavior)::PrintMessageAndReturnFalse; } else { - errorBehavior = ErrorBehaviorKind::PrintMessageAndAssert; + errorBehavior = decltype(errorBehavior)::PrintMessageAndAssert; } SmallPtrSet liveBlocks; diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index d8277da05a738..03b892917ba6d 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -2269,6 +2269,41 @@ class SILPrinter : public SILInstructionVisitor { } } + void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) { + *this << "[parameters"; + for (auto i : dfi->getParameterIndices()->getIndices()) + *this << ' ' << i; + *this << "] "; + *this << getIDAndType(dfi->getOriginalFunction()); + if (dfi->hasDerivativeFunctions()) { + *this << " with_derivative "; + *this << '{' << getIDAndType(dfi->getJVPFunction()) << ", " + << getIDAndType(dfi->getVJPFunction()) << '}'; + } + } + + void visitDifferentiableFunctionExtractInst( + DifferentiableFunctionExtractInst *dfei) { + *this << '['; + switch (dfei->getExtractee()) { + case NormalDifferentiableFunctionTypeComponent::Original: + *this << "original"; + break; + case NormalDifferentiableFunctionTypeComponent::JVP: + *this << "jvp"; + break; + case NormalDifferentiableFunctionTypeComponent::VJP: + *this << "vjp"; + break; + } + *this << "] "; + *this << getIDAndType(dfei->getOperand()); + if (dfei->hasExplicitExtracteeType()) { + *this << " as "; + *this << dfei->getType(); + } + } + void visitDifferentiabilityWitnessFunctionInst( DifferentiabilityWitnessFunctionInst *dwfi) { auto *witness = dwfi->getWitness(); diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index b6cc2a04713d1..51b254c9ba66d 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -4595,6 +4595,59 @@ class SILVerifier : public SILVerifierBase { "unknown verfication type"); } + void checkDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) { + // FIXME(TF-1197): Re-enable verification after substituted SIL function + // types. + return; +#if 0 + auto origTy = + dfi->getOriginalFunction()->getType().getAs(); + require(origTy, "The original function must have a function type"); + require(!origTy->isDifferentiable(), + "The original function must not be @differentiable"); + // Skip verification in lowered SIL: LoadableByAddress changes + // parameter/result conventions. + // TODO: Check that derivative function types match excluding + // parameter/result conventions in lowered SIL. + if (F.getModule().getStage() == SILStage::Lowered) + return; + if (dfi->hasDerivativeFunctions()) { + auto jvp = dfi->getJVPFunction(); + auto jvpType = jvp->getType().getAs(); + require(jvpType, "The JVP function must have a function type"); + require(!jvpType->isDifferentiable(), + "The JVP function must not be @differentiable"); + auto expectedJVPType = origTy->getAutoDiffDerivativeFunctionType( + dfi->getParameterIndices(), /*resultIndex*/ 0, + AutoDiffDerivativeFunctionKind::JVP, TC, + LookUpConformanceInModule(M)); + requireSameType(SILType::getPrimitiveObjectType(jvpType), + SILType::getPrimitiveObjectType(expectedJVPType), + "JVP type does not match expected JVP type"); + auto vjp = dfi->getVJPFunction(); + auto vjpType = vjp->getType().getAs(); + require(vjpType, "The VJP function must have a function type"); + require(!vjpType->isDifferentiable(), + "The VJP function must not be @differentiable"); + auto expectedVJPType = origTy->getAutoDiffDerivativeFunctionType( + dfi->getParameterIndices(), /*resultIndex*/ 0, + AutoDiffDerivativeFunctionKind::VJP, TC, + LookUpConformanceInModule(M)); + requireSameType(SILType::getPrimitiveObjectType(vjpType), + SILType::getPrimitiveObjectType(expectedVJPType), + "VJP type does not match expected VJP type"); + } +#endif + } + + void checkDifferentiableFunctionExtractInst( + DifferentiableFunctionExtractInst *dfei) { + auto fnTy = dfei->getOperand()->getType().getAs(); + require(fnTy, "The function operand must have a function type"); + require(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Normal, + "The function operand must be a '@differentiable' function"); + } + void checkDifferentiabilityWitnessFunctionInst( DifferentiabilityWitnessFunctionInst *dwfi) { auto witnessFnTy = dwfi->getType().castTo(); diff --git a/lib/SIL/ValueOwnership.cpp b/lib/SIL/ValueOwnership.cpp index a086223d11a21..3926c8979a2c7 100644 --- a/lib/SIL/ValueOwnership.cpp +++ b/lib/SIL/ValueOwnership.cpp @@ -159,6 +159,7 @@ CONSTANT_OWNERSHIP_INST(Unowned, ValueToBridgeObject) } CONSTANT_OR_NONE_OWNERSHIP_INST(Guaranteed, StructExtract) CONSTANT_OR_NONE_OWNERSHIP_INST(Guaranteed, TupleExtract) +CONSTANT_OR_NONE_OWNERSHIP_INST(Guaranteed, DifferentiableFunctionExtract) // OpenExistentialValue opens the boxed value inside an existential // CoW box. The semantics of an existential CoW box implies that we // can only consume the projected value inside the box if the box is @@ -263,6 +264,7 @@ FORWARDING_OWNERSHIP_INST(Enum) // frame from usage. In such cases, we have been creating unnecessary ref count // traffic in code. FORWARDING_OWNERSHIP_INST(InitExistentialRef) +FORWARDING_OWNERSHIP_INST(DifferentiableFunction) #undef FORWARDING_OWNERSHIP_INST ValueOwnershipKind diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index 788d8edde0701..8d182ab9c77d7 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -757,8 +757,8 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, void SILGenModule::emitDifferentiabilityWitnessesForFunction( SILDeclRef constant, SILFunction *F) { - // Visit `@differentiable` amd `@derivative` attributes and generate SIL - // differentiability witnesses. + // Visit `@derivative` attributes and generate SIL differentiability + // witnesses. // Skip if the SILDeclRef is a: // - Default argument generator function. // - Thunk. @@ -770,12 +770,6 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( auto *AFD = constant.getAbstractFunctionDecl(); auto emitWitnesses = [&](DeclAttributes &Attrs) { for (auto *diffAttr : Attrs.getAttributes()) { - SILFunction *jvp = nullptr; - SILFunction *vjp = nullptr; - if (auto *jvpDecl = diffAttr->getJVPFunction()) - jvp = getFunction(SILDeclRef(jvpDecl), ForDefinition); - if (auto *vjpDecl = diffAttr->getVJPFunction()) - vjp = getFunction(SILDeclRef(vjpDecl), ForDefinition); auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0}); assert((!F->getLoweredFunctionType()->getSubstGenericSignature() || diffAttr->getDerivativeGenericSignature()) && @@ -783,7 +777,8 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( "all original SIL functions with generic signatures"); AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices, diffAttr->getDerivativeGenericSignature()); - emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr); + emitDifferentiabilityWitness(AFD, F, config, /*jvp*/ nullptr, + /*vjp*/ nullptr, diffAttr); } for (auto *derivAttr : Attrs.getAttributes()) { SILFunction *jvp = nullptr; diff --git a/lib/SILGen/SILGenBuiltin.cpp b/lib/SILGen/SILGenBuiltin.cpp index b52d53d17513a..6062de8864e61 100644 --- a/lib/SILGen/SILGenBuiltin.cpp +++ b/lib/SILGen/SILGenBuiltin.cpp @@ -1023,6 +1023,200 @@ static ManagedValue emitBuiltinTypeTrait(SILGenFunction &SGF, return ManagedValue::forUnmanaged(val); } +static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction( + AutoDiffDerivativeFunctionKind kind, unsigned arity, + bool throws, SILGenFunction &SGF, SILLocation loc, + SubstitutionMap substitutions, ArrayRef args, SGFContext C) { + // FIXME(SR-11853): Support throwing functions. + assert(!throws && "Throwing functions are not yet supported"); + + auto origFnVal = args[0].getValue(); + SmallVector origFnArgVals; + for (auto& arg : args.drop_front(1)) + origFnArgVals.push_back(arg.getValue()); + + auto origFnType = origFnVal->getType().castTo(); + auto origFnUnsubstType = origFnType->getUnsubstitutedType(SGF.getModule()); + if (origFnType != origFnUnsubstType) { + origFnVal = SGF.B.createConvertFunction( + loc, origFnVal, SILType::getPrimitiveObjectType(origFnUnsubstType), + /*withoutActuallyEscaping*/ false); + } + + // Get the derivative function. + SILValue derivativeFn = SGF.B.createDifferentiableFunctionExtract( + loc, kind, origFnVal); + auto derivativeFnType = derivativeFn->getType().castTo(); + assert(derivativeFnType->getNumResults() == 2); + assert(derivativeFnType->getNumParameters() == origFnArgVals.size()); + + auto derivativeFnUnsubstType = + derivativeFnType->getUnsubstitutedType(SGF.getModule()); + if (derivativeFnType != derivativeFnUnsubstType) { + derivativeFn = SGF.B.createConvertFunction( + loc, derivativeFn, + SILType::getPrimitiveObjectType(derivativeFnUnsubstType), + /*withoutActuallyEscaping*/ false); + } + + // We don't need to destroy the original function or retain the + // `derivativeFn`, because they are trivial (because they are @noescape). + assert(origFnVal->getType().isTrivial(SGF.F)); + assert(derivativeFn->getType().isTrivial(SGF.F)); + + // Do the apply for the indirect result case. + if (derivativeFnType->hasIndirectFormalResults()) { + auto indResBuffer = SGF.getBufferForExprResult( + loc, derivativeFnType->getAllResultsInterfaceType(), C); + SmallVector applyArgs; + applyArgs.push_back(SGF.B.createTupleElementAddr(loc, indResBuffer, 0)); + for (auto origFnArgVal : origFnArgVals) + applyArgs.push_back(origFnArgVal); + auto differential = SGF.B.createApply(loc, derivativeFn, SubstitutionMap(), + applyArgs, /*isNonThrowing*/ false); + + derivativeFn = SILValue(); + + SGF.B.createStore(loc, differential, + SGF.B.createTupleElementAddr(loc, indResBuffer, 1), + StoreOwnershipQualifier::Init); + return SGF.manageBufferForExprResult( + indResBuffer, SGF.getTypeLowering(indResBuffer->getType()), C); + } + + // Do the apply for the direct result case. + auto resultTuple = SGF.B.createApply( + loc, derivativeFn, SubstitutionMap(), origFnArgVals, + /*isNonThrowing*/ false); + + derivativeFn = SILValue(); + + return SGF.emitManagedRValueWithCleanup(resultTuple); +} + +static ManagedValue emitBuiltinAutoDiffApplyTransposeFunction( + unsigned arity, bool throws, SILGenFunction &SGF, SILLocation loc, + SubstitutionMap substitutions, ArrayRef args, SGFContext C) { + // FIXME(SR-11853): Support throwing functions. + assert(!throws && "Throwing functions are not yet supported"); + + auto origFnVal = args.front().getValue(); + SmallVector origFnArgVals; + for (auto &arg : args.drop_front(1)) + origFnArgVals.push_back(arg.getValue()); + + // Get the transpose function. + // TODO(TF-1142): Create a linear_function_extract instead of an undef. + auto fnTy = origFnVal->getType().castTo(); + auto transposeFnType = + fnTy->getWithoutDifferentiability()->getAutoDiffTransposeFunctionType( + fnTy->getDifferentiabilityParameterIndices(), SGF.SGM.M.Types, + LookUpConformanceInModule(SGF.SGM.M.getSwiftModule())); + SILValue transposeFn = + SILUndef::get(SILType::getPrimitiveObjectType(transposeFnType), SGF.F); + auto transposeFnUnsubstType = + transposeFnType->getUnsubstitutedType(SGF.getModule()); + if (transposeFnType != transposeFnUnsubstType) { + transposeFn = SGF.B.createConvertFunction( + loc, transposeFn, + SILType::getPrimitiveObjectType(transposeFnUnsubstType), + /*withoutActuallyEscaping*/ false); + transposeFnType = transposeFn->getType().castTo(); + } + + SmallVector applyArgs; + if (transposeFnType->hasIndirectFormalResults()) + applyArgs.push_back( + SGF.getBufferForExprResult( + loc, transposeFnType->getAllResultsInterfaceType(), C)); + for (auto paramArg : args.drop_front()) { + applyArgs.push_back(paramArg.getValue()); + } + auto *apply = SGF.B.createApply( + loc, transposeFn, SubstitutionMap(), applyArgs); + if (transposeFnType->hasIndirectFormalResults()) { + auto resultAddress = applyArgs.front(); + AbstractionPattern pattern( + SGF.F.getLoweredFunctionType()->getSubstGenericSignature(), + resultAddress->getType().getASTType()); + auto &tl = + SGF.getTypeLowering(pattern, resultAddress->getType().getASTType()); + return SGF.manageBufferForExprResult(resultAddress, tl, C); + } else { + return SGF.emitManagedRValueWithCleanup(apply); + } +} + +static ManagedValue emitBuiltinApplyDerivative( + SILGenFunction &SGF, SILLocation loc, SubstitutionMap substitutions, + ArrayRef args, SGFContext C) { + auto *callExpr = loc.castToASTNode(); + auto builtinDecl = cast(cast( + cast(callExpr->getDirectCallee())->getRHS()) + ->getDecl()); + auto builtinName = builtinDecl->getName().str(); + AutoDiffDerivativeFunctionKind kind; + unsigned arity; + bool throws; + auto successfullyParsed = autodiff::getBuiltinApplyDerivativeConfig( + builtinName, kind, arity, throws); + assert(successfullyParsed); + return emitBuiltinAutoDiffApplyDerivativeFunction( + kind, arity, throws, SGF, loc, substitutions, args, C); +} + +static ManagedValue emitBuiltinApplyTranspose( + SILGenFunction &SGF, SILLocation loc, SubstitutionMap substitutions, + ArrayRef args, SGFContext C) { + auto *callExpr = loc.castToASTNode(); + auto builtinDecl = cast(cast( + cast(callExpr->getDirectCallee())->getRHS()) + ->getDecl()); + auto builtinName = builtinDecl->getName().str(); + unsigned arity; + bool throws; + auto successfullyParsed = autodiff::getBuiltinApplyTransposeConfig( + builtinName, arity, throws); + assert(successfullyParsed); + return emitBuiltinAutoDiffApplyTransposeFunction( + arity, throws, SGF, loc, substitutions, args, C); +} + +static ManagedValue emitBuiltinDifferentiableFunction( + SILGenFunction &SGF, SILLocation loc, SubstitutionMap substitutions, + ArrayRef args, SGFContext C) { + assert(args.size() == 3); + auto origFn = args.front(); + auto origType = origFn.getType().castTo(); + auto diffFn = SGF.B.createDifferentiableFunction( + loc, + IndexSubset::getDefault( + SGF.getASTContext(), origType->getNumParameters(), + /*includeAll*/ true), + origFn.forward(SGF), + std::make_pair(args[1].forward(SGF), args[2].forward(SGF))); + return SGF.emitManagedRValueWithCleanup(diffFn); +} + +static ManagedValue emitBuiltinLinearFunction( + SILGenFunction &SGF, SILLocation loc, SubstitutionMap substitutions, + ArrayRef args, SGFContext C) { + assert(args.size() == 2); + auto origFn = args.front(); + auto origType = origFn.getType().castTo(); + // TODO(TF-1142): Create a linear_function instead of an undef. + auto linearFnTy = origType->getWithDifferentiability( + DifferentiabilityKind::Linear, + IndexSubset::getDefault( + SGF.getASTContext(), origType->getNumParameters(), + /*includeAll*/ true)); + SILValue linearFn = SILUndef::get( + SILType::getPrimitiveObjectType(linearFnTy), SGF.F); + return SGF.emitManagedRValueWithCleanup(linearFn); +} + + + /// Emit SIL for the named builtin: globalStringTablePointer. Unlike the default /// ownership convention for named builtins, which is to take (non-trivial) /// arguments as Owned, this builtin accepts owned as well as guaranteed diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index 4d48e8101d411..83cb2f63135a8 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -4455,10 +4455,17 @@ getWitnessFunctionRef(SILGenFunction &SGF, switch (witnessKind) { case WitnessDispatchKind::Static: if (auto *derivativeId = witness.derivativeFunctionIdentifier) { - // TODO(TF-1139, TF-1140): Replace `undef` with `differentiable_function` - // and `differentiable_function_extract`. - auto derivativeFnSilType = SILType::getPrimitiveObjectType(witnessFTy); - return SILUndef::get(derivativeFnSilType, SGF.F); + auto originalFn = + SGF.emitGlobalFunctionRef(loc, witness.asAutoDiffOriginalFunction()); + auto *loweredParamIndices = autodiff::getLoweredParameterIndices( + derivativeId->getParameterIndices(), + witness.getDecl()->getInterfaceType()->castTo()); + auto diffFn = SGF.B.createDifferentiableFunction(loc, loweredParamIndices, + originalFn); + return SGF.B.createDifferentiableFunctionExtract( + loc, + NormalDifferentiableFunctionTypeComponent(derivativeId->getKind()), + diffFn); } return SGF.emitGlobalFunctionRef(loc, witness); case WitnessDispatchKind::Dynamic: diff --git a/lib/SILGen/SILGenThunk.cpp b/lib/SILGen/SILGenThunk.cpp index cfe7c3be00ae7..c8fce513fa347 100644 --- a/lib/SILGen/SILGenThunk.cpp +++ b/lib/SILGen/SILGenThunk.cpp @@ -194,14 +194,14 @@ SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk( auto *derivativeFnDecl = derivativeFnDeclRef.getDecl(); SILGenFunctionBuilder builder(*this); - auto originalFn = derivativeFnDeclRef.asAutoDiffOriginalFunction(); + auto originalFnDeclRef = derivativeFnDeclRef.asAutoDiffOriginalFunction(); // TODO(TF-685): Use principled thunk mangling. // Do not simply reuse reabstraction thunk mangling. auto name = derivativeFnDeclRef.mangle() + "_vtable_entry_thunk"; auto *thunk = builder.getOrCreateFunction( - derivativeFnDecl, name, originalFn.getLinkage(ForDefinition), constantTy, - IsBare, IsTransparent, derivativeFnDeclRef.isSerialized(), IsNotDynamic, - ProfileCounter(), IsThunk); + derivativeFnDecl, name, originalFnDeclRef.getLinkage(ForDefinition), + constantTy, IsBare, IsTransparent, derivativeFnDeclRef.isSerialized(), + IsNotDynamic, ProfileCounter(), IsThunk); if (!thunk->empty()) return thunk; @@ -212,14 +212,20 @@ SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk( auto loc = derivativeFnDeclRef.getAsRegularLocation(); SGF.collectThunkParams(loc, params); - // TODO(TF-1139, TF-1140): Replace `undef` with `differentiable_function` and - // `differentiable_function_extract`. - auto derivativeSilTy = SILType::getPrimitiveObjectType(constantTy); - auto derivativeFn = SILUndef::get(derivativeSilTy, *thunk); + auto originalFn = SGF.emitGlobalFunctionRef(loc, originalFnDeclRef); + auto *loweredParamIndices = autodiff::getLoweredParameterIndices( + derivativeId->getParameterIndices(), + derivativeFnDecl->getInterfaceType()->castTo()); + auto diffFn = + SGF.B.createDifferentiableFunction(loc, loweredParamIndices, originalFn); + auto derivativeFn = SGF.B.createDifferentiableFunctionExtract( + loc, NormalDifferentiableFunctionTypeComponent(derivativeId->getKind()), + diffFn); + auto derivativeFnSILTy = SILType::getPrimitiveObjectType(constantTy); SmallVector args(thunk->getArguments().begin(), thunk->getArguments().end()); auto apply = - SGF.emitApplyWithRethrow(loc, derivativeFn, derivativeSilTy, + SGF.emitApplyWithRethrow(loc, derivativeFn, derivativeFnSILTy, SGF.getForwardingSubstitutionMap(), args); SGF.B.createReturn(loc, apply); diff --git a/lib/SILOptimizer/IPO/GlobalOpt.cpp b/lib/SILOptimizer/IPO/GlobalOpt.cpp index b338a12522244..365fcd50e6199 100644 --- a/lib/SILOptimizer/IPO/GlobalOpt.cpp +++ b/lib/SILOptimizer/IPO/GlobalOpt.cpp @@ -146,8 +146,6 @@ class SILGlobalOpt { SILFunction *ParentF, llvm::DenseMap &ParentFuncs); - void placeInitializers(SILFunction *InitF, ArrayRef Calls); - /// Update UnhandledOnceCallee and InitializerCount by going through all /// "once" calls. void collectOnceCall(BuiltinInst *AI); @@ -275,6 +273,25 @@ void SILGlobalOpt::collectOnceCall(BuiltinInst *BI) { InitializerCount[Callee]++; } +static bool isPotentialStore(SILInstruction *inst) { + switch (inst->getKind()) { + case SILInstructionKind::LoadInst: + return false; + case SILInstructionKind::PointerToAddressInst: + case SILInstructionKind::StructElementAddrInst: + case SILInstructionKind::TupleElementAddrInst: + for (Operand *op : cast(inst)->getUses()) { + if (isPotentialStore(op->getUser())) + return true; + } + return false; + case SILInstructionKind::BeginAccessInst: + return cast(inst)->getAccessKind() != SILAccessKind::Read; + default: + return true; + } +} + /// return true if this block is inside a loop. bool SILGlobalOpt::isInLoop(SILBasicBlock *CurBB) { SILFunction *F = CurBB->getParent(); @@ -292,147 +309,6 @@ bool SILGlobalOpt::isInLoop(SILBasicBlock *CurBB) { return LoopBlocks.count(CurBB); } -/// Returns true if the block \p BB is terminated with a cond_br based on an -/// availability check. -static bool isAvailabilityCheck(SILBasicBlock *BB) { - auto *CBR = dyn_cast(BB->getTerminator()); - if (!CBR) - return false; - - auto *AI = dyn_cast(CBR->getCondition()); - if (!AI) - return false; - - SILFunction *F = AI->getReferencedFunctionOrNull(); - if (!F || !F->hasSemanticsAttrs()) - return false; - - return F->hasSemanticsAttrThatStartsWith("availability"); -} - -/// Returns true if there are any availability checks along the dominator tree -/// from \p From to \p To. -static bool isAvailabilityCheckOnDomPath(SILBasicBlock *From, SILBasicBlock *To, - DominanceInfo *DT) { - if (From == To) - return false; - - auto *Node = DT->getNode(To)->getIDom(); - for (;;) { - SILBasicBlock *BB = Node->getBlock(); - if (isAvailabilityCheck(BB)) - return true; - if (BB == From) - return false; - Node = Node->getIDom(); - assert(Node && "Should have hit To-block"); - } -} - -ApplyInst *SILGlobalOpt::getHoistedApplyForInitializer( - ApplyInst *AI, DominanceInfo *DT, SILFunction *InitF, SILFunction *ParentF, - llvm::DenseMap &ParentFuncs) { - auto PFI = ParentFuncs.find(ParentF); - if (PFI == ParentFuncs.end()) { - ParentFuncs[ParentF] = AI; - - // It's the first time we found a call to InitF in this function, so we - // try to hoist it out of any loop. - return AI; - } - - // Found a replacement for this init call. Ensure the replacement dominates - // the original call site. - ApplyInst *CommonAI = PFI->second; - assert(cast(CommonAI->getCallee()) - ->getReferencedFunctionOrNull() == InitF && - "ill-formed global init call"); - SILBasicBlock *DomBB = - DT->findNearestCommonDominator(AI->getParent(), CommonAI->getParent()); - - // We must not move initializers around availability-checks. - if (isAvailabilityCheckOnDomPath(DomBB, CommonAI->getParent(), DT)) - return nullptr; - - ApplyInst *Result = nullptr; - if (DomBB != CommonAI->getParent()) { - CommonAI->moveBefore(&*DomBB->begin()); - placeFuncRef(CommonAI, DT); - - // Try to hoist the existing AI again if we move it to another block, - // e.g. from a loop exit into the loop. - Result = CommonAI; - } - - AI->replaceAllUsesWith(CommonAI); - AI->eraseFromParent(); - HasChanged = true; - return Result; -} - -/// Optimize placement of initializer calls given a list of calls to the -/// same initializer. All original initialization points must be dominated by -/// the final initialization calls. -/// -/// The current heuristic hoists all initialization points within a function to -/// a single dominating call in the outer loop preheader. -void SILGlobalOpt::placeInitializers(SILFunction *InitF, - ArrayRef Calls) { - LLVM_DEBUG(llvm::dbgs() << "GlobalOpt: calls to " - << Demangle::demangleSymbolAsString(InitF->getName()) - << " : " << Calls.size() << "\n"); - // Map each initializer-containing function to its final initializer call. - llvm::DenseMap ParentFuncs; - for (auto *AI : Calls) { - assert(AI->getNumArguments() == 0 && "ill-formed global init call"); - assert( - cast(AI->getCallee())->getReferencedFunctionOrNull() == - InitF && - "wrong init call"); - SILFunction *ParentF = AI->getFunction(); - DominanceInfo *DT = DA->get(ParentF); - ApplyInst *HoistAI = - getHoistedApplyForInitializer(AI, DT, InitF, ParentF, ParentFuncs); - - // If we were unable to find anything, just go onto the next apply. - if (!HoistAI) { - continue; - } - - // Otherwise, move this call to the outermost loop preheader. - SILBasicBlock *BB = HoistAI->getParent(); - typedef llvm::DomTreeNodeBase DomTreeNode; - DomTreeNode *Node = DT->getNode(BB); - while (Node) { - SILBasicBlock *DomParentBB = Node->getBlock(); - if (isAvailabilityCheck(DomParentBB)) { - LLVM_DEBUG(llvm::dbgs() << " don't hoist above availability check " - "at bb" - << DomParentBB->getDebugID() << "\n"); - break; - } - BB = DomParentBB; - if (!isInLoop(BB)) - break; - Node = Node->getIDom(); - } - - if (BB == HoistAI->getParent()) { - // BB is either unreachable or not in a loop. - LLVM_DEBUG(llvm::dbgs() << " skipping (not in a loop): " << *HoistAI - << " in " << HoistAI->getFunction()->getName() - << "\n"); - continue; - } - - LLVM_DEBUG(llvm::dbgs() << " hoisting: " << *HoistAI << " in " - << HoistAI->getFunction()->getName() << "\n"); - HoistAI->moveBefore(&*BB->begin()); - placeFuncRef(HoistAI, DT); - HasChanged = true; - } -} - bool SILGlobalOpt::isAssignedOnlyOnceInInitializer(SILGlobalVariable *SILG, SILFunction *globalAddrF) { if (SILG->isLet()) @@ -440,16 +316,18 @@ bool SILGlobalOpt::isAssignedOnlyOnceInInitializer(SILGlobalVariable *SILG, // If we should skip this, it is probably because there are multiple stores. // Return false if there are multiple stores or no stores. - if (GlobalVarSkipProcessing.count(SILG) || !GlobalVarStore.count(SILG) || - // Check if there is more than one use the global addr function. If there - // is only one use, it must be the use that we are trying to optimize, so - // that is OK. If there is more than one use, one of the other uses may - // have a store attached to it which means there may be more than one - // assignment, so return false. - (GlobalInitCallMap.count(globalAddrF) && - GlobalInitCallMap[globalAddrF].size() != 1)) + if (GlobalVarSkipProcessing.count(SILG) || !GlobalVarStore.count(SILG)) return false; + if (GlobalInitCallMap.count(globalAddrF)) { + for (ApplyInst *initCall : GlobalInitCallMap[globalAddrF]) { + for (auto *Op : getNonDebugUses(initCall)) { + if (isPotentialStore(Op->getUser())) + return false; + } + } + } + // Otherwise, return true if this can't be used externally (false, otherwise). return !isPossiblyUsedExternally(SILG->getLinkage(), SILG->getModule().isWholeModule()); @@ -727,24 +605,6 @@ bool SILGlobalOpt::tryRemoveUnusedGlobal(SILGlobalVariable *global) { return true; } -static bool isPotentialStore(SILInstruction *inst) { - switch (inst->getKind()) { - case SILInstructionKind::LoadInst: - case SILInstructionKind::EndAccessInst: - return false; - case SILInstructionKind::StructElementAddrInst: - case SILInstructionKind::TupleElementAddrInst: - case SILInstructionKind::BeginAccessInst: - for (Operand *op : cast(inst)->getUses()) { - if (isPotentialStore(op->getUser())) - return true; - } - return false; - default: - return true; - } -} - /// If this is a read from a global let variable, map it. void SILGlobalOpt::collectGlobalAccess(GlobalAddrInst *GAI) { auto *SILG = GAI->getReferencedGlobal(); @@ -921,10 +781,6 @@ bool SILGlobalOpt::run() { } } while (changed); - for (auto &InitCalls : GlobalInitCallMap) { - placeInitializers(InitCalls.first, InitCalls.second); - } - // This is similiar to optimizeInitializer, but it's for globals which are // initialized in the "main" function and not by an initializer function. for (auto &Init : GlobalVarStore) { diff --git a/lib/SILOptimizer/LoopTransforms/LICM.cpp b/lib/SILOptimizer/LoopTransforms/LICM.cpp index efd423ced97f1..5c2e306db8b69 100644 --- a/lib/SILOptimizer/LoopTransforms/LICM.cpp +++ b/lib/SILOptimizer/LoopTransforms/LICM.cpp @@ -146,6 +146,57 @@ static bool mayWriteTo(AliasAnalysis *AA, SideEffectAnalysis *SEA, return false; } +/// Returns true if \p sideEffectInst cannot be reordered with a call to a +/// global initialier. +static bool mayConflictWithGlobalInit(AliasAnalysis *AA, + SILInstruction *sideEffectInst, ApplyInst *globalInitCall) { + if (auto *SI = dyn_cast(sideEffectInst)) { + return AA->mayReadOrWriteMemory(globalInitCall, SI->getDest()); + } + if (auto *LI = dyn_cast(sideEffectInst)) { + return AA->mayWriteToMemory(globalInitCall, LI->getOperand()); + } + return true; +} + +/// Returns true if any of the instructions in \p sideEffectInsts which are +/// post-dominated by a call to a global initialier cannot be reordered with +/// the call. +static bool mayConflictWithGlobalInit(AliasAnalysis *AA, + InstSet &sideEffectInsts, + ApplyInst *globalInitCall, + SILBasicBlock *preHeader, PostDominanceInfo *PD) { + if (!PD->dominates(globalInitCall->getParent(), preHeader)) + return true; + + SILBasicBlock *globalInitBlock = globalInitCall->getParent(); + for (auto *seInst : sideEffectInsts) { + // Only check instructions in blocks which are "before" (i.e. post-dominated + // by) the block which contains the init-call. + // Instructions which are before the call in the same block have already + // been checked. + if (PD->properlyDominates(globalInitBlock, seInst->getParent())) { + if (mayConflictWithGlobalInit(AA, seInst, globalInitCall)) + return true; + } + } + return false; +} + +/// Returns true if any of the instructions in \p sideEffectInsts cannot be +/// reordered with a call to a global initialier (which is in the same basic +/// block). +static bool mayConflictWithGlobalInit(AliasAnalysis *AA, + ArrayRef sideEffectInsts, + ApplyInst *globalInitCall) { + for (auto *seInst : sideEffectInsts) { + assert(seInst->getParent() == globalInitCall->getParent()); + if (mayConflictWithGlobalInit(AA, seInst, globalInitCall)) + return true; + } + return false; +} + // When Hoisting / Sinking, // Don't descend into control-dependent code. // Only traverse into basic blocks that dominate all exits. @@ -409,6 +460,8 @@ class LoopTreeOptimization { AliasAnalysis *AA; SideEffectAnalysis *SEA; DominanceInfo *DomTree; + PostDominanceAnalysis *PDA; + PostDominanceInfo *postDomTree = nullptr; AccessedStorageAnalysis *ASA; bool Changed; @@ -435,10 +488,11 @@ class LoopTreeOptimization { public: LoopTreeOptimization(SILLoop *TopLevelLoop, SILLoopInfo *LI, AliasAnalysis *AA, SideEffectAnalysis *SEA, - DominanceInfo *DT, AccessedStorageAnalysis *ASA, + DominanceInfo *DT, PostDominanceAnalysis *PDA, + AccessedStorageAnalysis *ASA, bool RunsOnHighLevelSil) - : LoopInfo(LI), AA(AA), SEA(SEA), DomTree(DT), ASA(ASA), Changed(false), - RunsOnHighLevelSIL(RunsOnHighLevelSil) { + : LoopInfo(LI), AA(AA), SEA(SEA), DomTree(DT), PDA(PDA), ASA(ASA), + Changed(false), RunsOnHighLevelSIL(RunsOnHighLevelSil) { // Collect loops for a recursive bottom-up traversal in the loop tree. BotUpWorkList.push_back(TopLevelLoop); for (unsigned i = 0; i < BotUpWorkList.size(); ++i) { @@ -556,9 +610,11 @@ static bool isSafeReadOnlyApply(SideEffectAnalysis *SEA, ApplyInst *AI) { } static void checkSideEffects(swift::SILInstruction &Inst, - InstSet &SideEffectInsts) { + InstSet &SideEffectInsts, + SmallVectorImpl &sideEffectsInBlock) { if (Inst.mayHaveSideEffects()) { SideEffectInsts.insert(&Inst); + sideEffectsInBlock.push_back(&Inst); } } @@ -708,6 +764,7 @@ void LoopTreeOptimization::analyzeCurrentLoop( // Interesting instructions in the loop: SmallVector ReadOnlyApplies; + SmallVector globalInitCalls; SmallVector Loads; SmallVector Stores; SmallVector FixLifetimes; @@ -715,6 +772,7 @@ void LoopTreeOptimization::analyzeCurrentLoop( SmallVector fullApplies; for (auto *BB : Loop->getBlocks()) { + SmallVector sideEffectsInBlock; for (auto &Inst : *BB) { switch (Inst.getKind()) { case SILInstructionKind::FixLifetimeInst: { @@ -731,12 +789,12 @@ void LoopTreeOptimization::analyzeCurrentLoop( case SILInstructionKind::StoreInst: { Stores.push_back(cast(&Inst)); LoadsAndStores.push_back(&Inst); - checkSideEffects(Inst, sideEffects); + checkSideEffects(Inst, sideEffects, sideEffectsInBlock); break; } case SILInstructionKind::BeginAccessInst: BeginAccesses.push_back(cast(&Inst)); - checkSideEffects(Inst, sideEffects); + checkSideEffects(Inst, sideEffects, sideEffectsInBlock); break; case SILInstructionKind::RefElementAddrInst: SpecialHoist.push_back(cast(&Inst)); @@ -747,12 +805,21 @@ void LoopTreeOptimization::analyzeCurrentLoop( // cond_fail that would have protected (executed before) a memory access // must - after hoisting - also be executed before said access. HoistUp.insert(&Inst); - checkSideEffects(Inst, sideEffects); + checkSideEffects(Inst, sideEffects, sideEffectsInBlock); break; case SILInstructionKind::ApplyInst: { auto *AI = cast(&Inst); if (isSafeReadOnlyApply(SEA, AI)) { ReadOnlyApplies.push_back(AI); + } else if (SILFunction *callee = AI->getReferencedFunctionOrNull()) { + // Calls to global inits are different because we don't care about + // side effects which are "after" the call in the loop. + if (callee->isGlobalInit() && + // Check against side-effects within the same block. + // Side-effects in other blocks are checked later (after we + // scanned all blocks of the loop). + !mayConflictWithGlobalInit(AA, sideEffectsInBlock, AI)) + globalInitCalls.push_back(AI); } // check for array semantics and side effects - same as default LLVM_FALLTHROUGH; @@ -761,7 +828,7 @@ void LoopTreeOptimization::analyzeCurrentLoop( if (auto fullApply = FullApplySite::isa(&Inst)) { fullApplies.push_back(fullApply); } - checkSideEffects(Inst, sideEffects); + checkSideEffects(Inst, sideEffects, sideEffectsInBlock); if (canHoistUpDefault(&Inst, Loop, DomTree, RunsOnHighLevelSIL)) { HoistUp.insert(&Inst); } @@ -780,6 +847,23 @@ void LoopTreeOptimization::analyzeCurrentLoop( HoistUp.insert(LI); } } + + if (!globalInitCalls.empty()) { + if (!postDomTree) { + postDomTree = PDA->get(Preheader->getParent()); + } + if (postDomTree->getRootNode()) { + for (ApplyInst *ginitCall : globalInitCalls) { + // Check against side effects which are "before" (i.e. post-dominated + // by) the global initializer call. + if (!mayConflictWithGlobalInit(AA, sideEffects, ginitCall, Preheader, + postDomTree)) { + HoistUp.insert(ginitCall); + } + } + } + } + // Collect memory locations for which we can move all loads and stores out // of the loop. for (StoreInst *SI : Stores) { @@ -1041,6 +1125,7 @@ class LICM : public SILFunctionTransform { } DominanceAnalysis *DA = PM->getAnalysis(); + PostDominanceAnalysis *PDA = PM->getAnalysis(); AliasAnalysis *AA = PM->getAnalysis(); SideEffectAnalysis *SEA = PM->getAnalysis(); AccessedStorageAnalysis *ASA = getAnalysis(); @@ -1051,8 +1136,8 @@ class LICM : public SILFunctionTransform { for (auto *TopLevelLoop : *LoopInfo) { if (!DomTree) DomTree = DA->get(F); - LoopTreeOptimization Opt(TopLevelLoop, LoopInfo, AA, SEA, DomTree, ASA, - RunsOnHighLevelSil); + LoopTreeOptimization Opt(TopLevelLoop, LoopInfo, AA, SEA, DomTree, PDA, + ASA, RunsOnHighLevelSil); Changed |= Opt.optimize(); } diff --git a/lib/SILOptimizer/Mandatory/MandatoryInlining.cpp b/lib/SILOptimizer/Mandatory/MandatoryInlining.cpp index 6e1bd8cd2f5f4..9278164033c8e 100644 --- a/lib/SILOptimizer/Mandatory/MandatoryInlining.cpp +++ b/lib/SILOptimizer/Mandatory/MandatoryInlining.cpp @@ -73,8 +73,6 @@ static void fixupReferenceCounts( DeadEndBlocks deadEndBlocks(pai->getFunction()); SmallVector leakingBlocks; - auto errorBehavior = ownership::ErrorBehaviorKind::ReturnFalse; - // Add a copy of each non-address type capture argument to lifetime extend the // captured argument over at least the inlined function and till the end of a // box if we have an address. This deals with the possibility of the closure @@ -136,28 +134,23 @@ static void fixupReferenceCounts( // am going to change this to use a different API on the linear lifetime // checker that makes this clearer. LinearLifetimeChecker checker(visitedBlocks, deadEndBlocks); - auto error = checker.checkValue(pai, {applySite.getCalleeOperand()}, {}, - errorBehavior, &leakingBlocks); - if (error.getFoundLeak()) { - while (!leakingBlocks.empty()) { - auto *leakingBlock = leakingBlocks.pop_back_val(); - auto loc = RegularLocation::getAutoGeneratedLocation(); - SILBuilderWithScope builder(leakingBlock->begin()); - if (hasOwnership) { - builder.createEndBorrow(loc, argument); - } - builder.emitDestroyValueOperation(loc, copy); - } - } - - // If we found an over consume it means that our value is consumed within - // the loop. That means our leak code will have lifetime extended the - // value over the loop. So we should /not/ insert a destroy after the - // apply site. In contrast, if we do not have an over consume, we must - // have been compensating for uses in the top of a diamond and need to - // insert a destroy after the apply since the leak will just cover the - // other path. - if (!error.getFoundOverConsume()) { + bool consumedInLoop = checker.completeConsumingUseSet( + pai, applySite.getCalleeOperand(), + [&](SILBasicBlock::iterator insertPt) { + SILBuilderWithScope builder(insertPt); + if (hasOwnership) { + builder.createEndBorrow(loc, argument); + } + builder.emitDestroyValueOperation(loc, copy); + }); + + // Since our applySite is in a different loop than our partial apply means + // thatour leak code will have lifetime extended the value over the + // loop. So we should /not/ insert a destroy after the apply site. In + // contrast, if we do not have a loop, we must have been compensating for + // uses in the top of a diamond and need to insert a destroy after the + // apply since the leak will just cover the other path. + if (!consumedInLoop) { applySite.insertAfterInvocation([&](SILBasicBlock::iterator iter) { if (hasOwnership) { SILBuilderWithScope(iter).createEndBorrow(loc, argument); @@ -174,7 +167,9 @@ static void fixupReferenceCounts( v = SILBuilderWithScope(pai).emitCopyValueOperation(loc, v); visitedBlocks.clear(); - // If we need to insert compensating destroys, do so. + // If our consuming partial apply does not post-dominate our + // partial_apply, compute the completion of the post dominance set and if + // that set is non-empty, insert compensating destroys at those places. // // NOTE: We use pai here since in non-ossa code emitCopyValueOperation // returns the operand of the strong_retain which may have a ValueBase @@ -187,17 +182,16 @@ static void fixupReferenceCounts( // am going to change this to use a different API on the linear lifetime // checker that makes this clearer. LinearLifetimeChecker checker(visitedBlocks, deadEndBlocks); - auto error = checker.checkValue(pai, {applySite.getCalleeOperand()}, {}, - errorBehavior, &leakingBlocks); - if (error.getFoundError()) { - while (!leakingBlocks.empty()) { - auto *leakingBlock = leakingBlocks.pop_back_val(); - auto loc = RegularLocation::getAutoGeneratedLocation(); - SILBuilderWithScope builder(leakingBlock->begin()); - builder.emitDestroyValueOperation(loc, v); - } - } - + checker.completeConsumingUseSet( + pai, applySite.getCalleeOperand(), + [&](SILBasicBlock::iterator insertPt) { + auto loc = RegularLocation::getAutoGeneratedLocation(); + SILBuilderWithScope builder(insertPt); + builder.emitDestroyValueOperation(loc, v); + }); + + // Then insert destroys after the apply site since our value is not being + // consumed as part of the actual apply. applySite.insertAfterInvocation([&](SILBasicBlock::iterator iter) { SILBuilderWithScope(iter).emitDestroyValueOperation(loc, v); }); @@ -226,17 +220,16 @@ static void fixupReferenceCounts( // am going to change this to use a different API on the linear lifetime // checker that makes this clearer. LinearLifetimeChecker checker(visitedBlocks, deadEndBlocks); - auto error = checker.checkValue(pai, {applySite.getCalleeOperand()}, {}, - errorBehavior, &leakingBlocks); - if (error.getFoundError()) { - while (!leakingBlocks.empty()) { - auto *leakingBlock = leakingBlocks.pop_back_val(); - auto loc = RegularLocation::getAutoGeneratedLocation(); - SILBuilderWithScope builder(leakingBlock->begin()); - builder.emitDestroyValueOperation(loc, v); - } - } - + checker.completeConsumingUseSet( + pai, applySite.getCalleeOperand(), + [&](SILBasicBlock::iterator insertPt) { + auto loc = RegularLocation::getAutoGeneratedLocation(); + SILBuilderWithScope builder(insertPt); + builder.emitDestroyValueOperation(loc, v); + }); + + // NOTE: Unlike with the unowned case above, when we are owned we do not + // need to insert destroys since the apply will consume the value for us. break; } } diff --git a/lib/SILOptimizer/Mandatory/PredictableMemOpt.cpp b/lib/SILOptimizer/Mandatory/PredictableMemOpt.cpp index c0873c021d01b..5d4d4f889cb6c 100644 --- a/lib/SILOptimizer/Mandatory/PredictableMemOpt.cpp +++ b/lib/SILOptimizer/Mandatory/PredictableMemOpt.cpp @@ -1197,38 +1197,30 @@ void AvailableValueAggregator::addHandOffCopyDestroysForPhis( // // Then perform the linear lifetime check. If we succeed, continue. We have // no further work to do. - auto errorKind = ownership::ErrorBehaviorKind::ReturnFalse; auto *loadOperand = &load->getAllOperands()[0]; LinearLifetimeChecker checker(visitedBlocks, deadEndBlocks); - auto error = - checker.checkValue(phi, {loadOperand}, {}, errorKind, &leakingBlocks); - - if (!error.getFoundError()) { - // If we did not find an error, then our copy_value must be strongly - // control equivalent as our load_borrow. So just insert a destroy_value - // for the copy_value. - auto next = std::next(load->getIterator()); - SILBuilderWithScope builder(next); - builder.emitDestroyValueOperation(next->getLoc(), phi); - continue; - } + bool consumedInLoop = checker.completeConsumingUseSet( + phi, loadOperand, [&](SILBasicBlock::iterator iter) { + SILBuilderWithScope builder(iter); + builder.emitDestroyValueOperation(loc, phi); + }); - // Ok, we found some leaking blocks and potentially a loop. If we do not - // find a loop, insert the destroy_value after the load_borrow. We do not do - // this if we found a loop since our leaking blocks will lifetime extend the - // value over the loop. - if (!error.getFoundOverConsume()) { - auto next = std::next(load->getIterator()); - SILBuilderWithScope builder(next); - builder.emitDestroyValueOperation(next->getLoc(), phi); + // Ok, we found some leaking blocks and potentially that our load is + // "consumed" inside a different loop in the loop nest from cvi. If we are + // consumed in the loop, then our visit should have inserted all of the + // necessary destroys for us by inserting the destroys on the loop + // boundaries. So, continue. + // + // NOTE: This includes cases where due to an infinite loop, we did not + // insert /any/ destroys since the loop has no boundary in a certain sense. + if (consumedInLoop) { + continue; } - // Ok, we found some leaking blocks. Insert destroys at the beginning of - // these blocks for our copy_value. - for (auto *bb : leakingBlocks) { - SILBuilderWithScope b(bb->begin()); - b.emitDestroyValueOperation(loc, phi); - } + // Otherwise, we need to insert one last destroy after the load for our phi. + auto next = std::next(load->getIterator()); + SILBuilderWithScope builder(next); + builder.emitDestroyValueOperation(next->getLoc(), phi); } // Alright! In summary, we just lifetime extended all of our phis, @@ -1289,38 +1281,30 @@ void AvailableValueAggregator::addMissingDestroysForCopiedValues( // // Then perform the linear lifetime check. If we succeed, continue. We have // no further work to do. - auto errorKind = ownership::ErrorBehaviorKind::ReturnFalse; auto *loadOperand = &load->getAllOperands()[0]; LinearLifetimeChecker checker(visitedBlocks, deadEndBlocks); - auto error = - checker.checkValue(cvi, {loadOperand}, {}, errorKind, &leakingBlocks); - - if (!error.getFoundError()) { - // If we did not find an error, then our copy_value must be strongly - // control equivalent as our load_borrow. So just insert a destroy_value - // for the copy_value. - auto next = std::next(load->getIterator()); - SILBuilderWithScope builder(next); - builder.emitDestroyValueOperation(next->getLoc(), cvi); - continue; - } + bool consumedInLoop = checker.completeConsumingUseSet( + cvi, loadOperand, [&](SILBasicBlock::iterator iter) { + SILBuilderWithScope builder(iter); + builder.emitDestroyValueOperation(loc, cvi); + }); - // Ok, we found some leaking blocks and potentially a loop. If we do not - // find a loop, insert the destroy_value after the load_borrow. We do not do - // this if we found a loop since our leaking blocks will lifetime extend the - // value over the loop. - if (!error.getFoundOverConsume()) { - auto next = std::next(load->getIterator()); - SILBuilderWithScope builder(next); - builder.emitDestroyValueOperation(next->getLoc(), cvi); + // Ok, we found some leaking blocks and potentially that our load is + // "consumed" inside a different loop in the loop nest from cvi. If we are + // consumed in the loop, then our visit should have inserted all of the + // necessary destroys for us by inserting the destroys on the loop + // boundaries. So, continue. + // + // NOTE: This includes cases where due to an infinite loop, we did not + // insert /any/ destroys since the loop has no boundary in a certain sense. + if (consumedInLoop) { + continue; } - // Ok, we found some leaking blocks. Insert destroys at the beginning of - // these blocks for our copy_value. - for (auto *bb : leakingBlocks) { - SILBuilderWithScope b(bb->begin()); - b.emitDestroyValueOperation(loc, cvi); - } + // Otherwise, we need to insert one last destroy after the load for our phi. + auto next = std::next(load->getIterator()); + SILBuilderWithScope builder(next); + builder.emitDestroyValueOperation(next->getLoc(), cvi); } } diff --git a/lib/SILOptimizer/PassManager/PassPipeline.cpp b/lib/SILOptimizer/PassManager/PassPipeline.cpp index e02810668e55c..f44e1bb56f573 100644 --- a/lib/SILOptimizer/PassManager/PassPipeline.cpp +++ b/lib/SILOptimizer/PassManager/PassPipeline.cpp @@ -453,6 +453,9 @@ static bool addMidLevelPassPipeline(SILPassPipelinePlan &P) { // for CapturePropagation. P.addDeadArgSignatureOpt(); + // A LICM pass at mid-level is mainly needed to hoist addressors of globals. + // It needs to be before global_init functions are inlined. + P.addLICM(); // Run loop unrolling after inlining and constant propagation, because loop // trip counts may have became constant. P.addLoopUnroll(); diff --git a/lib/SILOptimizer/SILCombiner/SILCombine.cpp b/lib/SILOptimizer/SILCombiner/SILCombine.cpp index 98d7e92ce0b72..5ee721a3692b8 100644 --- a/lib/SILOptimizer/SILCombiner/SILCombine.cpp +++ b/lib/SILOptimizer/SILCombiner/SILCombine.cpp @@ -250,10 +250,6 @@ class SILCombine : public SILFunctionTransform { /// The entry point to the transformation. void run() override { - // FIXME: We should be able to handle ownership. - if (getFunction()->hasOwnership()) - return; - auto *AA = PM->getAnalysis(); auto *DA = PM->getAnalysis(); auto *PCA = PM->getAnalysis(); diff --git a/lib/SILOptimizer/SILCombiner/SILCombiner.h b/lib/SILOptimizer/SILCombiner/SILCombiner.h index 5d7733c287a45..7d0632a36412f 100644 --- a/lib/SILOptimizer/SILCombiner/SILCombiner.h +++ b/lib/SILOptimizer/SILCombiner/SILCombiner.h @@ -163,8 +163,6 @@ class SILCombiner : /// Instruction visitors. SILInstruction *visitReleaseValueInst(ReleaseValueInst *DI); SILInstruction *visitRetainValueInst(RetainValueInst *CI); - SILInstruction *visitReleaseValueAddrInst(ReleaseValueAddrInst *DI); - SILInstruction *visitRetainValueAddrInst(RetainValueAddrInst *CI); SILInstruction *visitPartialApplyInst(PartialApplyInst *AI); SILInstruction *visitApplyInst(ApplyInst *AI); SILInstruction *visitBeginApplyInst(BeginApplyInst *BAI); @@ -208,14 +206,12 @@ class SILCombiner : SILInstruction *visitTupleExtractInst(TupleExtractInst *TEI); SILInstruction *visitFixLifetimeInst(FixLifetimeInst *FLI); SILInstruction *visitSwitchValueInst(SwitchValueInst *SVI); - SILInstruction *visitSelectValueInst(SelectValueInst *SVI); SILInstruction * visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *CCABI); SILInstruction * visitCheckedCastBranchInst(CheckedCastBranchInst *CBI); SILInstruction *visitUnreachableInst(UnreachableInst *UI); SILInstruction *visitAllocRefDynamicInst(AllocRefDynamicInst *ARDI); - SILInstruction *visitEnumInst(EnumInst *EI); SILInstruction *visitMarkDependenceInst(MarkDependenceInst *MDI); SILInstruction *visitClassifyBridgeObjectInst(ClassifyBridgeObjectInst *CBOI); diff --git a/lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp b/lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp index 4b4d37d77dd3f..977a2f7dd7355 100644 --- a/lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp +++ b/lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp @@ -79,6 +79,9 @@ static bool foldInverseReabstractionThunks(PartialApplyInst *PAI, } SILInstruction *SILCombiner::visitPartialApplyInst(PartialApplyInst *PAI) { + if (PAI->getFunction()->hasOwnership()) + return nullptr; + // partial_apply without any substitutions or arguments is just a // thin_to_thick_function. if (!PAI->hasSubstitutions() && (PAI->getNumArguments() == 0)) { @@ -1234,6 +1237,9 @@ FullApplySite SILCombiner::rewriteApplyCallee(FullApplySite apply, } SILInstruction *SILCombiner::visitApplyInst(ApplyInst *AI) { + if (AI->getFunction()->hasOwnership()) + return nullptr; + Builder.setCurrentDebugScope(AI->getDebugScope()); // apply{partial_apply(x,y)}(z) -> apply(z,x,y) is triggered // from visitPartialApplyInst(), so bail here. @@ -1313,6 +1319,9 @@ SILInstruction *SILCombiner::visitApplyInst(ApplyInst *AI) { } SILInstruction *SILCombiner::visitBeginApplyInst(BeginApplyInst *BAI) { + if (BAI->getFunction()->hasOwnership()) + return nullptr; + if (tryOptimizeInoutKeypath(BAI)) return nullptr; return nullptr; @@ -1368,6 +1377,9 @@ isTryApplyResultNotUsed(UserListTy &AcceptedUses, TryApplyInst *TAI) { } SILInstruction *SILCombiner::visitTryApplyInst(TryApplyInst *AI) { + if (AI->getFunction()->hasOwnership()) + return nullptr; + // apply{partial_apply(x,y)}(z) -> apply(z,x,y) is triggered // from visitPartialApplyInst(), so bail here. if (isa(AI->getCallee())) diff --git a/lib/SILOptimizer/SILCombiner/SILCombinerBuiltinVisitors.cpp b/lib/SILOptimizer/SILCombiner/SILCombinerBuiltinVisitors.cpp index 7c37686e464f4..fd48a8511ba46 100644 --- a/lib/SILOptimizer/SILCombiner/SILCombinerBuiltinVisitors.cpp +++ b/lib/SILOptimizer/SILCombiner/SILCombinerBuiltinVisitors.cpp @@ -530,6 +530,9 @@ SILInstruction *SILCombiner::optimizeStringObject(BuiltinInst *BI) { } SILInstruction *SILCombiner::visitBuiltinInst(BuiltinInst *I) { + if (I->getFunction()->hasOwnership()) + return nullptr; + if (I->getBuiltinInfo().ID == BuiltinValueKind::CanBeObjCClass) return optimizeBuiltinCanBeObjCClass(I); if (I->getBuiltinInfo().ID == BuiltinValueKind::IsConcrete) diff --git a/lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp b/lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp index 5de5677773758..e68a0ba7cfc5a 100644 --- a/lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp +++ b/lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp @@ -31,6 +31,9 @@ using namespace swift::PatternMatch; SILInstruction * SILCombiner::visitRefToRawPointerInst(RefToRawPointerInst *RRPI) { + if (RRPI->getFunction()->hasOwnership()) + return nullptr; + // Ref to raw pointer consumption of other ref casts. if (auto *URCI = dyn_cast(RRPI->getOperand())) { // (ref_to_raw_pointer (unchecked_ref_cast x)) @@ -57,6 +60,9 @@ SILCombiner::visitRefToRawPointerInst(RefToRawPointerInst *RRPI) { } SILInstruction *SILCombiner::visitUpcastInst(UpcastInst *UCI) { + if (UCI->getFunction()->hasOwnership()) + return nullptr; + // Ref to raw pointer consumption of other ref casts. // // (upcast (upcast x)) -> (upcast x) @@ -72,6 +78,8 @@ SILInstruction * SILCombiner:: visitPointerToAddressInst(PointerToAddressInst *PTAI) { auto *F = PTAI->getFunction(); + if (F->hasOwnership()) + return nullptr; Builder.setCurrentDebugScope(PTAI->getDebugScope()); @@ -200,6 +208,9 @@ visitPointerToAddressInst(PointerToAddressInst *PTAI) { SILInstruction * SILCombiner::visitUncheckedAddrCastInst(UncheckedAddrCastInst *UADCI) { + if (UADCI->getFunction()->hasOwnership()) + return nullptr; + Builder.setCurrentDebugScope(UADCI->getDebugScope()); // (unchecked-addr-cast (unchecked-addr-cast x X->Y) Y->Z) @@ -221,6 +232,9 @@ SILCombiner::visitUncheckedAddrCastInst(UncheckedAddrCastInst *UADCI) { SILInstruction * SILCombiner::visitUncheckedRefCastInst(UncheckedRefCastInst *URCI) { + if (URCI->getFunction()->hasOwnership()) + return nullptr; + // (unchecked-ref-cast (unchecked-ref-cast x X->Y) Y->Z) // -> // (unchecked-ref-cast x X->Z) @@ -253,6 +267,8 @@ SILCombiner::visitUncheckedRefCastInst(UncheckedRefCastInst *URCI) { SILInstruction * SILCombiner::visitBridgeObjectToRefInst(BridgeObjectToRefInst *BORI) { + if (BORI->getFunction()->hasOwnership()) + return nullptr; // Fold noop casts through Builtin.BridgeObject. // (bridge_object_to_ref (unchecked-ref-cast x BridgeObject) y) // -> (unchecked-ref-cast x y) @@ -266,6 +282,9 @@ SILCombiner::visitBridgeObjectToRefInst(BridgeObjectToRefInst *BORI) { SILInstruction * SILCombiner::visitUncheckedRefCastAddrInst(UncheckedRefCastAddrInst *URCI) { + if (URCI->getFunction()->hasOwnership()) + return nullptr; + SILType SrcTy = URCI->getSrc()->getType(); if (!SrcTy.isLoadable(*URCI->getFunction())) return nullptr; @@ -301,6 +320,9 @@ SILCombiner::visitUncheckedRefCastAddrInst(UncheckedRefCastAddrInst *URCI) { SILInstruction * SILCombiner:: visitUnconditionalCheckedCastAddrInst(UnconditionalCheckedCastAddrInst *UCCAI) { + if (UCCAI->getFunction()->hasOwnership()) + return nullptr; + if (CastOpt.optimizeUnconditionalCheckedCastAddrInst(UCCAI)) MadeChange = true; @@ -310,6 +332,9 @@ visitUnconditionalCheckedCastAddrInst(UnconditionalCheckedCastAddrInst *UCCAI) { SILInstruction * SILCombiner:: visitUnconditionalCheckedCastInst(UnconditionalCheckedCastInst *UCCI) { + if (UCCI->getFunction()->hasOwnership()) + return nullptr; + if (CastOpt.optimizeUnconditionalCheckedCastInst(UCCI)) { MadeChange = true; return nullptr; @@ -338,6 +363,9 @@ visitUnconditionalCheckedCastInst(UnconditionalCheckedCastInst *UCCI) { SILInstruction * SILCombiner:: visitRawPointerToRefInst(RawPointerToRefInst *RawToRef) { + if (RawToRef->getFunction()->hasOwnership()) + return nullptr; + // (raw_pointer_to_ref (ref_to_raw_pointer x X->Y) Y->Z) // -> // (unchecked_ref_cast X->Z) @@ -353,6 +381,9 @@ visitRawPointerToRefInst(RawPointerToRefInst *RawToRef) { SILInstruction * SILCombiner:: visitUncheckedTrivialBitCastInst(UncheckedTrivialBitCastInst *UTBCI) { + if (UTBCI->getFunction()->hasOwnership()) + return nullptr; + // (unchecked_trivial_bit_cast Y->Z // (unchecked_trivial_bit_cast X->Y x)) // -> @@ -406,6 +437,9 @@ visitUncheckedBitwiseCastInst(UncheckedBitwiseCastInst *UBCI) { SILInstruction * SILCombiner::visitThickToObjCMetatypeInst(ThickToObjCMetatypeInst *TTOCMI) { + if (TTOCMI->getFunction()->hasOwnership()) + return nullptr; + // Perform the following transformations: // (thick_to_objc_metatype (metatype @thick)) -> // (metatype @objc_metatype) @@ -423,6 +457,9 @@ SILCombiner::visitThickToObjCMetatypeInst(ThickToObjCMetatypeInst *TTOCMI) { SILInstruction * SILCombiner::visitObjCToThickMetatypeInst(ObjCToThickMetatypeInst *OCTTMI) { + if (OCTTMI->getFunction()->hasOwnership()) + return nullptr; + // Perform the following transformations: // (objc_to_thick_metatype (metatype @objc_metatype)) -> // (metatype @thick) @@ -440,6 +477,9 @@ SILCombiner::visitObjCToThickMetatypeInst(ObjCToThickMetatypeInst *OCTTMI) { SILInstruction * SILCombiner::visitCheckedCastBranchInst(CheckedCastBranchInst *CBI) { + if (CBI->getFunction()->hasOwnership()) + return nullptr; + if (CastOpt.optimizeCheckedCastBranchInst(CBI)) MadeChange = true; @@ -449,6 +489,9 @@ SILCombiner::visitCheckedCastBranchInst(CheckedCastBranchInst *CBI) { SILInstruction * SILCombiner:: visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *CCABI) { + if (CCABI->getFunction()->hasOwnership()) + return nullptr; + if (CastOpt.optimizeCheckedCastAddrBranchInst(CCABI)) MadeChange = true; @@ -457,6 +500,9 @@ visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *CCABI) { SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst( ConvertEscapeToNoEscapeInst *Cvt) { + if (Cvt->getFunction()->hasOwnership()) + return nullptr; + auto *OrigThinToThick = dyn_cast(Cvt->getConverted()); if (!OrigThinToThick) @@ -470,6 +516,9 @@ SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst( } SILInstruction *SILCombiner::visitConvertFunctionInst(ConvertFunctionInst *CFI) { + if (CFI->getFunction()->hasOwnership()) + return nullptr; + // If this conversion only changes substitutions, then rewrite applications // of the converted function as applications of the original. // diff --git a/lib/SILOptimizer/SILCombiner/SILCombinerMiscVisitors.cpp b/lib/SILOptimizer/SILCombiner/SILCombinerMiscVisitors.cpp index e1066b6c02923..743e4a83d0b5a 100644 --- a/lib/SILOptimizer/SILCombiner/SILCombinerMiscVisitors.cpp +++ b/lib/SILOptimizer/SILCombiner/SILCombinerMiscVisitors.cpp @@ -42,6 +42,8 @@ static llvm::cl::opt SILInstruction* SILCombiner::visitAllocExistentialBoxInst(AllocExistentialBoxInst *AEBI) { + if (AEBI->getFunction()->hasOwnership()) + return nullptr; // Optimize away the pattern below that happens when exceptions are created // and in some cases, due to inlining, are not needed. @@ -171,6 +173,9 @@ static EnumElementDecl *getInjectEnumCaseTo(SILValue Addr) { } SILInstruction *SILCombiner::visitSwitchEnumAddrInst(SwitchEnumAddrInst *SEAI) { + if (SEAI->getFunction()->hasOwnership()) + return nullptr; + // Convert switch_enum_addr -> br // if the only thing which writes to the address is an inject_enum_addr. SILValue Addr = SEAI->getOperand(); @@ -206,6 +211,9 @@ SILInstruction *SILCombiner::visitSwitchEnumAddrInst(SwitchEnumAddrInst *SEAI) { } SILInstruction *SILCombiner::visitSelectEnumAddrInst(SelectEnumAddrInst *SEAI) { + if (SEAI->getFunction()->hasOwnership()) + return nullptr; + // Canonicalize a select_enum_addr: if the default refers to exactly one case, // then replace the default with that case. Builder.setCurrentDebugScope(SEAI->getDebugScope()); @@ -249,11 +257,10 @@ SILInstruction *SILCombiner::visitSelectEnumAddrInst(SelectEnumAddrInst *SEAI) { return I; } -SILInstruction *SILCombiner::visitSelectValueInst(SelectValueInst *SVI) { - return nullptr; -} - SILInstruction *SILCombiner::visitSwitchValueInst(SwitchValueInst *SVI) { + if (SVI->getFunction()->hasOwnership()) + return nullptr; + SILValue Cond = SVI->getOperand(); BuiltinIntegerType *CondTy = Cond->getType().getAs(); if (!CondTy || !CondTy->isFixedWidth(1)) @@ -448,6 +455,9 @@ static bool somethingIsRetained(SILInstruction *from, AllocStackInst *alloc) { } SILInstruction *SILCombiner::visitAllocStackInst(AllocStackInst *AS) { + if (AS->getFunction()->hasOwnership()) + return nullptr; + // If we are testing SILCombine and we are asked not to eliminate // alloc_stacks, just return. if (DisableAllocStackOpts) @@ -559,6 +569,9 @@ SILInstruction *SILCombiner::visitAllocStackInst(AllocStackInst *AS) { } SILInstruction *SILCombiner::visitAllocRefInst(AllocRefInst *AR) { + if (AR->getFunction()->hasOwnership()) + return nullptr; + if (!AR) return nullptr; // Check if the only uses are deallocating stack or deallocating. @@ -695,6 +708,9 @@ static bool isZeroLoadFromEmptyCollection(LoadInst *LI) { } SILInstruction *SILCombiner::visitLoadInst(LoadInst *LI) { + if (LI->getFunction()->hasOwnership()) + return nullptr; + // (load (upcast-ptr %x)) -> (upcast-ref (load %x)) Builder.setCurrentDebugScope(LI->getDebugScope()); if (auto *UI = dyn_cast(LI->getOperand())) { @@ -727,6 +743,9 @@ SILInstruction *SILCombiner::visitLoadInst(LoadInst *LI) { /// -> /// %2 = index_addr %ptr, x+y SILInstruction *SILCombiner::visitIndexAddrInst(IndexAddrInst *IA) { + if (IA->getFunction()->hasOwnership()) + return nullptr; + unsigned index = 0; SILValue base = isConstIndexAddr(IA, index); if (!base) @@ -743,6 +762,8 @@ SILInstruction *SILCombiner::visitIndexAddrInst(IndexAddrInst *IA) { } SILInstruction *SILCombiner::visitReleaseValueInst(ReleaseValueInst *RVI) { + assert(!RVI->getFunction()->hasOwnership()); + SILValue Operand = RVI->getOperand(); SILType OperandTy = Operand->getType(); @@ -779,6 +800,8 @@ SILInstruction *SILCombiner::visitReleaseValueInst(ReleaseValueInst *RVI) { } SILInstruction *SILCombiner::visitRetainValueInst(RetainValueInst *RVI) { + assert(!RVI->getFunction()->hasOwnership()); + SILValue Operand = RVI->getOperand(); SILType OperandTy = Operand->getType(); @@ -845,17 +868,10 @@ SILInstruction *SILCombiner::visitRetainValueInst(RetainValueInst *RVI) { return nullptr; } -SILInstruction * -SILCombiner::visitReleaseValueAddrInst(ReleaseValueAddrInst *RVI) { - return nullptr; -} - -SILInstruction * -SILCombiner::visitRetainValueAddrInst(RetainValueAddrInst *RVI) { - return nullptr; -} - SILInstruction *SILCombiner::visitCondFailInst(CondFailInst *CFI) { + if (CFI->getFunction()->hasOwnership()) + return nullptr; + // Remove runtime asserts such as overflow checks and bounds checks. if (RemoveCondFails) return eraseInstFromFunction(*CFI); @@ -894,6 +910,8 @@ SILInstruction *SILCombiner::visitCondFailInst(CondFailInst *CFI) { } SILInstruction *SILCombiner::visitStrongRetainInst(StrongRetainInst *SRI) { + assert(!SRI->getFunction()->hasOwnership()); + // Retain of ThinToThickFunction is a no-op. SILValue funcOper = SRI->getOperand(); if (auto *CFI = dyn_cast(funcOper)) @@ -1024,6 +1042,9 @@ static SILValue createValueFromAddr(SILValue addr, SILBuilder *builder, /// We leave the cleaning up to mem2reg. SILInstruction * SILCombiner::visitInjectEnumAddrInst(InjectEnumAddrInst *IEAI) { + if (IEAI->getFunction()->hasOwnership()) + return nullptr; + // Given an inject_enum_addr of a concrete type without payload, promote it to // a store of an enum. Mem2reg/load forwarding will clean things up for us. We // can't handle the payload case here due to the flow problems caused by the @@ -1334,6 +1355,9 @@ SILCombiner::visitInjectEnumAddrInst(InjectEnumAddrInst *IEAI) { SILInstruction * SILCombiner:: visitUnreachableInst(UnreachableInst *UI) { + if (UI->getFunction()->hasOwnership()) + return nullptr; + // Make sure that this unreachable instruction // is the last instruction in the basic block. if (UI->getParent()->getTerminator() == UI) @@ -1366,6 +1390,9 @@ visitUnreachableInst(UnreachableInst *UI) { SILInstruction * SILCombiner:: visitUncheckedTakeEnumDataAddrInst(UncheckedTakeEnumDataAddrInst *TEDAI) { + if (TEDAI->getFunction()->hasOwnership()) + return nullptr; + // If our TEDAI has no users, there is nothing to do. if (TEDAI->use_empty()) return nullptr; @@ -1438,6 +1465,8 @@ visitUncheckedTakeEnumDataAddrInst(UncheckedTakeEnumDataAddrInst *TEDAI) { } SILInstruction *SILCombiner::visitStrongReleaseInst(StrongReleaseInst *SRI) { + assert(!SRI->getFunction()->hasOwnership()); + // Release of ThinToThickFunction is a no-op. if (isa(SRI->getOperand())) return eraseInstFromFunction(*SRI); @@ -1467,6 +1496,9 @@ SILInstruction *SILCombiner::visitStrongReleaseInst(StrongReleaseInst *SRI) { } SILInstruction *SILCombiner::visitCondBranchInst(CondBranchInst *CBI) { + if (CBI->getFunction()->hasOwnership()) + return nullptr; + // cond_br(xor(x, 1)), t_label, f_label -> cond_br x, f_label, t_label // cond_br(x == 0), t_label, f_label -> cond_br x, f_label, t_label // cond_br(x != 1), t_label, f_label -> cond_br x, f_label, t_label @@ -1599,6 +1631,9 @@ SILInstruction *SILCombiner::visitCondBranchInst(CondBranchInst *CBI) { } SILInstruction *SILCombiner::visitSelectEnumInst(SelectEnumInst *SEI) { + if (SEI->getFunction()->hasOwnership()) + return nullptr; + // Canonicalize a select_enum: if the default refers to exactly one case, then // replace the default with that case. if (SEI->hasDefault()) { @@ -1645,6 +1680,9 @@ SILInstruction *SILCombiner::visitSelectEnumInst(SelectEnumInst *SEI) { } SILInstruction *SILCombiner::visitTupleExtractInst(TupleExtractInst *TEI) { + if (TEI->getFunction()->hasOwnership()) + return nullptr; + // tuple_extract(apply([add|sub|...]overflow(x, 0)), 1) -> 0 // if it can be proven that no overflow can happen. if (TEI->getFieldNo() != 1) @@ -1659,6 +1697,9 @@ SILInstruction *SILCombiner::visitTupleExtractInst(TupleExtractInst *TEI) { } SILInstruction *SILCombiner::visitFixLifetimeInst(FixLifetimeInst *FLI) { + if (FLI->getFunction()->hasOwnership()) + return nullptr; + // fix_lifetime(alloc_stack) -> fix_lifetime(load(alloc_stack)) Builder.setCurrentDebugScope(FLI->getDebugScope()); if (auto *AI = dyn_cast(FLI->getOperand())) { @@ -1674,6 +1715,9 @@ SILInstruction *SILCombiner::visitFixLifetimeInst(FixLifetimeInst *FLI) { SILInstruction * SILCombiner:: visitAllocRefDynamicInst(AllocRefDynamicInst *ARDI) { + if (ARDI->getFunction()->hasOwnership()) + return nullptr; + SmallVector Counts; auto getCounts = [&] (AllocRefDynamicInst *AI) -> ArrayRef { for (Operand &Op : AI->getTailAllocatedCounts()) { @@ -1741,11 +1785,10 @@ visitAllocRefDynamicInst(AllocRefDynamicInst *ARDI) { return NewInst; } -SILInstruction *SILCombiner::visitEnumInst(EnumInst *EI) { - return nullptr; -} - SILInstruction *SILCombiner::visitMarkDependenceInst(MarkDependenceInst *mdi) { + if (mdi->getFunction()->hasOwnership()) + return nullptr; + // Simplify the base operand of a MarkDependenceInst to eliminate unnecessary // instructions that aren't adding value. // @@ -1795,6 +1838,9 @@ SILInstruction *SILCombiner::visitMarkDependenceInst(MarkDependenceInst *mdi) { SILInstruction *SILCombiner:: visitClassifyBridgeObjectInst(ClassifyBridgeObjectInst *CBOI) { + if (CBOI->getFunction()->hasOwnership()) + return nullptr; + auto *URC = dyn_cast(CBOI->getOperand()); if (!URC) return nullptr; diff --git a/lib/SILOptimizer/Transforms/CSE.cpp b/lib/SILOptimizer/Transforms/CSE.cpp index 6dcb18f2b50c4..696a137963041 100644 --- a/lib/SILOptimizer/Transforms/CSE.cpp +++ b/lib/SILOptimizer/Transforms/CSE.cpp @@ -965,6 +965,11 @@ bool CSE::canHandle(SILInstruction *Inst) { if (isLazyPropertyGetter(AI)) return true; + + if (SILFunction *callee = AI->getReferencedFunctionOrNull()) { + if (callee->isGlobalInit()) + return true; + } return false; } diff --git a/lib/SILOptimizer/Transforms/DestroyHoisting.cpp b/lib/SILOptimizer/Transforms/DestroyHoisting.cpp index eeff8989096d1..a4a5957c1bd69 100644 --- a/lib/SILOptimizer/Transforms/DestroyHoisting.cpp +++ b/lib/SILOptimizer/Transforms/DestroyHoisting.cpp @@ -549,15 +549,17 @@ SILValue DestroyHoisting::createAddress(unsigned locIdx, SILBuilder &builder) { assert(!isa(loc->representativeValue) && "only a root location can be a begin_access"); - SingleValueInstruction *&cachedProj = addressProjections[locIdx]; - if (cachedProj) - return cachedProj; - if (!domTree) domTree = DA->get(function); + + SILInstruction *ip = &*builder.getInsertionPoint(); + + SingleValueInstruction *&cachedProj = addressProjections[locIdx]; + if (cachedProj && domTree->properlyDominates(cachedProj, ip)) + return cachedProj; auto *projInst = cast(loc->representativeValue); - if (domTree->properlyDominates(projInst, &*builder.getInsertionPoint())) { + if (domTree->properlyDominates(projInst, ip)) { cachedProj = projInst; return projInst; } @@ -580,7 +582,7 @@ SILValue DestroyHoisting::createAddress(unsigned locIdx, SILBuilder &builder) { newProj = projBuilder.createTupleElementAddr(TEA->getLoc(), baseAddr, TEA->getFieldNo(), TEA->getType()); } - assert(domTree->properlyDominates(newProj, &*builder.getInsertionPoint()) && + assert(domTree->properlyDominates(newProj, ip) && "new projection does not dominate insert point"); // We need to remember the new projection instruction because in tailMerging // we might call locations.getLocationIdx() on such a new instruction. diff --git a/lib/SILOptimizer/UtilityPasses/SerializeSILPass.cpp b/lib/SILOptimizer/UtilityPasses/SerializeSILPass.cpp index 3b6bd7da1d04b..4b54d12b759c3 100644 --- a/lib/SILOptimizer/UtilityPasses/SerializeSILPass.cpp +++ b/lib/SILOptimizer/UtilityPasses/SerializeSILPass.cpp @@ -327,6 +327,8 @@ static bool hasOpaqueArchetype(TypeExpansionContext context, case SILInstructionKind::CondFailInst: case SILInstructionKind::DestructureStructInst: case SILInstructionKind::DestructureTupleInst: + case SILInstructionKind::DifferentiableFunctionInst: + case SILInstructionKind::DifferentiableFunctionExtractInst: case SILInstructionKind::DifferentiabilityWitnessFunctionInst: // Handle by operand and result check. break; diff --git a/lib/SILOptimizer/Utils/SILInliner.cpp b/lib/SILOptimizer/Utils/SILInliner.cpp index cf760936987e8..476d2927dc9dd 100644 --- a/lib/SILOptimizer/Utils/SILInliner.cpp +++ b/lib/SILOptimizer/Utils/SILInliner.cpp @@ -875,6 +875,8 @@ InlineCost swift::instructionInlineCost(SILInstruction &I) { case SILInstructionKind::SelectValueInst: case SILInstructionKind::KeyPathInst: case SILInstructionKind::GlobalValueInst: + case SILInstructionKind::DifferentiableFunctionInst: + case SILInstructionKind::DifferentiableFunctionExtractInst: case SILInstructionKind::DifferentiabilityWitnessFunctionInst: #define COMMON_ALWAYS_OR_SOMETIMES_LOADABLE_CHECKED_REF_STORAGE(Name) \ case SILInstructionKind::Name##ToRefInst: \ diff --git a/lib/Sema/CMakeLists.txt b/lib/Sema/CMakeLists.txt index c7ed5de620c91..f7a9958c76c1e 100644 --- a/lib/Sema/CMakeLists.txt +++ b/lib/Sema/CMakeLists.txt @@ -1,4 +1,3 @@ - add_swift_host_library(swiftSema STATIC BuilderTransform.cpp CSApply.cpp @@ -16,6 +15,7 @@ add_swift_host_library(swiftSema STATIC ConstraintLocator.cpp ConstraintSystem.cpp DebuggerTestingTransform.cpp + DerivedConformanceAdditiveArithmetic.cpp DerivedConformanceCaseIterable.cpp DerivedConformanceCodable.cpp DerivedConformanceCodingKey.cpp diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index f2a5f84f4689a..4e5df5613376e 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -3132,9 +3132,6 @@ namespace { nameLoc.getBaseNameLoc(), toType)); } - case OverloadChoiceKind::BaseType: - return base; - case OverloadChoiceKind::KeyPathApplication: llvm_unreachable("should only happen in a subscript"); @@ -8028,6 +8025,17 @@ Optional ConstraintSystem::applySolution( } } + // If there are no fixes recorded but score indicates that there + // should have been at least one, let's fail application and + // produce a fallback diagnostic to highlight the problem. + { + const auto &score = solution.getFixedScore(); + if (score.Data[SK_Fix] > 0 || score.Data[SK_Hole] > 0) { + maybeProduceFallbackDiagnostic(target); + return None; + } + } + ExprRewriter rewriter(*this, solution, shouldSuppressDiagnostics()); ExprWalker walker(rewriter); auto resultTarget = walker.rewriteTarget(target); diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 285ca364c1e74..0782046b41413 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -1086,6 +1086,10 @@ bool TypeVariableBinding::attempt(ConstraintSystem &cs) const { cs.DefaultedConstraints.push_back(srcLocator); if (type->isHole()) { + // Reflect in the score that this type variable couldn't be + // resolved and had to be bound to a placeholder "hole" type. + cs.increaseScore(SK_Hole); + if (auto *GP = TypeVar->getImpl().getGenericParameter()) { auto path = dstLocator->getPath(); // Drop `generic parameter` locator element so that all missing diff --git a/lib/Sema/CSDiagnostics.cpp b/lib/Sema/CSDiagnostics.cpp index 297aca198d2e7..83d4782865231 100644 --- a/lib/Sema/CSDiagnostics.cpp +++ b/lib/Sema/CSDiagnostics.cpp @@ -674,6 +674,15 @@ bool GenericArgumentsMismatchFailure::diagnoseAsError() { } break; } + + case ConstraintLocator::OptionalPayload: { + // If we have an inout expression, this comes from an + // InoutToPointer argument mismatch failure. + if (isa(anchor)) { + diagnostic = diag::cannot_convert_argument_value; + } + break; + } case ConstraintLocator::TupleElement: { auto *anchor = getRawAnchor(); @@ -3710,15 +3719,19 @@ bool AllowTypeOrInstanceMemberFailure::diagnoseAsError() { } // Fall back to a fix-it with a full type qualifier - if (auto *NTD = Member->getDeclContext()->getSelfNominalTypeDecl()) { - auto type = NTD->getSelfInterfaceType(); - if (auto *SE = dyn_cast(getRawAnchor())) { - auto *baseExpr = SE->getBase(); - Diag->fixItReplace(baseExpr->getSourceRange(), diag::replace_with_type, - type); - } else { - Diag->fixItInsert(loc, diag::insert_type_qualification, type); - } + const Expr *baseExpr = nullptr; + if (const auto SE = dyn_cast(getRawAnchor())) + baseExpr = SE->getBase(); + else if (const auto UDE = dyn_cast(getRawAnchor())) + baseExpr = UDE->getBase(); + + // An implicit 'self' reference base expression means we should + // prepend with qualification. + if (baseExpr && !baseExpr->isImplicit()) { + Diag->fixItReplace(baseExpr->getSourceRange(), + diag::replace_with_type, baseTy); + } else { + Diag->fixItInsert(loc, diag::insert_type_qualification, baseTy); } return true; diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index 0139cb7b67727..1546b99a2d84a 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -931,6 +931,9 @@ namespace { = { nullptr, nullptr }; unsigned currentEditorPlaceholderVariable = 0; + /// Keep track of acceptable DiscardAssignmentExpr's. + llvm::SmallPtrSet CorrectDiscardAssignmentExprs; + /// Returns false and emits the specified diagnostic if the member reference /// base is a nil literal. Returns true otherwise. bool isValidBaseOfMemberRef(Expr *base, Diag<> diagnostic) { @@ -1155,7 +1158,18 @@ namespace { while (parentExpr && isa(parentExpr)) parentExpr = CS.getParentExpr(parentExpr); + // In cases like `_ = nil?` AST would have `nil` + // wrapped in `BindOptionalExpr`. + if (parentExpr && isa(parentExpr)) + parentExpr = CS.getParentExpr(parentExpr); + if (parentExpr) { + // `_ = nil as? ...` + if (isa(parentExpr)) { + DE.diagnose(expr->getLoc(), diag::conditional_cast_from_nil); + return Type(); + } + // `_ = nil!` if (isa(parentExpr)) { DE.diagnose(expr->getLoc(), diag::cannot_force_unwrap_nil_literal); @@ -1167,13 +1181,13 @@ namespace { DE.diagnose(expr->getLoc(), diag::unresolved_nil_literal); return Type(); } - } - // `_ = nil` - if (auto *assignment = dyn_cast_or_null(parentExpr)) { - if (isa(assignment->getDest())) { - DE.diagnose(expr->getLoc(), diag::unresolved_nil_literal); - return Type(); + // `_ = nil` + if (auto *assignment = dyn_cast(parentExpr)) { + if (isa(assignment->getDest())) { + DE.diagnose(expr->getLoc(), diag::unresolved_nil_literal); + return Type(); + } } } @@ -2982,26 +2996,6 @@ namespace { if (!fromExpr) // Either wasn't constructed correctly or wasn't folded. return nullptr; - std::function nilLiteralExpr = [&](Expr *expr) -> Expr * { - expr = expr->getSemanticsProvidingExpr(); - if (expr->getKind() == ExprKind::NilLiteral) - return expr; - - if (auto *optionalEvalExpr = dyn_cast(expr)) - return nilLiteralExpr(optionalEvalExpr->getSubExpr()); - - if (auto *bindOptionalExpr = dyn_cast(expr)) - return nilLiteralExpr(bindOptionalExpr->getSubExpr()); - - return nullptr; - }; - - if (auto nilLiteral = nilLiteralExpr(fromExpr)) { - ctx.Diags.diagnose(nilLiteral->getLoc(), - diag::conditional_cast_from_nil); - return nullptr; - } - // Validate the resulting type. TypeResolutionOptions options(TypeResolverContext::ExplicitCastExpr); options |= TypeResolutionFlags::AllowUnboundGenerics; @@ -3066,9 +3060,15 @@ namespace { } Type visitDiscardAssignmentExpr(DiscardAssignmentExpr *expr) { + /// Diagnose a '_' that isn't on the immediate LHS of an assignment. + if (!CorrectDiscardAssignmentExprs.count(expr)) { + auto &DE = CS.getASTContext().Diags; + DE.diagnose(expr->getLoc(), diag::discard_expr_outside_of_assignment); + return Type(); + } + auto locator = CS.getConstraintLocator(expr); - auto typeVar = CS.createTypeVariable(locator, TVO_CanBindToNoEscape | - TVO_CanBindToHole); + auto typeVar = CS.createTypeVariable(locator, TVO_CanBindToNoEscape); return LValueType::get(typeVar); } @@ -3104,6 +3104,25 @@ namespace { } } + /// Scout out the specified destination of an AssignExpr to recursively + /// identify DiscardAssignmentExpr in legal places. We can only allow them + /// in simple pattern-like expressions, so we reject anything complex here. + void markAcceptableDiscardExprs(Expr *E) { + if (!E) return; + + if (auto *PE = dyn_cast(E)) + return markAcceptableDiscardExprs(PE->getSubExpr()); + if (auto *TE = dyn_cast(E)) { + for (auto &elt : TE->getElements()) + markAcceptableDiscardExprs(elt); + return; + } + if (auto *DAE = dyn_cast(E)) + CorrectDiscardAssignmentExprs.insert(DAE); + + // Otherwise, we can't support this. + } + Type visitAssignExpr(AssignExpr *expr) { // Handle invalid code. if (!expr->getDest() || !expr->getSrc()) @@ -3933,6 +3952,9 @@ namespace { return { false, expr }; } + if (auto *assignment = dyn_cast(expr)) + CG.markAcceptableDiscardExprs(assignment->getDest()); + return { true, expr }; } diff --git a/lib/Sema/CSRanking.cpp b/lib/Sema/CSRanking.cpp index e4d75539dd0fb..1100a309608b3 100644 --- a/lib/Sema/CSRanking.cpp +++ b/lib/Sema/CSRanking.cpp @@ -41,6 +41,10 @@ void ConstraintSystem::increaseScore(ScoreKind kind, unsigned value) { log.indent(solverState->depth * 2); log << "(increasing score due to "; switch (kind) { + case SK_Hole: + log << "hole in the constraint system"; + break; + case SK_Unavailable: log << "use of an unavailable declaration"; break; @@ -143,7 +147,6 @@ static bool sameOverloadChoice(const OverloadChoice &x, return false; switch (x.getKind()) { - case OverloadChoiceKind::BaseType: case OverloadChoiceKind::KeyPathApplication: // FIXME: Compare base types after substitution? return true; @@ -895,7 +898,6 @@ SolutionCompareResult ConstraintSystem::compareSolutions( case OverloadChoiceKind::TupleIndex: continue; - case OverloadChoiceKind::BaseType: case OverloadChoiceKind::KeyPathApplication: llvm_unreachable("Never considered different"); diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index a75bf1890b228..40cd26c18bc06 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -135,7 +135,6 @@ static bool areConservativelyCompatibleArgumentLabels( decl = choice.getDecl(); break; - case OverloadChoiceKind::BaseType: // KeyPath application is not filtered in `performMemberLookup`. case OverloadChoiceKind::KeyPathApplication: case OverloadChoiceKind::DynamicMemberLookup: @@ -2395,10 +2394,7 @@ ConstraintSystem::matchTypesBindTypeVar( TypeMatchOptions flags, ConstraintLocatorBuilder locator, llvm::function_ref formUnsolvedResult) { assert(typeVar->is() && "Expected a type variable!"); - // FIXME: Due to some SE-0110 related code farther up we can end - // up with type variables wrapped in parens that will trip this - // assert. For now, maintain the existing behavior. - // assert(!type->is() && "Expected a non-type variable!"); + assert(!type->is() && "Expected a non-type variable!"); // Simplify the right-hand type and perform the "occurs" check. typeVar = getRepresentative(typeVar); @@ -4105,14 +4101,8 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind, return getTypeMatchSuccess(); } - assert((type1->is() || type2->is()) && - "Expected a type variable!"); - // FIXME: Due to some SE-0110 related code farther up we can end - // up with type variables wrapped in parens that will trip this - // assert. For now, maintain the existing behavior. - // assert( - // (!type1->is() || !type2->is()) - // && "Expected a non-type variable!"); + assert((type1->is() != type2->is()) && + "Expected a type variable and a non type variable!"); auto *typeVar = typeVar1 ? typeVar1 : typeVar2; auto type = typeVar1 ? type2 : type1; @@ -5083,7 +5073,6 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint( // func foo(_: T) {} // foo(Foo.bar) <- if `Foo` doesn't have `bar` there is // no reason to complain about missing conformance. - increaseScore(SK_Fix); return SolutionKind::Solved; } @@ -6476,7 +6465,6 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint( baseObjTy->isHole()) { // If base type is a "hole" there is no reason to record any // more "member not found" fixes for chained member references. - increaseScore(SK_Fix); markMemberTypeAsPotentialHole(memberTy); return SolutionKind::Solved; } diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp index 870370f7d4b0f..842691afb44a8 100644 --- a/lib/Sema/CSSolver.cpp +++ b/lib/Sema/CSSolver.cpp @@ -1139,23 +1139,6 @@ static bool debugConstraintSolverForTarget( return startBound != endBound; } -/// If we aren't certain that we've emitted a diagnostic, emit a fallback -/// diagnostic. -static void maybeProduceFallbackDiagnostic( - ConstraintSystem &cs, SolutionApplicationTarget target) { - if (cs.Options.contains(ConstraintSystemFlags::SuppressDiagnostics)) - return; - - // Before producing fatal error here, let's check if there are any "error" - // diagnostics already emitted or waiting to be emitted. Because they are - // a better indication of the problem. - ASTContext &ctx = cs.getASTContext(); - if (ctx.Diags.hadAnyError() || ctx.hasDelayedConformanceErrors()) - return; - - ctx.Diags.diagnose(target.getLoc(), diag::failed_to_produce_diagnostic); -} - Optional> ConstraintSystem::solve( SolutionApplicationTarget &target, ExprTypeCheckListener *listener, @@ -1201,7 +1184,7 @@ Optional> ConstraintSystem::solve( } case SolutionResult::Error: - maybeProduceFallbackDiagnostic(*this, target); + maybeProduceFallbackDiagnostic(target); return None; case SolutionResult::TooComplex: diff --git a/lib/Sema/CodeSynthesis.cpp b/lib/Sema/CodeSynthesis.cpp index a895d469d0876..c24ca5c075a63 100644 --- a/lib/Sema/CodeSynthesis.cpp +++ b/lib/Sema/CodeSynthesis.cpp @@ -1201,6 +1201,93 @@ SynthesizeMemberwiseInitRequest::evaluate(Evaluator &evaluator, return ctor; } +llvm::Expected +ResolveEffectiveMemberwiseInitRequest::evaluate(Evaluator &evaluator, + NominalTypeDecl *decl) const { + // Compute the access level for the memberwise initializer. The minimum of: + // - Public, by default. This enables public nominal types to have public + // memberwise initializers. + // - The `public` default is important for synthesized member types, e.g. + // `TangentVector` structs synthesized during `Differentiable` derived + // conformances. Manually extending these types to define a public + // memberwise initializer causes a redeclaration error. + // - The minimum access level of memberwise-initialized properties in the + // nominal type declaration. + auto accessLevel = AccessLevel::Public; + for (auto *member : decl->getMembers()) { + auto *var = dyn_cast(member); + if (!var || + !var->isMemberwiseInitialized(/*preferDeclaredProperties*/ true)) + continue; + accessLevel = std::min(accessLevel, var->getFormalAccess()); + } + auto &ctx = decl->getASTContext(); + + // If a memberwise initializer exists, set its access level and return it. + if (auto *initDecl = decl->getMemberwiseInitializer()) { + initDecl->overwriteAccess(accessLevel); + return initDecl; + } + + auto isEffectiveMemberwiseInitializer = [&](ConstructorDecl *initDecl) { + // Check for `nullptr`. + if (!initDecl) + return false; + // Get all stored properties, excluding `let` properties with initial + // values. + SmallVector storedProperties; + for (auto *vd : decl->getStoredProperties()) { + if (vd->isLet() && vd->hasInitialValue()) + continue; + storedProperties.push_back(vd); + } + // Return false if initializer does not have interface type set. It is not + // possible to determine whether it is a memberwise initializer. + if (!initDecl->hasInterfaceType()) + return false; + auto initDeclType = + initDecl->getMethodInterfaceType()->getAs(); + // Return false if initializer does not have a valid interface type. + if (!initDeclType) + return false; + // Return false if stored property count does not have parameter count. + if (storedProperties.size() != initDeclType->getNumParams()) + return false; + // Return true if all stored property types/names match initializer + // parameter types/labels. + return llvm::all_of( + llvm::zip(storedProperties, initDeclType->getParams()), + [&](std::tuple pair) { + auto *storedProp = std::get<0>(pair); + auto param = std::get<1>(pair); + return storedProp->getInterfaceType()->isEqual( + param.getPlainType()) && + storedProp->getName() == param.getLabel(); + }); + }; + + // Otherwise, look for a user-defined effective memberwise initializer. + ConstructorDecl *memberwiseInitDecl = nullptr; + auto initDecls = decl->lookupDirect(DeclBaseName::createConstructor()); + for (auto *decl : initDecls) { + auto *initDecl = dyn_cast(decl); + if (!isEffectiveMemberwiseInitializer(initDecl)) + continue; + assert(!memberwiseInitDecl && "Memberwise initializer already found"); + memberwiseInitDecl = initDecl; + } + + // Otherwise, create a memberwise initializer, set its access level, and + // return it. + if (!memberwiseInitDecl) { + memberwiseInitDecl = createImplicitConstructor( + decl, ImplicitConstructorKind::Memberwise, ctx); + memberwiseInitDecl->overwriteAccess(accessLevel); + decl->addMember(memberwiseInitDecl); + } + return memberwiseInitDecl; +} + llvm::Expected HasDefaultInitRequest::evaluate(Evaluator &evaluator, NominalTypeDecl *decl) const { @@ -1263,3 +1350,22 @@ SynthesizeDefaultInitRequest::evaluate(Evaluator &evaluator, ctor->setBodySynthesizer(synthesizeSingleReturnFunctionBody); return ctor; } + +ValueDecl *swift::getProtocolRequirement(ProtocolDecl *protocol, + Identifier name) { + auto lookup = protocol->lookupDirect(name); + // Erase declarations that are not protocol requirements. + // This is important for removing default implementations of the same name. + llvm::erase_if(lookup, [](ValueDecl *v) { + return !isa(v->getDeclContext()) || + !v->isProtocolRequirement(); + }); + assert(lookup.size() == 1 && "Ambiguous protocol requirement"); + return lookup.front(); +} + +bool swift::hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) { + return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) { + return v->isLet() && v->hasInitialValue(); + }); +} diff --git a/lib/Sema/CodeSynthesis.h b/lib/Sema/CodeSynthesis.h index 50032b3021cf1..096fe48ecad04 100644 --- a/lib/Sema/CodeSynthesis.h +++ b/lib/Sema/CodeSynthesis.h @@ -60,6 +60,12 @@ Expr *buildSelfReference(VarDecl *selfDecl, Expr *buildArgumentForwardingExpr(ArrayRef params, ASTContext &ctx); +/// Returns the protocol requirement with the specified name. +ValueDecl *getProtocolRequirement(ProtocolDecl *protocol, Identifier name); + +// Returns true if given nominal type has a `let` stored with an initial value. +bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal); + } // end namespace swift #endif diff --git a/lib/Sema/Constraint.cpp b/lib/Sema/Constraint.cpp index 0b3ebb8e1a7cd..438a70355d491 100644 --- a/lib/Sema/Constraint.cpp +++ b/lib/Sema/Constraint.cpp @@ -405,9 +405,6 @@ void Constraint::print(llvm::raw_ostream &Out, SourceManager *sm) const { case OverloadChoiceKind::KeyPathDynamicMemberLookup: Out << "dynamic member lookup '" << overload.getName() << "'"; break; - case OverloadChoiceKind::BaseType: - Out << "base type"; - break; case OverloadChoiceKind::TupleIndex: Out << "tuple index " << overload.getTupleIndex(); break; diff --git a/lib/Sema/ConstraintGraph.cpp b/lib/Sema/ConstraintGraph.cpp index 9a9bdd8617de3..82c9a441522cd 100644 --- a/lib/Sema/ConstraintGraph.cpp +++ b/lib/Sema/ConstraintGraph.cpp @@ -428,6 +428,9 @@ void ConstraintGraph::mergeNodes(TypeVariableType *typeVar1, } void ConstraintGraph::bindTypeVariable(TypeVariableType *typeVar, Type fixed) { + assert(!fixed->is() && + "Cannot bind to type variable; merge equivalence classes instead"); + // If there are no type variables in the fixed type, there's nothing to do. if (!fixed->hasTypeVariable()) return; diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index 0d3cc987537a3..b2ed04447aad4 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -1549,7 +1549,6 @@ Type ConstraintSystem::getEffectiveOverloadType(const OverloadChoice &overload, // Declaration choices are handled below. break; - case OverloadChoiceKind::BaseType: case OverloadChoiceKind::DeclViaBridge: case OverloadChoiceKind::DeclViaDynamic: case OverloadChoiceKind::DeclViaUnwrappedOptional: @@ -1918,7 +1917,6 @@ std::pair ConstraintSystem::adjustTypeOfOverloadReference( case OverloadChoiceKind::DeclViaBridge: case OverloadChoiceKind::DeclViaUnwrappedOptional: case OverloadChoiceKind::TupleIndex: - case OverloadChoiceKind::BaseType: case OverloadChoiceKind::KeyPathApplication: return {refType, bindConstraintCreated}; case OverloadChoiceKind::DeclViaDynamic: { @@ -2020,7 +2018,6 @@ void ConstraintSystem::bindOverloadType( case OverloadChoiceKind::DeclViaBridge: case OverloadChoiceKind::DeclViaUnwrappedOptional: case OverloadChoiceKind::TupleIndex: - case OverloadChoiceKind::BaseType: case OverloadChoiceKind::KeyPathApplication: case OverloadChoiceKind::DeclViaDynamic: bindTypeOrIUO(openedType); @@ -2272,10 +2269,6 @@ void ConstraintSystem::resolveOverload(ConstraintLocator *locator, break; } - case OverloadChoiceKind::BaseType: - refType = choice.getBaseType(); - break; - case OverloadChoiceKind::TupleIndex: if (auto lvalueTy = choice.getBaseType()->getAs()) { // When the base of a tuple lvalue, the member is always an lvalue. @@ -2532,7 +2525,6 @@ DeclName OverloadChoice::getName() const { case OverloadChoiceKind::KeyPathDynamicMemberLookup: return DeclName(DynamicMember.getPointer()); - case OverloadChoiceKind::BaseType: case OverloadChoiceKind::TupleIndex: llvm_unreachable("no name!"); } @@ -3246,7 +3238,6 @@ bool ConstraintSystem::diagnoseAmbiguity(ArrayRef solutions) { // want them to noise up unrelated subscript diagnostics. break; - case OverloadChoiceKind::BaseType: case OverloadChoiceKind::TupleIndex: // FIXME: Actually diagnose something here. break; @@ -4396,3 +4387,20 @@ bool ConstraintSystem::isDeclUnavailable(const Decl *D, AvailabilityContext result = AvailabilityContext::alwaysAvailable(); return !TypeChecker::isDeclAvailable(D, loc, DC, result); } + +/// If we aren't certain that we've emitted a diagnostic, emit a fallback +/// diagnostic. +void ConstraintSystem::maybeProduceFallbackDiagnostic( + SolutionApplicationTarget target) const { + if (Options.contains(ConstraintSystemFlags::SuppressDiagnostics)) + return; + + // Before producing fatal error here, let's check if there are any "error" + // diagnostics already emitted or waiting to be emitted. Because they are + // a better indication of the problem. + ASTContext &ctx = getASTContext(); + if (ctx.Diags.hadAnyError() || ctx.hasDelayedConformanceErrors()) + return; + + ctx.Diags.diagnose(target.getLoc(), diag::failed_to_produce_diagnostic); +} diff --git a/lib/Sema/ConstraintSystem.h b/lib/Sema/ConstraintSystem.h index 9f7c2ad2976a6..4498a2e519c91 100644 --- a/lib/Sema/ConstraintSystem.h +++ b/lib/Sema/ConstraintSystem.h @@ -634,6 +634,8 @@ enum ScoreKind { /// A fix needs to be applied to the source. SK_Fix, + /// A hole in the constraint system. + SK_Hole, /// A reference to an @unavailable declaration. SK_Unavailable, /// A use of a disfavored overload. @@ -4643,6 +4645,10 @@ class ConstraintSystem { SmallVectorImpl &Ordering, SmallVectorImpl &PartitionBeginning); + /// If we aren't certain that we've emitted a diagnostic, emit a fallback + /// diagnostic. + void maybeProduceFallbackDiagnostic(SolutionApplicationTarget target) const; + SWIFT_DEBUG_DUMP; SWIFT_DEBUG_DUMPER(dump(Expr *)); diff --git a/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp b/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp new file mode 100644 index 0000000000000..9c9d630e56cb9 --- /dev/null +++ b/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp @@ -0,0 +1,329 @@ +//===--- DerivedConformanceAdditiveArithmetic.cpp -------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2018 - 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// This file implements explicit derivation of the AdditiveArithmetic protocol +// for struct types. +// +// Currently, this is gated by a frontend flag: +// `-enable-experimental-additive-arithmetic-derivation`. +// +// Swift Evolution pitch thread: +// https://forums.swift.org/t/additivearithmetic-conformance-synthesis-for-structs/26159 +// +//===----------------------------------------------------------------------===// + +#include "CodeSynthesis.h" +#include "TypeChecker.h" +#include "swift/AST/Decl.h" +#include "swift/AST/Expr.h" +#include "swift/AST/GenericSignature.h" +#include "swift/AST/Module.h" +#include "swift/AST/ParameterList.h" +#include "swift/AST/Pattern.h" +#include "swift/AST/ProtocolConformance.h" +#include "swift/AST/Stmt.h" +#include "swift/AST/Types.h" +#include "DerivedConformances.h" + +using namespace swift; + +// Represents synthesizable math operators. +enum MathOperator { + // `+(Self, Self)`: AdditiveArithmetic + Add, + // `-(Self, Self)`: AdditiveArithmetic + Subtract, +}; + +static StringRef getMathOperatorName(MathOperator op) { + switch (op) { + case Add: + return "+"; + case Subtract: + return "-"; + } +} + +bool DerivedConformance::canDeriveAdditiveArithmetic(NominalTypeDecl *nominal, + DeclContext *DC) { + // Experimental `AdditiveArithmetic` derivation must be enabled. + auto &ctx = nominal->getASTContext(); + if (!ctx.LangOpts.EnableExperimentalAdditiveArithmeticDerivedConformances) + return false; + // Nominal type must be a struct. (No stored properties is okay.) + auto *structDecl = dyn_cast(nominal); + if (!structDecl) + return false; + // Must not have any `let` stored properties with an initial value. + // - This restriction may be lifted later with support for "true" memberwise + // initializers that initialize all stored properties, including initial + // value information. + if (hasLetStoredPropertyWithInitialValue(nominal)) + return false; + // All stored properties must conform to `AdditiveArithmetic`. + auto &C = nominal->getASTContext(); + auto *proto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); + return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) { + if (v->getInterfaceType()->hasError()) + return false; + auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType()); + return (bool)TypeChecker::conformsToProtocol(varType, proto, DC, None); + }); +} + +// Synthesize body for math operator. +static std::pair +deriveBodyMathOperator(AbstractFunctionDecl *funcDecl, MathOperator op) { + auto *parentDC = funcDecl->getParent(); + auto *nominal = parentDC->getSelfNominalTypeDecl(); + auto &C = nominal->getASTContext(); + + // Create memberwise initializer: `Nominal.init(...)`. + auto *memberwiseInitDecl = nominal->getEffectiveMemberwiseInitializer(); + assert(memberwiseInitDecl && "Memberwise initializer must exist"); + auto *initDRE = + new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true); + initDRE->setFunctionRefKind(FunctionRefKind::SingleApply); + auto *nominalTypeExpr = TypeExpr::createForDecl(DeclNameLoc(), nominal, + funcDecl, /*Implicit*/ true); + auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, nominalTypeExpr); + + // Get operator protocol requirement. + auto *proto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); + auto operatorId = C.getIdentifier(getMathOperatorName(op)); + auto *operatorReq = getProtocolRequirement(proto, operatorId); + + // Create reference to operator parameters: lhs and rhs. + auto params = funcDecl->getParameters(); + + // Create expression combining lhs and rhs members using member operator. + auto createMemberOpExpr = [&](VarDecl *member) -> Expr * { + auto module = nominal->getModuleContext(); + auto memberType = + parentDC->mapTypeIntoContext(member->getValueInterfaceType()); + auto confRef = module->lookupConformance(memberType, proto); + assert(confRef && "Member does not conform to math protocol"); + + // Get member type's math operator, e.g. `Member.+`. + // Use protocol requirement declaration for the operator by default: this + // will be dynamically dispatched. + ValueDecl *memberOpDecl = operatorReq; + // If conformance reference is concrete, then use concrete witness + // declaration for the operator. + if (confRef.isConcrete()) + if (auto *concreteMemberMethodDecl = + confRef.getConcrete()->getWitnessDecl(operatorReq)) + memberOpDecl = concreteMemberMethodDecl; + assert(memberOpDecl && "Member operator declaration must exist"); + auto memberOpDRE = + new (C) DeclRefExpr(memberOpDecl, DeclNameLoc(), /*Implicit*/ true); + auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C); + auto memberOpExpr = + new (C) DotSyntaxCallExpr(memberOpDRE, SourceLoc(), memberTypeExpr); + + // Create expression `lhs.member rhs.member`. + // NOTE(TF-1054): create new `DeclRefExpr`s per loop iteration to avoid + // `ConstraintSystem::resolveOverload` error. + auto *lhsDRE = + new (C) DeclRefExpr(params->get(0), DeclNameLoc(), /*Implicit*/ true); + auto *rhsDRE = + new (C) DeclRefExpr(params->get(1), DeclNameLoc(), /*Implicit*/ true); + Expr *lhsArg = new (C) MemberRefExpr(lhsDRE, SourceLoc(), member, + DeclNameLoc(), /*Implicit*/ true); + auto *rhsArg = new (C) MemberRefExpr(rhsDRE, SourceLoc(), member, + DeclNameLoc(), /*Implicit*/ true); + auto *memberOpArgs = + TupleExpr::create(C, SourceLoc(), {lhsArg, rhsArg}, {}, {}, SourceLoc(), + /*HasTrailingClosure*/ false, + /*Implicit*/ true); + auto *memberOpCallExpr = + new (C) BinaryExpr(memberOpExpr, memberOpArgs, /*Implicit*/ true); + return memberOpCallExpr; + }; + + // Create array of member operator call expressions. + llvm::SmallVector memberOpExprs; + llvm::SmallVector memberNames; + for (auto member : nominal->getStoredProperties()) { + memberOpExprs.push_back(createMemberOpExpr(member)); + memberNames.push_back(member->getName()); + } + // Call memberwise initializer with member operator call expressions. + auto *callExpr = + CallExpr::createImplicit(C, initExpr, memberOpExprs, memberNames); + ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true); + return std::pair( + BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true), false); +} + +// Synthesize function declaration for the given math operator. +static ValueDecl *deriveMathOperator(DerivedConformance &derived, + MathOperator op) { + auto nominal = derived.Nominal; + auto parentDC = derived.getConformanceContext(); + auto &C = derived.Context; + auto selfInterfaceType = parentDC->getDeclaredInterfaceType(); + + // Create parameter declaration with the given name and type. + auto createParamDecl = [&](StringRef name, Type type) -> ParamDecl * { + auto *param = + new (C) ParamDecl(SourceLoc(), SourceLoc(), Identifier(), SourceLoc(), + C.getIdentifier(name), parentDC); + param->setSpecifier(ParamDecl::Specifier::Default); + param->setInterfaceType(type); + return param; + }; + + ParameterList *params = + ParameterList::create(C, {createParamDecl("lhs", selfInterfaceType), + createParamDecl("rhs", selfInterfaceType)}); + + auto operatorId = C.getIdentifier(getMathOperatorName(op)); + DeclName operatorDeclName(C, operatorId, params); + auto operatorDecl = + FuncDecl::create(C, SourceLoc(), StaticSpellingKind::KeywordStatic, + SourceLoc(), operatorDeclName, SourceLoc(), + /*Throws*/ false, SourceLoc(), + /*GenericParams=*/nullptr, params, + TypeLoc::withoutLoc(selfInterfaceType), parentDC); + operatorDecl->setImplicit(); + auto bodySynthesizer = [](AbstractFunctionDecl *funcDecl, + void *ctx) -> std::pair { + auto op = (MathOperator) reinterpret_cast(ctx); + return deriveBodyMathOperator(funcDecl, op); + }; + operatorDecl->setBodySynthesizer(bodySynthesizer, (void *)op); + operatorDecl->setGenericSignature(parentDC->getGenericSignatureOfContext()); + operatorDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); + + derived.addMembersToConformanceContext({operatorDecl}); + return operatorDecl; +} + +// Synthesize body for a property computed property getter. +static std::pair +deriveBodyPropertyGetter(AbstractFunctionDecl *funcDecl, ProtocolDecl *proto, + ValueDecl *reqDecl) { + auto *parentDC = funcDecl->getParent(); + auto *nominal = parentDC->getSelfNominalTypeDecl(); + auto &C = nominal->getASTContext(); + + auto *memberwiseInitDecl = nominal->getEffectiveMemberwiseInitializer(); + assert(memberwiseInitDecl && "Memberwise initializer must exist"); + auto *initDRE = + new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true); + initDRE->setFunctionRefKind(FunctionRefKind::SingleApply); + + auto *nominalTypeExpr = TypeExpr::createForDecl(DeclNameLoc(), nominal, + funcDecl, /*Implicit*/ true); + auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, nominalTypeExpr); + + auto createMemberPropertyExpr = [&](VarDecl *member) -> Expr * { + auto memberType = + parentDC->mapTypeIntoContext(member->getValueInterfaceType()); + Expr *memberExpr = nullptr; + // If the property is static, create a type expression: `Member`. + if (reqDecl->isStatic()) { + memberExpr = TypeExpr::createImplicit(memberType, C); + } + // If the property is not static, create a member ref expression: + // `self.member`. + else { + auto *selfDecl = funcDecl->getImplicitSelfDecl(); + auto *selfDRE = + new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true); + memberExpr = + new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(), + /*Implicit*/ true); + } + auto *module = nominal->getModuleContext(); + auto confRef = module->lookupConformance(memberType, proto); + assert(confRef && "Member does not conform to `AdditiveArithmetic`"); + // If conformance reference is not concrete, then concrete witness + // declaration for property cannot be resolved. Return reference to + // protocol requirement: this will be dynamically dispatched. + if (!confRef.isConcrete()) { + return new (C) MemberRefExpr(memberExpr, SourceLoc(), reqDecl, + DeclNameLoc(), /*Implicit*/ true); + } + // Otherwise, return reference to concrete witness declaration. + auto conf = confRef.getConcrete(); + auto *witnessDecl = conf->getWitnessDecl(reqDecl); + return new (C) MemberRefExpr(memberExpr, SourceLoc(), witnessDecl, + DeclNameLoc(), /*Implicit*/ true); + }; + + // Create array of `member.` expressions. + llvm::SmallVector memberPropExprs; + llvm::SmallVector memberNames; + for (auto member : nominal->getStoredProperties()) { + memberPropExprs.push_back(createMemberPropertyExpr(member)); + memberNames.push_back(member->getName()); + } + // Call memberwise initializer with member property expressions. + auto *callExpr = + CallExpr::createImplicit(C, initExpr, memberPropExprs, memberNames); + ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true); + auto *braceStmt = + BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true); + return std::pair(braceStmt, false); +} + +// Synthesize body for the `AdditiveArithmetic.zero` computed property getter. +static std::pair +deriveBodyAdditiveArithmetic_zero(AbstractFunctionDecl *funcDecl, void *) { + auto &C = funcDecl->getASTContext(); + auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); + auto *zeroReq = getProtocolRequirement(addArithProto, C.Id_zero); + return deriveBodyPropertyGetter(funcDecl, addArithProto, zeroReq); +} + +// Synthesize the static property declaration for `AdditiveArithmetic.zero`. +static ValueDecl *deriveAdditiveArithmetic_zero(DerivedConformance &derived) { + auto &C = derived.Context; + auto *nominal = derived.Nominal; + auto *parentDC = derived.getConformanceContext(); + + auto returnInterfaceTy = nominal->getDeclaredInterfaceType(); + auto returnTy = parentDC->mapTypeIntoContext(returnInterfaceTy); + + // Create property declaration. + VarDecl *propDecl; + PatternBindingDecl *pbDecl; + std::tie(propDecl, pbDecl) = derived.declareDerivedProperty( + C.Id_zero, returnInterfaceTy, returnTy, /*isStatic*/ true, + /*isFinal*/ true); + + // Create property getter. + auto *getterDecl = + derived.addGetterToReadOnlyDerivedProperty(propDecl, returnTy); + getterDecl->setBodySynthesizer(deriveBodyAdditiveArithmetic_zero, nullptr); + + derived.addMembersToConformanceContext({propDecl, pbDecl}); + return propDecl; +} + +ValueDecl * +DerivedConformance::deriveAdditiveArithmetic(ValueDecl *requirement) { + // Diagnose conformances in disallowed contexts. + if (checkAndDiagnoseDisallowedContext(requirement)) + return nullptr; + if (requirement->getBaseName() == Context.getIdentifier("+")) + return deriveMathOperator(*this, Add); + if (requirement->getBaseName() == Context.getIdentifier("-")) + return deriveMathOperator(*this, Subtract); + if (requirement->getBaseName() == Context.Id_zero) + return deriveAdditiveArithmetic_zero(*this); + Context.Diags.diagnose(requirement->getLoc(), + diag::broken_additive_arithmetic_requirement); + return nullptr; +} diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index fc7c23fe18873..ae2130163656b 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -66,6 +66,9 @@ bool DerivedConformance::derivesProtocolConformance(DeclContext *DC, return canDeriveHashable(Nominal); } + if (*knownProtocol == KnownProtocolKind::AdditiveArithmetic) + return canDeriveAdditiveArithmetic(Nominal, DC); + if (auto *enumDecl = dyn_cast(Nominal)) { switch (*knownProtocol) { // The presence of a raw type is an explicit declaration that @@ -217,6 +220,10 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal, if (name.isSimpleName(ctx.Id_intValue)) return getRequirement(KnownProtocolKind::CodingKey); + // AdditiveArithmetic.zero + if (name.isSimpleName(ctx.Id_zero)) + return getRequirement(KnownProtocolKind::AdditiveArithmetic); + return nullptr; } @@ -228,6 +235,13 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal, if (func->isOperator() && name.getBaseName() == "==") return getRequirement(KnownProtocolKind::Equatable); + // AdditiveArithmetic.+ + // AdditiveArithmetic.- + if (func->isOperator() && name.getArgumentNames().size() == 2 && + (name.getBaseName() == "+" || name.getBaseName() == "-")) { + return getRequirement(KnownProtocolKind::AdditiveArithmetic); + } + // Encodable.encode(to: Encoder) if (name.isCompoundName() && name.getBaseName() == ctx.Id_encode) { auto argumentNames = name.getArgumentNames(); diff --git a/lib/Sema/DerivedConformances.h b/lib/Sema/DerivedConformances.h index 332dc1f4c85d6..5a3465dd1594c 100644 --- a/lib/Sema/DerivedConformances.h +++ b/lib/Sema/DerivedConformances.h @@ -96,6 +96,17 @@ class DerivedConformance { static ValueDecl *getDerivableRequirement(NominalTypeDecl *nominal, ValueDecl *requirement); + /// Determine if an AdditiveArithmetic requirement can be derived for a type. + /// + /// \returns True if the requirement can be derived. + static bool canDeriveAdditiveArithmetic(NominalTypeDecl *type, + DeclContext *DC); + + /// Derive an AdditiveArithmetic requirement for a nominal type. + /// + /// \returns the derived member, which will also be added to the type. + ValueDecl *deriveAdditiveArithmetic(ValueDecl *requirement); + /// Derive a CaseIterable requirement for an enum if it has no associated /// values for any of its cases. /// diff --git a/lib/Sema/LookupVisibleDecls.cpp b/lib/Sema/LookupVisibleDecls.cpp index 2fa7173610dab..445ceb400e138 100644 --- a/lib/Sema/LookupVisibleDecls.cpp +++ b/lib/Sema/LookupVisibleDecls.cpp @@ -438,6 +438,9 @@ static void lookupDeclsFromProtocolsBeingConformedTo( continue; } if (auto *VD = dyn_cast(Member)) { + if (!isDeclVisibleInLookupMode(VD, LS, FromContext)) + continue; + if (!VD->isProtocolRequirement()) continue; diff --git a/lib/Sema/MiscDiagnostics.cpp b/lib/Sema/MiscDiagnostics.cpp index a1550e52353c3..3360edd2aa261 100644 --- a/lib/Sema/MiscDiagnostics.cpp +++ b/lib/Sema/MiscDiagnostics.cpp @@ -68,9 +68,6 @@ static void diagSyntacticUseRestrictions(const Expr *E, const DeclContext *DC, SmallPtrSet AlreadyDiagnosedMetatypes; SmallPtrSet AlreadyDiagnosedBitCasts; - /// Keep track of acceptable DiscardAssignmentExpr's. - SmallPtrSet CorrectDiscardAssignmentExprs; - /// Keep track of the arguments to CallExprs. SmallPtrSet CallArgs; @@ -229,7 +226,6 @@ static void diagSyntacticUseRestrictions(const Expr *E, const DeclContext *DC, // If we have an assignment expression, scout ahead for acceptable _'s. if (auto *AE = dyn_cast(E)) { auto destExpr = AE->getDest(); - markAcceptableDiscardExprs(destExpr); // If the user is assigning the result of a function that returns // Void to _ then warn, because that is redundant. if (auto DAE = dyn_cast(destExpr)) { @@ -246,14 +242,6 @@ static void diagSyntacticUseRestrictions(const Expr *E, const DeclContext *DC, } } - /// Diagnose a '_' that isn't on the immediate LHS of an assignment. - if (auto *DAE = dyn_cast(E)) { - if (!CorrectDiscardAssignmentExprs.count(DAE) && - !DAE->getType()->hasError()) - Ctx.Diags.diagnose(DAE->getLoc(), - diag::discard_expr_outside_of_assignment); - } - // Diagnose 'self.init' or 'super.init' nested in another expression // or closure. if (auto *rebindSelfExpr = dyn_cast(E)) { @@ -425,26 +413,6 @@ static void diagSyntacticUseRestrictions(const Expr *E, const DeclContext *DC, } } - - /// Scout out the specified destination of an AssignExpr to recursively - /// identify DiscardAssignmentExpr in legal places. We can only allow them - /// in simple pattern-like expressions, so we reject anything complex here. - void markAcceptableDiscardExprs(Expr *E) { - if (!E) return; - - if (auto *PE = dyn_cast(E)) - return markAcceptableDiscardExprs(PE->getSubExpr()); - if (auto *TE = dyn_cast(E)) { - for (auto &elt : TE->getElements()) - markAcceptableDiscardExprs(elt); - return; - } - if (auto *DAE = dyn_cast(E)) - CorrectDiscardAssignmentExprs.insert(DAE); - - // Otherwise, we can't support this. - } - void checkMagicIdentifierMismatch(ConcreteDeclRef callee, unsigned uncurryLevel, unsigned argIndex, diff --git a/lib/Sema/OverloadChoice.h b/lib/Sema/OverloadChoice.h index 7d28061d4ad31..072d0cb0533de 100644 --- a/lib/Sema/OverloadChoice.h +++ b/lib/Sema/OverloadChoice.h @@ -40,10 +40,6 @@ enum class OverloadChoiceKind : int { /// found via dynamic lookup and, therefore, might not actually be /// available at runtime. DeclViaDynamic, - /// The overload choice equates the member type with the - /// base type. Used for unresolved member expressions like ".none" that - /// refer to enum members with unit type. - BaseType, /// The overload choice selects a key path subscripting operation. KeyPathApplication, /// The member is looked up using @dynamicMemberLookup. diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 1a7690fba878f..a0889d34a07b7 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3598,108 +3598,6 @@ getDerivativeOriginalFunctionType(AnyFunctionType *derivativeFnTy) { return originalType; } -// Finds a derivative function declaration using the given function specifier, -// original function declaration, expected type, and "is valid" predicate. If no -// valid derivative function is found, emits diagnostics and returns false. -static FuncDecl *findAutoDiffDerivativeFunction( - DeclNameRefWithLoc specifier, AbstractFunctionDecl *original, - Type expectedTy, std::function isValid) { - auto &ctx = original->getASTContext(); - auto &diags = ctx.Diags; - auto noneValidDiagnostic = [&]() { - diags.diagnose(specifier.Loc, diag::differentiable_attr_overload_not_found, - specifier.Name, expectedTy); - }; - auto ambiguousDiagnostic = [&]() { - diags.diagnose(specifier.Loc, diag::attr_ambiguous_reference_to_decl, - specifier.Name, "differentiable"); - }; - auto notFunctionDiagnostic = [&]() { - diags.diagnose(specifier.Loc, - diag::differentiable_attr_derivative_not_function, - specifier.Name); - }; - std::function invalidTypeContextDiagnostic = [&]() { - diags.diagnose(specifier.Loc, - diag::differentiable_attr_function_not_same_type_context, - specifier.Name); - }; - - // Returns true if the original function and derivative function candidate are - // defined in compatible type contexts. If the original function and the - // derivative function have different parents, or if they both have no type - // context and are in different modules, return false. - std::function hasValidTypeContext = - [&](AbstractFunctionDecl *func) { - // Check if both functions are top-level. - if (!original->getInnermostTypeContext() && - !func->getInnermostTypeContext() && - original->getParentModule() == func->getParentModule()) - return true; - // Check if both functions are defined in the same type context. - if (auto typeCtx1 = original->getInnermostTypeContext()) - if (auto typeCtx2 = func->getInnermostTypeContext()) - return typeCtx1->getSelfNominalTypeDecl() == - typeCtx2->getSelfNominalTypeDecl(); - return original->getParent() == func->getParent(); - }; - - auto isABIPublic = [&](AbstractFunctionDecl *func) { - return func->getFormalAccess() >= AccessLevel::Public || - func->getAttrs().hasAttribute() || - func->getAttrs().hasAttribute(); - }; - - // If the original function is exported (i.e. it is public or - // `@usableFromInline`), then the derivative functions must also be exported. - // Returns true on error. - auto checkAccessControl = [&](AbstractFunctionDecl *func) { - if (!isABIPublic(original)) - return false; - if (isABIPublic(func)) - return false; - diags.diagnose(specifier.Loc, diag::differentiable_attr_invalid_access, - specifier.Name, original->getFullName()); - return true; - }; - - auto originalTypeCtx = original->getInnermostTypeContext(); - if (!originalTypeCtx) - originalTypeCtx = original->getParent(); - assert(originalTypeCtx); - - // Set lookup options. - auto lookupOptions = - defaultMemberLookupOptions | NameLookupFlags::IgnoreAccessControl; - - auto *candidate = findAbstractFunctionDecl( - specifier.Name, specifier.Loc.getBaseNameLoc(), /*baseType*/ Type(), - originalTypeCtx, isValid, noneValidDiagnostic, ambiguousDiagnostic, - notFunctionDiagnostic, lookupOptions, hasValidTypeContext, - invalidTypeContextDiagnostic); - if (!candidate) - return nullptr; - // Reject non-`func` registered derivatives. JVPs and VJPs must be `func` - // declarations. - if (isa(candidate)) { - diags.diagnose(specifier.Loc, - diag::differentiable_attr_derivative_not_function, - specifier.Name); - return nullptr; - } - if (checkAccessControl(candidate)) - return nullptr; - // Derivatives of class members must be final. - if (original->getDeclContext()->getSelfClassDecl() && !candidate->isFinal()) { - diags.diagnose(specifier.Loc, - diag::differentiable_attr_class_derivative_not_final); - return nullptr; - } - assert(isa(candidate)); - auto *funcDecl = cast(candidate); - return funcDecl; -} - /// Given a `@differentiable` attribute, attempts to resolve the original /// `AbstractFunctionDecl` for which it is registered, using the declaration /// on which it is actually declared. On error, emits diagnostic and returns @@ -3713,15 +3611,6 @@ resolveDifferentiableAttrOriginalFunction(DifferentiableAttr *attr) { auto &diags = ctx.Diags; auto *original = dyn_cast(D); if (auto *asd = dyn_cast(D)) { - // Derivative registration is unsupported for stored properties. - if (asd->getImplInfo().isSimpleStored() && - (attr->getJVP() || attr->getVJP())) { - diagnoseAndRemoveAttr( - diags, D, attr, - diag::differentiable_attr_stored_property_variable_unsupported); - attr->setInvalid(); - return nullptr; - } // If `@differentiable` attribute is declared directly on a // `AbstractStorageDecl` (a stored/computed property or subscript), // forward the attribute to the storage's getter. @@ -3917,79 +3806,6 @@ bool resolveDifferentiableAttrDifferentiabilityParameters( return false; } -/// Given a `@differentiable` attribute, attempts to resolve the JVP and VJP -/// derivative function declarations, if specified. The JVP and VJP functions -/// are returned as `jvp` and `vjp`, respectively. On error, emits diagnostic, -/// assigns `nullptr` to `jvp` and `vjp`, and returns true. -bool resolveDifferentiableAttrDerivativeFunctions( - DifferentiableAttr *attr, AbstractFunctionDecl *original, - IndexSubset *resolvedDiffParamIndices, GenericSignature derivativeGenSig, - FuncDecl *&jvp, FuncDecl *&vjp) { - jvp = nullptr; - vjp = nullptr; - - auto &ctx = original->getASTContext(); - auto &diags = ctx.Diags; - - // `@differentiable` attributes on protocol requirements do not support - // JVP/VJP. - bool isOriginalProtocolRequirement = - isa(original->getDeclContext()) && - original->isProtocolRequirement(); - if (isOriginalProtocolRequirement && (attr->getJVP() || attr->getVJP())) { - diags.diagnose(attr->getLocation(), - diag::differentiable_attr_protocol_req_assoc_func); - attr->setInvalid(); - return false; - } - - auto *originalFnTy = original->getInterfaceType()->castTo(); - auto lookupConformance = - LookUpConformanceInModule(original->getDeclContext()->getParentModule()); - - // Resolve the JVP function, if it is specified and exists. - if (attr->getJVP()) { - auto *expectedJVPFnTy = originalFnTy->getAutoDiffDerivativeFunctionType( - resolvedDiffParamIndices, AutoDiffDerivativeFunctionKind::JVP, - lookupConformance, derivativeGenSig, /*makeSelfParamFirst*/ true); - auto isValidJVP = [&](AbstractFunctionDecl *jvpCandidate) -> bool { - return checkFunctionSignature( - cast(expectedJVPFnTy->getCanonicalType()), - jvpCandidate->getInterfaceType()->getCanonicalType()); - }; - auto *jvp = findAutoDiffDerivativeFunction( - attr->getJVP().getValue(), original, expectedJVPFnTy, isValidJVP); - if (!jvp) { - attr->setInvalid(); - return true; - } - // Set the JVP function in the attribute. - attr->setJVPFunction(jvp); - } - - // Resolve the VJP function, if it is specified and exists. - if (attr->getVJP()) { - auto *expectedVJPFnTy = originalFnTy->getAutoDiffDerivativeFunctionType( - resolvedDiffParamIndices, AutoDiffDerivativeFunctionKind::VJP, - lookupConformance, derivativeGenSig, /*makeSelfParamFirst*/ true); - auto isValidVJP = [&](AbstractFunctionDecl *vjpCandidate) -> bool { - return checkFunctionSignature( - cast(expectedVJPFnTy->getCanonicalType()), - vjpCandidate->getInterfaceType()->getCanonicalType()); - }; - auto *vjp = findAutoDiffDerivativeFunction( - attr->getVJP().getValue(), original, expectedVJPFnTy, isValidVJP); - if (!vjp) { - attr->setInvalid(); - return true; - } - // Set the VJP function in the attribute. - attr->setVJPFunction(vjp); - } - - return false; -} - /// Checks whether differentiable programming is enabled for the given /// differentiation-related attribute. Returns true on error. bool checkIfDifferentiableProgrammingEnabled( @@ -4034,16 +3850,6 @@ llvm::Expected DifferentiableAttributeTypeCheckRequest::evaluate( if (checkIfDifferentiableProgrammingEnabled(ctx, attr)) return nullptr; - // Derivative registration is disabled for `@differentiable(linear)` - // attributes. Instead, use `@transpose` attribute to register transpose - // functions. - if (attr->isLinear() && (attr->getVJP() || attr->getJVP())) { - diagnoseAndRemoveAttr(diags, D, attr, - diag::differentiable_attr_no_vjp_or_jvp_when_linear); - attr->setInvalid(); - return nullptr; - } - // Resolve the original `AbstractFunctionDecl`. auto *original = resolveDifferentiableAttrOriginalFunction(attr); if (!original) @@ -4142,13 +3948,6 @@ llvm::Expected DifferentiableAttributeTypeCheckRequest::evaluate( return nullptr; } - // Resolve JVP and VJP derivative functions, if specified. - FuncDecl *jvp = nullptr; - FuncDecl *vjp = nullptr; - if (resolveDifferentiableAttrDerivativeFunctions( - attr, original, resolvedDiffParamIndices, derivativeGenSig, jvp, vjp)) - return nullptr; - if (auto *asd = dyn_cast(D)) { // Remove `@differentiable` attribute from storage declaration to prevent // duplicate attribute registration during SILGen. @@ -4158,10 +3957,8 @@ llvm::Expected DifferentiableAttributeTypeCheckRequest::evaluate( auto *getterDecl = asd->getAccessor(AccessorKind::Get); auto *newAttr = DifferentiableAttr::create( getterDecl, /*implicit*/ true, attr->AtLoc, attr->getRange(), - attr->isLinear(), resolvedDiffParamIndices, attr->getJVP(), - attr->getVJP(), attr->getDerivativeGenericSignature()); - newAttr->setJVPFunction(attr->getJVPFunction()); - newAttr->setVJPFunction(attr->getVJPFunction()); + attr->isLinear(), resolvedDiffParamIndices, + attr->getDerivativeGenericSignature()); auto insertion = ctx.DifferentiableAttrs.try_emplace( {getterDecl, resolvedDiffParamIndices}, newAttr); // Reject duplicate `@differentiable` attributes. diff --git a/lib/Sema/TypeCheckConstraints.cpp b/lib/Sema/TypeCheckConstraints.cpp index 91abd3910406b..c6202c8f838c4 100644 --- a/lib/Sema/TypeCheckConstraints.cpp +++ b/lib/Sema/TypeCheckConstraints.cpp @@ -3175,10 +3175,6 @@ void Solution::dump(raw_ostream &out) const { << ovl.second.openedType->getString(PO) << "\n"; break; - case OverloadChoiceKind::BaseType: - out << "base type " << choice.getBaseType()->getString(PO) << "\n"; - break; - case OverloadChoiceKind::KeyPathApplication: out << "key path application root " << choice.getBaseType()->getString(PO) << "\n"; @@ -3381,10 +3377,6 @@ void ConstraintSystem::print(raw_ostream &out) const { << resolved.openedType->getString(PO) << "\n"; break; - case OverloadChoiceKind::BaseType: - out << "base type " << choice.getBaseType()->getString(PO) << "\n"; - break; - case OverloadChoiceKind::KeyPathApplication: out << "key path application root " << choice.getBaseType()->getString(PO) << "\n"; diff --git a/lib/Sema/TypeCheckDecl.cpp b/lib/Sema/TypeCheckDecl.cpp index 75561bd466040..055a12f91df27 100644 --- a/lib/Sema/TypeCheckDecl.cpp +++ b/lib/Sema/TypeCheckDecl.cpp @@ -1253,12 +1253,9 @@ static PrecedenceGroupDecl * lookupPrecedenceGroup(const PrecedenceGroupDescriptor &descriptor) { auto *dc = descriptor.dc; if (auto sf = dc->getParentSourceFile()) { - OperatorLookupDescriptor desc{ - sf, - descriptor.ident, - dc->isCascadingContextForLookup(false), - descriptor.nameLoc - }; + auto desc = OperatorLookupDescriptor::forFile( + sf, descriptor.ident, dc->isCascadingContextForLookup(false), + descriptor.nameLoc); return evaluateOrDefault(sf->getASTContext().evaluator, LookupPrecedenceGroupRequest{desc}, nullptr); } else { @@ -1730,12 +1727,9 @@ FunctionOperatorRequest::evaluate(Evaluator &evaluator, FuncDecl *FD) const { FD->diagnose(diag::operator_in_local_scope); } - OperatorLookupDescriptor desc{ - FD->getDeclContext()->getParentSourceFile(), - operatorName, - FD->isCascadingContextForLookup(false), - FD->getLoc() - }; + auto desc = OperatorLookupDescriptor::forFile( + FD->getDeclContext()->getParentSourceFile(), operatorName, + FD->isCascadingContextForLookup(false), FD->getLoc()); OperatorDecl *op = nullptr; if (FD->isUnaryOperator()) { if (FD->getAttrs().hasAttribute()) { diff --git a/lib/Sema/TypeCheckDeclOverride.cpp b/lib/Sema/TypeCheckDeclOverride.cpp index 93a79abbb2ce7..afe783e5c7f12 100644 --- a/lib/Sema/TypeCheckDeclOverride.cpp +++ b/lib/Sema/TypeCheckDeclOverride.cpp @@ -655,8 +655,7 @@ static bool hasOverridingDifferentiableAttribute(ValueDecl *derivedDecl, // Get `@differentiable` attribute description. std::string baseDiffAttrString; llvm::raw_string_ostream os(baseDiffAttrString); - baseDA->print(os, derivedDecl, omitWrtClause, - /*omitDerivativeFunctions*/ true); + baseDA->print(os, derivedDecl, omitWrtClause); os.flush(); diags .diagnose(derivedDecl, diff --git a/lib/Sema/TypeCheckExpr.cpp b/lib/Sema/TypeCheckExpr.cpp index aa3443b05b7b4..f1350fbeb1b57 100644 --- a/lib/Sema/TypeCheckExpr.cpp +++ b/lib/Sema/TypeCheckExpr.cpp @@ -130,12 +130,9 @@ Expr *TypeChecker::substituteInputSugarTypeForResult(ApplyExpr *E) { static PrecedenceGroupDecl *lookupPrecedenceGroupForOperator(DeclContext *DC, Identifier name, SourceLoc loc) { - OperatorLookupDescriptor desc{ - DC->getParentSourceFile(), - name, - DC->isCascadingContextForLookup(true), - loc - }; + auto desc = OperatorLookupDescriptor::forFile( + DC->getParentSourceFile(), name, DC->isCascadingContextForLookup(true), + loc); auto &Ctx = DC->getASTContext(); if (auto op = evaluateOrDefault(Ctx.evaluator, LookupInfixOperatorRequest{desc}, diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 5cfe36429cbe5..da4f908c50652 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -308,7 +308,8 @@ static ValueDecl *getStandinForAccessor(AbstractStorageDecl *witness, /// witness. /// - If requirement's `@differentiable` attributes are met, or if `result` is /// not viable, returns `result`. -/// - Otherwise, returns a `DifferentiableConflict` `RequirementMatch`. +/// - Otherwise, returns a "missing `@differentiable` attribute" +/// `RequirementMatch`. // Note: the `result` argument is only necessary for using // `RequirementMatch::WitnessSubstitutions`. static RequirementMatch @@ -384,15 +385,50 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, } if (!foundExactConfig) { bool success = false; - if (supersetConfig) { - // If the witness has a "superset" derivative configuration, create an - // implicit `@differentiable` attribute with the exact requirement - // `@differentiable` attribute parameter indices. + // If no exact witness derivative configuration was found, check + // conditions for creating an implicit witness `@differentiable` attribute + // with the exact derivative configuration: + // - If the witness has a "superset" derivative configuration. + // - If the witness is less than public and is declared in the same file + // as the conformance. + // - `@differentiable` attributes are really only significant for public + // declarations: it improves usability to not require explicit + // `@differentiable` attributes for less-visible declarations. + bool createImplicitWitnessAttribute = + supersetConfig || witness->getFormalAccess() < AccessLevel::Public; + // If the witness has less-than-public visibility and is declared in a + // different file than the conformance, produce an error. + if (!supersetConfig && witness->getFormalAccess() < AccessLevel::Public && + dc->getModuleScopeContext() != + witness->getDeclContext()->getModuleScopeContext()) { + // FIXME(TF-1014): `@differentiable` attribute diagnostic does not + // appear if associated type inference is involved. + if (auto *vdWitness = dyn_cast(witness)) { + return RequirementMatch( + getStandinForAccessor(vdWitness, AccessorKind::Get), + MatchKind::MissingDifferentiableAttr, reqDiffAttr); + } else { + return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr, + reqDiffAttr); + } + } + if (createImplicitWitnessAttribute) { + auto derivativeGenSig = witnessAFD->getGenericSignature(); + if (supersetConfig) + derivativeGenSig = supersetConfig->derivativeGenericSignature; + // Use source location of the witness declaration as the source location + // of the implicit `@differentiable` attribute. auto *newAttr = DifferentiableAttr::create( - witnessAFD, /*implicit*/ true, reqDiffAttr->AtLoc, - reqDiffAttr->getRange(), reqDiffAttr->isLinear(), - reqDiffAttr->getParameterIndices(), /*jvp*/ None, - /*vjp*/ None, supersetConfig->derivativeGenericSignature); + witnessAFD, /*implicit*/ true, witness->getLoc(), witness->getLoc(), + reqDiffAttr->isLinear(), reqDiffAttr->getParameterIndices(), + derivativeGenSig); + // If the implicit attribute is inherited from a protocol requirement's + // attribute, store the protocol requirement attribute's location for + // use in diagnostics. + if (witness->getFormalAccess() < AccessLevel::Public) { + newAttr->getImplicitlyInheritedDifferentiableAttrLocation( + reqDiffAttr->getLocation()); + } auto insertion = ctx.DifferentiableAttrs.try_emplace( {witnessAFD, newAttr->getParameterIndices()}, newAttr); // Valid `@differentiable` attributes are uniqued by original function @@ -418,9 +454,9 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, if (auto *vdWitness = dyn_cast(witness)) { return RequirementMatch( getStandinForAccessor(vdWitness, AccessorKind::Get), - MatchKind::DifferentiableConflict, reqDiffAttr); + MatchKind::MissingDifferentiableAttr, reqDiffAttr); } else { - return RequirementMatch(witness, MatchKind::DifferentiableConflict, + return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr, reqDiffAttr); } } @@ -2318,14 +2354,15 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance, case MatchKind::NonObjC: diags.diagnose(match.Witness, diag::protocol_witness_not_objc); break; - case MatchKind::DifferentiableConflict: { + case MatchKind::MissingDifferentiableAttr: { + auto *witness = match.Witness; // Emit a note and fix-it showing the missing requirement `@differentiable` // attribute. auto *reqAttr = cast(match.UnmetAttribute); assert(reqAttr); // Omit printing `wrt:` clause if attribute's differentiability // parameters match inferred differentiability parameters. - auto *original = cast(match.Witness); + auto *original = cast(witness); auto *whereClauseGenEnv = reqAttr->getDerivativeGenericEnvironment(original); auto *inferredParameters = TypeChecker::inferDifferentiabilityParameters( @@ -2334,13 +2371,31 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance, inferredParameters->getNumIndices(); std::string reqDiffAttrString; llvm::raw_string_ostream os(reqDiffAttrString); - reqAttr->print(os, req, omitWrtClause, /*omitDerivativeFunctions*/ true); + reqAttr->print(os, req, omitWrtClause); os.flush(); - diags - .diagnose(match.Witness, - diag::protocol_witness_missing_differentiable_attr, - reqDiffAttrString) - .fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' '); + // If the witness has less-than-public visibility and is declared in a + // different file than the conformance, emit a specialized diagnostic. + if (witness->getFormalAccess() < AccessLevel::Public && + conformance->getDeclContext()->getModuleScopeContext() != + witness->getDeclContext()->getModuleScopeContext()) { + diags + .diagnose( + witness, + diag:: + protocol_witness_missing_differentiable_attr_nonpublic_other_file, + reqDiffAttrString, witness->getDescriptiveKind(), + witness->getFullName(), req->getDescriptiveKind(), + req->getFullName(), conformance->getType(), + conformance->getProtocol()->getDeclaredInterfaceType()) + .fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' '); + } + // Otherwise, emit a general "missing attribute" diagnostic. + else { + diags + .diagnose(witness, diag::protocol_witness_missing_differentiable_attr, + reqDiffAttrString) + .fixItInsert(witness->getStartLoc(), reqDiffAttrString + ' '); + } break; } } @@ -3795,10 +3850,7 @@ static void recordConformanceDependency(DeclContext *DC, Conformance->getDeclContext()->getParentModule()) return; - // FIXME: 'deinit' is being used as a dummy identifier here. Really we - // don't care about /any/ of the type's members, only that it conforms to - // the protocol. - tracker->addUsedMember({Adoptee, DeclBaseName::createDestructor()}, + tracker->addUsedMember({Adoptee, Identifier()}, DC->isCascadingContextForLookup(InExpression)); } @@ -5539,6 +5591,9 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC, case KnownProtocolKind::Decodable: return derived.deriveDecodable(Requirement); + case KnownProtocolKind::AdditiveArithmetic: + return derived.deriveAdditiveArithmetic(Requirement); + default: return nullptr; } diff --git a/lib/Sema/TypeCheckProtocol.h b/lib/Sema/TypeCheckProtocol.h index 9050e88102637..e64b2a3aa940c 100644 --- a/lib/Sema/TypeCheckProtocol.h +++ b/lib/Sema/TypeCheckProtocol.h @@ -209,9 +209,8 @@ enum class MatchKind : uint8_t { /// The witness is explicitly @nonobjc but the requirement is @objc. NonObjC, - /// The witness does not have a `@differentiable` attribute satisfying one - /// from the requirement. - DifferentiableConflict, + /// The witness is missing a `@differentiable` attribute from the requirement. + MissingDifferentiableAttr, }; /// Describes the kind of optional adjustment performed when @@ -362,7 +361,7 @@ struct RequirementMatch { : Witness(witness), Kind(kind), WitnessType(), UnmetAttribute(attr), ReqEnv(None) { assert(!hasWitnessType() && "Should have witness type"); - assert(UnmetAttribute); + assert(hasUnmetAttribute() && "Should have unmet attribute"); } RequirementMatch(ValueDecl *witness, MatchKind kind, @@ -437,7 +436,7 @@ struct RequirementMatch { case MatchKind::RethrowsConflict: case MatchKind::ThrowsConflict: case MatchKind::NonObjC: - case MatchKind::DifferentiableConflict: + case MatchKind::MissingDifferentiableAttr: return false; } @@ -467,7 +466,7 @@ struct RequirementMatch { case MatchKind::RethrowsConflict: case MatchKind::ThrowsConflict: case MatchKind::NonObjC: - case MatchKind::DifferentiableConflict: + case MatchKind::MissingDifferentiableAttr: return false; } @@ -478,7 +477,9 @@ struct RequirementMatch { bool hasRequirement() { return Kind == MatchKind::MissingRequirement; } /// Determine whether this requirement match has an unmet attribute. - bool hasUnmetAttribute() { return Kind == MatchKind::DifferentiableConflict; } + bool hasUnmetAttribute() { + return Kind == MatchKind::MissingDifferentiableAttr; + } swift::Witness getWitness(ASTContext &ctx) const; }; diff --git a/lib/Sema/TypeCheckType.cpp b/lib/Sema/TypeCheckType.cpp index 465f6e1967b8b..8a1028633cf49 100644 --- a/lib/Sema/TypeCheckType.cpp +++ b/lib/Sema/TypeCheckType.cpp @@ -880,7 +880,9 @@ Type TypeChecker::applyUnboundGenericArguments( auto genericSig = genericEnv->getGenericSignature(); for (auto gp : genericSig->getGenericParams()) { subs[gp->getCanonicalType()->castTo()] = - genericEnv->mapTypeIntoContext(gp); + (resolution.usesArchetypes() + ? genericEnv->mapTypeIntoContext(gp) + : gp); } } diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index 3daed88e91baf..4dbce8844c407 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -527,7 +527,10 @@ ModuleFile::readConformanceChecked(llvm::BitstreamCursor &Cursor, "reading specialized conformance for", conformingType); - auto subMap = getSubstitutionMap(substitutionMapID); + auto subMapOrError = getSubstitutionMapChecked(substitutionMapID); + if (!subMapOrError) + return subMapOrError.takeError(); + auto subMap = subMapOrError.get(); ProtocolConformanceRef genericConformance = readConformance(Cursor, genericEnv); @@ -571,7 +574,11 @@ ModuleFile::readConformanceChecked(llvm::BitstreamCursor &Cursor, case NORMAL_PROTOCOL_CONFORMANCE_ID: { NormalConformanceID conformanceID; NormalProtocolConformanceIdLayout::readRecord(scratch, conformanceID); - return ProtocolConformanceRef(readNormalConformance(conformanceID)); + + auto conformance = readNormalConformanceChecked(conformanceID); + if (!conformance) + return conformance.takeError(); + return ProtocolConformanceRef(conformance.get()); } case PROTOCOL_CONFORMANCE_XREF: { @@ -614,7 +621,7 @@ ModuleFile::readConformanceChecked(llvm::BitstreamCursor &Cursor, } } -NormalProtocolConformance *ModuleFile::readNormalConformance( +Expected ModuleFile::readNormalConformanceChecked( NormalConformanceID conformanceID) { auto &conformanceEntry = NormalConformances[conformanceID-1]; if (conformanceEntry.isComplete()) { @@ -647,13 +654,21 @@ NormalProtocolConformance *ModuleFile::readNormalConformance( rawIDs); ASTContext &ctx = getContext(); - DeclContext *dc = getDeclContext(contextID); + auto doOrError = getDeclContextChecked(contextID); + if (!doOrError) + return doOrError.takeError(); + DeclContext *dc = doOrError.get(); + assert(!isa(dc->getModuleScopeContext()) && "should not have serialized a conformance from a clang module"); Type conformingType = dc->getDeclaredInterfaceType(); PrettyStackTraceType trace(ctx, "reading conformance for", conformingType); - auto proto = cast(getDecl(protoID)); + auto protoOrError = getDeclChecked(protoID); + if (!protoOrError) + return protoOrError.takeError(); + auto proto = cast(protoOrError.get()); + PrettyStackTraceDecl traceTo("... to", proto); ++NumNormalProtocolConformancesLoaded; @@ -1069,7 +1084,10 @@ ModuleFile::getSubstitutionMapChecked(serialization::SubstitutionMapID id) { conformances.reserve(numConformances); for (unsigned i : range(numConformances)) { (void)i; - conformances.push_back(readConformance(DeclTypeCursor)); + auto conformanceOrError = readConformanceChecked(DeclTypeCursor); + if (!conformanceOrError) + return conformanceOrError.takeError(); + conformances.push_back(conformanceOrError.get()); } // Form the substitution map and record it. @@ -1654,7 +1672,7 @@ ModuleFile::resolveCrossReference(ModuleID MID, uint32_t pathLen) { return true; if (!fn->getOperatorDecl()) return true; - if (getStableFixity(fn->getOperatorDecl()->getKind()) != rawKind) + if (getStableFixity(fn->getOperatorDecl()->getFixity()) != rawKind) return true; return false; }); @@ -2268,6 +2286,29 @@ static bool attributeChainContains(DeclAttribute *attr) { return tempAttrs.hasAttribute(); } +// Set original declaration and parameter indices in `@differentiable` +// attributes. +// +// Serializing/deserializing the original declaration DeclID in +// `@differentiable` attributes does not work because it causes +// `@differentiable` attribute deserialization to enter an infinite loop. +// +// Instead, call this ad-hoc function after deserializing a declaration to set +// the original declaration and parameter indices for its `@differentiable` +// attributes. +static void setOriginalDeclarationAndParameterIndicesInDifferentiableAttributes( + Decl *decl, DeclAttribute *attrs, + llvm::DenseMap + &diffAttrParamIndicesMap) { + DeclAttributes tempAttrs; + tempAttrs.setRawAttributeChain(attrs); + for (auto *attr : tempAttrs.getAttributes()) { + auto *diffAttr = const_cast(attr); + diffAttr->setOriginalDeclaration(decl); + diffAttr->setParameterIndices(diffAttrParamIndicesMap[diffAttr]); + } +} + Decl *ModuleFile::getDecl(DeclID DID) { Expected deserialized = getDeclChecked(DID); if (!deserialized) { @@ -2294,6 +2335,9 @@ class DeclDeserializer { unsigned localDiscriminator = 0; StringRef filenameForPrivate; + // Auxiliary map for deserializing `@differentiable` attributes. + llvm::DenseMap diffAttrParamIndicesMap; + void AddAttribute(DeclAttribute *Attr) { // Advance the linked list. // This isn't just using DeclAttributes because that would result in the @@ -2772,7 +2816,10 @@ class DeclDeserializer { var->setIsSetterMutating(isSetterMutating); declOrOffset = var; - Type interfaceType = MF.getType(interfaceTypeID); + auto interfaceTypeOrError = MF.getTypeChecked(interfaceTypeID); + if (!interfaceTypeOrError) + return interfaceTypeOrError.takeError(); + Type interfaceType = interfaceTypeOrError.get(); var->setInterfaceType(interfaceType); var->setImplicitlyUnwrappedOptional(isIUO); @@ -3195,9 +3242,12 @@ class DeclDeserializer { auto genericSig = MF.getGenericSignature(genericSigID); if (genericSig) opaqueDecl->setGenericSignature(genericSig); - if (underlyingTypeID) - opaqueDecl->setUnderlyingTypeSubstitutions( - MF.getSubstitutionMap(underlyingTypeID)); + if (underlyingTypeID) { + auto subMapOrError = MF.getSubstitutionMapChecked(underlyingTypeID); + if (!subMapOrError) + return subMapOrError.takeError(); + opaqueDecl->setUnderlyingTypeSubstitutions(subMapOrError.get()); + } SubstitutionMap subs; if (genericSig) { subs = genericSig->getIdentitySubstitutionMap(); @@ -4257,6 +4307,36 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() { break; } + case decls_block::Differentiable_DECL_ATTR: { + bool isImplicit; + bool linear; + GenericSignatureID derivativeGenSigId; + ArrayRef parameters; + + serialization::decls_block::DifferentiableDeclAttrLayout::readRecord( + scratch, isImplicit, linear, derivativeGenSigId, parameters); + + auto derivativeGenSig = MF.getGenericSignature(derivativeGenSigId); + llvm::SmallBitVector parametersBitVector(parameters.size()); + for (unsigned i : indices(parameters)) + parametersBitVector[i] = parameters[i]; + auto *indices = IndexSubset::get(ctx, parametersBitVector); + auto *diffAttr = DifferentiableAttr::create( + ctx, isImplicit, SourceLoc(), SourceRange(), linear, + /*parsedParameters*/ {}, /*trailingWhereClause*/ nullptr); + + // Cache parameter indices so that they can set later. + // `DifferentiableAttr::setParameterIndices` cannot be called here + // because it requires `DifferentiableAttr::setOriginalDeclaration` to + // be called first. `DifferentiableAttr::setOriginalDeclaration` cannot + // be called here because the original declaration is not accessible in + // this function (`DeclDeserializer::deserializeDeclAttributes`). + diffAttrParamIndicesMap[diffAttr] = indices; + diffAttr->setDerivativeGenericSignature(derivativeGenSig); + Attr = diffAttr; + break; + } + case decls_block::Derivative_DECL_ATTR: { bool isImplicit; uint64_t origNameId; @@ -4391,8 +4471,18 @@ DeclDeserializer::getDeclCheckedImpl( switch (recordID) { #define CASE(RECORD_NAME) \ - case decls_block::RECORD_NAME##Layout::Code: \ - return deserialize##RECORD_NAME(scratch, blobData); + case decls_block::RECORD_NAME##Layout::Code: {\ + auto decl = deserialize##RECORD_NAME(scratch, blobData); \ + if (decl) { \ + /* \ + // Set original declaration and parameter indices in `@differentiable` \ + // attributes. \ + */ \ + setOriginalDeclarationAndParameterIndicesInDifferentiableAttributes(\ + decl.get(), DAttrs, diffAttrParamIndicesMap); \ + } \ + return decl; \ + } CASE(TypeAlias) CASE(GenericTypeParamDecl) @@ -5036,7 +5126,11 @@ class TypeDeserializer { decls_block::OpaqueArchetypeTypeLayout::readRecord(scratch, opaqueDeclID, subsID); - auto opaqueDecl = cast(MF.getDecl(opaqueDeclID)); + auto opaqueTypeOrError = MF.getDeclChecked(opaqueDeclID); + if (!opaqueTypeOrError) + return opaqueTypeOrError.takeError(); + + auto opaqueDecl = cast(opaqueTypeOrError.get()); auto subs = MF.getSubstitutionMap(subsID); return OpaqueTypeArchetypeType::get(opaqueDecl, subs); @@ -5761,9 +5855,22 @@ ModuleFile::loadAllConformances(const Decl *D, uint64_t contextData, fatalIfNotSuccess(DeclTypeCursor.JumpToBit(bitPosition)); while (numConformances--) { - auto conf = readConformance(DeclTypeCursor); - if (conf.isConcrete()) - conformances.push_back(conf.getConcrete()); + auto conformance = readConformanceChecked(DeclTypeCursor); + + if (!conformance) { + // Missing module errors are most likely caused by an + // implementation-only import hiding types and decls. + // rdar://problem/60291019 + if (conformance.errorIsA()) { + consumeError(conformance.takeError()); + return; + } + else + fatal(conformance.takeError()); + } + + if (conformance.get().isConcrete()) + conformances.push_back(conformance.get().getConcrete()); } } diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index a4cde60617cb4..6bc1dd5a4aa41 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -1042,7 +1042,8 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, Builder.setInsertionPoint(BB); Builder.setCurrentDebugScope(Fn->getDebugScope()); unsigned RawOpCode = 0, TyCategory = 0, TyCategory2 = 0, TyCategory3 = 0, - Attr = 0, NumSubs = 0, NumConformances = 0, IsNonThrowingApply = 0; + Attr = 0, Attr2 = 0, NumSubs = 0, NumConformances = 0, + IsNonThrowingApply = 0; ValueID ValID, ValID2, ValID3; TypeID TyID, TyID2, TyID3; TypeID ConcreteTyID; @@ -1146,6 +1147,18 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, TyCategory3, ValID3, ListOfValues); RawOpCode = (unsigned)SILInstructionKind::WitnessMethodInst; break; + case SIL_INST_DIFFERENTIABLE_FUNCTION: + SILInstDifferentiableFunctionLayout::readRecord( + scratch, /*numParams*/ Attr, /*hasDerivativeFunctions*/ Attr2, + ListOfValues); + RawOpCode = (unsigned)SILInstructionKind::DifferentiableFunctionInst; + break; + case SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT: + SILInstDifferentiableFunctionExtractLayout::readRecord( + scratch, TyID, TyCategory, ValID, /*extractee*/ Attr, + /*hasExplicitExtracteeType*/ Attr2); + RawOpCode = (unsigned)SILInstructionKind::DifferentiableFunctionExtractInst; + break; } // FIXME: validate @@ -2569,6 +2582,43 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, ResultVal = Builder.createKeyPath(Loc, pattern, subMap, operands, kpTy); break; } + case SILInstructionKind::DifferentiableFunctionInst: { + bool hasDerivativeFunctions = (bool)Attr2; + unsigned numOperands = hasDerivativeFunctions ? 3 : 1; + auto numParamIndices = ListOfValues.size() - numOperands * 3; + assert(ListOfValues.size() == numParamIndices + numOperands * 3); + auto rawParamIndices = + map>(ListOfValues.take_front(numParamIndices), + [](uint64_t i) { return (unsigned)i; }); + auto numParams = Attr; + auto *paramIndices = + IndexSubset::get(MF->getContext(), numParams, rawParamIndices); + SmallVector operands; + for (auto i = numParamIndices; i < numParamIndices + numOperands * 3; + i += 3) { + auto astTy = MF->getType(ListOfValues[i]); + auto silTy = getSILType(astTy, (SILValueCategory)ListOfValues[i + 1], Fn); + operands.push_back(getLocalValue(ListOfValues[i + 2], silTy)); + } + Optional> derivativeFunctions = None; + if (hasDerivativeFunctions) + derivativeFunctions = std::make_pair(operands[1], operands[2]); + ResultVal = Builder.createDifferentiableFunction( + Loc, paramIndices, operands[0], derivativeFunctions); + break; + } + case SILInstructionKind::DifferentiableFunctionExtractInst: { + auto astTy = MF->getType(TyID); + auto silTy = getSILType(astTy, SILValueCategory::Object, Fn); + auto val = getLocalValue(ValID, silTy); + NormalDifferentiableFunctionTypeComponent extractee(Attr); + Optional explicitExtracteeType = None; + if (Attr2) + explicitExtracteeType = silTy; + ResultVal = Builder.createDifferentiableFunctionExtract( + Loc, extractee, val, explicitExtracteeType); + break; + } case SILInstructionKind::DifferentiabilityWitnessFunctionInst: { StringRef mangledKey = MF->getIdentifierText(ValID); auto *witness = getSILDifferentiabilityWitnessForReference(mangledKey); diff --git a/lib/Serialization/ModuleFile.cpp b/lib/Serialization/ModuleFile.cpp index 1e68f88a2710f..b66a76d892039 100644 --- a/lib/Serialization/ModuleFile.cpp +++ b/lib/Serialization/ModuleFile.cpp @@ -2156,7 +2156,8 @@ TypeDecl *ModuleFile::lookupNestedType(Identifier name, return nullptr; } -OperatorDecl *ModuleFile::lookupOperator(Identifier name, DeclKind fixity) { +OperatorDecl *ModuleFile::lookupOperator(Identifier name, + OperatorFixity fixity) { PrettyStackTraceModuleFile stackEntry(*this); if (!OperatorDecls) diff --git a/lib/Serialization/ModuleFile.h b/lib/Serialization/ModuleFile.h index 628c46f86469c..c52d8452ba428 100644 --- a/lib/Serialization/ModuleFile.h +++ b/lib/Serialization/ModuleFile.h @@ -731,7 +731,7 @@ class ModuleFile /// Searches the module's operators for one with the given name and fixity. /// /// If none is found, returns null. - OperatorDecl *lookupOperator(Identifier name, DeclKind fixity); + OperatorDecl *lookupOperator(Identifier name, OperatorFixity fixity); /// Searches the module's precedence groups for one with the given /// name and fixity. @@ -978,13 +978,14 @@ class ModuleFile llvm::Expected readConformanceChecked(llvm::BitstreamCursor &Cursor, GenericEnvironment *genericEnv = nullptr); - + /// Read a SILLayout from the given cursor. SILLayout *readSILLayout(llvm::BitstreamCursor &Cursor); - /// Read the given normal conformance from the current module file. - NormalProtocolConformance * - readNormalConformance(serialization::NormalConformanceID id); + /// Read the given normal conformance from the current module file, + /// returns the conformance or the first error. + llvm::Expected + readNormalConformanceChecked(serialization::NormalConformanceID id); /// Reads a foreign error conformance from \c DeclTypeCursor, if present. Optional maybeReadForeignErrorConvention(); diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index 3918f79be0bc9..0b756b22eb545 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 548; // remove curried SILDeclRefs +const uint16_t SWIFTMODULE_VERSION_MINOR = 549; // differentiable_function, differentiable_function_extract /// A standard hash seed used for all string hashes in a serialized module. /// @@ -385,19 +385,18 @@ enum class SelfAccessKind : uint8_t { }; using SelfAccessKindField = BCFixed<2>; -/// Translates an operator DeclKind to a Serialization fixity, whose values are -/// guaranteed to be stable. -static inline OperatorKind getStableFixity(DeclKind kind) { - switch (kind) { - case DeclKind::PrefixOperator: +/// Translates an operator decl fixity to a Serialization fixity, whose values +/// are guaranteed to be stable. +static inline OperatorKind getStableFixity(OperatorFixity fixity) { + switch (fixity) { + case OperatorFixity::Prefix: return Prefix; - case DeclKind::PostfixOperator: + case OperatorFixity::Postfix: return Postfix; - case DeclKind::InfixOperator: + case OperatorFixity::Infix: return Infix; - default: - llvm_unreachable("unknown operator fixity"); } + llvm_unreachable("Unhandled case in switch"); } // These IDs must \em not be renumbered or reordered without incrementing @@ -1817,10 +1816,6 @@ namespace decls_block { Differentiable_DECL_ATTR, BCFixed<1>, // Implicit flag. BCFixed<1>, // Linear flag. - IdentifierIDField, // JVP name. - DeclIDField, // JVP function declaration. - IdentifierIDField, // VJP name. - DeclIDField, // VJP function declaration. GenericSignatureIDField, // Derivative generic signature. BCArray> // Differentiation parameter indices' bitvector. >; diff --git a/lib/Serialization/SILFormat.h b/lib/Serialization/SILFormat.h index 25d430e6e0f19..37826da34d8be 100644 --- a/lib/Serialization/SILFormat.h +++ b/lib/Serialization/SILFormat.h @@ -149,6 +149,8 @@ namespace sil_block { SIL_PROPERTY, SIL_ONE_OPERAND_EXTRA_ATTR, SIL_TWO_OPERANDS_EXTRA_ATTR, + SIL_INST_DIFFERENTIABLE_FUNCTION, + SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT, // We also share these layouts from the decls block. Their enumerators must // not overlap with ours. @@ -447,6 +449,22 @@ namespace sil_block { BCArray // SILDeclRef // may be trailed by an inline protocol conformance >; + + using SILInstDifferentiableFunctionLayout = BCRecordLayout< + SIL_INST_DIFFERENTIABLE_FUNCTION, + BCVBR<8>, // number of function parameters + BCFixed<1>, // has derivative functions? + BCArray // parameter indices and operands + >; + + using SILInstDifferentiableFunctionExtractLayout = BCRecordLayout< + SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT, + TypeIDField, + SILTypeCategoryField, + ValueIDField, + BCFixed<2>, // extractee + BCFixed<1> // has explicit extractee type? + >; } } // end namespace serialization diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index a5bc682619f95..2caa73c3888e3 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -1730,7 +1730,7 @@ void Serializer::writeCrossReference(const DeclContext *DC, uint32_t pathLen) { assert(op); abbrCode = DeclTypeAbbrCodes[XRefOperatorOrAccessorPathPieceLayout::Code]; auto emptyID = addDeclBaseNameRef(Identifier()); - auto fixity = getStableFixity(op->getKind()); + auto fixity = getStableFixity(op->getFixity()); XRefOperatorOrAccessorPathPieceLayout::emitRecord(Out, ScratchRecord, abbrCode, emptyID, fixity); @@ -1750,7 +1750,7 @@ void Serializer::writeCrossReference(const Decl *D) { abbrCode = DeclTypeAbbrCodes[XRefOperatorOrAccessorPathPieceLayout::Code]; auto nameID = addDeclBaseNameRef(op->getName()); - auto fixity = getStableFixity(op->getKind()); + auto fixity = getStableFixity(op->getFixity()); XRefOperatorOrAccessorPathPieceLayout::emitRecord(Out, ScratchRecord, abbrCode, nameID, fixity); @@ -2395,37 +2395,20 @@ class Serializer::DeclSerializer : public DeclVisitor { case DAK_Differentiable: { auto abbrCode = S.DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code]; auto *attr = cast(DA); - - IdentifierID jvpName = 0; - DeclID jvpRef = 0; - if (auto jvp = attr->getJVP()) - jvpName = S.addDeclBaseNameRef(jvp->Name.getBaseName()); - if (auto jvpFunction = attr->getJVPFunction()) - jvpRef = S.addDeclRef(jvpFunction); - - IdentifierID vjpName = 0; - DeclID vjpRef = 0; - if (auto vjp = attr->getVJP()) - vjpName = S.addDeclBaseNameRef(vjp->Name.getBaseName()); - if (auto vjpFunction = attr->getVJPFunction()) - vjpRef = S.addDeclRef(vjpFunction); - - auto paramIndices = attr->getParameterIndices(); - // NOTE(TF-836): `@differentiable` attribute serialization is blocked by - // `@differentiable` attribute type-checking (TF-828), which resolves - // parameter indices (`IndexSubset *`). - if (!paramIndices) - return; + assert(attr->getOriginalDeclaration() && + "`@differentiable` attribute should have original declaration set " + "during construction or parsing"); + auto *paramIndices = attr->getParameterIndices(); assert(paramIndices && "Parameter indices must be resolved"); - SmallVector indices; + SmallVector paramIndicesVector; for (unsigned i : range(paramIndices->getCapacity())) - indices.push_back(paramIndices->contains(i)); + paramIndicesVector.push_back(paramIndices->contains(i)); DifferentiableDeclAttrLayout::emitRecord( S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), - attr->isLinear(), jvpName, jvpRef, vjpName, vjpRef, + attr->isLinear(), S.addGenericSignatureRef(attr->getDerivativeGenericSignature()), - indices); + paramIndicesVector); return; } @@ -2442,12 +2425,12 @@ class Serializer::DeclSerializer : public DeclVisitor { getRawStableAutoDiffDerivativeFunctionKind(attr->getDerivativeKind()); auto *parameterIndices = attr->getParameterIndices(); assert(parameterIndices && "Parameter indices must be resolved"); - SmallVector indices; + SmallVector paramIndicesVector; for (unsigned i : range(parameterIndices->getCapacity())) - indices.push_back(parameterIndices->contains(i)); + paramIndicesVector.push_back(parameterIndices->contains(i)); DerivativeDeclAttrLayout::emitRecord( S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId, - origDeclID, derivativeKind, indices); + origDeclID, derivativeKind, paramIndicesVector); return; } @@ -2467,12 +2450,12 @@ class Serializer::DeclSerializer : public DeclVisitor { DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction()); auto *parameterIndices = attr->getParameterIndices(); assert(parameterIndices && "Parameter indices must be resolved"); - SmallVector indices; + SmallVector paramIndicesVector; for (unsigned i : range(parameterIndices->getCapacity())) - indices.push_back(parameterIndices->contains(i)); + paramIndicesVector.push_back(parameterIndices->contains(i)); TransposeDeclAttrLayout::emitRecord( S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId, - origDeclID, indices); + origDeclID, paramIndicesVector); return; } } @@ -4924,7 +4907,7 @@ void Serializer::writeAST(ModuleOrSourceFile DC) { .push_back({ extendedNominal, addDeclRef(D) }); } else if (auto OD = dyn_cast(D)) { operatorDecls[OD->getName()] - .push_back({ getStableFixity(OD->getKind()), addDeclRef(D) }); + .push_back({ getStableFixity(OD->getFixity()), addDeclRef(D) }); } else if (auto PGD = dyn_cast(D)) { precedenceGroupDecls[PGD->getName()] .push_back({ decls_block::PRECEDENCE_GROUP_DECL, addDeclRef(D) }); diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index bd6bcfb70576c..2c037edb137e1 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -2152,6 +2152,38 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) { break; } + case SILInstructionKind::DifferentiableFunctionInst: { + auto *dfi = cast(&SI); + SmallVector trailingInfo; + auto *paramIndices = dfi->getParameterIndices(); + for (unsigned idx : paramIndices->getIndices()) + trailingInfo.push_back(idx); + for (auto &op : dfi->getAllOperands()) { + auto val = op.get(); + trailingInfo.push_back(S.addTypeRef(val->getType().getASTType())); + trailingInfo.push_back((unsigned)val->getType().getCategory()); + trailingInfo.push_back(addValueRef(val)); + } + SILInstDifferentiableFunctionLayout::emitRecord( + Out, ScratchRecord, + SILAbbrCodes[SILInstDifferentiableFunctionLayout::Code], + paramIndices->getCapacity(), dfi->hasDerivativeFunctions(), + trailingInfo); + break; + } + case SILInstructionKind::DifferentiableFunctionExtractInst: { + auto *dfei = cast(&SI); + auto operandRef = addValueRef(dfei->getOperand()); + auto operandType = dfei->getOperand()->getType(); + auto operandTypeRef = S.addTypeRef(operandType.getASTType()); + auto rawExtractee = (unsigned)dfei->getExtractee(); + SILInstDifferentiableFunctionExtractLayout::emitRecord( + Out, ScratchRecord, + SILAbbrCodes[SILInstDifferentiableFunctionExtractLayout::Code], + operandTypeRef, (unsigned)operandType.getCategory(), operandRef, + rawExtractee, (unsigned)dfei->hasExplicitExtracteeType()); + break; + } case SILInstructionKind::DifferentiabilityWitnessFunctionInst: { auto *dwfi = cast(&SI); auto *witness = dwfi->getWitness(); @@ -2541,6 +2573,10 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) { registerSILAbbr(); registerSILAbbr(); registerSILAbbr(); + registerSILAbbr(); + registerSILAbbr(); + registerSILAbbr(); + registerSILAbbr(); registerSILAbbr(); registerSILAbbr(); @@ -2556,9 +2592,6 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) { registerSILAbbr(); registerSILAbbr(); - registerSILAbbr(); - registerSILAbbr(); - // Register the abbreviation codes so these layouts can exist in both // decl blocks and sil blocks. registerSILAbbr(); diff --git a/lib/Serialization/SerializedModuleLoader.cpp b/lib/Serialization/SerializedModuleLoader.cpp index 8ee3900cabe02..64081f34d4edf 100644 --- a/lib/Serialization/SerializedModuleLoader.cpp +++ b/lib/Serialization/SerializedModuleLoader.cpp @@ -1084,14 +1084,17 @@ SerializedASTFile::lookupNestedType(Identifier name, return File.lookupNestedType(name, parent); } -OperatorDecl *SerializedASTFile::lookupOperator(Identifier name, - DeclKind fixity) const { - return File.lookupOperator(name, fixity); +void SerializedASTFile::lookupOperatorDirect( + Identifier name, OperatorFixity fixity, + TinyPtrVector &results) const { + if (auto *op = File.lookupOperator(name, fixity)) + results.push_back(op); } -PrecedenceGroupDecl * -SerializedASTFile::lookupPrecedenceGroup(Identifier name) const { - return File.lookupPrecedenceGroup(name); +void SerializedASTFile::lookupPrecedenceGroupDirect( + Identifier name, TinyPtrVector &results) const { + if (auto *group = File.lookupPrecedenceGroup(name)) + results.push_back(group); } void SerializedASTFile::lookupVisibleDecls(ModuleDecl::AccessPathTy accessPath, diff --git a/stdlib/cmake/modules/AddSwiftStdlib.cmake b/stdlib/cmake/modules/AddSwiftStdlib.cmake index c0286fc70aaea..0ddb9d609bf88 100644 --- a/stdlib/cmake/modules/AddSwiftStdlib.cmake +++ b/stdlib/cmake/modules/AddSwiftStdlib.cmake @@ -802,26 +802,18 @@ function(_add_swift_target_library_single target name) # Include LLVM Bitcode slices for iOS, Watch OS, and Apple TV OS device libraries. if(SWIFT_EMBED_BITCODE_SECTION AND NOT SWIFTLIB_SINGLE_DONT_EMBED_BITCODE) if(${SWIFTLIB_SINGLE_SDK} MATCHES "(I|TV|WATCH)OS") - # The two branches of this if statement accomplish the same end result - # We are simply accounting for the fact that on CMake < 3.16 - # using a generator expression to - # specify a LINKER: argument does not work, + # Please note that using a generator expression to fit + # this in a single target_link_options does not work + # (at least in CMake 3.15 and 3.16), # since that seems not to allow the LINKER: prefix to be # evaluated (i.e. it will be added as-is to the linker parameters) - if(CMAKE_VERSION VERSION_LESS 3.16) - target_link_options(${target} PRIVATE - "LINKER:-bitcode_bundle" - "LINKER:-lto_library,${LLVM_LIBRARY_DIR}/libLTO.dylib") + target_link_options(${target} PRIVATE + "LINKER:-bitcode_bundle" + "LINKER:-lto_library,${LLVM_LIBRARY_DIR}/libLTO.dylib") - if(SWIFT_EMBED_BITCODE_SECTION_HIDE_SYMBOLS) - target_link_options(${target} PRIVATE - "LINKER:-bitcode_hide_symbols") - endif() - else() + if(SWIFT_EMBED_BITCODE_SECTION_HIDE_SYMBOLS) target_link_options(${target} PRIVATE - "LINKER:-bitcode_bundle" - $<$:"LINKER:-bitcode_hide_symbols"> - "LINKER:-lto_library,${LLVM_LIBRARY_DIR}/libLTO.dylib") + "LINKER:-bitcode_hide_symbols") endif() endif() endif() diff --git a/stdlib/public/Darwin/Dispatch/CMakeLists.txt b/stdlib/public/Darwin/Dispatch/CMakeLists.txt index f6d6225349144..83ec1e89e9dd4 100644 --- a/stdlib/public/Darwin/Dispatch/CMakeLists.txt +++ b/stdlib/public/Darwin/Dispatch/CMakeLists.txt @@ -11,6 +11,7 @@ add_swift_target_library(swiftDispatch ${SWIFT_SDK_OVERLAY_LIBRARY_BUILD_TYPES} Private.swift Queue.swift Source.swift + Schedulers+DispatchQueue.swift Time.swift "${SWIFT_SOURCE_DIR}/stdlib/linker-support/magic-symbols-for-install-name.c" @@ -21,6 +22,7 @@ add_swift_target_library(swiftDispatch ${SWIFT_SDK_OVERLAY_LIBRARY_BUILD_TYPES} SWIFT_MODULE_DEPENDS_IOS Darwin ObjectiveC # auto-updated SWIFT_MODULE_DEPENDS_TVOS Darwin ObjectiveC # auto-updated SWIFT_MODULE_DEPENDS_WATCHOS Darwin ObjectiveC # auto-updated + FRAMEWORK_DEPENDS_WEAK Combine DEPLOYMENT_VERSION_OSX ${SWIFTLIB_DEPLOYMENT_VERSION_DISPATCH_OSX} DEPLOYMENT_VERSION_IOS ${SWIFTLIB_DEPLOYMENT_VERSION_DISPATCH_IOS} diff --git a/stdlib/public/Darwin/Dispatch/Schedulers+DispatchQueue.swift b/stdlib/public/Darwin/Dispatch/Schedulers+DispatchQueue.swift new file mode 100644 index 0000000000000..11780b57d713b --- /dev/null +++ b/stdlib/public/Darwin/Dispatch/Schedulers+DispatchQueue.swift @@ -0,0 +1,283 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +// Only support 64bit +#if !(os(iOS) && (arch(i386) || arch(arm))) + +import Combine + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +private func clampedIntProduct(_ m1: Int, _ m2: UInt64) -> Int { + assert(m2 > 0, "multiplier must be positive") + guard m1 < Int.max, m2 < Int.max else { return Int.max } + let (result, overflow) = m1.multipliedReportingOverflow(by: Int(m2)) + if overflow { + return m1 > 0 ? Int.max : Int.min + } + return result +} + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension DispatchTimeInterval { + fileprivate var nanoseconds: Int { + switch self { + case .seconds(let s): return clampedIntProduct(s, NSEC_PER_SEC) + case .milliseconds(let ms): return clampedIntProduct(ms, NSEC_PER_MSEC) + case .microseconds(let us): return clampedIntProduct(us, NSEC_PER_USEC) + case .nanoseconds(let ns): return ns + case .never: return Int.max + } + } +} + +// This is Strideable except: +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension DispatchTime /* : Strideable */ { + typealias Stride = DispatchTimeInterval + + public func distance(to other: DispatchTime) -> DispatchTimeInterval { + let lhs = other.rawValue + let rhs = rawValue + if lhs >= rhs { + return DispatchTimeInterval.nanoseconds(Int(lhs - rhs)) + } else { + return DispatchTimeInterval.nanoseconds(0 - Int(rhs - lhs)) + } + } + + public func advanced(by n: DispatchTimeInterval) -> DispatchTime { + return self + n + } +} + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension DispatchQueue: Scheduler { + /// The scheduler time type used by the dispatch queue. + public struct SchedulerTimeType: Strideable, Codable, Hashable { + /// The dispatch time represented by this type. + public var dispatchTime: DispatchTime + + /// Creates a dispatch queue time type instance. + /// + /// - Parameter time: The dispatch time to represent. + public init(_ time: DispatchTime) { + dispatchTime = time + } + + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + let time = DispatchTime(uptimeNanoseconds: try container.decode(UInt64.self)) + self.init(time) + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + try container.encode(dispatchTime.uptimeNanoseconds) + } + + /// Returns the distance to another dispatch queue time. + /// + /// - Parameter other: Another dispatch queue time. + /// - Returns: The time interval between this time and the provided time. + public func distance(to other: SchedulerTimeType) -> Stride { + return Stride(self.dispatchTime.distance(to: other.dispatchTime)) + } + + /// Returns a dispatch queue scheduler time calculated by advancing this instance’s time by the given interval. + /// + /// - Parameter n: A time interval to advance. + /// - Returns: A dispatch queue time advanced by the given interval from this instance’s time. + public func advanced(by n: Stride) -> SchedulerTimeType { + return SchedulerTimeType(self.dispatchTime.advanced(by: n.timeInterval)) + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(dispatchTime.rawValue) + } + + public struct Stride: SchedulerTimeIntervalConvertible, Comparable, SignedNumeric, ExpressibleByFloatLiteral, Hashable, Codable { + /// If created via floating point literal, the value is converted to nanoseconds via multiplication. + public typealias FloatLiteralType = Double + + /// Nanoseconds, same as DispatchTimeInterval. + public typealias IntegerLiteralType = Int + public typealias Magnitude = Int + + /// The value of this time interval in nanoseconds. + public var magnitude: Int + + /// A `DispatchTimeInterval` created with the value of this type in nanoseconds. + public var timeInterval: DispatchTimeInterval { + return .nanoseconds(magnitude) + } + + /// Creates a dispatch queue time interval from the given dispatch time interval. + /// + /// - Parameter timeInterval: A dispatch time interval. + public init(_ timeInterval: DispatchTimeInterval) { + magnitude = Int(timeInterval.nanoseconds) + } + + /// Creates a dispatch queue time interval from a floating-point seconds value. + /// + /// - Parameter value: The number of seconds, as a `Double`. + public init(floatLiteral value: Double) { + magnitude = Int(value * 1_000_000_000) + } + + /// Creates a dispatch queue time interval from an integer seconds value. + /// + /// - Parameter value: The number of seconds, as an `Int`. + public init(integerLiteral value: Int) { + magnitude = value * 1_000_000_000 + } + + /// Creates a dispatch queue time interval from a binary integer type. + /// + /// If `exactly` cannot convert to an `Int`, the resulting time interval is `nil`. + /// - Parameter exactly: A binary integer representing a time interval. + public init?(exactly source: T) where T: BinaryInteger { + if let v = Int(exactly: source) { + magnitude = v + } else { + return nil + } + } + + // --- + + public static func < (lhs: Stride, rhs: Stride) -> Bool { + return lhs.magnitude < rhs.magnitude + } + + // --- + + public static func * (lhs: Stride, rhs: Stride) -> Stride { + return Stride(.nanoseconds(lhs.magnitude * rhs.magnitude)) + } + + public static func + (lhs: Stride, rhs: Stride) -> Stride { + return Stride(.nanoseconds(lhs.magnitude + rhs.magnitude)) + } + + public static func - (lhs: Stride, rhs: Stride) -> Stride { + return Stride(.nanoseconds(lhs.magnitude - rhs.magnitude)) + } + + // --- + + public static func -= (lhs: inout Stride, rhs: Stride) { + let result = lhs - rhs + lhs = result + } + + public static func *= (lhs: inout Stride, rhs: Stride) { + let result = lhs * rhs + lhs = result + } + + public static func += (lhs: inout Stride, rhs: Stride) { + let result = lhs + rhs + lhs = result + } + + // --- + + public static func seconds(_ s: Double) -> Stride { + return Stride(.nanoseconds(Int(s * 1_000_000_000))) + } + + public static func seconds(_ s: Int) -> Stride { + return Stride(.seconds(s)) + } + + public static func milliseconds(_ ms: Int) -> Stride { + return Stride(.milliseconds(ms)) + } + + public static func microseconds(_ us: Int) -> Stride { + return Stride(.microseconds(us)) + } + + public static func nanoseconds(_ ns: Int) -> Stride { + return Stride(.nanoseconds(ns)) + } + } + } + + /// Options that affect the operation of the dispatch queue scheduler. + public struct SchedulerOptions { + /// The dispatch queue quality of service. + public var qos: DispatchQoS + + /// The dispatch queue work item flags. + public var flags: DispatchWorkItemFlags + + /// The dispatch group, if any, that should be used for performing actions. + public var group: DispatchGroup? + + public init(qos: DispatchQoS = .unspecified, flags: DispatchWorkItemFlags = [], group: DispatchGroup? = nil) { + self.qos = qos + self.flags = flags + self.group = group + } + } + + public var minimumTolerance: SchedulerTimeType.Stride { + return SchedulerTimeType.Stride(DispatchTimeInterval.seconds(0)) + } + + public var now: DispatchQueue.SchedulerTimeType { + return SchedulerTimeType(DispatchTime.now()) + } + + public func schedule(options: SchedulerOptions?, _ action: @escaping () -> Void) { + let qos = options?.qos ?? .unspecified + let flags = options?.flags ?? [] + + if let group = options?.group { + // Distinguish on the group because it appears to not be a call-through like the others. This may need to be adjusted. + self.async(group: group, qos: qos, flags: flags, execute: action) + } else { + self.async(qos: qos, flags: flags, execute: action) + } + } + + public func schedule(after date: SchedulerTimeType, + tolerance: SchedulerTimeType.Stride, + options: SchedulerOptions?, + _ action: @escaping () -> Void) { + // TODO: Tolerance ignored + let qos = options?.qos ?? .unspecified + let flags = options?.flags ?? [] + + self.asyncAfter(deadline: date.dispatchTime, qos: qos, flags: flags, execute: action) + } + + public func schedule(after date: SchedulerTimeType, + interval: SchedulerTimeType.Stride, + tolerance: SchedulerTimeType.Stride, + options: SchedulerOptions?, + _ action: @escaping () -> Void) -> Cancellable { + let source = DispatchSource.makeTimerSource(flags: DispatchSource.TimerFlags(), queue: self) + + source.schedule(deadline: date.dispatchTime, + repeating: interval.timeInterval, + leeway: tolerance.timeInterval) + source.setEventHandler(handler: action) + source.resume() + + return AnyCancellable(source.cancel) + } +} + +#endif /* !(os(iOS) && (arch(i386) || arch(arm))) */ diff --git a/stdlib/public/Darwin/Foundation/CMakeLists.txt b/stdlib/public/Darwin/Foundation/CMakeLists.txt index dcee6a4c27dbb..8f19f213dc525 100644 --- a/stdlib/public/Darwin/Foundation/CMakeLists.txt +++ b/stdlib/public/Darwin/Foundation/CMakeLists.txt @@ -7,19 +7,19 @@ add_swift_target_library(swiftFoundation ${SWIFT_SDK_OVERLAY_LIBRARY_BUILD_TYPES BundleLookup.mm Calendar.swift CharacterSet.swift + CheckClass.mm Codable.swift Collections+DataProtocol.swift + CombineTypealiases.swift ContiguousBytes.swift Data.swift DataProtocol.swift - DispatchData+DataProtocol.swift - NSData+DataProtocol.swift - Pointers+DataProtocol.swift DataThunks.m Date.swift DateComponents.swift DateInterval.swift Decimal.swift + DispatchData+DataProtocol.swift FileManager.swift Foundation.swift IndexPath.swift @@ -30,6 +30,7 @@ add_swift_target_library(swiftFoundation ${SWIFT_SDK_OVERLAY_LIBRARY_BUILD_TYPES Notification.swift NSArray.swift NSCoder.swift + NSData+DataProtocol.swift NSDate.swift NSDictionary.swift NSError.swift @@ -53,15 +54,26 @@ add_swift_target_library(swiftFoundation ${SWIFT_SDK_OVERLAY_LIBRARY_BUILD_TYPES NSURL.swift PersonNameComponents.swift PlistEncoder.swift + Pointers+DataProtocol.swift Progress.swift + Publishers+KeyValueObserving.swift + Publishers+Locking.swift + Publishers+NotificationCenter.swift + Publishers+Timer.swift + Publishers+URLSession.swift ReferenceConvertible.swift + Scanner.swift + Schedulers+Date.swift + Schedulers+OperationQueue.swift + Schedulers+RunLoop.swift String.swift TimeZone.swift URL.swift + URLCache.swift URLComponents.swift URLRequest.swift + URLSession.swift UUID.swift - CheckClass.mm "${SWIFT_SOURCE_DIR}/stdlib/linker-support/magic-symbols-for-install-name.c" @@ -80,6 +92,7 @@ add_swift_target_library(swiftFoundation ${SWIFT_SDK_OVERLAY_LIBRARY_BUILD_TYPES SWIFT_MODULE_DEPENDS_WATCHOS Darwin Dispatch CoreFoundation ObjectiveC # auto-updated CoreGraphics # imported in Swift FRAMEWORK_DEPENDS Foundation + FRAMEWORK_DEPENDS_WEAK Combine DEPLOYMENT_VERSION_OSX ${SWIFTLIB_DEPLOYMENT_VERSION_FOUNDATION_OSX} DEPLOYMENT_VERSION_IOS ${SWIFTLIB_DEPLOYMENT_VERSION_FOUNDATION_IOS} diff --git a/stdlib/public/Darwin/Foundation/Codable.swift b/stdlib/public/Darwin/Foundation/Codable.swift index 7254a4006ccef..59f47fab2688d 100644 --- a/stdlib/public/Darwin/Foundation/Codable.swift +++ b/stdlib/public/Darwin/Foundation/Codable.swift @@ -55,3 +55,26 @@ extension DecodingError { } } } + +// Only support 64bit +#if !(os(iOS) && (arch(i386) || arch(arm))) + +import Combine + +//===----------------------------------------------------------------------===// +// Generic Decoding +//===----------------------------------------------------------------------===// + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension JSONEncoder: TopLevelEncoder { } + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension PropertyListEncoder: TopLevelEncoder { } + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension JSONDecoder: TopLevelDecoder { } + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension PropertyListDecoder: TopLevelDecoder { } + +#endif /* !(os(iOS) && (arch(i386) || arch(arm))) */ diff --git a/stdlib/public/Darwin/Foundation/CombineTypealiases.swift b/stdlib/public/Darwin/Foundation/CombineTypealiases.swift new file mode 100644 index 0000000000000..cb3cb62c9057e --- /dev/null +++ b/stdlib/public/Darwin/Foundation/CombineTypealiases.swift @@ -0,0 +1,23 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2018 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +#if !(os(iOS) && (arch(i386) || arch(arm))) // Combine isn't on 32-bit iOS + +import Combine + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +public typealias Published = Combine.Published + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +public typealias ObservableObject = Combine.ObservableObject + +#endif /* !(os(iOS) && (arch(i386) || arch(arm))) */ diff --git a/stdlib/public/Darwin/Foundation/JSONEncoder.swift b/stdlib/public/Darwin/Foundation/JSONEncoder.swift index 464a565581137..e149c981ca7c0 100644 --- a/stdlib/public/Darwin/Foundation/JSONEncoder.swift +++ b/stdlib/public/Darwin/Foundation/JSONEncoder.swift @@ -75,6 +75,13 @@ open class JSONEncoder { /// Produce JSON with dictionary keys sorted in lexicographic order. @available(macOS 10.13, iOS 11.0, watchOS 4.0, tvOS 11.0, *) public static let sortedKeys = OutputFormatting(rawValue: 1 << 1) + + /// By default slashes get escaped ("/" → "\/", "http://apple.com/" → "http:\/\/apple.com\/") + /// for security reasons, allowing outputted JSON to be safely embedded within HTML/XML. + /// In contexts where this escaping is unnecessary, the JSON is known to not be embedded, + /// or is intended only for display, this option avoids this escaping. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public static let withoutEscapingSlashes = OutputFormatting(rawValue: 1 << 3) } /// The strategy to use for encoding `Date` values. @@ -257,18 +264,7 @@ open class JSONEncoder { EncodingError.Context(codingPath: [], debugDescription: "Top-level \(T.self) did not encode any values.")) } - if topLevel is NSNull { - throw EncodingError.invalidValue(value, - EncodingError.Context(codingPath: [], debugDescription: "Top-level \(T.self) encoded as null JSON fragment.")) - } else if topLevel is NSNumber { - throw EncodingError.invalidValue(value, - EncodingError.Context(codingPath: [], debugDescription: "Top-level \(T.self) encoded as number JSON fragment.")) - } else if topLevel is NSString { - throw EncodingError.invalidValue(value, - EncodingError.Context(codingPath: [], debugDescription: "Top-level \(T.self) encoded as string JSON fragment.")) - } - - let writingOptions = JSONSerialization.WritingOptions(rawValue: self.outputFormatting.rawValue) + let writingOptions = JSONSerialization.WritingOptions(rawValue: self.outputFormatting.rawValue).union(.fragmentsAllowed) do { return try JSONSerialization.data(withJSONObject: topLevel, options: writingOptions) } catch { @@ -1204,7 +1200,7 @@ open class JSONDecoder { open func decode(_ type: T.Type, from data: Data) throws -> T { let topLevel: Any do { - topLevel = try JSONSerialization.jsonObject(with: data) + topLevel = try JSONSerialization.jsonObject(with: data, options: .fragmentsAllowed) } catch { throw DecodingError.dataCorrupted(DecodingError.Context(codingPath: [], debugDescription: "The given data was not valid JSON.", underlyingError: error)) } diff --git a/stdlib/public/Darwin/Foundation/NSError.swift b/stdlib/public/Darwin/Foundation/NSError.swift index aa41c239ec8fa..b6d1d34fdc057 100644 --- a/stdlib/public/Darwin/Foundation/NSError.swift +++ b/stdlib/public/Darwin/Foundation/NSError.swift @@ -1977,6 +1977,26 @@ extension URLError.Code { } } +extension URLError { + /// Reasons used by URLError to indicate why a background URLSessionTask was cancelled. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public enum BackgroundTaskCancelledReason : Int { + case userForceQuitApplication + case backgroundUpdatesDisabled + case insufficientSystemResources + } +} + +extension URLError { + /// Reasons used by URLError to indicate that a URLSessionTask failed because of unsatisfiable network constraints. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public enum NetworkUnavailableReason : Int { + case cellular + case expensive + case constrained + } +} + extension URLError { private var _nsUserInfo: [AnyHashable : Any] { return (self as NSError).userInfo @@ -2000,6 +2020,24 @@ extension URLError { return nil } + + /// The reason why a background URLSessionTask was cancelled. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public var backgroundTaskCancelledReason: BackgroundTaskCancelledReason? { + return (_nsUserInfo[NSURLErrorBackgroundTaskCancelledReasonKey] as? Int).flatMap(BackgroundTaskCancelledReason.init(rawValue:)) + } + + /// The reason why the network is unavailable when the task failed due to unsatisfiable network constraints. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public var networkUnavailableReason: NetworkUnavailableReason? { + return (_nsUserInfo[NSURLErrorNetworkUnavailableReasonKey] as? Int).flatMap(NetworkUnavailableReason.init(rawValue:)) + } + + /// An opaque data blob to resume a failed download task. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public var downloadTaskResumeData: Data? { + return _nsUserInfo[NSURLSessionDownloadTaskResumeData] as? Data + } } extension URLError { diff --git a/stdlib/public/Darwin/Foundation/NSObject.swift b/stdlib/public/Darwin/Foundation/NSObject.swift index bb7bf0819cfc1..ea1f9a443a72f 100644 --- a/stdlib/public/Darwin/Foundation/NSObject.swift +++ b/stdlib/public/Darwin/Foundation/NSObject.swift @@ -12,6 +12,7 @@ @_exported import Foundation // Clang module import ObjectiveC +import _SwiftFoundationOverlayShims // This exists to allow for dynamic dispatch on KVO methods added to NSObject. // Extending NSObject with these methods would disallow overrides. @@ -173,6 +174,9 @@ public class NSKeyValueObservation : NSObject { // workaround for Erroneous (?) error when using bridging in the Foundation overlay // specifically, overriding observeValue(forKeyPath:of:change:context:) complains that it's not Obj-C-compatible @nonobjc static let swizzler: () = { + let cls = NSClassFromString("_NSKVOCompatibility") as? _NSKVOCompatibilityShim.Type + cls?._noteProcessHasUsedKVOSwiftOverlay() + let bridgeClass: AnyClass = Helper.self let observeSel = #selector(NSObject.observeValue(forKeyPath:of:change:context:)) let swapSel = #selector(Helper._swizzle_me_observeValue(forKeyPath:of:change:context:)) diff --git a/stdlib/public/Darwin/Foundation/Publishers+KeyValueObserving.swift b/stdlib/public/Darwin/Foundation/Publishers+KeyValueObserving.swift new file mode 100644 index 0000000000000..df9271fcf2350 --- /dev/null +++ b/stdlib/public/Darwin/Foundation/Publishers+KeyValueObserving.swift @@ -0,0 +1,210 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +// Only support 64bit +#if !(os(iOS) && (arch(i386) || arch(arm))) + +@_exported import Foundation // Clang module +import Combine + +// The following protocol is so that we can reference `Self` in the Publisher +// below. This is based on a trick used in the the standard library's +// implementation of `NSObject.observe(key path)` +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +public protocol _KeyValueCodingAndObservingPublishing {} + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension NSObject: _KeyValueCodingAndObservingPublishing {} + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension _KeyValueCodingAndObservingPublishing where Self: NSObject { + /// Publish values when the value identified by a KVO-compliant keypath changes. + /// + /// - Parameters: + /// - keyPath: The keypath of the property to publish. + /// - options: Key-value observing options. + /// - Returns: A publisher that emits elements each time the property’s value changes. + public func publisher(for keyPath: KeyPath, + options: NSKeyValueObservingOptions = [.initial, .new]) + -> NSObject.KeyValueObservingPublisher { + return NSObject.KeyValueObservingPublisher(object: self, keyPath: keyPath, options: options) + } +} + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension NSObject.KeyValueObservingPublisher { + /// Returns a publisher that emits values when a KVO-compliant property changes. + /// + /// - Returns: A key-value observing publisher. + public func didChange() + -> Publishers.Map, Void> { + return map { _ in () } + } +} + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension NSObject { + /// A publisher that emits events when the value of a KVO-compliant property changes. + public struct KeyValueObservingPublisher : Equatable { + public let object: Subject + public let keyPath: KeyPath + public let options: NSKeyValueObservingOptions + + public init( + object: Subject, + keyPath: KeyPath, + options: NSKeyValueObservingOptions + ) { + self.object = object + self.keyPath = keyPath + self.options = options + } + + public static func == ( + lhs: KeyValueObservingPublisher, + rhs: KeyValueObservingPublisher + ) -> Bool { + return lhs.object === rhs.object + && lhs.keyPath == rhs.keyPath + && lhs.options == rhs.options + } + } +} + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension NSObject.KeyValueObservingPublisher: Publisher { + public typealias Output = Value + public typealias Failure = Never + + public func receive(subscriber: S) where S.Input == Output, S.Failure == Failure { + let s = NSObject.KVOSubscription(object, keyPath, options, subscriber) + subscriber.receive(subscription: s) + } +} + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension NSObject { + private final class KVOSubscription: Subscription, CustomStringConvertible, CustomReflectable, CustomPlaygroundDisplayConvertible { + private var observation: NSKeyValueObservation? // GuardedBy(lock) + private var demand: Subscribers.Demand // GuardedBy(lock) + + // for configurations that care about '.initial' we need to 'cache' the value to account for backpressure, along with whom to send it to + // + // TODO: in the future we might want to consider interjecting a temporary publisher that does this, so that all KVO subscriptions don't incur the cost. + private var receivedInitial: Bool // GuardedBy(lock) + private var last: Value? // GuardedBy(lock) + private var subscriber: AnySubscriber? // GuardedBy(lock) + + private let lock = Lock() + + // This lock can only be held for the duration of downstream callouts + private let downstreamLock = RecursiveLock() + + var description: String { return "KVOSubscription" } + var customMirror: Mirror { + lock.lock() + defer { lock.unlock() } + return Mirror(self, children: [ + "observation": observation as Any, + "demand": demand + ]) + } + var playgroundDescription: Any { return description } + + init( + _ object: Subject, + _ keyPath: KeyPath, + _ options: NSKeyValueObservingOptions, + _ subscriber: S) + where + S.Input == Value, + S.Failure == Never + { + demand = .max(0) + receivedInitial = false + self.subscriber = AnySubscriber(subscriber) + + observation = object.observe( + keyPath, + options: options + ) { [weak self] obj, _ in + guard let self = self else { + return + } + let value = obj[keyPath: keyPath] + self.lock.lock() + if self.demand > 0, let sub = self.subscriber { + self.demand -= 1 + self.lock.unlock() + + self.downstreamLock.lock() + let additional = sub.receive(value) + self.downstreamLock.unlock() + + self.lock.lock() + self.demand += additional + self.lock.unlock() + } else { + // Drop the value, unless we've asked for .initial, and this + // is the first value. + if self.receivedInitial == false && options.contains(.initial) { + self.last = value + self.receivedInitial = true + } + self.lock.unlock() + } + } + } + + deinit { + lock.cleanupLock() + downstreamLock.cleanupLock() + } + + func request(_ d: Subscribers.Demand) { + lock.lock() + demand += d + if demand > 0, let v = last, let sub = subscriber { + demand -= 1 + last = nil + lock.unlock() + + downstreamLock.lock() + let additional = sub.receive(v) + downstreamLock.unlock() + + lock.lock() + demand += additional + } else { + demand -= 1 + last = nil + } + lock.unlock() + } + + func cancel() { + lock.lock() + guard let o = observation else { + lock.unlock() + return + } + lock.unlock() + + observation = nil + subscriber = nil + last = nil + o.invalidate() + } + } +} + +#endif /* !(os(iOS) && (arch(i386) || arch(arm))) */ diff --git a/stdlib/public/Darwin/Foundation/Publishers+Locking.swift b/stdlib/public/Darwin/Foundation/Publishers+Locking.swift new file mode 100644 index 0000000000000..0279f9bdccd45 --- /dev/null +++ b/stdlib/public/Darwin/Foundation/Publishers+Locking.swift @@ -0,0 +1,117 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +// Only support 64bit +#if !(os(iOS) && (arch(i386) || arch(arm))) + +import Darwin + +@available(macOS 10.12, iOS 10.0, tvOS 10.0, watchOS 3.0, *) +extension UnsafeMutablePointer where Pointee == os_unfair_lock_s { + internal init() { + let l = UnsafeMutablePointer.allocate(capacity: 1) + l.initialize(to: os_unfair_lock()) + self = l + } + + internal func cleanupLock() { + deinitialize(count: 1) + deallocate() + } + + internal func lock() { + os_unfair_lock_lock(self) + } + + internal func tryLock() -> Bool { + let result = os_unfair_lock_trylock(self) + return result + } + + internal func unlock() { + os_unfair_lock_unlock(self) + } +} + +@available(macOS 10.12, iOS 10.0, tvOS 10.0, watchOS 3.0, *) +typealias Lock = os_unfair_lock_t + +#if canImport(DarwinPrivate) + +@_implementationOnly import DarwinPrivate + +@available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) +extension UnsafeMutablePointer where Pointee == os_unfair_recursive_lock_s { + internal init() { + let l = UnsafeMutablePointer.allocate(capacity: 1) + l.initialize(to: os_unfair_recursive_lock_s()) + self = l + } + + internal func cleanupLock() { + deinitialize(count: 1) + deallocate() + } + + internal func lock() { + os_unfair_recursive_lock_lock(self) + } + + internal func tryLock() -> Bool { + let result = os_unfair_recursive_lock_trylock(self) + return result + } + + internal func unlock() { + os_unfair_recursive_lock_unlock(self) + } +} + +@available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) +typealias RecursiveLock = os_unfair_recursive_lock_t + +#else + +// Kept in overlay since some builds may not have `DarwinPrivate` but we should have the availability the same +@available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) +internal struct RecursiveLock { + private let lockPtr: UnsafeMutablePointer + + internal init() { + lockPtr = UnsafeMutablePointer.allocate(capacity: 1) + var attr = pthread_mutexattr_t() + pthread_mutexattr_settype(&attr, PTHREAD_MUTEX_RECURSIVE) + pthread_mutex_init(lockPtr, &attr) + } + + internal func cleanupLock() { + pthread_mutex_destroy(lockPtr) + lockPtr.deinitialize(count: 1) + lockPtr.deallocate() + } + + internal func lock() { + pthread_mutex_lock(lockPtr) + } + + internal func tryLock() -> Bool { + return pthread_mutex_trylock(lockPtr) == 0 + } + + internal func unlock() { + pthread_mutex_unlock(lockPtr) + } +} + +#endif + +#endif /* !(os(iOS) && (arch(i386) || arch(arm))) */ diff --git a/stdlib/public/Darwin/Foundation/Publishers+NotificationCenter.swift b/stdlib/public/Darwin/Foundation/Publishers+NotificationCenter.swift new file mode 100644 index 0000000000000..834f4b7f5def1 --- /dev/null +++ b/stdlib/public/Darwin/Foundation/Publishers+NotificationCenter.swift @@ -0,0 +1,183 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +// Only support 64bit +#if !(os(iOS) && (arch(i386) || arch(arm))) + +@_exported import Foundation // Clang module +import Combine + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension NotificationCenter { + /// Returns a publisher that emits events when broadcasting notifications. + /// + /// - Parameters: + /// - name: The name of the notification to publish. + /// - object: The object posting the named notfication. If `nil`, the publisher emits elements for any object producing a notification with the given name. + /// - Returns: A publisher that emits events when broadcasting notifications. + public func publisher( + for name: Notification.Name, + object: AnyObject? = nil + ) -> NotificationCenter.Publisher { + return NotificationCenter.Publisher(center: self, name: name, object: object) + } +} + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension NotificationCenter { + /// A publisher that emits elements when broadcasting notifications. + public struct Publisher: Combine.Publisher { + public typealias Output = Notification + public typealias Failure = Never + + /// The notification center this publisher uses as a source. + public let center: NotificationCenter + /// The name of notifications published by this publisher. + public let name: Notification.Name + /// The object posting the named notfication. + public let object: AnyObject? + + /// Creates a publisher that emits events when broadcasting notifications. + /// + /// - Parameters: + /// - center: The notification center to publish notifications for. + /// - name: The name of the notification to publish. + /// - object: The object posting the named notfication. If `nil`, the publisher emits elements for any object producing a notification with the given name. + public init(center: NotificationCenter, name: Notification.Name, object: AnyObject? = nil) { + self.center = center + self.name = name + self.object = object + } + + public func receive(subscriber: S) where S.Input == Output, S.Failure == Failure { + subscriber.receive(subscription: Notification.Subscription(center, name, object, subscriber)) + } + } +} + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension NotificationCenter.Publisher: Equatable { + public static func == ( + lhs: NotificationCenter.Publisher, + rhs: NotificationCenter.Publisher + ) -> Bool { + return lhs.center === rhs.center + && lhs.name == rhs.name + && lhs.object === rhs.object + } +} + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension Notification { + fileprivate final class Subscription: Combine.Subscription, CustomStringConvertible, CustomReflectable, CustomPlaygroundDisplayConvertible + where + S.Input == Notification + { + private let lock = Lock() + + // This lock can only be held for the duration of downstream callouts + private let downstreamLock = RecursiveLock() + + private var demand: Subscribers.Demand // GuardedBy(lock) + + private var center: NotificationCenter? // GuardedBy(lock) + private let name: Notification.Name // Stored only for debug info + private var object: AnyObject? // Stored only for debug info + private var observation: AnyObject? // GuardedBy(lock) + + var description: String { return "NotificationCenter Observer" } + var customMirror: Mirror { + lock.lock() + defer { lock.unlock() } + return Mirror(self, children: [ + "center": center as Any, + "name": name as Any, + "object": object as Any, + "demand": demand + ]) + } + var playgroundDescription: Any { return description } + + init(_ center: NotificationCenter, + _ name: Notification.Name, + _ object: AnyObject?, + _ next: S) + { + self.demand = .max(0) + self.center = center + self.name = name + self.object = object + + self.observation = center.addObserver( + forName: name, + object: object, + queue: nil + ) { [weak self] note in + guard let self = self else { return } + + self.lock.lock() + guard self.observation != nil else { + self.lock.unlock() + return + } + + let demand = self.demand + if demand > 0 { + self.demand -= 1 + } + self.lock.unlock() + + if demand > 0 { + self.downstreamLock.lock() + let additionalDemand = next.receive(note) + self.downstreamLock.unlock() + + if additionalDemand > 0 { + self.lock.lock() + self.demand += additionalDemand + self.lock.unlock() + } + } else { + // Drop it on the floor + } + } + } + + deinit { + lock.cleanupLock() + downstreamLock.cleanupLock() + } + + func request(_ d: Subscribers.Demand) { + lock.lock() + demand += d + lock.unlock() + } + + func cancel() { + lock.lock() + guard let center = self.center, + let observation = self.observation else { + lock.unlock() + return + } + self.center = nil + self.observation = nil + self.object = nil + lock.unlock() + + center.removeObserver(observation) + } + } +} + +#endif /* !(os(iOS) && (arch(i386) || arch(arm))) */ diff --git a/stdlib/public/Darwin/Foundation/Publishers+Timer.swift b/stdlib/public/Darwin/Foundation/Publishers+Timer.swift new file mode 100644 index 0000000000000..76011cd73400d --- /dev/null +++ b/stdlib/public/Darwin/Foundation/Publishers+Timer.swift @@ -0,0 +1,328 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +// Only support 64bit +#if !(os(iOS) && (arch(i386) || arch(arm))) + +@_exported import Foundation // Clang module +import Combine + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension Timer { + /// Returns a publisher that repeatedly emits the current date on the given interval. + /// + /// - Parameters: + /// - interval: The time interval on which to publish events. For example, a value of `0.5` publishes an event approximately every half-second. + /// - tolerance: The allowed timing variance when emitting events. Defaults to `nil`, which allows any variance. + /// - runLoop: The run loop on which the timer runs. + /// - mode: The run loop mode in which to run the timer. + /// - options: Scheduler options passed to the timer. Defaults to `nil`. + /// - Returns: A publisher that repeatedly emits the current date on the given interval. + public static func publish( + every interval: TimeInterval, + tolerance: TimeInterval? = nil, + on runLoop: RunLoop, + in mode: RunLoop.Mode, + options: RunLoop.SchedulerOptions? = nil) + -> TimerPublisher + { + return TimerPublisher(interval: interval, runLoop: runLoop, mode: mode, options: options) + } + + /// A publisher that repeatedly emits the current date on a given interval. + public final class TimerPublisher: ConnectablePublisher { + public typealias Output = Date + public typealias Failure = Never + + public let interval: TimeInterval + public let tolerance: TimeInterval? + public let runLoop: RunLoop + public let mode: RunLoop.Mode + public let options: RunLoop.SchedulerOptions? + + private lazy var routingSubscription: RoutingSubscription = { + return RoutingSubscription(parent: self) + }() + + // Stores if a `.connect()` happened before subscription, internally readable for tests + internal var isConnected: Bool { + return routingSubscription.isConnected + } + + /// Creates a publisher that repeatedly emits the current date on the given interval. + /// + /// - Parameters: + /// - interval: The interval on which to publish events. + /// - tolerance: The allowed timing variance when emitting events. Defaults to `nil`, which allows any variance. + /// - runLoop: The run loop on which the timer runs. + /// - mode: The run loop mode in which to run the timer. + /// - options: Scheduler options passed to the timer. Defaults to `nil`. + public init(interval: TimeInterval, tolerance: TimeInterval? = nil, runLoop: RunLoop, mode: RunLoop.Mode, options: RunLoop.SchedulerOptions? = nil) { + self.interval = interval + self.tolerance = tolerance + self.runLoop = runLoop + self.mode = mode + self.options = options + } + + /// Adapter subscription to allow `Timer` to multiplex to multiple subscribers + /// the values produced by a single `TimerPublisher.Inner` + private class RoutingSubscription: Subscription, Subscriber, CustomStringConvertible, CustomReflectable, CustomPlaygroundDisplayConvertible { + typealias Input = Date + typealias Failure = Never + + private typealias ErasedSubscriber = AnySubscriber + + private let lock: Lock + + // Inner is IUP due to init requirements + private var inner: Inner! + private var subscribers: [ErasedSubscriber] = [] + + private var _lockedIsConnected = false + var isConnected: Bool { + get { + lock.lock() + defer { lock.unlock() } + return _lockedIsConnected + } + + set { + lock.lock() + let oldValue = _lockedIsConnected + _lockedIsConnected = newValue + + // Inner will always be non-nil + let inner = self.inner! + lock.unlock() + + guard newValue, !oldValue else { + return + } + inner.enqueue() + } + } + + var description: String { return "Timer" } + var customMirror: Mirror { return inner.customMirror } + var playgroundDescription: Any { return description } + var combineIdentifier: CombineIdentifier { return inner.combineIdentifier } + + init(parent: TimerPublisher) { + self.lock = Lock() + self.inner = Inner(parent, self) + } + + deinit { + lock.cleanupLock() + } + + func addSubsriber(_ sub: S) + where + S.Failure == Failure, + S.Input == Output + { + lock.lock() + subscribers.append(AnySubscriber(sub)) + lock.unlock() + + sub.receive(subscription: self) + } + + func receive(subscription: Subscription) { + lock.lock() + let subscribers = self.subscribers + lock.unlock() + + for sub in subscribers { + sub.receive(subscription: subscription) + } + } + + func receive(_ value: Input) -> Subscribers.Demand { + var resultingDemand: Subscribers.Demand = .max(0) + lock.lock() + let subscribers = self.subscribers + let isConnected = _lockedIsConnected + lock.unlock() + + guard isConnected else { return .none } + + for sub in subscribers { + resultingDemand += sub.receive(value) + } + return resultingDemand + } + + func receive(completion: Subscribers.Completion) { + lock.lock() + let subscribers = self.subscribers + lock.unlock() + + for sub in subscribers { + sub.receive(completion: completion) + } + } + + func request(_ demand: Subscribers.Demand) { + lock.lock() + // Inner will always be non-nil + let inner = self.inner! + lock.unlock() + + inner.request(demand) + } + + func cancel() { + lock.lock() + // Inner will always be non-nil + let inner = self.inner! + _lockedIsConnected = false + self.subscribers = [] + lock.unlock() + + inner.cancel() + } + } + + public func receive(subscriber: S) where Failure == S.Failure, Output == S.Input { + routingSubscription.addSubsriber(subscriber) + } + + public func connect() -> Cancellable { + routingSubscription.isConnected = true + return routingSubscription + } + + private typealias Parent = TimerPublisher + private final class Inner: NSObject, Subscription, CustomReflectable, CustomPlaygroundDisplayConvertible + where + Downstream.Input == Date, + Downstream.Failure == Never + { + private lazy var timer: Timer? = { + let t = Timer( + timeInterval: parent?.interval ?? 0, + target: self, + selector: #selector(timerFired), + userInfo: nil, + repeats: true + ) + + t.tolerance = parent?.tolerance ?? 0 + + return t + }() + + private let lock: Lock + private var downstream: Downstream? + private var parent: Parent? + private var started: Bool + private var demand: Subscribers.Demand + + override var description: String { return "Timer" } + var customMirror: Mirror { + lock.lock() + defer { lock.unlock() } + return Mirror(self, children: [ + "downstream": downstream as Any, + "interval": parent?.interval as Any, + "tolerance": parent?.tolerance as Any + ]) + } + var playgroundDescription: Any { return description } + + init(_ parent: Parent, _ downstream: Downstream) { + self.lock = Lock() + self.parent = parent + self.downstream = downstream + self.started = false + self.demand = .max(0) + super.init() + } + + deinit { + lock.cleanupLock() + } + + func enqueue() { + lock.lock() + guard let t = timer, let parent = self.parent, !started else { + lock.unlock() + return + } + + started = true + lock.unlock() + + parent.runLoop.add(t, forMode: parent.mode) + } + + func cancel() { + lock.lock() + guard let t = timer else { + lock.unlock() + return + } + + // clear out all optionals + downstream = nil + parent = nil + started = false + demand = .max(0) + timer = nil + lock.unlock() + + // cancel the timer + t.invalidate() + } + + func request(_ n: Subscribers.Demand) { + lock.lock() + defer { lock.unlock() } + guard parent != nil else { + return + } + demand += n + } + + @objc + func timerFired(arg: Any) { + lock.lock() + guard let ds = downstream, parent != nil else { + lock.unlock() + return + } + + // This publisher drops events on the floor when there is no space in the subscriber + guard demand > 0 else { + lock.unlock() + return + } + + demand -= 1 + lock.unlock() + + let extra = ds.receive(Date()) + guard extra > 0 else { + return + } + + lock.lock() + demand += extra + lock.unlock() + } + } + } +} + +#endif /* !(os(iOS) && (arch(i386) || arch(arm))) */ diff --git a/stdlib/public/Darwin/Foundation/Publishers+URLSession.swift b/stdlib/public/Darwin/Foundation/Publishers+URLSession.swift new file mode 100644 index 0000000000000..565d0647662a0 --- /dev/null +++ b/stdlib/public/Darwin/Foundation/Publishers+URLSession.swift @@ -0,0 +1,175 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +// Only support 64bit +#if !(os(iOS) && (arch(i386) || arch(arm))) + +@_exported import Foundation // Clang module +import Combine + +// MARK: Data Tasks + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension URLSession { + /// Returns a publisher that wraps a URL session data task for a given URL. + /// + /// The publisher publishes data when the task completes, or terminates if the task fails with an error. + /// - Parameter url: The URL for which to create a data task. + /// - Returns: A publisher that wraps a data task for the URL. + public func dataTaskPublisher( + for url: URL) + -> DataTaskPublisher + { + let request = URLRequest(url: url) + return DataTaskPublisher(request: request, session: self) + } + + /// Returns a publisher that wraps a URL session data task for a given URL request. + /// + /// The publisher publishes data when the task completes, or terminates if the task fails with an error. + /// - Parameter request: The URL request for which to create a data task. + /// - Returns: A publisher that wraps a data task for the URL request. + public func dataTaskPublisher( + for request: URLRequest) + -> DataTaskPublisher + { + return DataTaskPublisher(request: request, session: self) + } + + public struct DataTaskPublisher: Publisher { + public typealias Output = (data: Data, response: URLResponse) + public typealias Failure = URLError + + public let request: URLRequest + public let session: URLSession + + public init(request: URLRequest, session: URLSession) { + self.request = request + self.session = session + } + + public func receive(subscriber: S) where Failure == S.Failure, Output == S.Input { + subscriber.receive(subscription: Inner(self, subscriber)) + } + + private typealias Parent = DataTaskPublisher + private final class Inner: Subscription, CustomStringConvertible, CustomReflectable, CustomPlaygroundDisplayConvertible + where + Downstream.Input == Parent.Output, + Downstream.Failure == Parent.Failure + { + typealias Input = Downstream.Input + typealias Failure = Downstream.Failure + + private let lock: Lock + private var parent: Parent? // GuardedBy(lock) + private var downstream: Downstream? // GuardedBy(lock) + private var demand: Subscribers.Demand // GuardedBy(lock) + private var task: URLSessionDataTask! // GuardedBy(lock) + + var description: String { return "DataTaskPublisher" } + var customMirror: Mirror { + lock.lock() + defer { lock.unlock() } + return Mirror(self, children: [ + "task": task as Any, + "downstream": downstream as Any, + "parent": parent as Any, + "demand": demand, + ]) + } + var playgroundDescription: Any { return description } + + init(_ parent: Parent, _ downstream: Downstream) { + self.lock = Lock() + self.parent = parent + self.downstream = downstream + self.demand = .max(0) + } + + deinit { + lock.cleanupLock() + } + + // MARK: - Upward Signals + func request(_ d: Subscribers.Demand) { + precondition(d > 0, "Invalid request of zero demand") + + lock.lock() + guard let p = parent else { + // We've already been cancelled so bail + lock.unlock() + return + } + + // Avoid issues around `self` before init by setting up only once here + if self.task == nil { + let task = p.session.dataTask( + with: p.request, + completionHandler: handleResponse(data:response:error:) + ) + self.task = task + } + + self.demand += d + let task = self.task! + lock.unlock() + + task.resume() + } + + private func handleResponse(data: Data?, response: URLResponse?, error: Error?) { + lock.lock() + guard demand > 0, + parent != nil, + let ds = downstream + else { + lock.unlock() + return + } + + parent = nil + downstream = nil + + // We clear demand since this is a single shot shape + demand = .max(0) + task = nil + lock.unlock() + + if let response = response, error == nil { + _ = ds.receive((data ?? Data(), response)) + ds.receive(completion: .finished) + } else { + let urlError = error as? URLError ?? URLError(.unknown) + ds.receive(completion: .failure(urlError)) + } + } + + func cancel() { + lock.lock() + guard parent != nil else { + lock.unlock() + return + } + parent = nil + downstream = nil + demand = .max(0) + let task = self.task + self.task = nil + lock.unlock() + task?.cancel() + } + } + } +} + +#endif /* !(os(iOS) && (arch(i386) || arch(arm))) */ diff --git a/stdlib/public/Darwin/Foundation/Scanner.swift b/stdlib/public/Darwin/Foundation/Scanner.swift new file mode 100644 index 0000000000000..ea77ebc6a360f --- /dev/null +++ b/stdlib/public/Darwin/Foundation/Scanner.swift @@ -0,0 +1,213 @@ +// Copyright (c) 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +@_exported import Foundation // Clang module + +extension CharacterSet { + fileprivate func contains(_ character: Character) -> Bool { + return character.unicodeScalars.allSatisfy(self.contains(_:)) + } +} + +// ----- + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension Scanner { + public enum NumberRepresentation { + case decimal // See the %d, %f and %F format conversions. + case hexadecimal // See the %x, %X, %a and %A format conversions. For integers, a leading 0x or 0X is optional; for floating-point numbers, it is required. + } + + public var currentIndex: String.Index { + get { + let string = self.string + var index = string._toUTF16Index(scanLocation) + + var delta = 0 + while index != string.endIndex && index.samePosition(in: string) == nil { + delta += 1 + index = string._toUTF16Index(scanLocation + delta) + } + + return index + } + set { scanLocation = string._toUTF16Offset(newValue) } + } + + fileprivate func _scan + (representation: NumberRepresentation, + scanDecimal: (UnsafeMutablePointer?) -> Bool, + scanHexadecimal: (UnsafeMutablePointer?) -> Bool) -> Integer? { + var value: Integer = .max + + switch representation { + case .decimal: guard scanDecimal(&value) else { return nil } + case .hexadecimal: guard scanHexadecimal(&value) else { return nil } + } + + return value + } + + fileprivate func _scan + (representation: NumberRepresentation, + scanDecimal: (UnsafeMutablePointer?) -> Bool, + scanHexadecimal: (UnsafeMutablePointer?) -> Bool) -> FloatingPoint? { + var value: FloatingPoint = .greatestFiniteMagnitude + + switch representation { + case .decimal: guard scanDecimal(&value) else { return nil } + case .hexadecimal: guard scanHexadecimal(&value) else { return nil } + } + + return value + } + + fileprivate func _scan + (representation: NumberRepresentation, + scanDecimal: (UnsafeMutablePointer?) -> Bool, + overflowingScanHexadecimal: (UnsafeMutablePointer?) -> Bool) -> Integer? { + return _scan(representation: representation, scanDecimal: scanDecimal, scanHexadecimal: { (pointer) -> Bool in + var unsignedValue: OverflowingHexadecimalInteger = .max + guard overflowingScanHexadecimal(&unsignedValue) else { return false } + if unsignedValue <= Integer.max { + pointer?.pointee = Integer(unsignedValue) + } + return true + }) + } + + public func scanInt(representation: NumberRepresentation = .decimal) -> Int? { +#if arch(x86_64) || arch(arm64) || arch(s390x) || arch(powerpc64) || arch(powerpc64le) + if let value = scanInt64(representation: representation) { + return Int(value) + } +#elseif arch(i386) || arch(arm) + if let value = scanInt32(representation: representation) { + return Int(value) + } +#else + #error("This architecture isn't known. Add it to the 32-bit or 64-bit line; if the machine word isn't either of those, you need to implement appropriate scanning and handle the potential overflow here.") +#endif + return nil + } + + public func scanInt32(representation: NumberRepresentation = .decimal) -> Int32? { + return _scan(representation: representation, scanDecimal: self.scanInt32(_:), overflowingScanHexadecimal: self.scanHexInt32(_:)) + } + + public func scanInt64(representation: NumberRepresentation = .decimal) -> Int64? { + return _scan(representation: representation, scanDecimal: self.scanInt64(_:), overflowingScanHexadecimal: self.scanHexInt64(_:)) + } + + public func scanUInt64(representation: NumberRepresentation = .decimal) -> UInt64? { + return _scan(representation: representation, scanDecimal: self.scanUnsignedLongLong(_:), scanHexadecimal: self.scanHexInt64(_:)) + } + + public func scanFloat(representation: NumberRepresentation = .decimal) -> Float? { + return _scan(representation: representation, scanDecimal: self.scanFloat(_:), scanHexadecimal: self.scanHexFloat(_:)) + } + + public func scanDouble(representation: NumberRepresentation = .decimal) -> Double? { + return _scan(representation: representation, scanDecimal: self.scanDouble(_:), scanHexadecimal: self.scanHexDouble(_:)) + } + + public func scanDecimal() -> Decimal? { + var value: Decimal = 0 + guard scanDecimal(&value) else { return nil } + return value + } + + + fileprivate var _currentIndexAfterSkipping: String.Index { + guard let skips = charactersToBeSkipped else { return currentIndex } + + let index = string[currentIndex...].firstIndex(where: { !skips.contains($0) }) + return index ?? string.endIndex + } + + public func scanString(_ searchString: String) -> String? { + let currentIndex = _currentIndexAfterSkipping + + guard let substringEnd = string.index(currentIndex, offsetBy: searchString.count, limitedBy: string.endIndex) else { return nil } + + if string.compare(searchString, options: self.caseSensitive ? [] : .caseInsensitive, range: currentIndex ..< substringEnd, locale: self.locale as? Locale) == .orderedSame { + let it = string[currentIndex ..< substringEnd] + self.currentIndex = substringEnd + return String(it) + } else { + return nil + } + } + + public func scanCharacters(from set: CharacterSet) -> String? { + let currentIndex = _currentIndexAfterSkipping + + let substringEnd = string[currentIndex...].firstIndex(where: { !set.contains($0) }) ?? string.endIndex + guard currentIndex != substringEnd else { return nil } + + let substring = string[currentIndex ..< substringEnd] + self.currentIndex = substringEnd + return String(substring) + } + + public func scanUpToString(_ substring: String) -> String? { + guard !substring.isEmpty else { return nil } + let string = self.string + let startIndex = _currentIndexAfterSkipping + + var beginningOfNewString = string.endIndex + var currentSearchIndex = startIndex + + repeat { + guard let range = string.range(of: substring, options: self.caseSensitive ? [] : .caseInsensitive, range: currentSearchIndex ..< string.endIndex, locale: self.locale as? Locale) else { + // If the string isn't found at all, it means it's not in the string. Just take everything to the end. + beginningOfNewString = string.endIndex + break + } + + // range(of:…) can return partial grapheme ranges when dealing with emoji. + // Make sure we take a range only if it doesn't split a grapheme in the string. + if let maybeBeginning = range.lowerBound.samePosition(in: string), + range.upperBound.samePosition(in: string) != nil { + beginningOfNewString = maybeBeginning + break + } + + // If we got here, we need to search again starting from just after the location we found. + currentSearchIndex = range.upperBound + } while beginningOfNewString == string.endIndex && currentSearchIndex < string.endIndex + + guard startIndex != beginningOfNewString else { return nil } + + let foundSubstring = string[startIndex ..< beginningOfNewString] + self.currentIndex = beginningOfNewString + return String(foundSubstring) + } + + public func scanUpToCharacters(from set: CharacterSet) -> String? { + let currentIndex = _currentIndexAfterSkipping + let string = self.string + + let firstCharacterInSet = string[currentIndex...].firstIndex(where: { set.contains($0) }) ?? string.endIndex + guard currentIndex != firstCharacterInSet else { return nil } + self.currentIndex = firstCharacterInSet + return String(string[currentIndex ..< firstCharacterInSet]) + } + + public func scanCharacter() -> Character? { + let currentIndex = _currentIndexAfterSkipping + + let string = self.string + + guard currentIndex != string.endIndex else { return nil } + + let character = string[currentIndex] + self.currentIndex = string.index(after: currentIndex) + return character + } +} diff --git a/stdlib/public/Darwin/Foundation/Schedulers+Date.swift b/stdlib/public/Darwin/Foundation/Schedulers+Date.swift new file mode 100644 index 0000000000000..63b6eb0c05939 --- /dev/null +++ b/stdlib/public/Darwin/Foundation/Schedulers+Date.swift @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +// Only support 64bit +#if !(os(iOS) && (arch(i386) || arch(arm))) + +@_exported import Foundation // Clang module +import Combine + +// Date cannot conform to Strideable per rdar://35158274 +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) + extension Date /* : Strideable */ { + public typealias Stride = TimeInterval + + public func distance(to other: Date) -> TimeInterval { + return other.timeIntervalSinceReferenceDate - self.timeIntervalSinceReferenceDate + } + + public func advanced(by n: TimeInterval) -> Date { + return self + n + } +} + +#endif /* !(os(iOS) && (arch(i386) || arch(arm))) */ \ No newline at end of file diff --git a/stdlib/public/Darwin/Foundation/Schedulers+OperationQueue.swift b/stdlib/public/Darwin/Foundation/Schedulers+OperationQueue.swift new file mode 100644 index 0000000000000..963c83fe2ee4a --- /dev/null +++ b/stdlib/public/Darwin/Foundation/Schedulers+OperationQueue.swift @@ -0,0 +1,214 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +// Only support 64bit +#if !(os(iOS) && (arch(i386) || arch(arm))) + +@_exported import Foundation // Clang module +import Combine + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension OperationQueue: Scheduler { + /// The scheduler time type used by the operation queue. + public struct SchedulerTimeType: Strideable, Codable, Hashable { + /// The date represented by this type. + public var date: Date + + /// Initializes a operation queue scheduler time with the given date. + /// + /// - Parameter date: The date to represent. + public init(_ date: Date) { + self.date = date + } + + /// Returns the distance to another operation queue scheduler time. + /// + /// - Parameter other: Another operation queue time. + /// - Returns: The time interval between this time and the provided time. + public func distance(to other: OperationQueue.SchedulerTimeType) -> OperationQueue.SchedulerTimeType.Stride { + return OperationQueue.SchedulerTimeType.Stride(floatLiteral: date.distance(to: other.date)) + } + + /// Returns a operation queue scheduler time calculated by advancing this instance’s time by the given interval. + /// + /// - Parameter n: A time interval to advance. + /// - Returns: A operation queue time advanced by the given interval from this instance’s time. + public func advanced(by n: OperationQueue.SchedulerTimeType.Stride) -> OperationQueue.SchedulerTimeType { + return OperationQueue.SchedulerTimeType(date.advanced(by: n.timeInterval)) + } + + /// The interval by which operation queue times advance. + public struct Stride: ExpressibleByFloatLiteral, Comparable, SignedNumeric, Codable, SchedulerTimeIntervalConvertible { + public typealias FloatLiteralType = TimeInterval + public typealias IntegerLiteralType = TimeInterval + public typealias Magnitude = TimeInterval + + /// The value of this time interval in seconds. + public var magnitude: TimeInterval + + /// The value of this time interval in seconds. + public var timeInterval: TimeInterval { + return magnitude + } + + public init(integerLiteral value: TimeInterval) { + magnitude = value + } + + public init(floatLiteral value: TimeInterval) { + magnitude = value + } + + public init(_ timeInterval: TimeInterval) { + magnitude = timeInterval + } + + public init?(exactly source: T) where T: BinaryInteger { + if let d = TimeInterval(exactly: source) { + magnitude = d + } else { + return nil + } + } + + // --- + + public static func < (lhs: Stride, rhs: Stride) -> Bool { + return lhs.magnitude < rhs.magnitude + } + + // --- + + public static func * (lhs: Stride, rhs: Stride) -> Stride { + return Stride(lhs.timeInterval * rhs.timeInterval) + } + + public static func + (lhs: Stride, rhs: Stride) -> Stride { + return Stride(lhs.magnitude + rhs.magnitude) + } + + public static func - (lhs: Stride, rhs: Stride) -> Stride { + return Stride(lhs.magnitude - rhs.magnitude) + } + + // --- + + public static func *= (lhs: inout Stride, rhs: Stride) { + let result = lhs * rhs + lhs = result + } + + public static func += (lhs: inout Stride, rhs: Stride) { + let result = lhs + rhs + lhs = result + } + + public static func -= (lhs: inout Stride, rhs: Stride) { + let result = lhs - rhs + lhs = result + } + + // --- + + public static func seconds(_ s: Int) -> Stride { + return Stride(Double(s)) + } + + public static func seconds(_ s: Double) -> Stride { + return Stride(s) + } + + public static func milliseconds(_ ms: Int) -> Stride { + return Stride(Double(ms) / 1_000.0) + } + + public static func microseconds(_ us: Int) -> Stride { + return Stride(Double(us) / 1_000_000.0) + } + + public static func nanoseconds(_ ns: Int) -> Stride { + return Stride(Double(ns) / 1_000_000_000.0) + } + } + } + + /// Options that affect the operation of the operation queue scheduler. + public struct SchedulerOptions { } + + private final class DelayReadyOperation: Operation, Cancellable { + static var readySchedulingQueue: DispatchQueue = { + return DispatchQueue(label: "DelayReadyOperation") + }() + + var action: (() -> Void)? + var readyFromAfter: Bool + + init(_ action: @escaping() -> Void, after: OperationQueue.SchedulerTimeType) { + self.action = action + readyFromAfter = false + super.init() + let deadline = DispatchTime.now() + after.date.timeIntervalSinceNow + DelayReadyOperation.readySchedulingQueue.asyncAfter(deadline: deadline) { [weak self] in + self?.becomeReady() + } + } + + override func main() { + action!() + action = nil + } + + func becomeReady() { + willChangeValue(for: \.isReady) + readyFromAfter = true + didChangeValue(for: \.isReady) + } + + override var isReady: Bool { + return super.isReady && readyFromAfter + } + } + + public func schedule(options: OperationQueue.SchedulerOptions?, + _ action: @escaping () -> Void) { + let op = BlockOperation(block: action) + addOperation(op) + } + + public func schedule(after date: OperationQueue.SchedulerTimeType, + tolerance: OperationQueue.SchedulerTimeType.Stride, + options: OperationQueue.SchedulerOptions?, + _ action: @escaping () -> Void) { + let op = DelayReadyOperation(action, after: date) + addOperation(op) + } + + public func schedule(after date: OperationQueue.SchedulerTimeType, + interval: OperationQueue.SchedulerTimeType.Stride, + tolerance: OperationQueue.SchedulerTimeType.Stride, + options: OperationQueue.SchedulerOptions?, + _ action: @escaping () -> Void) -> Cancellable { + let op = DelayReadyOperation(action, after: date.advanced(by: interval)) + addOperation(op) + return AnyCancellable(op) + } + + public var now: OperationQueue.SchedulerTimeType { + return OperationQueue.SchedulerTimeType(Date()) + } + + public var minimumTolerance: OperationQueue.SchedulerTimeType.Stride { + return OperationQueue.SchedulerTimeType.Stride(0.0) + } +} + +#endif /* !(os(iOS) && (arch(i386) || arch(arm))) */ diff --git a/stdlib/public/Darwin/Foundation/Schedulers+RunLoop.swift b/stdlib/public/Darwin/Foundation/Schedulers+RunLoop.swift new file mode 100644 index 0000000000000..65b6faff19db1 --- /dev/null +++ b/stdlib/public/Darwin/Foundation/Schedulers+RunLoop.swift @@ -0,0 +1,199 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +// Only support 64bit +#if !(os(iOS) && (arch(i386) || arch(arm))) + +@_exported import Foundation // Clang module +import Combine + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension RunLoop: Scheduler { + /// The scheduler time type used by the run loop. + public struct SchedulerTimeType: Strideable, Codable, Hashable { + /// The date represented by this type. + public var date: Date + + /// Initializes a run loop scheduler time with the given date. + /// + /// - Parameter date: The date to represent. + public init(_ date: Date) { + self.date = date + } + + /// Returns the distance to another run loop scheduler time. + /// + /// - Parameter other: Another dispatch queue time. + /// - Returns: The time interval between this time and the provided time. + public func distance(to other: RunLoop.SchedulerTimeType) -> SchedulerTimeType.Stride { + return Stride(floatLiteral: date.distance(to: other.date)) + } + + /// Returns a run loop scheduler time calculated by advancing this instance’s time by the given interval. + /// + /// - Parameter n: A time interval to advance. + /// - Returns: A dispatch queue time advanced by the given interval from this instance’s time. + public func advanced(by n: SchedulerTimeType.Stride) -> RunLoop.SchedulerTimeType { + return SchedulerTimeType(date.advanced(by: n.timeInterval)) + } + + /// The interval by which run loop times advance. + public struct Stride: ExpressibleByFloatLiteral, Comparable, SignedNumeric, Codable, SchedulerTimeIntervalConvertible { + public typealias FloatLiteralType = TimeInterval + public typealias IntegerLiteralType = TimeInterval + public typealias Magnitude = TimeInterval + + /// The value of this time interval in seconds. + public var magnitude: TimeInterval + + /// The value of this time interval in seconds. + public var timeInterval: TimeInterval { + return magnitude + } + + public init(integerLiteral value: TimeInterval) { + magnitude = value + } + + public init(floatLiteral value: TimeInterval) { + magnitude = value + } + + public init(_ timeInterval: TimeInterval) { + magnitude = timeInterval + } + + public init?(exactly source: T) where T: BinaryInteger { + if let d = TimeInterval(exactly: source) { + magnitude = d + } else { + return nil + } + } + + // --- + + public static func < (lhs: Stride, rhs: Stride) -> Bool { + return lhs.magnitude < rhs.magnitude + } + + // --- + + public static func * (lhs: Stride, rhs: Stride) -> Stride { + return Stride(lhs.timeInterval * rhs.timeInterval) + } + + public static func + (lhs: Stride, rhs: Stride) -> Stride { + return Stride(lhs.magnitude + rhs.magnitude) + } + + public static func - (lhs: Stride, rhs: Stride) -> Stride { + return Stride(lhs.magnitude - rhs.magnitude) + } + + // --- + + public static func *= (lhs: inout Stride, rhs: Stride) { + let result = lhs * rhs + lhs = result + } + + public static func += (lhs: inout Stride, rhs: Stride) { + let result = lhs + rhs + lhs = result + } + + public static func -= (lhs: inout Stride, rhs: Stride) { + let result = lhs - rhs + lhs = result + } + + // --- + + public static func seconds(_ s: Int) -> Stride { + return Stride(Double(s)) + } + + public static func seconds(_ s: Double) -> Stride { + return Stride(s) + } + + public static func milliseconds(_ ms: Int) -> Stride { + return Stride(Double(ms) / 1_000.0) + } + + public static func microseconds(_ us: Int) -> Stride { + return Stride(Double(us) / 1_000_000.0) + } + + public static func nanoseconds(_ ns: Int) -> Stride { + return Stride(Double(ns) / 1_000_000_000.0) + } + } + } + + /// Options that affect the operation of the run loop scheduler. + public struct SchedulerOptions { } + + public func schedule(options: SchedulerOptions?, + _ action: @escaping () -> Void) { + self.perform(action) + } + + public func schedule(after date: SchedulerTimeType, + tolerance: SchedulerTimeType.Stride, + options: SchedulerOptions?, + _ action: @escaping () -> Void) { + let ti = date.date.timeIntervalSince(Date()) + self.perform(#selector(self.runLoopScheduled), with: _CombineRunLoopAction(action), afterDelay: ti) + } + + public func schedule(after date: SchedulerTimeType, + interval: SchedulerTimeType.Stride, + tolerance: SchedulerTimeType.Stride, + options: SchedulerOptions?, + _ action: @escaping () -> Void) -> Cancellable { + let timer = Timer(fire: date.date, interval: interval.timeInterval, repeats: true) { _ in + action() + } + + timer.tolerance = tolerance.timeInterval + self.add(timer, forMode: .default) + + return AnyCancellable(timer.invalidate) + } + + public var now: SchedulerTimeType { + return SchedulerTimeType(Date()) + } + + public var minimumTolerance: SchedulerTimeType.Stride { + return 0.0 + } + + @objc + fileprivate func runLoopScheduled(action: _CombineRunLoopAction) { + action.action() + } +} + +@objc +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +private class _CombineRunLoopAction: NSObject { + let action: () -> Void + + init(_ action: @escaping () -> Void) { + self.action = action + } +} + +#endif /* !(os(iOS) && (arch(i386) || arch(arm))) */ diff --git a/stdlib/public/Darwin/Foundation/URLCache.swift b/stdlib/public/Darwin/Foundation/URLCache.swift new file mode 100644 index 0000000000000..2a2f7588e9791 --- /dev/null +++ b/stdlib/public/Darwin/Foundation/URLCache.swift @@ -0,0 +1,20 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +@_exported import Foundation // Clang module + +extension URLCache { + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public convenience init(memoryCapacity: Int, diskCapacity: Int, directory: URL? = nil) { + self.init(__memoryCapacity: memoryCapacity, diskCapacity: diskCapacity, directoryURL: directory) + } +} diff --git a/stdlib/public/Darwin/Foundation/URLRequest.swift b/stdlib/public/Darwin/Foundation/URLRequest.swift index 82659e1093834..9695aba1fb9e8 100644 --- a/stdlib/public/Darwin/Foundation/URLRequest.swift +++ b/stdlib/public/Darwin/Foundation/URLRequest.swift @@ -122,6 +122,30 @@ public struct URLRequest : ReferenceConvertible, Equatable, Hashable { } } + /// `true` if the receiver is allowed to use an interface marked as expensive to + /// satify the request, `false` otherwise. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public var allowsExpensiveNetworkAccess: Bool { + get { + return _handle.map { $0.allowsExpensiveNetworkAccess } + } + set { + _applyMutation { $0.allowsExpensiveNetworkAccess = newValue } + } + } + + /// `true` if the receiver is allowed to use an interface marked as constrained to + /// satify the request, `false` otherwise. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public var allowsConstrainedNetworkAccess: Bool { + get { + return _handle.map { $0.allowsConstrainedNetworkAccess } + } + set { + _applyMutation { $0.allowsConstrainedNetworkAccess = newValue } + } + } + /// The HTTP request method of the receiver. public var httpMethod: String? { get { diff --git a/stdlib/public/Darwin/Foundation/URLSession.swift b/stdlib/public/Darwin/Foundation/URLSession.swift new file mode 100644 index 0000000000000..79cabfc93ded0 --- /dev/null +++ b/stdlib/public/Darwin/Foundation/URLSession.swift @@ -0,0 +1,69 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +@_exported import Foundation // Clang module + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension URLSessionWebSocketTask { + public enum Message { + case data(Data) + case string(String) + } + + public func send(_ message: Message, completionHandler: @escaping (Error?) -> Void) { + switch message { + case .data(let data): + __send(__NSURLSessionWebSocketMessage(data: data), completionHandler: completionHandler) + case .string(let string): + __send(__NSURLSessionWebSocketMessage(string: string), completionHandler: completionHandler) + } + } + + public func receive(completionHandler: @escaping (Result) -> Void) { + __receiveMessage { message, error in + switch (message, error) { + case (.some(let message), nil): + switch message.type { + case .data: + completionHandler(.success(.data(message.data!))) + case .string: + completionHandler(.success(.string(message.string!))) + @unknown default: + break + } + case (nil, .some(let error)): + completionHandler(.failure(error)) + case (_, _): + fatalError("Only one of message or error should be nil") + } + } + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension URLSessionTaskTransactionMetrics { + public var localPort: Int? { + return __localPort as? Int + } + + public var remotePort: Int? { + return __remotePort as? Int + } + + public var negotiatedTLSProtocolVersion: tls_protocol_version_t? { + return (__negotiatedTLSProtocolVersion as? UInt16).flatMap(tls_protocol_version_t.init(rawValue:)) + } + + public var negotiatedTLSCipherSuite: tls_ciphersuite_t? { + return (__negotiatedTLSCipherSuite as? UInt16).flatMap(tls_ciphersuite_t.init(rawValue:)) + } +} diff --git a/stdlib/public/Darwin/Foundation/UUID.swift b/stdlib/public/Darwin/Foundation/UUID.swift index 203cd4a47cd14..413174ff1cb68 100644 --- a/stdlib/public/Darwin/Foundation/UUID.swift +++ b/stdlib/public/Darwin/Foundation/UUID.swift @@ -99,22 +99,15 @@ public struct UUID : ReferenceConvertible, Hashable, Equatable, CustomStringConv } public static func ==(lhs: UUID, rhs: UUID) -> Bool { - return lhs.uuid.0 == rhs.uuid.0 && - lhs.uuid.1 == rhs.uuid.1 && - lhs.uuid.2 == rhs.uuid.2 && - lhs.uuid.3 == rhs.uuid.3 && - lhs.uuid.4 == rhs.uuid.4 && - lhs.uuid.5 == rhs.uuid.5 && - lhs.uuid.6 == rhs.uuid.6 && - lhs.uuid.7 == rhs.uuid.7 && - lhs.uuid.8 == rhs.uuid.8 && - lhs.uuid.9 == rhs.uuid.9 && - lhs.uuid.10 == rhs.uuid.10 && - lhs.uuid.11 == rhs.uuid.11 && - lhs.uuid.12 == rhs.uuid.12 && - lhs.uuid.13 == rhs.uuid.13 && - lhs.uuid.14 == rhs.uuid.14 && - lhs.uuid.15 == rhs.uuid.15 + return withUnsafeBytes(of: rhs.uuid) { (rhsPtr) -> Bool in + return withUnsafeBytes(of: lhs.uuid) { (lhsPtr) -> Bool in + let lhsFirstChunk = lhsPtr.load(fromByteOffset: 0, as: UInt64.self) + let lhsSecondChunk = lhsPtr.load(fromByteOffset: MemoryLayout.size, as: UInt64.self) + let rhsFirstChunk = rhsPtr.load(fromByteOffset: 0, as: UInt64.self) + let rhsSecondChunk = rhsPtr.load(fromByteOffset: MemoryLayout.size, as: UInt64.self) + return ((lhsFirstChunk ^ rhsFirstChunk) | (lhsSecondChunk ^ rhsSecondChunk)) == 0 + } + } } } diff --git a/stdlib/public/Darwin/WatchKit/CMakeLists.txt b/stdlib/public/Darwin/WatchKit/CMakeLists.txt index 37431bd22ceff..a4d2684126efb 100644 --- a/stdlib/public/Darwin/WatchKit/CMakeLists.txt +++ b/stdlib/public/Darwin/WatchKit/CMakeLists.txt @@ -13,7 +13,6 @@ add_swift_target_library(swiftWatchKit ${SWIFT_SDK_OVERLAY_LIBRARY_BUILD_TYPES} FRAMEWORK_DEPENDS_WEAK WatchKit SWIFT_COMPILE_FLAGS_WATCHOS -Xfrontend -disable-autolink-framework -Xfrontend CoreText - DEPLOYMENT_VERSION_IOS ${SWIFTLIB_DEPLOYMENT_VERSION_WATCHKIT_IOS} DEPLOYMENT_VERSION_WATCHOS ${SWIFTLIB_DEPLOYMENT_VERSION_WATCHKIT_WATCHOS} INSTALL_IN_COMPONENT sdk-overlay ) diff --git a/stdlib/public/SwiftShims/FoundationOverlayShims.h b/stdlib/public/SwiftShims/FoundationOverlayShims.h index 4a26a2eefdf75..1c175352cae86 100644 --- a/stdlib/public/SwiftShims/FoundationOverlayShims.h +++ b/stdlib/public/SwiftShims/FoundationOverlayShims.h @@ -71,3 +71,7 @@ static inline _Bool _withStackOrHeapBuffer(size_t amount, void (__attribute__((n } return true; } + +@protocol _NSKVOCompatibilityShim ++ (void)_noteProcessHasUsedKVOSwiftOverlay; +@end diff --git a/stdlib/public/SwiftShims/NetworkOverlayShims.h b/stdlib/public/SwiftShims/NetworkOverlayShims.h index ef9151c21e8b7..4d0234b4cac61 100644 --- a/stdlib/public/SwiftShims/NetworkOverlayShims.h +++ b/stdlib/public/SwiftShims/NetworkOverlayShims.h @@ -24,6 +24,11 @@ #pragma clang assume_nonnull begin +static inline uint32_t +_swift_nw_data_transfer_report_all_paths(void) { + return (uint32_t)(-1); +} + typedef void (^__swift_nw_connection_send_completion_t)(_Nullable nw_error_t error); static inline SWIFT_NW_RETURNS_RETAINED nw_content_context_t @@ -51,6 +56,12 @@ _swift_nw_connection_send(nw_connection_t connection, _Nullable dispatch_data_t nw_connection_send(connection, content, context, is_complete, completion); } +API_AVAILABLE(macos(10.15)) API_UNAVAILABLE(ios, watchos, tvos) +static inline SWIFT_NW_RETURNS_RETAINED nw_parameters_t +_swift_nw_parameters_create_custom_ip(uint8_t custom_ip_protocol_number) { + nw_parameters_create_custom_ip(custom_ip_protocol_number, _nw_parameters_configure_protocol_default_configuration); +} + API_AVAILABLE(macos(10.14), ios(12.0), watchos(5.0), tvos(12.0)) _Nullable SWIFT_NW_RETURNS_RETAINED nw_endpoint_t nw_endpoint_create_unix(const char *path); diff --git a/stdlib/public/SwiftShims/UIKitOverlayShims.h b/stdlib/public/SwiftShims/UIKitOverlayShims.h index 2be287b9366af..bed05075f374c 100644 --- a/stdlib/public/SwiftShims/UIKitOverlayShims.h +++ b/stdlib/public/SwiftShims/UIKitOverlayShims.h @@ -21,10 +21,10 @@ #if TARGET_OS_TV || TARGET_OS_IOS static inline BOOL _swift_UIKit_UIFocusEnvironmentContainsEnvironment( - id environment, - id otherEnvironment -) API_AVAILABLE(ios(11.0), tvos(11.0)) { - return [UIFocusSystem environment:environment containsEnvironment:otherEnvironment]; + id environment, + id otherEnvironment + ) API_AVAILABLE(ios(11.0), tvos(11.0)) { + return [UIFocusSystem environment:environment containsEnvironment:otherEnvironment]; } #endif // TARGET_OS_TV || TARGET_OS_IOS @@ -33,5 +33,163 @@ static inline BOOL _swift_UIKit_UIFocusEnvironmentContainsEnvironment( #endif +//===--------------------===// +// diffable data source // +//===--------------------===// + +#if TARGET_OS_TV || TARGET_OS_IOS + +#if __has_feature(nullability) +#pragma clang assume_nonnull begin +#endif + +typedef UITableViewCell * _Nullable (^UITableViewDiffableDataSourceCellProvider)(UITableView * _Nonnull, NSIndexPath * _Nonnull, id _Nonnull); +typedef UICollectionViewCell * _Nullable (^UICollectionViewDiffableDataSourceCellProvider)(UICollectionView * _Nonnull,NSIndexPath * _Nonnull, id _Nonnull); + +typedef UICollectionViewCell * _Nullable(^UIDiffableDataSourceCollectionViewCellProvider)(UICollectionView*, NSIndexPath *indexPath, id identifier); +typedef UICollectionReusableView * _Nullable(^UIDiffableDataSourceSupplementaryViewProvider)(UICollectionView *, NSString *kind, NSIndexPath *indexPath); + +typedef NSString * _Nonnull(^UIDiffableDataSourceCellReuseIdentifierProvider)(id _Nonnull identifier); +typedef void(^UIDiffableDataSourceCollectionViewCellConfigurationHandler)(__kindof UICollectionViewCell * _Nonnull , id _Nonnull identifier); + +typedef NSString * _Nonnull(^UIDiffableDataSourceSupplementaryViewReuseIdentifierProvider)(NSString * _Nonnull kind, NSIndexPath * _Nonnull indexPath); +typedef void(^UIDiffableDataSourceSupplementaryViewConfigurationHandler)(__kindof UICollectionReusableView * _Nonnull, NSString * _Nonnull kind, NSIndexPath * _Nonnull indexPath); + +typedef UITableViewCell * _Nullable(^UIDiffableDataSourceTableViewCellProvider)(__kindof UITableView * _Nonnull, NSIndexPath * _Nonnull, id _Nonnull identifier); +typedef void(^UIDiffableDataSourceTableViewCellConfigurationHandler)(__kindof UITableViewCell * _Nonnull , id _Nonnull identifier); + +@class __UIDiffableDataSourceSnapshot; + +API_AVAILABLE(ios(13.0), tvos(13.0)) +@interface __UIDiffableDataSource : NSObject + +- (instancetype)initWithCollectionView:(UICollectionView*)collectionView + cellProvider:(UIDiffableDataSourceCollectionViewCellProvider)cellProvider; + +- (instancetype)initWithCollectionView:(UICollectionView *)collectionView + cellProvider:(UIDiffableDataSourceCollectionViewCellProvider)cellProvider + dataSource:(id)dataSource; + +- (instancetype)initWithCollectionView:(UICollectionView*)collectionView + reuseIdentifierProvider:(UIDiffableDataSourceCellReuseIdentifierProvider)cellReuseProvider + cellConfigurationHandler:(UIDiffableDataSourceCollectionViewCellConfigurationHandler)cellConfigurationHandler; + +- (NSString*)description; + +- (instancetype)initWithTableView:(UITableView*)tableView + cellProvider:(UIDiffableDataSourceTableViewCellProvider)cellProvider; + +- (instancetype)initWithTableView:(UITableView*)tableView + reuseIdentifierProvider:(UIDiffableDataSourceCellReuseIdentifierProvider)cellResueProvider + cellConfigurationHandler:(UIDiffableDataSourceTableViewCellConfigurationHandler)cellConfigurationHandler; + +@property(nonatomic) UITableViewRowAnimation tableViewDefaultRowAnimation; +@property(nonatomic,weak,readonly,nullable) UITableView *tableView; +@property(nonatomic,copy) UITableViewDiffableDataSourceCellProvider tableViewCellProvider; + +@property(nonatomic,weak,readonly,nullable) UICollectionView *collectionView; +@property(nonatomic,nullable,copy) UIDiffableDataSourceSupplementaryViewProvider supplementaryViewProvider; + + +- (instancetype)init NS_UNAVAILABLE; + +@property(nonatomic,readonly) NSInteger numberOfItems; +@property(nonatomic,readonly) NSInteger numberOfSections; +@property(nonatomic,readonly) NSArray *sectionIdentifiers; +@property(nonatomic,readonly) NSArray *itemIdentifiers; + +- (NSInteger)numberOfItemsInSection:(id)sectionIdentifier; +- (NSArray*)itemIdentifiersInSectionWithIdentifier:(id)sectionIdentifier; +- (nullable id)sectionIdentifierForSectionContainingItemIdentifier:(id)identifier; + +- (NSInteger)indexOfItemIdentifier:(id)itemIdentifier; +- (NSInteger)indexOfSectionIdentifier:(id)sectionIdentifier; + +- (void)appendItemsWithIdentifiers:(NSArray*)identifiers; +- (void)appendItemsWithIdentifiers:(NSArray*)identifiers intoSectionWithIdentifier:(id _Nullable)sectionIdentifier; + +- (void)insertItemsWithIdentifiers:(NSArray*)identifiers beforeItemWithIdentifier:(id)itemIdentifier; +- (void)insertItemsWithIdentifiers:(NSArray*)identifiers afterItemWithIdentifier:(id)itemIdentifier; + +- (void)deleteItemsWithIdentifiers:(NSArray*)identifiers; +- (void)deleteAllItems; + +- (void)moveItemWithIdentifier:(id)fromIdentifier beforeItemWithIdentifier:(id)toIdentifier; +- (void)moveItemWithIdentifier:(id)fromIdentifier afterItemWithIdentifier:(id)toIdentifier; + +- (void)reloadItemsWithIdentifiers:(NSArray*)identifiers; + +- (void)appendSectionsWithIdentifiers:(NSArray*)sectionIdentifiers; + +- (void)insertSectionsWithIdentifiers:(NSArray*)sectionIdentifiers beforeSectionWithIdentifier:(id)toSectionIdentifier; +- (void)insertSectionsWithIdentifiers:(NSArray*)sectionIdentifiers afterSectionWithIdentifier:(id)toSectionIdentifier; + +- (void)deleteSectionsWithIdentifiers:(NSArray*)sectionIdentifiers; + +- (void)moveSectionWithIdentifier:(id)fromSectionIdentifier beforeSectionWithIdentifier:(id)toSectionIdentifier; +- (void)moveSectionWithIdentifier:(id)fromSectionIdentifier afterSectionWithIdentifier:(id)toSectionIdentifier; + +- (void)reloadSectionsWithIdentifiers:(NSArray*)sectionIdentifiers; + +- (nullable id)itemIdentifierForIndexPath:(NSIndexPath*)indexPath; +- (nullable NSIndexPath*)indexPathForItemIdentifier:(id)identifier; + + +- (__UIDiffableDataSourceSnapshot*)snapshot; +- (__UIDiffableDataSourceSnapshot*)emptySnapshot; +- (void)applyDifferencesFromSnapshot:(__UIDiffableDataSourceSnapshot*)snapshot; +- (void)reloadFromSnapshot:(__UIDiffableDataSourceSnapshot*)snapshot; +- (void)applyDifferencesFromSnapshot:(__UIDiffableDataSourceSnapshot *)snapshot animatingDifferences:(BOOL)animatingDifferences; + + +// deprecated + +- (void)appendSectionWithIdentifier:(id)sectionIdentifier; +- (void)insertSectionWithIdentifier:(id)sectionIdentifier beforeSectionWithIdentifier:(id)toSectionIdentifier; +- (void)insertSectionWithIdentifier:(id)sectionIdentifier afterSectionWithIdentifier:(id)toSectionIdentifier; +- (void)applySnapshot:(__UIDiffableDataSourceSnapshot*)snapshot; + + +@property(nonatomic,nullable,copy) UIDiffableDataSourceSupplementaryViewReuseIdentifierProvider supplementaryReuseIdentifierProvider; +@property(nonatomic,nullable,copy) UIDiffableDataSourceSupplementaryViewConfigurationHandler supplementaryViewConfigurationHandler; + +@property(nonatomic,copy) UICollectionViewDiffableDataSourceCellProvider collectionViewCellProvider; + + +// helpers + +- (NSInteger)_numberOfSectionsForCollectionView:(UICollectionView*)collectionView NS_SWIFT_NAME(_numberOfSectionsForCollectionView(_:)); +- (NSInteger)_numberOfItemsInSection:(NSInteger)section collectionView:(UICollectionView*)collectionView NS_SWIFT_NAME(_numberOfItemsInSection(_:collectionView:)); +- (UICollectionViewCell*)_cellForItemAtIndexPath:(NSIndexPath*)indexPath collectionView:(UICollectionView*)collectionView NS_SWIFT_NAME(_cellForItemAtIndexPath(_:collectionView:)); +- (UICollectionReusableView*)_viewForSupplementaryElementOfKind:(NSString *)kind atIndexPath:(NSIndexPath *)indexPath collectionView:(UICollectionView *)collectionView NS_SWIFT_NAME(_viewForSupplementaryElementOfKind(_:atIndexPath:collectionView:)); + +- (NSInteger)_numberOfSectionsForTableView:(UITableView*)tableView NS_SWIFT_NAME(_numberOfSectionsForTableView(_:)); +- (NSInteger)_numberOfRowsInSection:(NSInteger)section tableView:(UITableView*)tableView NS_SWIFT_NAME(_numberOfRowsInSection(_:tableView:)); +- (UITableViewCell*)_cellForRowAtIndexPath:(NSIndexPath*)indexPath tableView:(UITableView*)tableView NS_SWIFT_NAME(_cellForRowAtIndexPath(_:tableView:)); + +- (NSInteger)_numberOfSectionsForCollectionViewDeprecatedSPI:(UICollectionView*)collectionView NS_SWIFT_NAME(numberOfSections(for:)); +- (NSInteger)_numberOfItemsInSectionDeprecatedSPI:(NSInteger)section collectionView:(UICollectionView*)collectionView NS_SWIFT_NAME(numberOfItems(inSection:collectionView:)); +- (UICollectionViewCell*)_cellForItemAtIndexPathDeprecatedSPI:(NSIndexPath*)indexPath collectionView:(UICollectionView*)collectionView NS_SWIFT_NAME(cellForItem(at:collectionView:)); +- (UICollectionReusableView*)_viewForSupplementaryElementOfKindDeprecatedSPI:(NSString *)kind atIndexPath:(NSIndexPath *)indexPath collectionView:(UICollectionView *)collectionView NS_SWIFT_NAME(viewForSupplementaryElement(ofKind:at:collectionView:)); + +- (NSInteger)_numberOfSectionsForTableViewDeprecatedSPI:(UITableView*)tableView NS_SWIFT_NAME(numberOfSections(for:)); +- (NSInteger)_numberOfRowsInSectionDeprecatedSPI:(NSInteger)section tableView:(UITableView*)tableView NS_SWIFT_NAME(numberOfRows(inSection:tableView:)); +- (UITableViewCell*)_cellForRowAtIndexPathDeprecatedSPI:(NSIndexPath*)indexPath tableView:(UITableView*)tableView NS_SWIFT_NAME(cellForRow(at:tableView:)); + +@end + + +API_AVAILABLE(ios(13.0), tvos(13.0)) +@interface __UIDiffableDataSourceSnapshot : __UIDiffableDataSource +- (instancetype)init; +@end + +#if __has_feature(nullability) +#pragma clang assume_nonnull end +#endif + + +#endif // TARGET_OS_TV || TARGET_OS_IOS + #endif // SWIFT_STDLIB_SHIMS_UIKIT_OVERLAY_H diff --git a/stdlib/public/core/BridgeObjectiveC.swift b/stdlib/public/core/BridgeObjectiveC.swift index f38362fc2084a..becebdabc82bd 100644 --- a/stdlib/public/core/BridgeObjectiveC.swift +++ b/stdlib/public/core/BridgeObjectiveC.swift @@ -85,11 +85,13 @@ public protocol _ObjectiveCBridgeable { #if _runtime(_ObjC) -@available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) -@available(*, deprecated) +// Note: This function is not intended to be called from Swift. The +// availability information here is perfunctory; this function isn't considered +// part of the Stdlib's Swift ABI. +@available(macOS 10.15.4, iOS 13.4, watchOS 6.2, tvOS 13.4, *) @_cdecl("_SwiftCreateBridgedArray") @usableFromInline -internal func _SwiftCreateBridgedArray( +internal func _SwiftCreateBridgedArray_DoNotCall( values: UnsafePointer, numValues: Int ) -> Unmanaged { @@ -98,11 +100,13 @@ internal func _SwiftCreateBridgedArray( return Unmanaged.passRetained(bridged) } -@available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) -@available(*, deprecated) +// Note: This function is not intended to be called from Swift. The +// availability information here is perfunctory; this function isn't considered +// part of the Stdlib's Swift ABI. +@available(macOS 10.15.4, iOS 13.4, watchOS 6.2, tvOS 13.4, *) @_cdecl("_SwiftCreateBridgedMutableArray") @usableFromInline -internal func _SwiftCreateBridgedMutableArray( +internal func _SwiftCreateBridgedMutableArray_DoNotCall( values: UnsafePointer, numValues: Int ) -> Unmanaged { diff --git a/stdlib/public/core/StringBridge.swift b/stdlib/public/core/StringBridge.swift index 6a15eb28a865a..9d48925bab045 100644 --- a/stdlib/public/core/StringBridge.swift +++ b/stdlib/public/core/StringBridge.swift @@ -483,14 +483,13 @@ extension String { } } -@available(macOS, introduced: 9999, deprecated) -@available(iOS, introduced: 9999, deprecated) -@available(watchOS, introduced: 9999, deprecated) -@available(tvOS, introduced: 9999, deprecated) -@available(*, deprecated) +// Note: This function is not intended to be called from Swift. The +// availability information here is perfunctory; this function isn't considered +// part of the Stdlib's Swift ABI. +@available(macOS 10.15.4, iOS 13.4, watchOS 6.2, tvOS 13.4, *) @_cdecl("_SwiftCreateBridgedString") @usableFromInline -internal func _SwiftCreateBridgedString( +internal func _SwiftCreateBridgedString_DoNotCall( bytes: UnsafePointer, length: Int, encoding: _swift_shims_CFStringEncoding @@ -529,7 +528,7 @@ public func _getDescription(_ x: T) -> AnyObject { @_silgen_name("swift_stdlib_NSStringFromUTF8") @usableFromInline //this makes the symbol available to the runtime :( -@available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) +@available(macOS 10.15.4, iOS 13.4, watchOS 6.2, tvOS 13.4, *) internal func _NSStringFromUTF8(_ s: UnsafePointer, _ len: Int) -> AnyObject { return String( diff --git a/test/AutoDiff/Parse/differentiable_attr_parse.swift b/test/AutoDiff/Parse/differentiable_attr_parse.swift index eb94ff59728ab..251be24a5923c 100644 --- a/test/AutoDiff/Parse/differentiable_attr_parse.swift +++ b/test/AutoDiff/Parse/differentiable_attr_parse.swift @@ -9,32 +9,17 @@ struct Foo { var x: Float } -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(vjp: foo(_:_:)) // okay +@differentiable // okay func bar(_ x: Float, _: Float) -> Float { return 1 + x } -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(vjp: foo(_:_:)) // okay -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(vjp: foo(_:_:) where T : FloatingPoint) // okay +@differentiable(where T : FloatingPoint) // okay func bar(_ x: T, _: T) -> T { return 1 + x } -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(vjp: foo(_:_:)) // okay -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(wrt: (self, x, y), vjp: foo(_:_:)) // okay -func bar(_ x: Float, _ y: Float) -> Float { - return 1 + x -} - -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(vjp: foo(_:_:)) // okay -// expected-warning @+1 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(wrt: (self, x, y), jvp: bar, vjp: foo(_:_:)) // okay +@differentiable(wrt: (self, x, y)) // okay func bar(_ x: Float, _ y: Float) -> Float { return 1 + x } @@ -67,8 +52,7 @@ func playWellWithOtherAttrs(_ x: Float, _: Float) -> Float { } @_transparent -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(wrt: (self), vjp: _vjpSquareRoot) // okay +@differentiable(wrt: (self)) // okay public func squareRoot() -> Self { var lhs = self lhs.formSquareRoot() @@ -112,44 +96,37 @@ func two(x: Float, y: Float) -> Float { /// Bad -// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected 'wrt:' or 'where'}} @differentiable(3) func bar(_ x: Float, _: Float) -> Float { return 1 + x } -// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected 'wrt:' or 'where'}} @differentiable(foo(_:_:)) func bar(_ x: Float, _: Float) -> Float { return 1 + x } -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} -@differentiable(vjp: foo(_:_:), 3) -func bar(_ x: Float, _: Float) -> Float { - return 1 + x -} - -// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected 'wrt:' or 'where'}} @differentiable(wrt: (x), foo(_:_:)) func bar(_ x: Float, _: Float) -> Float { return 1 + x } -// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected 'wrt:' or 'where'}} @differentiable(wrt: x, y) func bar(_ x: Float, _ y: Float) -> Float { return 1 + x } -// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected 'wrt:' or 'where'}} @differentiable(wrt: 0, 1) func two(x: Float, y: Float) -> Float { return x + y } -// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected 'wrt:' or 'where'}} @differentiable(wrt: 0, y) func two(x: Float, y: Float) -> Float { return x + y @@ -161,55 +138,51 @@ func two(x: Float, y: Float) -> Float { return x + y } -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} // expected-error @+1 {{expected ')' in 'differentiable' attribute}} -@differentiable(vjp: foo(_:_:) +@differentiable(wrt: (x) func bar(_ x: Float, _: Float) -> Float { return 1 + x } -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} // expected-error @+1 {{expected ':' or '==' to indicate a conformance or same-type requirement}} -@differentiable(vjp: foo(_:_:) where T) +@differentiable(wrt: (x) where T) func bar(_ x: T, _: T) -> T { return 1 + x } -// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected 'wrt:' or 'where'}} @differentiable(,) func bar(_ x: Float, _: Float) -> Float { return 1 + x } -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} // expected-error @+1 {{unexpected ',' separator}} -@differentiable(vjp: foo(_:_:),) +@differentiable(wrt: (x),) func bar(_ x: Float, _: Float) -> Float { return 1 + x } -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} // expected-error @+1 {{unexpected ',' separator}} -@differentiable(vjp: foo(_:_:), where T) +@differentiable(wrt: (x), where T) func bar(_ x: T, _: T) -> T { return 1 + x } -// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} +// expected-error @+1 {{expected 'wrt:' or 'where'}} @differentiable(wrt: x, linear) func slope4(_ x: Float) -> Float { return 4 * x } -// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} -@differentiable(wrt: x, linear, vjp: const5) +// expected-error @+1 {{expected 'wrt:' or 'where'}} +@differentiable(wrt: x, linear) func slope5(_ x: Float) -> Float { return 5 * x } -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -// expected-error @+1 {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}} -@differentiable(wrt: x, vjp: const6, linear) -func slope5(_ x: Float) -> Float { - return 6 * x +// Test removed `jvp:' and 'vjp:' arguments. +// expected-error @+1 {{expected 'wrt:' or 'where' in '@differentiable' attribute}} +@differentiable(jvp: foo, vjp: foo) +func bar(_ x: Float, _: Float) -> Float { + return 1 + x } diff --git a/test/AutoDiff/SIL/Parse/sildeclref_parse.sil b/test/AutoDiff/SIL/Parse/sildeclref.sil similarity index 98% rename from test/AutoDiff/SIL/Parse/sildeclref_parse.sil rename to test/AutoDiff/SIL/Parse/sildeclref.sil index 9c7452949d5ae..94d7df9325251 100644 --- a/test/AutoDiff/SIL/Parse/sildeclref_parse.sil +++ b/test/AutoDiff/SIL/Parse/sildeclref.sil @@ -1,5 +1,4 @@ // RUN: %target-sil-opt -enable-experimental-differentiable-programming %s -module-name=sildeclref_parse | %target-sil-opt -enable-experimental-differentiable-programming -module-name=sildeclref_parse | %FileCheck %s -// REQUIRES: differentiable_programming // Parse AutoDiff derivative SILDeclRefs via `witness_method` and `class_method` instructions. diff --git a/test/AutoDiff/SIL/Serialization/differentiation.swift b/test/AutoDiff/SIL/Serialization/differentiable_function.swift similarity index 100% rename from test/AutoDiff/SIL/Serialization/differentiation.swift rename to test/AutoDiff/SIL/Serialization/differentiable_function.swift diff --git a/test/AutoDiff/SIL/Serialization/differentiability_witness_function_inst.sil b/test/AutoDiff/SIL/differentiability_witness_function_inst.sil similarity index 98% rename from test/AutoDiff/SIL/Serialization/differentiability_witness_function_inst.sil rename to test/AutoDiff/SIL/differentiability_witness_function_inst.sil index 496f6654c2ccc..de2328be4f50b 100644 --- a/test/AutoDiff/SIL/Serialization/differentiability_witness_function_inst.sil +++ b/test/AutoDiff/SIL/differentiability_witness_function_inst.sil @@ -14,10 +14,8 @@ // IRGen test. // RUN: %target-swift-frontend -emit-ir %s | %FileCheck %s --check-prefix=IRGEN --check-prefix %target-cpu -// NOTE: `%target-cpu`-specific FileCheck lines exist because lowered function -// types in LLVM IR differ between architectures. +// NOTE: `%target-cpu`-specific FileCheck lines exist because lowered function types in LLVM IR differ between architectures. -// REQUIRES: differentiable_programming // NOTE(SR-12090): `shell` is required only to run `sed` as a SR-12090 workaround. // REQUIRES: shell diff --git a/test/AutoDiff/SIL/Serialization/differentiability_witness_function_inst_transpose.sil b/test/AutoDiff/SIL/differentiability_witness_function_inst_transpose.sil similarity index 99% rename from test/AutoDiff/SIL/Serialization/differentiability_witness_function_inst_transpose.sil rename to test/AutoDiff/SIL/differentiability_witness_function_inst_transpose.sil index dc8289d8827ae..04d4675b6befd 100644 --- a/test/AutoDiff/SIL/Serialization/differentiability_witness_function_inst_transpose.sil +++ b/test/AutoDiff/SIL/differentiability_witness_function_inst_transpose.sil @@ -15,7 +15,6 @@ // RUN: sed -e 's/import Swift$/import Swift; import _Differentiation/' %t/tmp.sil > %t/tmp_fixed.sil // RUN: %target-sil-opt %t/tmp_fixed.sil -module-name main -emit-sorted-sil | %FileCheck %s -// REQUIRES: differentiable_programming // NOTE(SR-12090): `shell` is required only to run `sed` as a SR-12090 workaround. // REQUIRES: shell diff --git a/test/AutoDiff/SIL/differentiable_function_inst.sil b/test/AutoDiff/SIL/differentiable_function_inst.sil new file mode 100644 index 0000000000000..f7b27ac5e6a92 --- /dev/null +++ b/test/AutoDiff/SIL/differentiable_function_inst.sil @@ -0,0 +1,83 @@ +// Round-trip parsing/printing test. + +// RUN: %target-sil-opt -enable-experimental-differentiable-programming %s -emit-sorted-sil | %FileCheck %s --check-prefix=CHECK-SIL + +// Round-trip serialization-deserialization test. + +// RUN: %empty-directory(%t) +// RUN: %target-sil-opt -enable-experimental-differentiable-programming %s -emit-sib -o %t/tmp.sib -module-name main +// RUN: %target-sil-opt -enable-experimental-differentiable-programming %t/tmp.sib -o %t/tmp.sil -module-name main +// NOTE(SR-12090): Workaround because import declarations are not preserved in .sib files. +// RUN: sed -e 's/import Swift$/import Swift; import _Differentiation/' %t/tmp.sil > %t/tmp_fixed.sil +// RUN: %target-sil-opt -enable-experimental-differentiable-programming %t/tmp_fixed.sil -module-name main -emit-sorted-sil | %FileCheck %s --check-prefix=CHECK-SIL + +// IRGen test. + +// RUN: %target-swift-frontend -enable-experimental-differentiable-programming -emit-ir %s | %FileCheck %s --check-prefix=CHECK-IRGEN + +// NOTE(SR-12090): `shell` is required only to run `sed` as a SR-12090 workaround. +// REQUIRES: shell + +// Would need to update for arm64e. +// UNSUPPORTED: CPU=arm64e + +sil_stage raw + +import Swift +import Builtin + +import _Differentiation + +sil @function : $@convention(thin) (Float) -> Float +sil @function_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + +sil @make_differentiable_function : $@convention(thin) () -> @differentiable @convention(thin) (Float) -> Float { +bb0: + %orig_fn = function_ref @function : $@convention(thin) (Float) -> Float + %vjp_fn = function_ref @function_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + %diff_fn = differentiable_function [parameters 0] %orig_fn : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp_fn : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} + %extracted_vjp = differentiable_function_extract [vjp] %diff_fn : $@differentiable @convention(thin) (Float) -> Float + %extracted_original = differentiable_function_extract [original] %diff_fn : $@differentiable @convention(thin) (Float) -> Float + return %diff_fn : $@differentiable @convention(thin) (Float) -> Float +} + +// CHECK-SIL-LABEL: @make_differentiable_function : $@convention(thin) () -> @differentiable @convention(thin) (Float) -> Float { +// CHECK-SIL: [[ORIG_FN:%.*]] = function_ref @function : $@convention(thin) (Float) -> Float +// CHECK-SIL: [[VJP_FN:%.*]] = function_ref @function_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK-SIL: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]] : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[VJP_FN]] : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} +// CHECK-SIL: [[EXTRACTED_VJP_FN:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] : $@differentiable @convention(thin) (Float) -> Float +// CHECK-SIL: [[EXTRACTED_ORIG_FN:%.*]] = differentiable_function_extract [original] [[DIFF_FN]] : $@differentiable @convention(thin) (Float) -> Float +// CHECK-SIL: return [[DIFF_FN]] : $@differentiable @convention(thin) (Float) -> Float + +// CHECK-IRGEN-LABEL: define {{.*}}swiftcc { i8*, i8*, i8* } @make_differentiable_function() +// CHECK-IRGEN-NEXT: entry: +// CHECK-IRGEN-NEXT: ret { i8*, i8*, i8* } { i8* bitcast (float (float)* @function to i8*), i8* undef, i8* bitcast ({ float, i8*, %swift.refcounted* } (float)* @function_vjp to i8*) } + +sil @examplefunc : $@convention(thin) (Float, Float, Float) -> Float +sil @examplemethod : $@convention(method) (Float, Float, Float) -> Float + +// CHECK-SIL-LABEL: sil @test_roundtrip_parse +sil @test_roundtrip_parse : $@convention(thin) () -> () { +bb0: + %0 = function_ref @examplefunc : $@convention(thin) (Float, Float, Float) -> Float + %1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (Float, Float, Float) -> Float with_derivative {undef : $@convention(thin) (Float, Float, Float) -> (Float, @owned @callee_guaranteed (Float, Float, Float) -> Float), undef : $@convention(thin) (Float, Float, Float) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float, Float))} + + // CHECK-SIL: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float + %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float + %3 = differentiable_function [parameters 0] %0 : $@convention(thin) (Float, Float, Float) -> Float with_derivative {undef : $@convention(thin) (Float, Float, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), undef : $@convention(thin) (Float, Float, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} + + // CHECK-SIL: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @noDerivative Float, @noDerivative Float) -> Float + %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @noDerivative Float, @noDerivative Float) -> Float + %5 = function_ref @examplemethod : $@convention(method) (Float, Float, Float) -> Float + %6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (Float, Float, Float) -> Float with_derivative {undef : $@convention(method) (Float, Float, Float) -> (Float, @owned @callee_guaranteed (Float, Float, Float) -> Float), undef : $@convention(method) (Float, Float, Float) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float, Float))} + + // CHECK-SIL: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float + %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float + %8 = differentiable_function [parameters 0] %5 : $@convention(method) (Float, Float, Float) -> Float with_derivative {undef : $@convention(method) (Float, Float, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), undef : $@convention(method) (Float, Float, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} + + // CHECK-SIL: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @noDerivative Float, @noDerivative Float) -> Float + %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @noDerivative Float, @noDerivative Float) -> Float + + %ret = tuple () + return %ret : $() +} diff --git a/test/AutoDiff/SIL/Serialization/sil_differentiability_witness.sil b/test/AutoDiff/SIL/sil_differentiability_witness.sil similarity index 99% rename from test/AutoDiff/SIL/Serialization/sil_differentiability_witness.sil rename to test/AutoDiff/SIL/sil_differentiability_witness.sil index 4db87196b6573..550ad0eaf4c5b 100644 --- a/test/AutoDiff/SIL/Serialization/sil_differentiability_witness.sil +++ b/test/AutoDiff/SIL/sil_differentiability_witness.sil @@ -15,7 +15,6 @@ // RUN: %target-swift-frontend -emit-ir %s | %FileCheck --check-prefix=IRGEN %s -// REQUIRES: differentiable_programming // NOTE(SR-12090): `shell` is required only to run `sed` as a SR-12090 workaround. // REQUIRES: shell diff --git a/test/AutoDiff/SILGen/autodiff_builtins.swift b/test/AutoDiff/SILGen/autodiff_builtins.swift new file mode 100644 index 0000000000000..659244178a923 --- /dev/null +++ b/test/AutoDiff/SILGen/autodiff_builtins.swift @@ -0,0 +1,136 @@ +// RUN: %target-swift-frontend -parse-stdlib -emit-silgen -enable-experimental-differentiable-programming %s | %FileCheck %s + +import _Differentiation +import Swift + +@_silgen_name("f_direct_arity1") +func f_direct_arity1(_ x: Float) -> Float { + x +} + +@_silgen_name("f_direct_arity1_jvp") +func f_direct_arity1_jvp(_ x: Float) -> (Float, (Float) -> Float) { + (x, { $0 }) +} + +@_silgen_name("f_direct_arity1_vjp") +func f_direct_arity1_vjp(_ x: Float) -> (Float, (Float) -> Float) { + (x, { $0 }) +} + +@_silgen_name("f_direct_arity2") +func f_direct_arity2(_ x: Float, _ y: Float) -> Float { + x +} + +@_silgen_name("f_indirect_arity1") +func f_indirect_arity1(_ x: T) -> T { + x +} + +// MARK: - applyDerivative + +@_silgen_name("applyDerivative_f_direct_arity1_jvp") +func applyDerivative_f1_jvp(_ x: Float) -> (Float, (Float) -> Float) { + return Builtin.applyDerivative_jvp(f_direct_arity1, x) +} +// CHECK-LABEL: sil{{.*}}@applyDerivative_f_direct_arity1_jvp +// CHECK: bb0([[X:%.*]] : $Float): +// CHECK: [[D:%.*]] = differentiable_function_extract [jvp] +// CHECK: [[D_RESULT:%.*]] = apply [[D]]([[X]]) +// CHECK: ([[D_RESULT_0:%.*]], [[D_RESULT_1:%.*]]) = destructure_tuple [[D_RESULT]] +// CHECK: [[D_RESULT_RETUPLED:%.*]] = tuple ([[D_RESULT_0]] : {{.*}}, [[D_RESULT_1]] : {{.*}}) +// CHECK: return [[D_RESULT_RETUPLED]] + +@_silgen_name("applyDerivative_f_direct_arity1_vjp") +func applyDerivative_f1_vjp(_ x: Float) -> (Float, (Float) -> Float) { + return Builtin.applyDerivative_vjp(f_direct_arity1, x) +} +// CHECK-LABEL: sil{{.*}}@applyDerivative_f_direct_arity1_vjp +// CHECK: bb0([[X:%.*]] : $Float): +// CHECK: [[D:%.*]] = differentiable_function_extract [vjp] +// CHECK: [[D_RESULT:%.*]] = apply [[D]]([[X]]) +// CHECK: ([[D_RESULT_0:%.*]], [[D_RESULT_1:%.*]]) = destructure_tuple [[D_RESULT]] +// CHECK: [[D_RESULT_RETUPLED:%.*]] = tuple ([[D_RESULT_0]] : {{.*}}, [[D_RESULT_1]] : {{.*}}) +// CHECK: return [[D_RESULT_RETUPLED]] + +@_silgen_name("applyDerivative_f_direct_arity2_vjp") +func applyDerivative_f1_vjp(_ x: Float, _ y: Float) -> (Float, (Float) -> (Float, Float)) { + return Builtin.applyDerivative_vjp_arity2(f_direct_arity2, x, y) +} +// CHECK-LABEL: sil{{.*}}@applyDerivative_f_direct_arity2_vjp +// CHECK: bb0([[X:%.*]] : $Float, [[Y:%.*]] : $Float): +// CHECK: [[D:%.*]] = differentiable_function_extract [vjp] +// CHECK: [[D_RESULT:%.*]] = apply [[D]]([[X]], [[Y]]) +// CHECK: ([[D_RESULT_0:%.*]], [[D_RESULT_1:%.*]]) = destructure_tuple [[D_RESULT]] +// CHECK: [[D_RESULT_RETUPLED:%.*]] = tuple ([[D_RESULT_0]] : {{.*}}, [[D_RESULT_1]] : {{.*}}) +// CHECK: return [[D_RESULT_RETUPLED]] + +@_silgen_name("applyDerivative_f_indirect_arity1_vjp") +func applyDerivative_f1_vjp(t0: T) -> (T, (T.TangentVector) -> T.TangentVector) { + return Builtin.applyDerivative_vjp(f_indirect_arity1, t0) +} +// CHECK-LABEL: sil{{.*}}@applyDerivative_f_indirect_arity1_vjp +// CHECK: bb0([[ORIG_RESULT_OUT_PARAM:%.*]] : $*T, [[X:%.]] : $*T): +// CHECK: [[D:%.*]] = differentiable_function_extract [vjp] +// CHECK: [[D_RESULT_BUFFER:%.*]] = alloc_stack $(T, @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for ) +// CHECK: [[D_RESULT_BUFFER_0_FOR_STORE:%.*]] = tuple_element_addr [[D_RESULT_BUFFER]] : ${{.*}}, 0 +// CHECK: [[D_RESULT:%.*]] = apply [[D]]([[D_RESULT_BUFFER_0_FOR_STORE]], [[X]]) +// CHECK: [[D_RESULT_BUFFER_1_FOR_STORE:%.*]] = tuple_element_addr [[D_RESULT_BUFFER]] : ${{.*}}, 1 +// CHECK: store [[D_RESULT]] to [init] [[D_RESULT_BUFFER_1_FOR_STORE]] +// CHECK: [[D_RESULT_BUFFER_0_FOR_LOAD:%.*]] = tuple_element_addr [[D_RESULT_BUFFER]] : ${{.*}}, 0 +// CHECK: [[D_RESULT_BUFFER_1_FOR_LOAD:%.*]] = tuple_element_addr [[D_RESULT_BUFFER]] : ${{.*}}, 1 +// CHECK: [[PULLBACK:%.*]] = load [take] [[D_RESULT_BUFFER_1_FOR_LOAD]] +// CHECK: copy_addr [take] [[D_RESULT_BUFFER_0_FOR_LOAD]] to [initialization] [[ORIG_RESULT_OUT_PARAM]] +// CHECK: return [[PULLBACK]] + +// MARK: - applyTranspose +// TODO(TF-1142): Add linear_function_extracts to this test when they exist. + +@_silgen_name("applyTranspose_f_direct_arity1") +func applyTranspose_f_direct_arity1(_ x: Float) -> Float { + return Builtin.applyTranspose_arity1(f_direct_arity1, x) +} +// CHECK-LABEL: sil{{.*}}@applyTranspose_f_direct_arity1 +// CHECK: bb0([[X:%.*]] : $Float): +// CHECK: [[RESULT:%.*]] = apply undef([[X]]) +// CHECK: return [[RESULT]] + +@_silgen_name("applyTranspose_f_direct_arity2") +func applyTranspose_f_direct_arity2(_ x: Float) -> (Float, Float) { + return Builtin.applyTranspose_arity2(f_direct_arity2, x) +} +// CHECK-LABEL: sil{{.*}}@applyTranspose_f_direct_arity2 +// CHECK: bb0([[X:%.*]] : $Float) +// CHECK: [[RESULT:%.*]] = apply undef([[X]]) +// CHECK: ([[RESULT_0:%.*]], [[RESULT_1:%.*]]) = destructure_tuple [[RESULT]] +// CHECK: [[RETUPLED_RESULT:%.*]] = tuple ([[RESULT_0]] : $Float, [[RESULT_1]] : $Float) +// CHECK: return [[RETUPLED_RESULT]] + +@_silgen_name("applyTranspose_f_indirect_arity1") +func applyTranspose_f_indirect_arity1(_ x: T) -> T { + return Builtin.applyTranspose_arity1(f_indirect_arity1, x) +} +// CHECK-LABEL: sil{{.*}}@applyTranspose_f_indirect_arity1 +// CHECK: bb0([[OUT_PARAM:%.*]] : $*T, [[X:%.*]] : $*T): +// CHECK: [[RESULT:%.*]] = apply [[TRANSPOSE:%.*]]([[OUT_PARAM]], [[X]]) + +// MARK: - differentiableFunction + +@_silgen_name("differentiableFunction_f_direct_arity1") +func differentiableFunction_f_direct_arity1() -> @differentiable (Float) -> Float { + return Builtin.differentiableFunction_arity1(f_direct_arity1, f_direct_arity1_jvp, f_direct_arity1_vjp) +} +// CHECK-LABEL: sil{{.*}}@differentiableFunction_f_direct_arity1 +// CHECK: [[DIFF_FN:%.*]] = differentiable_function +// CHECK: return [[DIFF_FN]] + +// MARK: - linearFunction +// TODO(TF-1142): Add linear_funcion to this test when it exists. + +@_silgen_name("linearFunction_f_direct_arity1") +func linearFunction_f_direct_arity1() -> @differentiable(linear) (Float) -> Float { + return Builtin.linearFunction_arity1(f_direct_arity1, f_direct_arity1) +} +// CHECK-LABEL: sil{{.*}}@linearFunction_f_direct_arity1 +// CHECK: return undef diff --git a/test/AutoDiff/SILGen/differentiable_function.swift b/test/AutoDiff/SILGen/differentiable_function.swift index bd1282fa59199..e298eb690a55e 100644 --- a/test/AutoDiff/SILGen/differentiable_function.swift +++ b/test/AutoDiff/SILGen/differentiable_function.swift @@ -1,5 +1,4 @@ // RUN: %target-swift-frontend -emit-silgen -enable-experimental-differentiable-programming %s | %FileCheck %s -// REQUIRES: differentiable_programming // Test SILGen for `@differentiable` function typed values. diff --git a/test/AutoDiff/SILGen/sil_differentiability_witness_silgen.swift b/test/AutoDiff/SILGen/sil_differentiability_witness.swift similarity index 70% rename from test/AutoDiff/SILGen/sil_differentiability_witness_silgen.swift rename to test/AutoDiff/SILGen/sil_differentiability_witness.swift index a0124eedc6e60..72657e3cbe7be 100644 --- a/test/AutoDiff/SILGen/sil_differentiability_witness_silgen.swift +++ b/test/AutoDiff/SILGen/sil_differentiability_witness.swift @@ -1,5 +1,4 @@ // RUN: %target-swift-frontend -emit-silgen -enable-experimental-differentiable-programming %s | %target-sil-opt -enable-experimental-differentiable-programming | %FileCheck %s -// REQUIRES: differentiable_programming // Test SIL differentiability witness SIL generation. @@ -31,9 +30,9 @@ public func foo_vjp(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { } // CHECK-LABEL: // differentiability witness for foo(_:) -// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3fooyS2fF : $@convention(thin) (Float) -> Float { -// CHECK-NEXT: jvp: @AD__$s36sil_differentiability_witness_silgen3fooyS2fF__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -// CHECK-NEXT: vjp: @AD__$s36sil_differentiability_witness_silgen3fooyS2fF__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s29sil_differentiability_witness3fooyS2fF : $@convention(thin) (Float) -> Float { +// CHECK-NEXT: jvp: @AD__$s29sil_differentiability_witness3fooyS2fF__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK-NEXT: vjp: @AD__$s29sil_differentiability_witness3fooyS2fF__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) // CHECK-NEXT: } // Test internal non-generic function. @@ -50,8 +49,8 @@ public func bar_jvp(_ x: Float, _ y: T) -> (value: Float, differential: (Floa } // CHECK-LABEL: // differentiability witness for bar(_:_:) -// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <τ_0_0> @$s36sil_differentiability_witness_silgen3baryS2f_xtlF : $@convention(thin) (Float, @in_guaranteed T) -> Float { -// CHECK-NEXT: jvp: @AD__$s36sil_differentiability_witness_silgen3baryS2f_xtlF__jvp_src_0_wrt_0_l : $@convention(thin) <τ_0_0> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <τ_0_0> @$s29sil_differentiability_witness3baryS2f_xtlF : $@convention(thin) (Float, @in_guaranteed T) -> Float { +// CHECK-NEXT: jvp: @AD__$s29sil_differentiability_witness3baryS2f_xtlF__jvp_src_0_wrt_0_l : $@convention(thin) <τ_0_0> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float) // CHECK-NEXT: } // Test internal generic function. @@ -76,9 +75,9 @@ func generic_vjp(_ x: T, _ y: Float) -> ( } // CHECK-LABEL: // differentiability witness for generic(_:_:) -// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @$s36sil_differentiability_witness_silgen7genericyxx_SftlF : $@convention(thin) (@in_guaranteed T, Float) -> @out T { -// CHECK-NEXT: jvp: @AD__$s36sil_differentiability_witness_silgen7genericyxx_SftlF__jvp_src_0_wrt_0_1_{{s|16_Differentiation}}14DifferentiableRzl : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0, Float) -> @out τ_0_1 for <τ_0_0.TangentVector, τ_0_0.TangentVector>) -// CHECK-NEXT: vjp: @AD__$s36sil_differentiability_witness_silgen7genericyxx_SftlF__vjp_src_0_wrt_0_1_{{s|16_Differentiation}}14DifferentiableRzl : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (@out τ_0_1, Float) for <τ_0_0.TangentVector, τ_0_0.TangentVector>) +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @$s29sil_differentiability_witness7genericyxx_SftlF : $@convention(thin) (@in_guaranteed T, Float) -> @out T { +// CHECK-NEXT: jvp: @AD__$s29sil_differentiability_witness7genericyxx_SftlF__jvp_src_0_wrt_0_1_{{s|16_Differentiation}}14DifferentiableRzl : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0, Float) -> @out τ_0_1 for <τ_0_0.TangentVector, τ_0_0.TangentVector>) +// CHECK-NEXT: vjp: @AD__$s29sil_differentiability_witness7genericyxx_SftlF__vjp_src_0_wrt_0_1_{{s|16_Differentiation}}14DifferentiableRzl : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (@out τ_0_1, Float) for <τ_0_0.TangentVector, τ_0_0.TangentVector>) // CHECK-NEXT: } public struct Foo: Differentiable { @@ -89,7 +88,7 @@ public struct Foo: Differentiable { public var x: Float // CHECK-LABEL: // differentiability witness for Foo.x.getter -// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV1xSfvg : $@convention(method) (Foo) -> Float { +// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s29sil_differentiability_witness3FooV1xSfvg : $@convention(method) (Foo) -> Float { // CHECK-NEXT: } @differentiable @@ -98,7 +97,7 @@ public struct Foo: Differentiable { } // CHECK-LABEL: // differentiability witness for Foo.init(_:) -// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooVyACSfcfC : $@convention(method) (Float, @thin Foo.Type) -> Foo { +// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s29sil_differentiability_witness3FooVyACSfcfC : $@convention(method) (Float, @thin Foo.Type) -> Foo { // CHECK-NEXT: } @differentiable @@ -107,7 +106,7 @@ public struct Foo: Differentiable { } // CHECK-LABEL: // differentiability witness for Foo.method() -// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV6methodSfyF : $@convention(method) (Foo) -> Float { +// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s29sil_differentiability_witness3FooV6methodSfyF : $@convention(method) (Foo) -> Float { // CHECK-NEXT: } @differentiable @@ -116,7 +115,7 @@ public struct Foo: Differentiable { } // CHECK-LABEL: // differentiability witness for Foo.computedProperty.getter -// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV16computedPropertySfvg : $@convention(method) (Foo) -> Float { +// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s29sil_differentiability_witness3FooV16computedPropertySfvg : $@convention(method) (Foo) -> Float { // CHECK-NEXT: } @differentiable @@ -125,7 +124,7 @@ public struct Foo: Differentiable { } // CHECK-LABEL: // differentiability witness for Foo.subscript.getter -// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooVSfycig : $@convention(method) (Foo) -> Float { +// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s29sil_differentiability_witness3FooVSfycig : $@convention(method) (Foo) -> Float { // CHECK-NEXT: } } @@ -161,17 +160,17 @@ public func wrt_subset_vjp_wrt_x_y(_ tup: (Int, Int), _ x: Float, _ y: Float) -> } // CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:) -// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 2] [results 0] @$s36sil_differentiability_witness_silgen10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float { +// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 2] [results 0] @$s29sil_differentiability_witness10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float { // CHECK-NEXT: } // CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:) -// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 3] [results 0] @$s36sil_differentiability_witness_silgen10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float { +// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 3] [results 0] @$s29sil_differentiability_witness10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float { // CHECK-NEXT: jvp: // CHECK-NEXT: vjp: // CHECK-NEXT: } // CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:) -// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 2 3] [results 0] @$s36sil_differentiability_witness_silgen10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float { +// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 2 3] [results 0] @$s29sil_differentiability_witness10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float { // CHECK-NEXT: jvp: // CHECK-NEXT: vjp: // CHECK-NEXT: } @@ -191,8 +190,8 @@ extension P1 { } // CHECK-LABEL: // differentiability witness for P1.foo() -// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <τ_0_0 where τ_0_0 : P1> @$s36sil_differentiability_witness_silgen2P1PAAE3fooSfyF : $@convention(method) (@in_guaranteed Self) -> Float { -// CHECK-NEXT: vjp: @AD__$s36sil_differentiability_witness_silgen2P1PAAE3fooSfyF__vjp_src_0_wrt_0_36sil_differentiability_witness_silgen2P1Rzl : $@convention(method) <τ_0_0 where τ_0_0 : P1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_0.TangentVector>) +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <τ_0_0 where τ_0_0 : P1> @$s29sil_differentiability_witness2P1PAAE3fooSfyF : $@convention(method) (@in_guaranteed Self) -> Float { +// CHECK-NEXT: vjp: @AD__$s29sil_differentiability_witness2P1PAAE3fooSfyF__vjp_src_0_wrt_0_29sil_differentiability_witness2P1Rzl : $@convention(method) <τ_0_0 where τ_0_0 : P1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_0.TangentVector>) // CHECK-NEXT: } // Test custom derivatives of functions with generic signatures and `@differentiable` attributes. diff --git a/test/AutoDiff/SILGen/vtable.swift b/test/AutoDiff/SILGen/vtable.swift index 6763012887d3d..1f0d86412b2c5 100644 --- a/test/AutoDiff/SILGen/vtable.swift +++ b/test/AutoDiff/SILGen/vtable.swift @@ -1,5 +1,4 @@ // RUN: %target-swift-frontend -enable-experimental-differentiable-programming -emit-silgen %s | %FileCheck %s -// REQUIRES: differentiable_programming // Test derivative function vtable entries for `@differentiable` class members: // - Methods. @@ -96,6 +95,19 @@ class Sub: Super { class SubSub: Sub {} +// Check vtable entry thunks. + +// CHECK-LABEL: sil hidden [transparent] [thunk] [ossa] @AD__${{.*}}5SuperC6methody{{.*}}jvp_src_0_wrt_0_vtable_entry_thunk : $@convention(method) (Float, Float, @guaranteed Super) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +// CHECK: bb0(%0 : $Float, %1 : $Float, %2 : @guaranteed $Super): +// CHECK: %3 = function_ref @$s6vtable5SuperC6methodyS2f_SftF : $@convention(method) (Float, Float, @guaranteed Super) -> Float +// CHECK: %4 = differentiable_function [parameters 0] %3 : $@convention(method) (Float, Float, @guaranteed Super) -> Float +// CHECK: %5 = differentiable_function_extract [jvp] %4 : $@differentiable @convention(method) (Float, @noDerivative Float, @noDerivative @guaranteed Super) -> Float +// CHECK: %6 = apply %5(%0, %1, %2) : $@convention(method) (Float, Float, @guaranteed Super) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// CHECK: return %6 : $(Float, @callee_guaranteed (Float) -> Float) +// CHECK: } + +// Check vtable entries: new vs `[override]` vs `[inherited]` entries. + // CHECK-LABEL: sil_vtable Super { // CHECK: #Super.method: (Super) -> (Float, Float) -> Float : @$s6vtable5SuperC6methodyS2f_SftF // CHECK: #Super.method!jvp.SUU: (Super) -> (Float, Float) -> Float : @AD__$s6vtable5SuperC6methodyS2f_SftF__jvp_src_0_wrt_0_vtable_entry_thunk diff --git a/test/AutoDiff/SILGen/witness_table.swift b/test/AutoDiff/SILGen/witness_table.swift index 65ebf5bf2c99c..a7c0da0d23aa1 100644 --- a/test/AutoDiff/SILGen/witness_table.swift +++ b/test/AutoDiff/SILGen/witness_table.swift @@ -1,5 +1,4 @@ // RUN: %target-swift-frontend -enable-experimental-differentiable-programming -emit-silgen %s | %FileCheck %s -// REQUIRES: differentiable_programming // Test derivative function witness table entries for `@differentiable` protocol requirements. @@ -36,12 +35,32 @@ struct Struct: Protocol { } // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}method{{.*}}_jvp_SUU : $@convention(witness_method: Protocol) (Float, Double, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK: [[ORIG_FN:%.*]] = function_ref {{.*}}method{{.*}} : $@convention(method) (Float, Double, Struct) -> Float + // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]] + // CHECK: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[DIFF_FN]] + // CHECK: apply [[JVP_FN]] + // CHECK: } // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}method{{.*}}_vjp_SUU : $@convention(witness_method: Protocol) (Float, Double, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK: [[ORIG_FN:%.*]] = function_ref {{.*}}method{{.*}} : $@convention(method) (Float, Double, Struct) -> Float + // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]] + // CHECK: [[VJP_FN:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] + // CHECK: apply [[VJP_FN]] + // CHECK: } // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}method{{.*}}_jvp_SSS : $@convention(witness_method: Protocol) (Float, Double, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float, Double, @in_guaranteed τ_0_0) -> Float for ) { + // CHECK: [[ORIG_FN:%.*]] = function_ref {{.*}}method{{.*}} : $@convention(method) (Float, Double, Struct) -> Float + // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0 1 2] [[ORIG_FN]] + // CHECK: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[DIFF_FN]] + // CHECK: apply [[JVP_FN]] + // CHECK: } // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}method{{.*}}_vjp_SSS : $@convention(witness_method: Protocol) (Float, Double, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> (Float, Double, @out τ_0_0) for ) { + // CHECK: [[ORIG_FN:%.*]] = function_ref {{.*}}method{{.*}} : $@convention(method) (Float, Double, Struct) -> Float + // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0 1 2] [[ORIG_FN]] + // CHECK: [[VJP_FN:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] + // CHECK: apply [[VJP_FN]] + // CHECK: } @differentiable var property: Float { @@ -50,9 +69,18 @@ struct Struct: Protocol { } // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}property{{.*}}_jvp_S : $@convention(witness_method: Protocol) (@in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for ) { + // CHECK: [[ORIG_FN:%.*]] = function_ref {{.*}}property{{.*}} : $@convention(method) (Struct) -> Float + // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]] + // CHECK: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[DIFF_FN]] + // CHECK: apply [[JVP_FN]] + // CHECK: } // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__${{.*}}property{{.*}}_vjp_S : $@convention(witness_method: Protocol) (@in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for ) { - + // CHECK: [[ORIG_FN:%.*]] = function_ref {{.*}}property{{.*}} : $@convention(method) (Struct) -> Float + // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]] + // CHECK: [[VJP_FN:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] + // CHECK: apply [[VJP_FN]] + // CHECK: } @differentiable(wrt: x) subscript(_ x: Float, _ y: Float) -> Float { @@ -61,8 +89,18 @@ struct Struct: Protocol { } // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUU : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK: [[ORIG_FN:%.*]] = function_ref @$s13witness_table6StructVyS2f_Sftcig : $@convention(method) (Float, Float, Struct) -> Float + // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]] + // CHECK: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[DIFF_FN]] + // CHECK: apply [[JVP_FN]] + // CHECK: } // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUU : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK: [[ORIG_FN:%.*]] = function_ref @$s13witness_table6StructVyS2f_Sftcig : $@convention(method) (Float, Float, Struct) -> Float + // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [[ORIG_FN]] + // CHECK: [[VJP_FN:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] + // CHECK: apply [[VJP_FN]] + // CHECK: } } // CHECK-LABEL: sil_witness_table hidden Struct: Protocol module witness_table { diff --git a/test/AutoDiff/Sema/DerivedConformances/Inputs/struct_additive_arithmetic_other_module.swift b/test/AutoDiff/Sema/DerivedConformances/Inputs/struct_additive_arithmetic_other_module.swift new file mode 100644 index 0000000000000..92f45c786c4e5 --- /dev/null +++ b/test/AutoDiff/Sema/DerivedConformances/Inputs/struct_additive_arithmetic_other_module.swift @@ -0,0 +1,12 @@ +// expected-note @+1 3 {{type declared here}} +struct OtherFileNonconforming : Equatable { + var int: Int + var float: Float +} + +// expected-note @+1 3 {{type declared here}} +struct GenericOtherFileNonconforming : Equatable { + var x: T + var int: Int + var float: Float +} diff --git a/test/AutoDiff/Sema/DerivedConformances/struct_additive_arithmetic.swift b/test/AutoDiff/Sema/DerivedConformances/struct_additive_arithmetic.swift new file mode 100644 index 0000000000000..4c0048b2de8ad --- /dev/null +++ b/test/AutoDiff/Sema/DerivedConformances/struct_additive_arithmetic.swift @@ -0,0 +1,117 @@ +// RUN: %target-swift-frontend -enable-experimental-differentiable-programming -typecheck -verify -primary-file %s %S/Inputs/struct_additive_arithmetic_other_module.swift + +import _Differentiation + +func testAdditiveArithmetic( + _ x: inout T +) { + // Test `AdditiveArithmetic` requirements: `zero`, `+`, `-`. + let zero = T.zero + x += x + zero + x -= x - zero +} + +struct Empty : AdditiveArithmetic {} +func testEmpty() { + var empty = Empty() + testAdditiveArithmetic(&empty) +} + +struct Int2: AdditiveArithmetic { + var a: Int + var b: Int +} +func testInt2() { + var int2 = Int2(a: 1, b: 1) + testAdditiveArithmetic(&int2) +} + +// Test generic type. +struct Vector2: AdditiveArithmetic { + var x: T + var y: T +} +func testVector2() { + var vec2 = Vector2(x: 1, y: 1) + testAdditiveArithmetic(&vec2) +} + +// Test nested type. +struct Nested: AdditiveArithmetic { + var int2: Int2 + var int: Int +} +func testNested(int2: Int2) { + var nested = Nested(int2: int2, int: 1) + testAdditiveArithmetic(&nested) +} + +// Test mixed type. +// Note: `Numeric` refines `AdditiveArithmetic`. +struct Mixed: AdditiveArithmetic { + var nested: Nested + var float: Float + var uint8: UInt8 +} +func testMixed(nested: Nested) { + var mixed = Mixed(nested: nested, float: 1, uint8: 1) + testAdditiveArithmetic(&mixed) +} + +// Test type in generic context. +struct A { + struct B { + struct GenericContextNested : AdditiveArithmetic { + var nested: Nested + var float: Float + var uint8: UInt8 + } + } +} +func testGenericContext(nested: Nested) -> A.B.GenericContextNested { + var genericNested = + A.B.GenericContextNested(nested: nested, float: 1, uint8: 1) + testAdditiveArithmetic(&genericNested) + return genericNested +} + +// Test extension. +struct Extended { + var x: Int +} +extension Extended : Equatable, AdditiveArithmetic {} + +// Test extension of generic type. +struct GenericExtended { + var x: T +} +extension GenericExtended : Equatable, AdditiveArithmetic where T : AdditiveArithmetic {} + +// Test memberwise initializer synthesis. +struct NoMemberwiseInitializer : AdditiveArithmetic { + var value: T + init(randomLabel value: T) { self.value = value } +} +struct NoMemberwiseInitializerCustomZero: AdditiveArithmetic { + var x: Float + static var zero: Self { return NoMemberwiseInitializerCustomZero(0) } + init(_ x: Float) { + self.x = x + } +} +struct NoMemberwiseInitializerExtended { + var value: T + init(_ value: T) { + self.value = value + } +} +extension NoMemberwiseInitializerExtended: Equatable, AdditiveArithmetic + where T : AdditiveArithmetic {} + +// Test derived conformances in disallowed contexts. + +// expected-error @+1 3 {{implementation of 'AdditiveArithmetic' cannot be automatically synthesized in an extension in a different file to the type}} +extension OtherFileNonconforming : AdditiveArithmetic {} + +// expected-error @+1 3 {{implementation of 'AdditiveArithmetic' cannot be automatically synthesized in an extension in a different file to the type}} +extension GenericOtherFileNonconforming : AdditiveArithmetic {} diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index 82060690e93c9..f846bad486f47 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -1,5 +1,4 @@ // RUN: %target-swift-frontend-typecheck -enable-experimental-differentiable-programming -verify %s -// REQUIRES: differentiable_programming import _Differentiation diff --git a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift index 74e05e5f32a2b..91c512f9ed196 100644 --- a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift @@ -1,14 +1,13 @@ // RUN: %target-swift-frontend-typecheck -enable-experimental-differentiable-programming -verify %s -// REQUIRES: differentiable_programming import _Differentiation // Dummy `Differentiable`-conforming type. -struct DummyTangentVector: Differentiable & AdditiveArithmetic { - static var zero: Self { Self() } - static func + (_: Self, _: Self) -> Self { Self() } - static func - (_: Self, _: Self) -> Self { Self() } - typealias TangentVector = Self +public struct DummyTangentVector: Differentiable & AdditiveArithmetic { + public static var zero: Self { Self() } + public static func + (_: Self, _: Self) -> Self { Self() } + public static func - (_: Self, _: Self) -> Self { Self() } + public typealias TangentVector = Self } @differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}} @@ -32,8 +31,7 @@ func testLocalVariables() { } } -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(vjp: dfoo) // expected-error {{'@differentiable' attribute cannot be applied to this declaration}} +@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}} protocol P {} @differentiable() // ok! @@ -154,7 +152,10 @@ struct DifferentiableInstanceMethod: Differentiable { } // Test subscript methods. -struct SubscriptMethod { +struct SubscriptMethod: Differentiable { + typealias TangentVector = DummyTangentVector + mutating func move(along _: TangentVector) {} + @differentiable // ok subscript(implicitGetter x: Float) -> Float { return x @@ -169,489 +170,19 @@ struct SubscriptMethod { subscript(explicit x: Float) -> Float { @differentiable // ok get { return x } - @differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}} + // expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}} + @differentiable set {} } subscript(x: Float, y: Float) -> Float { @differentiable // ok get { return x + y } - @differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}} - set {} - } -} - -// JVP - -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(jvp: jvpSimpleJVP) -func jvpSimple(x: Float) -> Float { - return x -} - -func jvpSimpleJVP(x: Float) -> (Float, ((Float) -> Float)) { - return (x, { v in v }) -} - -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(wrt: y, jvp: jvpWrtSubsetJVP) -func jvpWrtSubset1(x: Float, y: Float) -> Float { - return x + y -} - -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(wrt: (y), jvp: jvpWrtSubsetJVP) -func jvpWrtSubset2(x: Float, y: Float) -> Float { - return x + y -} - -func jvpWrtSubsetJVP(x: Float, y: Float) -> (Float, (Float) -> Float) { - return (x + y, { v in v }) -} - -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(jvp: jvp2ParamsJVP) -func jvp2Params(x: Float, y: Float) -> Float { - return x + y -} - -func jvp2ParamsJVP(x: Float, y: Float) -> (Float, (Float, Float) -> Float) { - return (x + y, { (a, b) in a + b }) -} - -// expected-error @+1 {{unknown parameter name 'y'}} -@differentiable(wrt: (y)) -func jvpUnknownParam(x: Float) -> Float { - return x -} - -// expected-error @+1 {{parameters must be specified in original order}} -@differentiable(wrt: (y, x)) -func jvpParamOrderNotIncreasing(x: Float, y: Float) -> Float { - return x * y -} - -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -// expected-error @+1 {{'jvpWrongTypeJVP' does not have expected type '(Float) -> (Float, (Float.TangentVector) -> Float.TangentVector)' (aka '(Float) -> (Float, (Float) -> Float)'}} -@differentiable(jvp: jvpWrongTypeJVP) -func jvpWrongType(x: Float) -> Float { - return x -} - -func jvpWrongTypeJVP(x: Float) -> (Float, (Float) -> Int) { - return (x, { v in Int(v) }) -} - -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -// expected-error @+1 {{no differentiation parameters could be inferred; must differentiate with respect to at least one parameter conforming to 'Differentiable'}} -@differentiable(jvp: jvpSimpleJVP) -func jvpNonDiffParam(x: Int) -> Float { - return Float(x) -} - -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -// expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'Int' does not conform to 'Differentiable'}} -@differentiable(jvp: jvpSimpleJVP) -func jvpNonDiffResult(x: Float) -> Int { - return Int(x) -} - -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -// expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but '(Float, Int)' does not conform to 'Differentiable'}} -@differentiable(jvp: jvpSimpleJVP) -func jvpNonDiffResult2(x: Float) -> (Float, Int) { - return (x, Int(x)) -} - -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -// expected-error @+1 {{ambiguous reference to 'jvpAmbiguousVJP' in '@differentiable' attribute}} -@differentiable(jvp: jvpAmbiguousVJP) -func jvpAmbiguous(x: Float) -> Float { - return x -} -func jvpAmbiguousVJP(_ x: Float) -> (Float, (Float) -> Float) { - return (x, { $0 }) -} -func jvpAmbiguousVJP(x: Float) -> (Float, (Float) -> Float) { - return (x, { $0 }) -} - -class DifferentiableClassMethod { - // Direct differentiation case. - @differentiable - func foo(_ x: Float) -> Float { - return x - } -} - -struct JVPStruct { - @differentiable - let p: Float - - // expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - // expected-error @+1 {{'funcJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}} - @differentiable(wrt: (self), jvp: funcJVP) - func funcWrongType() -> Double { - fatalError("unimplemented") - } -} - -extension JVPStruct { - func funcJVP() -> (Float, (JVPStruct) -> Float) { - fatalError("unimplemented") - } -} - -extension JVPStruct: AdditiveArithmetic { - static var zero: JVPStruct { fatalError("unimplemented") } - static func + (lhs: JVPStruct, rhs: JVPStruct) -> JVPStruct { - fatalError("unimplemented") - } - static func - (lhs: JVPStruct, rhs: JVPStruct) -> JVPStruct { - fatalError("unimplemented") - } - typealias Scalar = Float - static func * (lhs: Float, rhs: JVPStruct) -> JVPStruct { - fatalError("unimplemented") - } -} - -extension JVPStruct: Differentiable { - typealias TangentVector = JVPStruct -} - -extension JVPStruct { - // expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - @differentiable(wrt: x, jvp: wrtAllNonSelfJVP) - func wrtAllNonSelf(x: Float) -> Float { - return x + p - } - - func wrtAllNonSelfJVP(x: Float) -> (Float, (Float) -> Float) { - return (x + p, { v in v }) - } -} - -extension JVPStruct { - // expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - @differentiable(wrt: (self, x), jvp: wrtAllJVP) - func wrtAll(x: Float) -> Float { - return x + p - } - - func wrtAllJVP(x: Float) -> (Float, (JVPStruct, Float) -> Float) { - return (x + p, { (a, b) in a.p + b }) - } -} - -extension JVPStruct { - // expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - @differentiable(jvp: computedPropJVP) - var computedPropOk1: Float { - return 0 - } - - var computedPropOk2: Float { - // expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - @differentiable(jvp: computedPropJVP) - get { - return 0 - } - } - - // expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - // expected-error @+1 {{'computedPropJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}} - @differentiable(jvp: computedPropJVP) - var computedPropWrongType: Double { - return 0 - } - - var computedPropWrongAccessor: Float { - get { - return 0 - } - // expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - // expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}} - @differentiable(jvp: computedPropJVP) - set { - fatalError("unimplemented") - } - } - - func computedPropJVP() -> (Float, (JVPStruct) -> Float) { - fatalError("unimplemented") - } -} - -// VJP - -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(vjp: vjpSimpleVJP) -func vjpSimple(x: Float) -> Float { - return x -} - -func vjpSimpleVJP(x: Float) -> (Float, ((Float) -> Float)) { - return (x, { v in v }) -} - -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(wrt: (y), vjp: vjpWrtSubsetVJP) -func vjpWrtSubset(x: Float, y: Float) -> Float { - return x + y -} - -func vjpWrtSubsetVJP(x: Float, y: Float) -> (Float, (Float) -> Float) { - return (x + y, { v in v }) -} - -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(vjp: vjp2ParamsVJP) -func vjp2Params(x: Float, y: Float) -> Float { - return x + y -} - -func vjp2ParamsVJP(x: Float, y: Float) -> (Float, (Float) -> (Float, Float)) { - return (x + y, { v in (v, v) }) -} - -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -// expected-error @+1 {{'vjpWrongTypeVJP' does not have expected type '(Float) -> (Float, (Float.TangentVector) -> Float.TangentVector)' (aka '(Float) -> (Float, (Float) -> Float)'}} -@differentiable(vjp: vjpWrongTypeVJP) -func vjpWrongType(x: Float) -> Float { - return x -} - -func vjpWrongTypeVJP(x: Float) -> (Float, (Float) -> Int) { - return (x, { v in Int(v) }) -} - -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -// expected-error @+1 {{no differentiation parameters could be inferred; must differentiate with respect to at least one parameter conforming to 'Differentiable'}} -@differentiable(vjp: vjpSimpleVJP) -func vjpNonDiffParam(x: Int) -> Float { - return Float(x) -} - -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -// expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'Int' does not conform to 'Differentiable'}} -@differentiable(vjp: vjpSimpleVJP) -func vjpNonDiffResult(x: Float) -> Int { - return Int(x) -} - -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -// expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but '(Float, Int)' does not conform to 'Differentiable'}} -@differentiable(vjp: vjpSimpleVJP) -func vjpNonDiffResult2(x: Float) -> (Float, Int) { - return (x, Int(x)) -} - -struct VJPStruct { - let p: Float - - // expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - // expected-error @+1 {{'funcVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}} - @differentiable(vjp: funcVJP) - func funcWrongType() -> Double { - fatalError("unimplemented") - } -} - -extension VJPStruct { - func funcVJP() -> (Float, (Float) -> VJPStruct) { - fatalError("unimplemented") - } -} - -extension VJPStruct: AdditiveArithmetic { - static var zero: VJPStruct { fatalError("unimplemented") } - static func + (lhs: VJPStruct, rhs: VJPStruct) -> VJPStruct { - fatalError("unimplemented") - } - static func - (lhs: VJPStruct, rhs: VJPStruct) -> VJPStruct { - fatalError("unimplemented") - } - typealias Scalar = Float - static func * (lhs: Float, rhs: VJPStruct) -> VJPStruct { - fatalError("unimplemented") - } -} - -extension VJPStruct: Differentiable { - typealias TangentVector = VJPStruct -} - -extension VJPStruct { - // expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - @differentiable(wrt: x, vjp: wrtAllNonSelfVJP) - func wrtAllNonSelf(x: Float) -> Float { - return x + p - } - - func wrtAllNonSelfVJP(x: Float) -> (Float, (Float) -> Float) { - return (x + p, { v in v }) - } -} - -extension VJPStruct { - // expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - @differentiable(wrt: (self, x), vjp: wrtAllVJP) - func wrtAll(x: Float) -> Float { - return x + p - } - - func wrtAllVJP(x: Float) -> (Float, (Float) -> (VJPStruct, Float)) { - fatalError("unimplemented") - } -} - -extension VJPStruct { - // expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - @differentiable(vjp: computedPropVJP) - var computedPropOk1: Float { - return 0 - } - - var computedPropOk2: Float { - // expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - @differentiable(vjp: computedPropVJP) - get { - return 0 - } - } - - // expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - // expected-error @+1 {{'computedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}} - @differentiable(vjp: computedPropVJP) - var computedPropWrongType: Double { - return 0 - } - - var computedPropWrongAccessor: Float { - get { - return 0 - } - // expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} // expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}} - @differentiable(vjp: computedPropVJP) - set { - fatalError("unimplemented") - } - } - - func computedPropVJP() -> (Float, (Float) -> VJPStruct) { - fatalError("unimplemented") - } -} - -// expected-error @+2 {{empty 'where' clause in '@differentiable' attribute}} -// expected-error @+1 {{expected type}} -@differentiable(where) -func emptyWhereClause(x: T) -> T { - return x -} - -// expected-error @+1 {{'where' clause is valid only when original function is generic 'nongenericWhereClause(x:)'}} -@differentiable(where T: Differentiable) -func nongenericWhereClause(x: Float) -> Float { - return x -} - -// expected-warning @+1 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(jvp: jvpWhere1, vjp: vjpWhere1 where T: Differentiable) -func where1(x: T) -> T { - return x -} -func jvpWhere1(x: T) -> (T, (T.TangentVector) -> T.TangentVector) { - return (x, { v in v }) -} -func vjpWhere1(x: T) -> (T, (T.TangentVector) -> T.TangentVector) { - return (x, { v in v }) -} - -// Test derivative functions with result tuple type labels. -// expected-warning @+1 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(jvp: jvpResultLabels, vjp: vjpResultLabels) -func derivativeResultLabels(_ x: Float) -> Float { - return x -} -func jvpResultLabels(_ x: Float) -> (value: Float, differential: (Float) -> Float) { - return (x, { $0 }) -} -func vjpResultLabels(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { - return (x, { $0 }) -} -struct ResultLabelTest { - // expected-warning @+1 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - @differentiable(jvp: jvpResultLabels, vjp: vjpResultLabels) - static func derivativeResultLabels(_ x: Float) -> Float { - return x - } - static func jvpResultLabels(_ x: Float) -> (value: Float, differential: (Float) -> Float) { - return (x, { $0 }) - } - static func vjpResultLabels(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { - return (x, { $0 }) - } - - // expected-warning @+1 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - @differentiable(jvp: jvpResultLabels, vjp: vjpResultLabels) - func derivativeResultLabels(_ x: Float) -> Float { - return x - } - func jvpResultLabels(_ x: Float) -> (value: Float, differential: (Float) -> Float) { - return (x, { $0 }) - } - func vjpResultLabels(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { - return (x, { $0 }) - } -} - -struct Tensor: AdditiveArithmetic { - static var zero: Self { Self() } - static func + (_: Self, _: Self) -> Self { Self() } - static func - (_: Self, _: Self) -> Self { Self() } -} -extension Tensor: Differentiable where Scalar: Differentiable { - typealias TangentVector = Self -} -// expected-warning @+1 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(jvp: jvpWhere2, vjp: vjpWhere2 where Scalar: Differentiable) -func where2(x: Tensor) -> Tensor { - return x -} -func jvpWhere2(x: Tensor) -> (Tensor, (Tensor) -> Tensor) { - return (x, { v in v }) -} -func vjpWhere2(x: Tensor) -> (Tensor, (Tensor) -> Tensor) { - return (x, { v in v }) -} - -struct A { - struct B { - @differentiable(wrt: x where T: Differentiable, V: Differentiable, V.TangentVector == V) - func whereInGenericContext(x: T) -> T { - return x - } - } -} - -extension FloatingPoint { - @differentiable(wrt: (self) where Self: Differentiable) - func whereClauseExtension() -> Self { - return self + @differentiable + set {} } } -// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -// expected-error @+1 {{'vjpNonvariadic' does not have expected type '(Float, Int32...) -> (Float, (Float.TangentVector) -> Float.TangentVector)' (aka '(Float, Int32...) -> (Float, (Float) -> Float)')}} -@differentiable(wrt: x, vjp: vjpNonvariadic) -func variadic(_ x: Float, indices: Int32...) -> Float { - return x -} -func vjpNonvariadic(_ x: Float, indices: [Int32]) -> (Float, (Float) -> Float) { - return (x, { $0 }) -} // expected-error @+3 {{type 'Scalar' constrained to non-protocol, non-class type 'Float'}} // expected-error @+2 {{no differentiation parameters could be inferred; must differentiate with respect to at least one parameter conforming to 'Differentiable'}} @@ -706,58 +237,98 @@ protocol ProtocolRequirements: Differentiable { } protocol ProtocolRequirementsRefined: ProtocolRequirements { - // expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}} {{3-3=@differentiable }} + // expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}} func f1(_ x: Float) -> Float } -// expected-error @+1 {{does not conform to protocol 'ProtocolRequirements'}} -struct DiffAttrConformanceErrors: ProtocolRequirements { +// Test missing `@differentiable` attribute for internal protocol witnesses. +// No errors expected; internal `@differentiable` attributes are created. + +struct InternalDiffAttrConformance: ProtocolRequirements { typealias TangentVector = DummyTangentVector mutating func move(along _: TangentVector) {} var x: Float var y: Float - // FIXME(TF-284): Fix unexpected diagnostic. - // expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} - // expected-note @+1 {{candidate has non-matching type '(x: Float, y: Float)'}} init(x: Float, y: Float) { self.x = x self.y = y } - // FIXME(TF-284): Fix unexpected diagnostic. - // expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} - // expected-note @+1 {{candidate has non-matching type '(x: Float, y: Int)'}} init(x: Float, y: Int) { self.x = x self.y = Float(y) } - // expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} - // expected-note @+1 {{candidate has non-matching type '(Float, Float) -> Float'}} func amb(x: Float, y: Float) -> Float { return x } - // expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)'}} {{3-3=@differentiable(wrt: x) }} - // expected-note @+1 {{candidate has non-matching type '(Float, Int) -> Float'}} func amb(x: Float, y: Int) -> Float { return x } - // expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} func f1(_ x: Float) -> Float { return x } - // expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} @differentiable(wrt: (self, x)) func f2(_ x: Float, _ y: Float) -> Float { return x + y } } +// Test missing `@differentiable` attribute for public protocol witnesses. Errors expected. + +// expected-error @+1 {{does not conform to protocol 'ProtocolRequirements'}} +public struct PublicDiffAttrConformance: ProtocolRequirements { + public typealias TangentVector = DummyTangentVector + public mutating func move(along _: TangentVector) {} + + var x: Float + var y: Float + + // FIXME(TF-284): Fix unexpected diagnostic. + // expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{10-10=@differentiable }} + // expected-note @+1 {{candidate has non-matching type '(x: Float, y: Float)'}} + public init(x: Float, y: Float) { + self.x = x + self.y = y + } + + // FIXME(TF-284): Fix unexpected diagnostic. + // expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{10-10=@differentiable }} + // expected-note @+1 {{candidate has non-matching type '(x: Float, y: Int)'}} + public init(x: Float, y: Int) { + self.x = x + self.y = Float(y) + } + + // expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{10-10=@differentiable }} + // expected-note @+1 {{candidate has non-matching type '(Float, Float) -> Float'}} + public func amb(x: Float, y: Float) -> Float { + return x + } + + // expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)'}} {{10-10=@differentiable(wrt: x) }} + // expected-note @+1 {{candidate has non-matching type '(Float, Int) -> Float'}} + public func amb(x: Float, y: Int) -> Float { + return x + } + + // expected-note @+1 {{candidate is missing attribute '@differentiable'}} + public func f1(_ x: Float) -> Float { + return x + } + + // expected-note @+2 {{candidate is missing attribute '@differentiable'}} + @differentiable(wrt: (self, x)) + public func f2(_ x: Float, _ y: Float) -> Float { + return x + y + } +} + protocol ProtocolRequirementsWithDefault_NoConformingTypes { @differentiable func f1(_ x: Float) -> Float @@ -769,51 +340,38 @@ extension ProtocolRequirementsWithDefault_NoConformingTypes { } protocol ProtocolRequirementsWithDefault { - // expected-note @+2 {{protocol requires function 'f1'}} @differentiable func f1(_ x: Float) -> Float } extension ProtocolRequirementsWithDefault { - // expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} func f1(_ x: Float) -> Float { x } } -// expected-error @+1 {{type 'DiffAttrConformanceErrors2' does not conform to protocol 'ProtocolRequirementsWithDefault'}} struct DiffAttrConformanceErrors2: ProtocolRequirementsWithDefault { - typealias TangentVector = DummyTangentVector - mutating func move(along _: TangentVector) {} - - // expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} func f1(_ x: Float) -> Float { x } } protocol NotRefiningDiffable { @differentiable(wrt: x) - // expected-note @+1 {{protocol requires function 'a' with type '(Float) -> Float'; do you want to add a stub?}} func a(_ x: Float) -> Float } -// expected-error @+1 {{type 'CertainlyNotDiffableWrtSelf' does not conform to protocol 'NotRefiningDiffable'}} struct CertainlyNotDiffableWrtSelf: NotRefiningDiffable { - // expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }} func a(_ x: Float) -> Float { return x * 5.0 } } - protocol TF285: Differentiable { @differentiable(wrt: (x, y)) @differentiable(wrt: x) - // expected-note @+1 {{protocol requires function 'foo(x:y:)' with type '(Float, Float) -> Float'; do you want to add a stub?}} func foo(x: Float, y: Float) -> Float } -// expected-error @+1 {{type 'TF285MissingOneDiffAttr' does not conform to protocol 'TF285'}} struct TF285MissingOneDiffAttr: TF285 { typealias TangentVector = DummyTangentVector mutating func move(along _: TangentVector) {} - // Requirement is missing an attribute. + // Requirement is missing the required `@differentiable(wrt: (x, y))` attribute. + // Since `TF285MissingOneDiffAttr.foo` is internal, the attribute is implicitly created. @differentiable(wrt: x) - // expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))}} {{3-3=@differentiable(wrt: (x, y)) }} func foo(x: Float, y: Float) -> Float { return x } @@ -826,9 +384,8 @@ struct TF_521 { var real: T var imaginary: T - // expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} // expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'TF_521' does not conform to 'Differentiable'}} - @differentiable(vjp: _vjpInit where T: Differentiable, T == T.TangentVector) + @differentiable(where T: Differentiable, T == T.TangentVector) init(real: T = 0, imaginary: T = 0) { self.real = real self.imaginary = imaginary @@ -838,14 +395,9 @@ struct TF_521 { extension TF_521: Differentiable where T: Differentiable { // expected-note @+1 {{possibly intended match 'TF_521.TangentVector' does not conform to 'AdditiveArithmetic'}} typealias TangentVector = TF_521 - typealias AllDifferentiableVariables = TF_521 -} -extension TF_521 where T: Differentiable, T == T.TangentVector { - static func _vjpInit(real: T, imaginary: T) -> (TF_521, (TF_521) -> (T, T)) { - return (TF_521(real: real, imaginary: imaginary), { ($0.real, $0.imaginary) }) - } } -let _: @differentiable (Float, Float) -> TF_521 = { r, i in + +let _: @differentiable(Float, Float) -> TF_521 = { r, i in TF_521(real: r, imaginary: i) } @@ -882,65 +434,6 @@ struct NonDiffableStruct { } } -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(linear, wrt: x, vjp: const3) // expected-error {{cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' attribute for transpose registration instead}} -func slope1(_ x: Float) -> Float { - return 3 * x -} - -// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(linear, wrt: x, jvp: const3) // expected-error {{cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' attribute for transpose registration instead}} -func slope2(_ x: Float) -> Float { - return 3 * x -} - -// expected-warning @+1 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} -@differentiable(linear, jvp: const3, vjp: const3) // expected-error {{cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' attribute for transpose registration instead}} -func slope3(_ x: Float) -> Float { - return 3 * x -} - -// Check that `@differentiable` attribute rejects stored properties. -struct StoredProperty: Differentiable { - typealias TangentVector = DummyTangentVector - mutating func move(along _: TangentVector) {} - - // expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - // expected-error @+1 {{'@differentiable' attribute on stored property cannot specify 'jvp:' or 'vjp:'}} - @differentiable(vjp: vjpStored) - var stored: Float - - func vjpStored() -> (Float, (Float) -> TangentVector) { - (stored, { _ in .zero }) - } -} - -// Check that `@differentiable` attribute rejects non-`func` derivatives. -struct Struct: Differentiable { - typealias TangentVector = DummyTangentVector - mutating func move(along _: TangentVector) {} - - // expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - // expected-error @+1 {{registered derivative 'computedPropertyVJP' must be a 'func' declaration}} - @differentiable(vjp: computedPropertyVJP) - func testComputedProperty() -> Float { 1 } - var computedPropertyVJP: (Float, (Float) -> TangentVector) { - (1, { _ in .zero }) - } - - // expected-error @+1 {{expected a vjp function name}} - @differentiable(vjp: init) - func testInitializer() -> Struct { self } - init(_ x: Struct) {} - - // expected-error @+1 {{expected a vjp function name}} - @differentiable(vjp: subscript) - func testSubscript() -> Float { 1 } - subscript() -> (Float, (Float) -> TangentVector) { - (1, { _ in .zero }) - } -} - // Index based 'wrt:' struct NumberWrtStruct: Differentiable { @@ -1019,6 +512,18 @@ func two9(x: Float, y: Float) -> Float { return x + y } +// Inout 'wrt:' arguments. + +@differentiable(wrt: y) +func inout1(x: Float, y: inout Float) -> Void { + let _ = x + y +} +// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} +@differentiable(wrt: y) +func inout2(x: Float, y: inout Float) -> Float { + let _ = x + y +} + // Test refining protocol requirements with `@differentiable` attribute. public protocol Distribution { @@ -1047,11 +552,6 @@ protocol ProtocolRequirementUnsupported: Differentiable { // expected-error @+1 {{'@differentiable' attribute on protocol requirement cannot specify 'where' clause}} @differentiable(where Scalar: Differentiable) func unsupportedWhereClause(value: Scalar) -> Float - - // expected-warning @+2 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - // expected-error @+1 {{'@differentiable' attribute on protocol requirement cannot specify 'jvp:' or 'vjp:'}} - @differentiable(wrt: x, jvp: dfoo, vjp: dfoo) - func unsupportedDerivatives(_ x: Float) -> Float } extension ProtocolRequirementUnsupported { func dfoo(_ x: Float) -> (Float, (Float) -> Float) { @@ -1089,19 +589,13 @@ class Super: Differentiable { static func testStaticMethod(_ x: Float) -> Float { x } @differentiable(wrt: (self, x)) - // expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - @differentiable(wrt: x, vjp: vjp) + @differentiable(wrt: x) // expected-note @+1 2 {{overridden declaration is here}} func testMissingAttributes(_ x: Float) -> Float { x } - // expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - @differentiable(wrt: x, vjp: vjp) + @differentiable(wrt: x) func testSuperclassDerivatives(_ x: Float) -> Float { x } - final func vjp(_ x: Float) -> (Float, (Float) -> Float) { - fatalError() - } - // Test duplicate attributes with different derivative generic signatures. // expected-error @+1 {{duplicate '@differentiable' attribute with same parameters}} @differentiable(wrt: x where T: Differentiable) @@ -1109,11 +603,14 @@ class Super: Differentiable { @differentiable(wrt: x) func instanceMethod(_ x: Float, y: T) -> Float { x } - // expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} // expected-error @+1 {{'@differentiable' attribute cannot be declared on class members returning 'Self'}} - @differentiable(vjp: vjpDynamicSelfResult) + @differentiable func dynamicSelfResult() -> Self { self } + // expected-error @+1 {{'@differentiable' attribute cannot be declared on class members returning 'Self'}} + @differentiable + var testDynamicSelfProperty: Self { self } + // TODO(TF-632): Fix "'TangentVector' is not a member type of 'Self'" diagnostic. // The underlying error should appear instead: // "covariant 'Self' can only appear at the top level of method result type". @@ -1124,14 +621,9 @@ class Super: Differentiable { } class Sub: Super { - // expected-error @+2 {{overriding declaration is missing attribute '@differentiable(wrt: x)'}} {{12-12=@differentiable(wrt: x) }} - // expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}} {{12-12=@differentiable }} + // expected-error @+2 {{overriding declaration is missing attribute '@differentiable(wrt: x)'}} + // expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}} override func testMissingAttributes(_ x: Float) -> Float { x } - - // expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - // expected-error @+1 {{'vjp' is not defined in the current type context}} - @differentiable(wrt: x, vjp: vjp) - override func testSuperclassDerivatives(_ x: Float) -> Float { x } } final class FinalClass: Differentiable { @@ -1182,15 +674,14 @@ extension InoutParameters { mutating func mutatingMethod(_ other: Self) -> Self {} } -// Test unsupported accessors: `set`, `_read`, `_modify`. +// Test accessors: `set`, `_read`, `_modify`. -struct UnsupportedAccessors: Differentiable { +struct Accessors: Differentiable { typealias TangentVector = DummyTangentVector mutating func move(along _: TangentVector) {} var stored: Float var computed: Float { - // `set` has an `inout` parameter: `(inout Self) -> (Float) -> ()`. // expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}} @differentiable set { stored = newValue } diff --git a/test/AutoDiff/Serialization/derivative_attr.swift b/test/AutoDiff/Serialization/derivative_attr.swift index 8b306740424c1..1e1b6339704a6 100644 --- a/test/AutoDiff/Serialization/derivative_attr.swift +++ b/test/AutoDiff/Serialization/derivative_attr.swift @@ -5,8 +5,6 @@ // BCANALYZER-NOT: UnknownCode -// REQUIRES: differentiable_programming - import _Differentiation // Dummy `Differentiable`-conforming type. diff --git a/test/AutoDiff/Serialization/differentiable_attr.swift b/test/AutoDiff/Serialization/differentiable_attr.swift index c1dcdfacfc5b2..96da6a266446f 100644 --- a/test/AutoDiff/Serialization/differentiable_attr.swift +++ b/test/AutoDiff/Serialization/differentiable_attr.swift @@ -1,20 +1,15 @@ // RUN: %empty-directory(%t) -// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t +// RUN: %target-swift-frontend -enable-experimental-differentiable-programming %s -emit-module -parse-as-library -o %t // RUN: llvm-bcanalyzer %t/differentiable_attr.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER -// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s -// REQUIRES: differentiable_programming - -// TODO(TF-836): Enable this test. -// Blocked by TF-828: `@differentiable` attribute type-checking. -// XFAIL: * +// RUN: %target-sil-opt -enable-experimental-differentiable-programming -disable-sil-linking -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s // BCANALYZER-NOT: UnknownCode import _Differentiation -// CHECK: @differentiable(wrt: x, jvp: jvpSimple, vjp: vjpSimple) +// CHECK: @differentiable(wrt: x) // CHECK-NEXT: func simple(x: Float) -> Float -@differentiable(jvp: jvpSimple, vjp: vjpSimple) +@differentiable func simple(x: Float) -> Float { return x } @@ -73,23 +68,18 @@ func testOnlyWhereClause(x: T) -> T { return x } -// CHECK: @differentiable(wrt: x, vjp: vjpTestWhereClause where T : Differentiable) +// CHECK: @differentiable(wrt: x where T : Differentiable) // CHECK-NEXT: func testWhereClause(x: T) -> T where T : Numeric -@differentiable(vjp: vjpTestWhereClause where T : Differentiable) +@differentiable(where T : Differentiable) func testWhereClause(x: T) -> T { return x } -func vjpTestWhereClause(x: T) -> (T, (T.TangentVector) -> T.TangentVector) - where T : Numeric, T : Differentiable -{ - return (x, { v in v }) -} protocol P {} extension P { - // CHECK: @differentiable(wrt: self, vjp: vjpTestWhereClauseMethod where Self : Differentiable) + // CHECK: @differentiable(wrt: self where Self : Differentiable) // CHECK-NEXT: func testWhereClauseMethod() -> Self - @differentiable(wrt: self, vjp: vjpTestWhereClauseMethod where Self : Differentiable) + @differentiable(wrt: self where Self : Differentiable) func testWhereClauseMethod() -> Self { return self } @@ -100,9 +90,9 @@ extension P where Self : Differentiable { } } -// CHECK: @differentiable(wrt: x, vjp: vjpTestWhereClauseMethodTypeConstraint where T : Differentiable, T == T.TangentVector) +// CHECK: @differentiable(wrt: x where T : Differentiable, T == T.TangentVector) // CHECK-NEXT: func testWhereClauseMethodTypeConstraint(x: T) -> T where T : Numeric -@differentiable(vjp: vjpTestWhereClauseMethodTypeConstraint where T : Differentiable, T == T.TangentVector) +@differentiable(where T : Differentiable, T == T.TangentVector) func testWhereClauseMethodTypeConstraint(x: T) -> T { return x } @@ -113,9 +103,9 @@ func vjpTestWhereClauseMethodTypeConstraint(x: T) -> (T, (T) -> T) } extension P { - // CHECK: @differentiable(wrt: self, vjp: vjpTestWhereClauseMethodTypeConstraint where Self : Differentiable, Self == Self.TangentVector) + // CHECK: @differentiable(wrt: self where Self : Differentiable, Self == Self.TangentVector) // CHECK-NEXT: func testWhereClauseMethodTypeConstraint() -> Self - @differentiable(wrt: self, vjp: vjpTestWhereClauseMethodTypeConstraint where Self.TangentVector == Self, Self : Differentiable) + @differentiable(wrt: self where Self.TangentVector == Self, Self : Differentiable) func testWhereClauseMethodTypeConstraint() -> Self { return self } diff --git a/test/AutoDiff/Serialization/differentiation.swift b/test/AutoDiff/Serialization/differentiable_function.swift similarity index 74% rename from test/AutoDiff/Serialization/differentiation.swift rename to test/AutoDiff/Serialization/differentiable_function.swift index 3a17b5440b8a9..7e1f8af5e6a67 100644 --- a/test/AutoDiff/Serialization/differentiation.swift +++ b/test/AutoDiff/Serialization/differentiable_function.swift @@ -1,7 +1,7 @@ // RUN: %empty-directory(%t) // RUN: %target-swift-frontend %s -emit-module -parse-as-library -enable-experimental-differentiable-programming -o %t -// RUN: llvm-bcanalyzer %t/differentiation.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER -// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/differentiation.swiftmodule -enable-experimental-differentiable-programming -o - | %FileCheck %s +// RUN: llvm-bcanalyzer %t/differentiable_function.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER +// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/differentiable_function.swiftmodule -enable-experimental-differentiable-programming -o - | %FileCheck %s // BCANALYZER-NOT: UnknownCode diff --git a/test/AutoDiff/Serialization/transpose_attr.swift b/test/AutoDiff/Serialization/transpose_attr.swift index 80bc7a220586b..a5c5c9255a87e 100644 --- a/test/AutoDiff/Serialization/transpose_attr.swift +++ b/test/AutoDiff/Serialization/transpose_attr.swift @@ -4,7 +4,6 @@ // RUN: %target-sil-opt -enable-experimental-differentiable-programming -disable-sil-linking -enable-sil-verify-all %t/transpose_attr.swiftmodule -o - | %FileCheck %s // BCANALYZER-NOT: UnknownCode -// REQUIRES: differentiable_programming // TODO(TF-838): Enable this test. // Blocked by TF-830: `@transpose` attribute type-checking. diff --git a/test/AutoDiff/Syntax/Outputs/round_trip_parse_gen.swift.withkinds b/test/AutoDiff/Syntax/Outputs/round_trip_parse_gen.swift.withkinds index c49bacacb7633..2503bdef381c1 100644 --- a/test/AutoDiff/Syntax/Outputs/round_trip_parse_gen.swift.withkinds +++ b/test/AutoDiff/Syntax/Outputs/round_trip_parse_gen.swift.withkinds @@ -11,19 +11,19 @@ // Note: RUN lines copied from test/Syntax/round_trip_parse_gen.swift. -@differentiable(jvp: foo(_:_:)) +@differentiable func bar(_ x: Float, _: Float) -> Float { return 1 } -@differentiable(jvp: foo(_:_:) where T : FloatingPoint) +@differentiable(where T : FloatingPoint) func bar<T : Numeric>(_ x: T, _: T) -> T { return 1 } -@differentiable(wrt: x, jvp: foo(_:_:)) +@differentiable(wrt: x) func bar(_ x: Float, _: Float) -> Float { return 1 } -@differentiable(wrt: (self, x, y), jvp: foo(_:_:)) +@differentiable(wrt: (self, x, y)) func bar(_ x: Float, y: Float) -> Float { return 1 } -@differentiable(wrt: (self, x, y), jvp: bar, vjp: foo(_:_:) where T : FloatingPoint) +@differentiable(wrt: (self, x, y) where T : FloatingPoint) func bar<T : Numeric>(_ x: T, y: T) -> T { return 1 } @derivative(of: -) diff --git a/test/AutoDiff/Syntax/round_trip_parse_gen.swift b/test/AutoDiff/Syntax/round_trip_parse_gen.swift index bfde3d5e04cfd..c71cd77b378dd 100644 --- a/test/AutoDiff/Syntax/round_trip_parse_gen.swift +++ b/test/AutoDiff/Syntax/round_trip_parse_gen.swift @@ -11,19 +11,19 @@ // Note: RUN lines copied from test/Syntax/round_trip_parse_gen.swift. -@differentiable(jvp: foo(_:_:)) +@differentiable func bar(_ x: Float, _: Float) -> Float { return 1 } -@differentiable(jvp: foo(_:_:) where T : FloatingPoint) +@differentiable(where T : FloatingPoint) func bar(_ x: T, _: T) -> T { return 1 } -@differentiable(wrt: x, jvp: foo(_:_:)) +@differentiable(wrt: x) func bar(_ x: Float, _: Float) -> Float { return 1 } -@differentiable(wrt: (self, x, y), jvp: foo(_:_:)) +@differentiable(wrt: (self, x, y)) func bar(_ x: Float, y: Float) -> Float { return 1 } -@differentiable(wrt: (self, x, y), jvp: bar, vjp: foo(_:_:) where T : FloatingPoint) +@differentiable(wrt: (self, x, y) where T : FloatingPoint) func bar(_ x: T, y: T) -> T { return 1 } @derivative(of: -) diff --git a/test/AutoDiff/lit.local.cfg b/test/AutoDiff/lit.local.cfg new file mode 100644 index 0000000000000..f0f2cc478d4fd --- /dev/null +++ b/test/AutoDiff/lit.local.cfg @@ -0,0 +1,2 @@ +if 'differentiable_programming' not in config.available_features: + config.unsupported = True diff --git a/test/AutoDiff/stdlib/differentiable_protocol.swift b/test/AutoDiff/stdlib/differentiable_protocol.swift index 4ff49e8708301..bf5462963616e 100644 --- a/test/AutoDiff/stdlib/differentiable_protocol.swift +++ b/test/AutoDiff/stdlib/differentiable_protocol.swift @@ -1,5 +1,4 @@ // RUN: %target-typecheck-verify-swift -// REQUIRES: differentiable_programming import _Differentiation diff --git a/test/AutoDiff/stdlib/differentiable_stdlib_conformances.swift b/test/AutoDiff/stdlib/differentiable_stdlib_conformances.swift index 977072bac6318..21fc2137c28d0 100644 --- a/test/AutoDiff/stdlib/differentiable_stdlib_conformances.swift +++ b/test/AutoDiff/stdlib/differentiable_stdlib_conformances.swift @@ -1,6 +1,5 @@ // RUN: %target-run-simple-swift // REQUIRES: executable_test -// REQUIRES: differentiable_programming import _Differentiation diff --git a/test/ClangImporter/access-specifiers-module-interface.swift b/test/ClangImporter/access-specifiers-module-interface.swift index 2b4b536e2f2ae..c65b4b84138da 100644 --- a/test/ClangImporter/access-specifiers-module-interface.swift +++ b/test/ClangImporter/access-specifiers-module-interface.swift @@ -9,31 +9,31 @@ // CHECK-NEXT: init() // CHECK-NEXT: } // CHECK-NEXT: struct PublicEnum : Equatable, RawRepresentable { -// CHECK-NEXT: init(_ rawValue: [[ENUM_INT:Int32|UInt32]]) -// CHECK-NEXT: init(rawValue: [[ENUM_INT]]) -// CHECK-NEXT: var rawValue: [[ENUM_INT]] -// CHECK-NEXT: typealias RawValue = [[ENUM_INT]] +// CHECK-NEXT: init(_ rawValue: [[ENUM_UNDERLYING_TYPE:Int32|UInt32]]) +// CHECK-NEXT: init(rawValue: [[ENUM_UNDERLYING_TYPE]]) +// CHECK-NEXT: var rawValue: [[ENUM_UNDERLYING_TYPE]] +// CHECK-NEXT: typealias RawValue = [[ENUM_UNDERLYING_TYPE]] // CHECK-NEXT: } -// CHECK-NEXT: @frozen enum PublicClosedEnum : [[ENUM_INT]] { -// CHECK-NEXT: init?(rawValue: [[ENUM_INT]]) -// CHECK-NEXT: var rawValue: [[ENUM_INT]] { get } -// CHECK-NEXT: typealias RawValue = [[ENUM_INT]] +// CHECK-NEXT: @frozen enum PublicClosedEnum : [[ENUM_UNDERLYING_TYPE]] { +// CHECK-NEXT: init?(rawValue: [[ENUM_UNDERLYING_TYPE]]) +// CHECK-NEXT: var rawValue: [[ENUM_UNDERLYING_TYPE]] { get } +// CHECK-NEXT: typealias RawValue = [[ENUM_UNDERLYING_TYPE]] // CHECK-NEXT: case value1 // CHECK-NEXT: @available(swift, obsoleted: 3, renamed: "value1") // CHECK-NEXT: static var Value1: PublicPrivate.PublicClosedEnum { get } // CHECK-NEXT: } -// CHECK-NEXT: enum PublicOpenEnum : [[ENUM_INT]] { -// CHECK-NEXT: init?(rawValue: [[ENUM_INT]]) -// CHECK-NEXT: var rawValue: [[ENUM_INT]] { get } -// CHECK-NEXT: typealias RawValue = [[ENUM_INT]] +// CHECK-NEXT: enum PublicOpenEnum : [[ENUM_UNDERLYING_TYPE]] { +// CHECK-NEXT: init?(rawValue: [[ENUM_UNDERLYING_TYPE]]) +// CHECK-NEXT: var rawValue: [[ENUM_UNDERLYING_TYPE]] { get } +// CHECK-NEXT: typealias RawValue = [[ENUM_UNDERLYING_TYPE]] // CHECK-NEXT: case value1 // CHECK-NEXT: @available(swift, obsoleted: 3, renamed: "value1") // CHECK-NEXT: static var Value1: PublicPrivate.PublicOpenEnum { get } // CHECK-NEXT: } // CHECK-NEXT: struct PublicFlagEnum : OptionSet { -// CHECK-NEXT: init(rawValue: [[ENUM_INT]]) -// CHECK-NEXT: let rawValue: [[ENUM_INT]] -// CHECK-NEXT: typealias RawValue = [[ENUM_INT]] +// CHECK-NEXT: init(rawValue: [[ENUM_UNDERLYING_TYPE]]) +// CHECK-NEXT: let rawValue: [[ENUM_UNDERLYING_TYPE]] +// CHECK-NEXT: typealias RawValue = [[ENUM_UNDERLYING_TYPE]] // CHECK-NEXT: typealias Element = PublicPrivate.PublicFlagEnum // CHECK-NEXT: typealias ArrayLiteralElement = PublicPrivate.PublicFlagEnum // CHECK-NEXT: } diff --git a/test/ClangImporter/availability_returns_twice.swift b/test/ClangImporter/availability_returns_twice.swift index c91e922fb0e89..84deef467aa15 100644 --- a/test/ClangImporter/availability_returns_twice.swift +++ b/test/ClangImporter/availability_returns_twice.swift @@ -8,7 +8,7 @@ #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin typealias JumpBuffer = Int32 -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) import Glibc typealias JumpBuffer = jmp_buf #else diff --git a/test/ClangImporter/clang_builtins.swift b/test/ClangImporter/clang_builtins.swift index 720bb637cb909..ea3fb40684748 100644 --- a/test/ClangImporter/clang_builtins.swift +++ b/test/ClangImporter/clang_builtins.swift @@ -2,7 +2,7 @@ #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) import Glibc #elseif os(Windows) import MSVCRT diff --git a/test/Constraints/casts.swift b/test/Constraints/casts.swift index 3747e66dec5c5..7737bca125fc5 100644 --- a/test/Constraints/casts.swift +++ b/test/Constraints/casts.swift @@ -221,7 +221,7 @@ func process(p: Any?) { func compare(_: T, _: T) {} // expected-note {{'compare' declared here}} func compare(_: T?, _: T?) {} -_ = nil? as? Int?? // expected-error {{nil literal cannot be the source of a conditional cast}} +_ = nil? as? Int?? // expected-error {{'nil' requires a contextual type}} func test_tuple_casts_no_warn() { struct Foo {} diff --git a/test/Constraints/one_way_solve.swift b/test/Constraints/one_way_solve.swift index ed854fa58c0ff..6f03fac0bc65d 100644 --- a/test/Constraints/one_way_solve.swift +++ b/test/Constraints/one_way_solve.swift @@ -38,9 +38,9 @@ func testTernaryOneWayOverload(b: Bool) { // CHECK: solving component #1 // CHECK: Initial bindings: $T11 := Int8 - // CHECK: found solution 0 0 0 0 0 0 2 0 0 0 0 0 + // CHECK: found solution 0 0 0 0 0 0 0 2 0 0 0 0 0 - // CHECK: composed solution 0 0 0 0 0 0 2 0 0 0 0 0 - // CHECK-NOT: composed solution 0 0 0 0 0 0 2 0 0 0 0 0 + // CHECK: composed solution 0 0 0 0 0 0 0 2 0 0 0 0 0 + // CHECK-NOT: composed solution 0 0 0 0 0 0 0 2 0 0 0 0 0 let _: Int8 = b ? Builtin.one_way(int8Or16(17)) : Builtin.one_way(int8Or16(42)) } diff --git a/test/Constraints/optional.swift b/test/Constraints/optional.swift index 9c7af5d4f525d..d3bf1c65a1121 100644 --- a/test/Constraints/optional.swift +++ b/test/Constraints/optional.swift @@ -439,9 +439,10 @@ func sr_12309() { _ = (nil!) // expected-error {{'nil' literal cannot be force unwrapped}} _ = (nil)! // expected-error {{'nil' literal cannot be force unwrapped}} _ = ((nil))! // expected-error {{'nil' literal cannot be force unwrapped}} - _ = nil? // expected-error {{value of optional type 'Optional<_>' must be unwrapped to a value of type '_'}} - // expected-note@-1 {{coalesce using '??' to provide a default when the optional value contains 'nil'}} - // expected-note@-2 {{force-unwrap using '!' to abort execution if the optional value contains 'nil'}} + _ = nil? // expected-error {{'nil' requires a contextual type}} + _ = ((nil?)) // expected-error {{'nil' requires a contextual type}} + _ = ((nil))? // expected-error {{'nil' requires a contextual type}} + _ = ((nil)?) // expected-error {{'nil' requires a contextual type}} _ = nil // expected-error {{'nil' requires a contextual type}} _ = (nil) // expected-error {{'nil' requires a contextual type}} _ = ((nil)) // expected-error {{'nil' requires a contextual type}} diff --git a/test/Constraints/rdar44770297.swift b/test/Constraints/rdar44770297.swift index d3dbbf4904655..82d96f7459112 100644 --- a/test/Constraints/rdar44770297.swift +++ b/test/Constraints/rdar44770297.swift @@ -4,11 +4,11 @@ protocol P { associatedtype A } -func foo(_: () throws -> T) -> T.A? { +func foo(_: () throws -> T) -> T.A? { // expected-note {{where 'T' = 'Never'}} fatalError() } -// TODO(diagnostics): This expression is truly ambiguous because there is no conformance between `Never` and `P` -// which means no associated type `A` and `nil` can't be an argument to any overload of `&` so we end -// up generating at least 3 fixes per overload of `&`. But we could at least point to where the problems are. -let _ = foo() {fatalError()} & nil // expected-error {{type of expression is ambiguous without more context}} +let _ = foo() {fatalError()} & nil // expected-error {{global function 'foo' requires that 'Never' conform to 'P'}} +// expected-error@-1 {{value of optional type 'Never.A?' must be unwrapped to a value of type 'Never.A'}} +// expected-note@-2 {{force-unwrap}} +// expected-note@-3 {{coalesce using '??'}} diff --git a/test/Constraints/valid_pointer_conversions.swift b/test/Constraints/valid_pointer_conversions.swift index 2f7e60098aa27..3a3bd9ea9cbc4 100644 --- a/test/Constraints/valid_pointer_conversions.swift +++ b/test/Constraints/valid_pointer_conversions.swift @@ -31,6 +31,13 @@ func givesPtr(_ str: String) { takesDoubleOptionalPtr(i) // expected-error {{cannot convert value of type 'Int' to expected argument type 'UnsafeRawPointer??'}} takesMutableDoubleOptionalPtr(arr) // expected-error {{cannot convert value of type '[Int]' to expected argument type 'UnsafeMutableRawPointer??'}} - // FIXME(SR-12382): Poor diagnostic. - takesMutableDoubleOptionalTypedPtr(&i) // expected-error {{type of expression is ambiguous without more context}} + takesMutableDoubleOptionalTypedPtr(&i) // expected-error {{cannot convert value of type 'UnsafeMutablePointer' to expected argument type 'UnsafeMutablePointer'}} + // expected-note@-1 {{arguments to generic parameter 'Pointee' ('Int' and 'Double') are expected to be equal}} } + +// SR12382 +func SR12382(_ x: UnsafeMutablePointer??) {} + +var i = 0 +SR12382(&i) // expected-error {{cannot convert value of type 'UnsafeMutablePointer' to expected argument type 'UnsafeMutablePointer'}} +// expected-note@-1 {{arguments to generic parameter 'Pointee' ('Int' and 'Double') are expected to be equal}} diff --git a/test/FixCode/verify-fixits.swift b/test/FixCode/verify-fixits.swift new file mode 100644 index 0000000000000..81c4943e39a41 --- /dev/null +++ b/test/FixCode/verify-fixits.swift @@ -0,0 +1,9 @@ +// RUN: cp %s %t +// RUN: not %swift -typecheck -target %target-triple -verify-apply-fixes %t +// RUN: diff %t %s.result + +func f1() { + guard true { return } // expected-error {{...}} + + guard true { return } // expected-error {{expected 'else' after 'guard' condition}} {{none}} +} \ No newline at end of file diff --git a/test/FixCode/verify-fixits.swift.result b/test/FixCode/verify-fixits.swift.result new file mode 100644 index 0000000000000..d6c3324067f79 --- /dev/null +++ b/test/FixCode/verify-fixits.swift.result @@ -0,0 +1,9 @@ +// RUN: cp %s %t +// RUN: not %swift -typecheck -target %target-triple -verify-apply-fixes %t +// RUN: diff %t %s.result + +func f1() { + guard true { return } // expected-error {{expected 'else' after 'guard' condition}} + + guard true { return } // expected-error {{expected 'else' after 'guard' condition}} {{14-14=else }} +} \ No newline at end of file diff --git a/test/Fuzzing/fuzzer_test.swift b/test/Fuzzing/fuzzer_test.swift index 6991a423cbe48..fbcfb7194edb9 100644 --- a/test/Fuzzing/fuzzer_test.swift +++ b/test/Fuzzing/fuzzer_test.swift @@ -10,7 +10,7 @@ #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) import Glibc #elseif os(Windows) import MSVCRT diff --git a/test/Fuzzing/fuzzer_test_simpler.swift b/test/Fuzzing/fuzzer_test_simpler.swift index 559539cfb3226..e599130585fe5 100644 --- a/test/Fuzzing/fuzzer_test_simpler.swift +++ b/test/Fuzzing/fuzzer_test_simpler.swift @@ -10,7 +10,7 @@ #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) import Glibc #elseif os(Windows) import MSVCRT diff --git a/test/IDE/complete_value_expr.swift b/test/IDE/complete_value_expr.swift index f16a8ae567156..9599cc8329a32 100644 --- a/test/IDE/complete_value_expr.swift +++ b/test/IDE/complete_value_expr.swift @@ -200,6 +200,7 @@ // RUN: %target-swift-ide-test -code-completion -source-filename %s -code-completion-token=IN_DICTIONARY_LITERAL_1 | %FileCheck %s -check-prefix=SIMPLE_OBJECT_DOT // RUN: %target-swift-ide-test -code-completion -source-filename %s -code-completion-token=IN_DICTIONARY_LITERAL_2 | %FileCheck %s -check-prefix=SIMPLE_OBJECT_DOT // RUN: %target-swift-ide-test -code-completion -source-filename %s -code-completion-token=COMPLETE_CALL_RESULT | %FileCheck %s -check-prefix=COMPLETE_CALL_RESULT +// RUN: %target-swift-ide-test -code-completion -source-filename %s -code-completion-keywords=false -code-completion-token=BROKEN_CONFORMANCE | %FileCheck %s -check-prefix=BROKEN_CONFORMANCE // Test code completion of expressions that produce a value. @@ -2185,3 +2186,19 @@ func testWrapSuccess(promise: Int, seal: Resolver) { // COMPLETE_CALL_RESULT: Pattern/CurrModule: ({#Void#}, {#Bool#})[#Void#]; name=(Void, Bool) // COMPLETE_CALL_RESULT: End completions } + +protocol BrokenConformanceP { + static func staticFunc() + func instanceFunc() +} +extension BrokenConformanceP { + static func staticFuncExtension() {} +} +struct BrokenConformanceS: BrokenConformanceP { +} +func testBrokenConformance(arg: BrokenConformanceS) { + arg.#^BROKEN_CONFORMANCE^# + // BROKEN_CONFORMANCE: Begin completions, 1 items + // BROKEN_CONFORMANCE: Decl[InstanceMethod]/Super: instanceFunc()[#Void#]; + // BROKEN_CONFORMANCE: End completions +} diff --git a/test/IRGen/autolink-runtime-compatibility-arm64e.swift b/test/IRGen/autolink-runtime-compatibility-arm64e.swift new file mode 100644 index 0000000000000..57be53f6f3c3f --- /dev/null +++ b/test/IRGen/autolink-runtime-compatibility-arm64e.swift @@ -0,0 +1,10 @@ +// REQUIRES: CPU=arm64e,OS=ios + +// Doesn't autolink compatibility library because target OS doesn't need it +// RUN: %target-swift-frontend -target arm64e-apple-ios11.0 -emit-ir -parse-stdlib %s | %FileCheck -check-prefix=NO-FORCE-LOAD %s + +public func foo() {} + +// NO-FORCE-LOAD-NOT: FORCE_LOAD +// NO-FORCE-LOAD-NOT: !{!"-lswiftCompatibility50"} +// NO-FORCE-LOAD-NOT: !{!"-lswiftCompatibilityDynamicReplacements"} diff --git a/test/IRGen/builtin_math.swift b/test/IRGen/builtin_math.swift index d20ebe4a21cc3..5c85dac5e5c20 100644 --- a/test/IRGen/builtin_math.swift +++ b/test/IRGen/builtin_math.swift @@ -2,7 +2,7 @@ #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) import Glibc #elseif os(Windows) import MSVCRT diff --git a/test/IRGen/objc_protocol_instance_methods.swift b/test/IRGen/objc_protocol_instance_methods.swift index 2ad684984d393..d67484867b6c9 100644 --- a/test/IRGen/objc_protocol_instance_methods.swift +++ b/test/IRGen/objc_protocol_instance_methods.swift @@ -7,6 +7,9 @@ // CHECK-NOT: _PROTOCOL_INSTANCE_METHODS_NSObject{{.*}}"\01L_selector_data(conformsToProtocol:)"{{.*}}"\01L_selector_data(conformsToProtocol:)" +// Make sure that extended method types are in sync with entries in method list. +// CHECK: @_PROTOCOL_INSTANCE_METHODS_NSObject = private constant { i32, i32, [5 x +// CHECK: @_PROTOCOL_METHOD_TYPES_NSObject = private constant [5 import Foundation @objc protocol P: NSObjectProtocol {} diff --git a/test/IRGen/opaque_result_type.swift b/test/IRGen/opaque_result_type.swift index 09c100030145d..0dda32dbf18e8 100644 --- a/test/IRGen/opaque_result_type.swift +++ b/test/IRGen/opaque_result_type.swift @@ -37,7 +37,7 @@ extension String: P { // -- mangled underlying type // CHECK-SAME: @"symbolic Si" // -- conformance to O - // CHECK-SAME: @"get_witness_table S2i18opaque_result_type1OHpyHC + // CHECK-SAME: @"get_witness_table Si18opaque_result_type1OHpyHC // CHECK-SAME: }> func poo() -> some O { return 0 @@ -68,7 +68,7 @@ public class C: P, Q { // -- mangled underlying type // CHECK-SAME: @"symbolic Si" // -- conformance to O - // CHECK-SAME: @"get_witness_table S2i18opaque_result_type1OHpyHC + // CHECK-SAME: @"get_witness_table Si18opaque_result_type1OHpyHC // CHECK-SAME: }> func poo() -> some O { return 0 @@ -82,9 +82,9 @@ public class C: P, Q { // -- mangled underlying type // CHECK-SAME: @"symbolic Si" // -- conformance to O - // CHECK-SAME: @"get_witness_table S2i18opaque_result_type1OHpyHC + // CHECK-SAME: @"get_witness_table Si18opaque_result_type1OHpyHC // -- conformance to O2 - // CHECK-SAME: @"get_witness_table S2i18opaque_result_type2O2HpyHC + // CHECK-SAME: @"get_witness_table Si18opaque_result_type2O2HpyHC // CHECK-SAME: }> func qoo() -> some O & O2 { return 0 @@ -99,7 +99,7 @@ public class C: P, Q { // -- mangled underlying type // CHECK-SAME: @"symbolic SS" // -- conformance to P -// CHECK-SAME: @"get_witness_table S2S18opaque_result_type1PHpyHC +// CHECK-SAME: @"get_witness_table SS18opaque_result_type1PHpyHC // CHECK-SAME: }> func foo(x: String) -> some P { return x @@ -113,7 +113,7 @@ func foo(x: String) -> some P { // -- mangled underlying type // CHECK-SAME: @"symbolic _____ 18opaque_result_type1CC" // -- conformance to Q -// CHECK-SAME: @"get_witness_table 18opaque_result_type1CCAcA1QHPyHC +// CHECK-SAME: @"get_witness_table 18opaque_result_type1CCAA1QHPyHC // CHECK-SAME: }> func bar(y: C) -> some Q { return y diff --git a/test/IRGen/opaque_result_type_metadata_peephole.swift b/test/IRGen/opaque_result_type_metadata_peephole.swift index 17cd995728a38..efbd46928d56b 100644 --- a/test/IRGen/opaque_result_type_metadata_peephole.swift +++ b/test/IRGen/opaque_result_type_metadata_peephole.swift @@ -12,9 +12,9 @@ func foo() -> some P { // The mangled underlying type in foo2 ought to look through foo()'s opaque type // CHECK-LABEL: @"$s36opaque_result_type_metadata_peephole4foo2QryFQOMQ" = {{.*}} constant // DEFAULT-SAME: @"symbolic Si" -// DEFAULT-SAME: @"get_witness_table S2i36opaque_result_type_metadata_external1PHpyHC +// DEFAULT-SAME: @"get_witness_table Si36opaque_result_type_metadata_external1PHpyHC // IMPLICIT-DYNAMIC-SAME: @"symbolic _____yQo_ 36opaque_result_type_metadata_peephole3fooQryFQO" -// IMPLICIT-DYNAMIC-SAME: @"get_witness_table 36opaque_result_type_metadata_peephole3fooQryFQOyQo_0a1_b1_c1_D9_external1P +// IMPLICIT-DYNAMIC-SAME: @"get_witness_table x36opaque_result_type_metadata_external1PHD1_0a1_b1_c1_D9_peephole3fooQryFQOyQo_HO func foo2() -> some P { return foo() } diff --git a/test/Incremental/single-file/AnyObject.swift b/test/Incremental/single-file/AnyObject.swift index fe8855ed32c61..adb47bd8556cd 100644 --- a/test/Incremental/single-file/AnyObject.swift +++ b/test/Incremental/single-file/AnyObject.swift @@ -1,3 +1,4 @@ +// REQUIRES: rdar60050653 // REQUIRES: objc_interop // RUN: %empty-directory(%t) @@ -15,6 +16,7 @@ import Foundation // expected-private-conformance {{Swift.CustomDebugStringConvertible}} // expected-private-conformance {{Swift.CVarArg}} // expected-private-conformance {{Swift.CustomStringConvertible}} +// expected-cascading-superclass {{main.LookupFactory}} @objc private class LookupFactory: NSObject { // expected-provides {{AssignmentPrecedence}} // expected-provides {{IntegerLiteralType}} @@ -29,7 +31,7 @@ import Foundation // expected-cascading-member {{__C.NSObject.init}} // expected-cascading-member {{main.LookupFactory.init}} - // expected-cascading-member {{main.LookupFactory.deinit}} + // expected-private-member {{main.LookupFactory.deinit}} // expected-cascading-member {{main.LookupFactory.someMember}} // expected-cascading-member {{main.LookupFactory.someMethod}} } diff --git a/test/Incremental/single-file/Conformances.swift b/test/Incremental/single-file/Conformances.swift index c60e3a35b05f3..a3363fbebb431 100644 --- a/test/Incremental/single-file/Conformances.swift +++ b/test/Incremental/single-file/Conformances.swift @@ -10,7 +10,7 @@ private protocol PrivateProtocol { } // expected-provides {{PrivateProtocol}} public struct PublicConformance { } // expected-provides {{PublicConformance}} // expected-cascading-member {{main.PublicConformance.init}} -// expected-cascading-member {{main.PublicConformance.deinit}} +// expected-cascading-conformance {{main.PublicConformance}} extension PublicConformance: PublicProtocol { } extension PublicConformance: InternalProtocol { } extension PublicConformance: FilePrivateProtocol { } @@ -20,8 +20,7 @@ extension PublicConformance: PrivateProtocol { } private struct PrivateConformance { } // expected-provides {{PrivateConformance}} // expected-cascading-member {{main.PrivateConformance.init}} -// FIXME: This could be a private dependency... -// expected-cascading-member {{main.PrivateConformance.deinit}} +// expected-cascading-conformance {{main.PrivateConformance}} extension PrivateConformance: PublicProtocol { } // expected-cascading-conformance {{main.PublicProtocol}} extension PrivateConformance: InternalProtocol { } // expected-cascading-conformance {{main.InternalProtocol}} extension PrivateConformance: FilePrivateProtocol { } // expected-cascading-conformance {{main.FilePrivateProtocol}} diff --git a/test/Interpreter/SDK/libc.swift b/test/Interpreter/SDK/libc.swift index 6bfb5351cbca7..0515cde35f671 100644 --- a/test/Interpreter/SDK/libc.swift +++ b/test/Interpreter/SDK/libc.swift @@ -11,7 +11,7 @@ #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) import Glibc #elseif os(Windows) import MSVCRT diff --git a/test/Interpreter/dynamicReplacement_property_observer.swift b/test/Interpreter/dynamicReplacement_property_observer.swift index 1b8d27252fd80..29777250d22cc 100644 --- a/test/Interpreter/dynamicReplacement_property_observer.swift +++ b/test/Interpreter/dynamicReplacement_property_observer.swift @@ -14,7 +14,7 @@ #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) import Glibc #elseif os(Windows) import MSVCRT diff --git a/test/Interpreter/dynamic_replacement.swift b/test/Interpreter/dynamic_replacement.swift index f68f240339322..5f3dc11fe534b 100644 --- a/test/Interpreter/dynamic_replacement.swift +++ b/test/Interpreter/dynamic_replacement.swift @@ -56,7 +56,7 @@ import StdlibUnittest #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) import Glibc #elseif os(Windows) import MSVCRT diff --git a/test/Interpreter/dynamic_replacement_chaining.swift b/test/Interpreter/dynamic_replacement_chaining.swift index d95779f0e4e37..8fc4c09b5d3df 100644 --- a/test/Interpreter/dynamic_replacement_chaining.swift +++ b/test/Interpreter/dynamic_replacement_chaining.swift @@ -28,7 +28,7 @@ import StdlibUnittest #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) import Glibc #elseif os(Windows) import MSVCRT diff --git a/test/Interpreter/dynamic_replacement_without_previous_calls.swift b/test/Interpreter/dynamic_replacement_without_previous_calls.swift index 760847a951fec..13ea1f3ab2fec 100644 --- a/test/Interpreter/dynamic_replacement_without_previous_calls.swift +++ b/test/Interpreter/dynamic_replacement_without_previous_calls.swift @@ -14,7 +14,7 @@ import StdlibUnittest #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) import Glibc #elseif os(Windows) import MSVCRT diff --git a/test/Interpreter/generic_casts.swift b/test/Interpreter/generic_casts.swift index ffdd3a027f74c..99941a08524d6 100644 --- a/test/Interpreter/generic_casts.swift +++ b/test/Interpreter/generic_casts.swift @@ -159,7 +159,7 @@ func nongenericAnyIsPAndPCSubConforming(type: Any.Type) -> Bool { func genericAnyIs(type: Any.Type, to: T.Type, expected: Bool) -> Bool { // If we're testing against a runtime that doesn't have the fix this tests, // just pretend we got it right. - if #available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) { + if #available(macOS 10.15.4, iOS 13.4, watchOS 6.2, tvOS 13.4, *) { return type is T.Type } else { return expected diff --git a/test/Interpreter/generic_casts_objc.swift b/test/Interpreter/generic_casts_objc.swift index 61b934fd57870..f7e9ccd80c01e 100644 --- a/test/Interpreter/generic_casts_objc.swift +++ b/test/Interpreter/generic_casts_objc.swift @@ -22,7 +22,7 @@ func nongenericAnyIsPObjC(type: Any.Type) -> Bool { func genericAnyIs(type: Any.Type, to: T.Type, expected: Bool) -> Bool { // If we're testing against a runtime that doesn't have the fix this tests, // just pretend we got it right. - if #available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) { + if #available(macOS 10.15.4, iOS 13.4, watchOS 6.2, tvOS 13.4, *) { return type is T.Type } else { return expected diff --git a/test/NameBinding/name_lookup.swift b/test/NameBinding/name_lookup.swift index fb5f62cf9943a..82881dfbaad03 100644 --- a/test/NameBinding/name_lookup.swift +++ b/test/NameBinding/name_lookup.swift @@ -344,6 +344,19 @@ class ThisDerived1 : ThisBase1 { } } +protocol Crawlable {} +extension Crawlable { + static func crawl() {} +} +struct GenericChameleon: Crawlable { + static func chameleon() {} + + func testStaticOnInstance(arg: GenericChameleon) { + arg.chameleon() // expected-error {{static member 'chameleon' cannot be used on instance of type 'GenericChameleon'}} {{5-8=GenericChameleon}} + arg.crawl() // expected-error {{static member 'crawl' cannot be used on instance of type 'GenericChameleon'}} {{5-8=GenericChameleon}} + } +} + extension ThisBase1 { var baseExtProp : Int { get { diff --git a/test/NameBinding/reference-dependencies-fine.swift b/test/NameBinding/reference-dependencies-fine.swift index e15fbcf24e500..987a4a848bcae 100644 --- a/test/NameBinding/reference-dependencies-fine.swift +++ b/test/NameBinding/reference-dependencies-fine.swift @@ -464,7 +464,6 @@ struct Sentinel2 {} // CHECK-MEMBER-DAG: member interface 4main10IntWrapperV Int false -// CHECK-MEMBER-DAG: member interface 4main10IntWrapperV deinit false // CHECK-POTENTIALMEMBER-DAG: potentialMember interface SL '' false // CHECK-POTENTIALMEMBER-DAG: potentialMember interface 4main18ClassFromOtherFileC '' false // CHECK-MEMBER-DAG: member interface Si max false @@ -472,13 +471,13 @@ struct Sentinel2 {} // CHECK-POTENTIALMEMBER-DAG: potentialMember interface s33ExpressibleByUnicodeScalarLiteralP '' false // CHECK-MEMBER-DAG: member interface Sx Stride false // CHECK-MEMBER-DAG: member interface Sa reduce false -// CHECK-MEMBER-DAG: member interface 4main17OtherFileIntArrayV deinit false +// CHECK-POTENTIALMEMBER-DAG: potentialMember interface 4main17OtherFileIntArrayV '' false // CHECK-MEMBER-DAG: member interface 4main18OtherFileOuterTypeV InnerType false // CHECK-MEMBER-DAG: member interface 4main18OtherFileOuterTypeV05InnerE0V init false // CHECK-MEMBER-DAG: member interface 4main18OtherFileOuterTypeV05InnerE0V sharedConstant false // CHECK-MEMBER-DAG: member interface 4main26OtherFileSecretTypeWrapperV0dE0V constant false -// CHECK-MEMBER-DAG: member interface 4main25OtherFileProtoImplementorV deinit false -// CHECK-MEMBER-DAG: member interface 4main26OtherFileProtoImplementor2V deinit false +// CHECK-POTENTIALMEMBER-DAG: potentialMember interface 4main25OtherFileProtoImplementorV '' false +// CHECK-POTENTIALMEMBER-DAG: potentialMember interface 4main26OtherFileProtoImplementor2V '' false // CHECK-MEMBER-DAG: member interface s15EmptyCollectionV8IteratorV init false // CHECK-MEMBER-DAG: member interface 4main13OtherFileEnumO Value false // CHECK-MEMBER-DAG: member interface 4main20OtherFileEnumWrapperV Enum false diff --git a/test/NameBinding/reference-dependencies-members-fine.swift b/test/NameBinding/reference-dependencies-members-fine.swift index bbb99fbfb3e98..8ca770c26b725 100644 --- a/test/NameBinding/reference-dependencies-members-fine.swift +++ b/test/NameBinding/reference-dependencies-members-fine.swift @@ -52,7 +52,7 @@ protocol SomeProto {} // PROVIDES-NOMINAL-DAG: nominal implementation 4main10OtherClassC '' true // PROVIDES-NOMINAL-2-DAG: nominal interface 4main10OtherClassC '' true // PROVIDES-MEMBER-DAG: potentialMember interface 4main10OtherClassC '' true -// DEPENDS-MEMBER-DAG: member interface 4main10OtherClassC deinit false +// DEPENDS-MEMBER-DAG: potentialMember interface 4main10OtherClassC '' true extension OtherClass : SomeProto {} // PROVIDES-NOMINAL-DAG: nominal implementation 4main11OtherStructV '' true diff --git a/test/NameBinding/reference-dependencies-members.swift b/test/NameBinding/reference-dependencies-members.swift index 397d4d465de12..118fbe9c28c11 100644 --- a/test/NameBinding/reference-dependencies-members.swift +++ b/test/NameBinding/reference-dependencies-members.swift @@ -52,7 +52,7 @@ protocol SomeProto {} // DEPENDS-NOMINAL-DAG: 10OtherClassC" // DEPENDS-NOMINAL-DAG: 9SomeProtoP" // DEPENDS-MEMBER-DAG: - ["{{.+}}9SomeProtoP", ""] -// DEPENDS-MEMBER-DAG: - ["{{.+}}10OtherClassC", "deinit"] +// DEPENDS-MEMBER-DAG: - ["{{.+}}10OtherClassC", ""] extension OtherClass : SomeProto {} // PROVIDES-NOMINAL-NEGATIVE-NOT: 11OtherStructV"{{$}} diff --git a/test/Parse/enum.swift b/test/Parse/enum.swift index 3a51bfaa37901..5caa221e06396 100644 --- a/test/Parse/enum.swift +++ b/test/Parse/enum.swift @@ -473,7 +473,7 @@ enum SE0036 { func staticReferenceInInstanceMethod() { _ = A // expected-error {{enum case 'A' cannot be used as an instance member}} {{9-9=SE0036.}} - _ = self.A // expected-error {{enum case 'A' cannot be used as an instance member}} {{9-9=SE0036.}} + _ = self.A // expected-error {{enum case 'A' cannot be used as an instance member}} {{9-13=SE0036}} _ = SE0036.A } @@ -488,7 +488,7 @@ enum SE0036 { func staticReferenceInSwitchInInstanceMethod() { switch self { case A: break // expected-error {{enum case 'A' cannot be used as an instance member}} {{10-10=.}} - case B(_): break // expected-error {{enum case 'B' cannot be used as an instance member}} {{10-10=.}} + case B(_): break // expected-error {{'_' can only appear in a pattern or on the left side of an assignment}} case C(let x): _ = x; break // expected-error {{enum case 'C' cannot be used as an instance member}} {{10-10=.}} } } @@ -532,7 +532,7 @@ enum SE0036_Generic { func foo() { switch self { - case A(_): break // expected-error {{enum case 'A' cannot be used as an instance member}} {{10-10=.}} expected-error {{missing argument label 'x:' in call}} + case A(_): break // expected-error {{'_' can only appear in a pattern or on the left side of an assignment}} } switch self { diff --git a/test/Parse/matching_patterns.swift b/test/Parse/matching_patterns.swift index c5cbfd1fd07ca..1d32472a264bc 100644 --- a/test/Parse/matching_patterns.swift +++ b/test/Parse/matching_patterns.swift @@ -28,20 +28,17 @@ case square(9): // 'var' and 'let' patterns. case var a: a = 1 -case let a: // expected-warning {{case is already handled by previous patterns; consider removing it}} +case let a: a = 1 // expected-error {{cannot assign}} case var var a: // expected-error {{'var' cannot appear nested inside another 'var' or 'let' pattern}} - // expected-warning@-1 {{case is already handled by previous patterns; consider removing it}} a += 1 case var let a: // expected-error {{'let' cannot appear nested inside another 'var' or 'let' pattern}} - // expected-warning@-1 {{case is already handled by previous patterns; consider removing it}} print(a, terminator: "") case var (var b): // expected-error {{'var' cannot appear nested inside another 'var'}} - // expected-warning@-1 {{case is already handled by previous patterns; consider removing it}} b += 1 // 'Any' pattern. -case _: // expected-warning {{case is already handled by previous patterns; consider removing it}} +case _: () // patterns are resolved in expression-only positions are errors. @@ -283,14 +280,12 @@ case (1, 2, 3): // patterns in expression-only positions are errors. case +++(_, var d, 3): // expected-error@-1{{'_' can only appear in a pattern or on the left side of an assignment}} -// expected-error@-2{{'var' binding pattern cannot appear in an expression}} () case (_, var e, 3) +++ (1, 2, 3): -// expected-error@-1{{'_' can only appear in a pattern}} -// expected-error@-2{{'var' binding pattern cannot appear in an expression}} +// expected-error@-1{{'_' can only appear in a pattern or on the left side of an assignment}} () case (let (_, _, _)) + 1: -// expected-error@-1 {{expression pattern of type 'Int' cannot match values of type '(Int, Int, Int)'}} +// expected-error@-1 {{'_' can only appear in a pattern or on the left side of an assignment}} () } diff --git a/test/Prototypes/BigInt.swift b/test/Prototypes/BigInt.swift index b5f9f3924e0df..fe4a8ab53ccb6 100644 --- a/test/Prototypes/BigInt.swift +++ b/test/Prototypes/BigInt.swift @@ -19,7 +19,7 @@ import StdlibUnittest #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) import Glibc #elseif os(Windows) import MSVCRT diff --git a/test/SILGen/dynamically_replaceable.swift b/test/SILGen/dynamically_replaceable.swift index e808ca1b9b3f5..8debacd2036fa 100644 --- a/test/SILGen/dynamically_replaceable.swift +++ b/test/SILGen/dynamically_replaceable.swift @@ -414,7 +414,7 @@ struct WrapperWithInitialValue { } } -// CHECK-LABEL: sil hidden [ossa] @$s23dynamically_replaceable10SomeStructV1tSbvpfP +// CHECK-NOT: sil hidden [ossa] @$s23dynamically_replaceable10SomeStructV1tSbvpfP public struct SomeStruct { @WrapperWithInitialValue var t = false } diff --git a/test/SILGen/property_wrappers.swift b/test/SILGen/property_wrappers.swift index 5dbf31ce6903e..bef64012673e0 100644 --- a/test/SILGen/property_wrappers.swift +++ b/test/SILGen/property_wrappers.swift @@ -510,6 +510,8 @@ public protocol TestProtocol {} public class TestClass { @WrapperWithInitialValue var value: T + // CHECK-LABEL: sil [ossa] @$s17property_wrappers9TestClassC5valuexvpfP : $@convention(thin) (@in T) -> @out WrapperWithInitialValue + // CHECK-LABEL: sil hidden [ossa] @$s17property_wrappers9TestClassC5value8protocolACyxGx_qd__tcAA0C8ProtocolRd__lufc // CHECK: [[BACKING_INIT:%.*]] = function_ref @$s17property_wrappers9TestClassC5valuexvpfP : $@convention(thin) <τ_0_0> (@in τ_0_0) -> @out WrapperWithInitialValue<τ_0_0> // CHECK-NEXT: partial_apply [callee_guaranteed] [[BACKING_INIT]]() diff --git a/test/SILGen/property_wrappers_library_evolution.swift b/test/SILGen/property_wrappers_library_evolution.swift index 09f750c213aba..270e01170848a 100644 --- a/test/SILGen/property_wrappers_library_evolution.swift +++ b/test/SILGen/property_wrappers_library_evolution.swift @@ -1,6 +1,6 @@ // RUN: %empty-directory(%t) // RUN: %target-swift-frontend -emit-module -o %t -enable-library-evolution %S/Inputs/property_wrapper_defs.swift -// RUN: %target-swift-emit-silgen -primary-file %s -I %t -enable-library-evolution +// RUN: %target-swift-emit-silgen -primary-file %s -I %t -enable-library-evolution | %FileCheck %s import property_wrapper_defs // rdar://problem/55995892 @@ -8,3 +8,12 @@ import property_wrapper_defs public enum E { case a } struct M { @MyPublished private var e = E.a } + +// Ensure that the backing initializer is serialized. +@frozen +public struct StructUsesPublishedAsPrivate { + public var integer: Int = 17 + + // CHECK: sil non_abi [serialized] [ossa] @$s35property_wrappers_library_evolution28StructUsesPublishedAsPrivateV6stringSSvpfP : $@convention(thin) (@owned String) -> @out MyPublished + @MyPublished var string: String = "Hello" +} diff --git a/test/SILOptimizer/cse.sil b/test/SILOptimizer/cse.sil index b6fa53fad8774..9f42954cc3c2f 100644 --- a/test/SILOptimizer/cse.sil +++ b/test/SILOptimizer/cse.sil @@ -1267,3 +1267,31 @@ bb0(%0 : $Proto & Ping): return %12 : $Ping } +sil [global_init] @$s4test10testGlobalSivau : $@convention(thin) () -> Builtin.RawPointer + +// CHECK-LABEL: sil @cse_global_init +// CHECK: [[P:%[0-9]+]] = apply +// CHECK: [[A:%[0-9]+]] = pointer_to_address [[P]] +// CHECK: begin_access [modify] [dynamic] [no_nested_conflict] [[A]] +// CHECK: begin_access [read] [dynamic] [no_nested_conflict] [[A]] +// CHECK: // end sil function 'cse_global_init' +sil @cse_global_init : $@convention(thin) () -> Int64 { +bb0: + %2 = function_ref @$s4test10testGlobalSivau : $@convention(thin) () -> Builtin.RawPointer + %3 = apply %2() : $@convention(thin) () -> Builtin.RawPointer + %4 = pointer_to_address %3 : $Builtin.RawPointer to [strict] $*Int64 + %5 = integer_literal $Builtin.Int64, 42 + %6 = struct $Int64 (%5 : $Builtin.Int64) + %7 = begin_access [modify] [dynamic] [no_nested_conflict] %4 : $*Int64 + store %6 to %7 : $*Int64 + end_access %7 : $*Int64 + %10 = function_ref @$s4test10testGlobalSivau : $@convention(thin) () -> Builtin.RawPointer + %11 = apply %10() : $@convention(thin) () -> Builtin.RawPointer + %12 = pointer_to_address %11 : $Builtin.RawPointer to [strict] $*Int64 + %33 = begin_access [read] [dynamic] [no_nested_conflict] %12 : $*Int64 + %35 = load %33 : $*Int64 + end_access %33 : $*Int64 + return %35 : $Int64 +} + + diff --git a/test/SILOptimizer/destroy_hoisting_crash.swift b/test/SILOptimizer/destroy_hoisting_crash.swift new file mode 100644 index 0000000000000..ce96950dfb714 --- /dev/null +++ b/test/SILOptimizer/destroy_hoisting_crash.swift @@ -0,0 +1,21 @@ +// RUN: %target-swift-frontend -O %s -emit-sil -o /dev/null + +public struct S { + let args: [Substring] + let arg: Substring + + enum Error: Swift.Error { + case Case + } + + public init(arg: String) throws { + args = arg.split(separator: "\n") + guard args.count > 0 else { throw Error.Case } + + let parts = args[0].split(separator: " ") + guard parts.count > 2 else { throw Error.Case } + + self.arg = parts[1] + } +} + diff --git a/test/SILOptimizer/global_hoisting_crash.swift b/test/SILOptimizer/global_hoisting_crash.swift new file mode 100644 index 0000000000000..b640f656e9bbb --- /dev/null +++ b/test/SILOptimizer/global_hoisting_crash.swift @@ -0,0 +1,24 @@ +// RUN: %empty-directory(%t) +// RUN: %target-build-swift -O %s -o %t/a.out +// RUN: %target-run %t/a.out | %FileCheck %s + +// REQUIRES: executable_test + +struct Teststruct { + static let s = Teststruct() + + @inline(never) + init() { + let set = Set() + for _ in set { + // Check that the global initializer is not hoisted out of this loop, + // resulting in a dispatch_once re-retrance crash. + _ = Teststruct.s + } + } +} + +// CHECK: Teststruct +print(Teststruct.s) + + diff --git a/test/SILOptimizer/global_init_opt.swift b/test/SILOptimizer/global_init_opt.swift new file mode 100644 index 0000000000000..d6e127b84bf75 --- /dev/null +++ b/test/SILOptimizer/global_init_opt.swift @@ -0,0 +1,33 @@ +// RUN: %target-swift-frontend -parse-as-library -O -module-name=test %s -emit-sil | %FileCheck %s + +var gg: Int = { + print("gg init") + return 27 +}() + +// CHECK-LABEL: sil @$s4test3cseSiyF +// CHECK: builtin "once" +// CHECK-NOT: builtin "once" +// CHECK: [[G:%[0-9]+]] = load +// CHECK-NOT: builtin "once" +// CHECK: builtin "sadd_{{.*}}"([[G]] : $Builtin.Int{{[0-9]+}}, [[G]] : $Builtin.Int{{[0-9]+}}, %{{[0-9]+}} : $Builtin.Int1) +// CHECK-NOT: builtin "once" +// CHECK: } // end sil function '$s4test3cseSiyF' +public func cse() -> Int { + return gg + gg +} + +// CHECK-LABEL: sil @$s4test4licmSiyF +// CHECK: bb0: +// CHECK: builtin "once" +// CHECK: bb1: +// CHECK-NOT: builtin "once" +// CHECK: } // end sil function '$s4test4licmSiyF' +public func licm() -> Int { + var s = 0 + for _ in 0..<100 { + s += gg + } + return s +} + diff --git a/test/SILOptimizer/globalopt.sil b/test/SILOptimizer/globalopt.sil index e14318a4fc7d2..0892873906699 100644 --- a/test/SILOptimizer/globalopt.sil +++ b/test/SILOptimizer/globalopt.sil @@ -1,461 +1,55 @@ // RUN: %target-sil-opt -enable-sil-verify-all %s -global-opt | %FileCheck %s -// -// ginit.cold has a hammock with an initializer call on the slow path. -// ginit.loop has a loop containing an initializer call. sil_stage canonical - import Builtin import Swift -// globalinit_token0 -sil_global private @globalinit_token0 : $Builtin.Word -sil_global @MyConst : $Int32 - -// globalinit_func0 -sil @globalinit_func0 : $@convention(c) () -> () - -// ginit.MyConst.mutableAddressor : Swift.Int32 -sil [global_init] @_TF5ginita7MyConstSi : $@convention(thin) () -> Builtin.RawPointer - -// Don't hoist this initializer call. -// ginit.cold (Swift.Int32) -> Swift.Int32 -// CHECK-LABEL: sil @_TF5ginit4coldFSiSi -// CHECK-NOT: 5ginita7MyConst -// CHECK: bb1: -// CHECK: 5ginita7MyConst -// CHECK: {{^bb2}} -sil @_TF5ginit4coldFSiSi : $@convention(thin) (Int32) -> Int32 { -bb0(%0 : $Int32): - %1 = integer_literal $Builtin.Int32, 0 // users: %4, %5, %13 - %3 = struct_extract %0 : $Int32, #Int32._value // user: %4 - %4 = builtin "cmp_sgt_Int32"(%3 : $Builtin.Int32, %1 : $Builtin.Int32) : $Builtin.Int1 // user: %5 - cond_br %4, bb1, bb2(%1 : $Builtin.Int32) // id: %5 - -bb1: // Preds: bb0 - // function_ref ginit.MyConst.mutableAddressor : Swift.Int32 - %6 = function_ref @_TF5ginita7MyConstSi : $@convention(thin) () -> Builtin.RawPointer // user: %7 - %7 = apply %6() : $@convention(thin) () -> Builtin.RawPointer // user: %8 - %8 = pointer_to_address %7 : $Builtin.RawPointer to [strict] $*Int32 // user: %9 - %9 = struct_element_addr %8 : $*Int32, #Int32._value // user: %10 - %10 = load %9 : $*Builtin.Int32 // user: %13 - %12 = integer_literal $Builtin.Int1, -1 // user: %13 - %13 = builtin "sadd_with_overflow_Int32"(%1 : $Builtin.Int32, %10 : $Builtin.Int32, %12 : $Builtin.Int1) : $(Builtin.Int32, Builtin.Int1) // user: %14 - %14 = tuple_extract %13 : $(Builtin.Int32, Builtin.Int1), 0 // user: %15 - br bb2(%14 : $Builtin.Int32) // id: %15 - -bb2(%16 : $Builtin.Int32): // Preds: bb0 bb1 - %17 = struct $Int32 (%16 : $Builtin.Int32) // user: %18 - return %17 : $Int32 // id: %18 -} - -// Do hoist this initializer call. -// ginit.loop (Swift.Int32) -> Swift.Int32 -// CHECK-LABEL: sil @_TF5ginit4loopFSiSi -// CHECK: {{^bb0}} -// CHECK: 5ginita7MyConst -// CHECK: {{^bb1}} -// CHECK-NOT: 5ginita7MyConst -sil @_TF5ginit4loopFSiSi : $@convention(thin) (Int32) -> Int32 { -bb0(%0 : $Int32): - %1 = integer_literal $Builtin.Int32, 0 // user: %8 - %2 = integer_literal $Builtin.Int32, 1 // users: %6, %8, %22 - %3 = struct_extract %0 : $Int32, #Int32._value // user: %6 - %5 = integer_literal $Builtin.Int1, -1 // users: %6, %22, %37 - %6 = builtin "sadd_with_overflow_Int32"(%3 : $Builtin.Int32, %2 : $Builtin.Int32, %5 : $Builtin.Int1) : $(Builtin.Int32, Builtin.Int1) // user: %7 - %7 = tuple_extract %6 : $(Builtin.Int32, Builtin.Int1), 0 // user: %13 - br bb1(%1 : $Builtin.Int32, %2 : $Builtin.Int32) // id: %8 - -bb1(%9 : $Builtin.Int32, %10 : $Builtin.Int32): // Preds: bb0 bb5 - %11 = struct $Int32 (%10 : $Builtin.Int32) // user: %24 - %13 = builtin "cmp_eq_Int32"(%10 : $Builtin.Int32, %7 : $Builtin.Int32) : $Builtin.Int1 // user: %14 - cond_br %13, bb2, bb4 // id: %14 - -bb2: // Preds: bb1 - %15 = enum $Optional, #Optional.none!enumelt // user: %16 - br bb3(%10 : $Builtin.Int32, %15 : $Optional) // id: %16 - -bb3(%17 : $Builtin.Int32, %18 : $Optional): // Preds: bb2 bb4 - %19 = alloc_stack $Optional // users: %20, %26, %27, %40 - store %18 to %19 : $*Optional // id: %20 - switch_enum %18 : $Optional, case #Optional.some!enumelt: bb5, case #Optional.none!enumelt: bb6 // id: %21 - -bb4: // Preds: bb1 - %22 = builtin "sadd_with_overflow_Int32"(%10 : $Builtin.Int32, %2 : $Builtin.Int32, %5 : $Builtin.Int1) : $(Builtin.Int32, Builtin.Int1) // user: %23 - %23 = tuple_extract %22 : $(Builtin.Int32, Builtin.Int1), 0 // user: %25 - %24 = enum $Optional, #Optional.some!enumelt, %11 : $Int32 // user: %25 - br bb3(%23 : $Builtin.Int32, %24 : $Optional) // id: %25 - -bb5: // Preds: bb3 - %26 = unchecked_take_enum_data_addr %19 : $*Optional, #Optional.some!enumelt - dealloc_stack %19 : $*Optional // id: %27 - %28 = alloc_stack $Optional // users: %29, %30, %31 - store %18 to %28 : $*Optional // id: %29 - %30 = unchecked_take_enum_data_addr %28 : $*Optional, #Optional.some!enumelt - dealloc_stack %28 : $*Optional // id: %31 - // function_ref ginit.MyConst.mutableAddressor : Swift.Int32 - %32 = function_ref @_TF5ginita7MyConstSi : $@convention(thin) () -> Builtin.RawPointer // user: %33 - %33 = apply %32() : $@convention(thin) () -> Builtin.RawPointer // user: %34 - %34 = pointer_to_address %33 : $Builtin.RawPointer to [strict] $*Int32 // user: %35 - %35 = struct_element_addr %34 : $*Int32, #Int32._value // user: %36 - %36 = load %35 : $*Builtin.Int32 // user: %37 - %37 = builtin "sadd_with_overflow_Int32"(%9 : $Builtin.Int32, %36 : $Builtin.Int32, %5 : $Builtin.Int1) : $(Builtin.Int32, Builtin.Int1) // user: %38 - %38 = tuple_extract %37 : $(Builtin.Int32, Builtin.Int1), 0 // user: %39 - br bb1(%38 : $Builtin.Int32, %17 : $Builtin.Int32) // id: %39 - -bb6: // Preds: bb3 - dealloc_stack %19 : $*Optional // id: %40 - %41 = struct $Int32 (%9 : $Builtin.Int32) // user: %42 - return %41 : $Int32 // id: %42 -} - -// libg.MyGlobal.mutableAddressor : Swift.Int32 -sil [global_init] @_TF4libga8MyGlobalSi : $@convention(thin) () -> Builtin.RawPointer - -// Hoist this initializer call out of a loop, but not into the function entry. -// ginit.loop (Swift.Int32) -> Swift.Int32 -// CHECK-LABEL: sil @_TF10ginitloops3runFSiSi -// CHECK: {{^bb2}} -// CHECK: function_ref @_TF4libga8MyGlobalSi -// CHECK-NEXT: apply -// CHECK: {{^bb6}} -// CHECK-NOT: addressor -// CHECK-NOT: mutableAddressor -// CHECK: pointer_to_address -// CHECK: {{br bb6}} -// -// ginitloops.run (Swift.Int32) -> Swift.Int32 -sil @_TF10ginitloops3runFSiSi : $@convention(thin) (Int32) -> Int32 { -bb0(%0 : $Int32): - %1 = integer_literal $Builtin.Int32, 1000 // user: %4 - %3 = struct_extract %0 : $Int32, #Int32._value // users: %4, %16, %29 - %4 = builtin "cmp_sgt_Int32"(%3 : $Builtin.Int32, %1 : $Builtin.Int32) : $Builtin.Int1 // user: %5 - cond_br %4, bb1, bb2 // id: %5 - -bb1: // Preds: bb0 - %6 = integer_literal $Builtin.Int32, -1 // user: %7 - %7 = struct $Int32 (%6 : $Builtin.Int32) // user: %8 - br bb9(%7 : $Int32) // id: %8 - -bb2: // Preds: bb0 - %9 = integer_literal $Builtin.Int32, 0 // user: %11 - %10 = integer_literal $Builtin.Int32, 1 // users: %11, %22, %25, %34 - br bb3(%9 : $Builtin.Int32, %10 : $Builtin.Int32) // id: %11 - -bb3(%12 : $Builtin.Int32, %13 : $Builtin.Int32): // Preds: bb2 bb7 - %14 = struct $Int32 (%13 : $Builtin.Int32) // user: %24 - %16 = builtin "cmp_eq_Int32"(%13 : $Builtin.Int32, %3 : $Builtin.Int32) : $Builtin.Int1 // user: %17 - cond_br %16, bb4, bb5 // id: %17 - -bb4: // Preds: bb3 - %18 = struct $Int32 (%12 : $Builtin.Int32) // user: %19 - br bb9(%18 : $Int32) // id: %19 - -bb5: // Preds: bb3 - %21 = integer_literal $Builtin.Int1, -1 // user: %22 - %22 = builtin "sadd_with_overflow_Int32"(%13 : $Builtin.Int32, %10 : $Builtin.Int32, %21 : $Builtin.Int1) : $(Builtin.Int32, Builtin.Int1) // user: %23 - %23 = tuple_extract %22 : $(Builtin.Int32, Builtin.Int1), 0 // user: %31 - %24 = enum $Optional, #Optional.some!enumelt, %14 : $Int32 - br bb6(%12 : $Builtin.Int32, %10 : $Builtin.Int32) // id: %25 - -bb6(%26 : $Builtin.Int32, %27 : $Builtin.Int32): // Preds: bb5 bb8 - %28 = struct $Int32 (%27 : $Builtin.Int32) // user: %36 - %29 = builtin "cmp_eq_Int32"(%27 : $Builtin.Int32, %3 : $Builtin.Int32) : $Builtin.Int1 // user: %30 - cond_br %29, bb7, bb8 // id: %30 - -bb7: // Preds: bb6 - br bb3(%26 : $Builtin.Int32, %23 : $Builtin.Int32) // id: %31 - -bb8: // Preds: bb6 - %33 = integer_literal $Builtin.Int1, -1 // user: %34 - %34 = builtin "sadd_with_overflow_Int32"(%27 : $Builtin.Int32, %10 : $Builtin.Int32, %33 : $Builtin.Int1) : $(Builtin.Int32, Builtin.Int1) // user: %35 - %35 = tuple_extract %34 : $(Builtin.Int32, Builtin.Int1), 0 // user: %46 - %36 = enum $Optional, #Optional.some!enumelt, %28 : $Int32 - // function_ref libg.MyGlobal.mutableAddressor : Swift.Int32 - %37 = function_ref @_TF4libga8MyGlobalSi : $@convention(thin) () -> Builtin.RawPointer // user: %38 - %38 = apply %37() : $@convention(thin) () -> Builtin.RawPointer // user: %39 - %39 = pointer_to_address %38 : $Builtin.RawPointer to [strict] $*Int32 // user: %40 - %40 = struct_element_addr %39 : $*Int32, #Int32._value // user: %41 - %41 = load %40 : $*Builtin.Int32 // user: %44 - %43 = integer_literal $Builtin.Int1, -1 // user: %44 - %44 = builtin "sadd_with_overflow_Int32"(%26 : $Builtin.Int32, %41 : $Builtin.Int32, %43 : $Builtin.Int1) : $(Builtin.Int32, Builtin.Int1) // user: %45 - %45 = tuple_extract %44 : $(Builtin.Int32, Builtin.Int1), 0 // user: %46 - br bb6(%45 : $Builtin.Int32, %35 : $Builtin.Int32) // id: %46 - -bb9(%47 : $Int32): // Preds: bb1 bb4 - return %47 : $Int32 // id: %48 -} - - -// Do NOT hoist this initializer out of a cold block. -// CHECK-LABEL: sil @_TF9ginitcold3runFSiSi -// CHECK-NOT: addressor -// CHECK-NOT: mutableAddressor -// CHECK: {{^bb3}} -// CHECK: cond_br -// CHECK: {{^bb4}} -// CHECK: function_ref @_TF4libga8MyGlobalSi -// CHECK-NEXT: apply -// CHECK: pointer_to_address -// CHECK: br -sil @_TF9ginitcold3runFSiSi : $@convention(thin) (Int32) -> Int32 { -bb0(%0 : $Int32): - %1 = integer_literal $Builtin.Int32, 0 // users: %4, %19 - %2 = integer_literal $Builtin.Int32, 1 // users: %4, %14 - %3 = struct_extract %0 : $Int32, #Int32._value // user: %8 - br bb1(%1 : $Builtin.Int32, %2 : $Builtin.Int32) // id: %4 - -bb1(%5 : $Builtin.Int32, %6 : $Builtin.Int32): // Preds: bb0 bb5 - %8 = builtin "cmp_eq_Int32"(%6 : $Builtin.Int32, %3 : $Builtin.Int32) : $Builtin.Int1 // user: %9 - cond_br %8, bb2, bb3 // id: %9 - -bb2: // Preds: bb1 - %10 = struct $Int32 (%5 : $Builtin.Int32) // user: %11 - return %10 : $Int32 // id: %11 - -bb3: // Preds: bb1 - %13 = integer_literal $Builtin.Int1, -1 // users: %14, %29 - %14 = builtin "sadd_with_overflow_Int32"(%6 : $Builtin.Int32, %2 : $Builtin.Int32, %13 : $Builtin.Int1) : $(Builtin.Int32, Builtin.Int1) // user: %15 - %15 = tuple_extract %14 : $(Builtin.Int32, Builtin.Int1), 0 // user: %33 - %16 = integer_literal $Builtin.Int32, 10 // user: %18 - %18 = builtin "srem_Int32"(%5 : $Builtin.Int32, %16 : $Builtin.Int32) : $Builtin.Int32 // user: %19 - %19 = builtin "cmp_eq_Int32"(%18 : $Builtin.Int32, %1 : $Builtin.Int32) : $Builtin.Int1 // user: %22 - %20 = integer_literal $Builtin.Int1, 0 // user: %22 - %22 = builtin "int_expect_Int1"(%19 : $Builtin.Int1, %20 : $Builtin.Int1) : $Builtin.Int1 // user: %23 - cond_br %22, bb4, bb5(%5 : $Builtin.Int32) // id: %23 - -bb4: // Preds: bb3 - // function_ref libg.MyGlobal.mutableAddressor : Swift.Int32 - %24 = function_ref @_TF4libga8MyGlobalSi : $@convention(thin) () -> Builtin.RawPointer // user: %25 - %25 = apply %24() : $@convention(thin) () -> Builtin.RawPointer // user: %26 - %26 = pointer_to_address %25 : $Builtin.RawPointer to [strict] $*Int32 // user: %27 - %27 = struct_element_addr %26 : $*Int32, #Int32._value // user: %28 - %28 = load %27 : $*Builtin.Int32 // user: %29 - %29 = builtin "sadd_with_overflow_Int32"(%5 : $Builtin.Int32, %28 : $Builtin.Int32, %13 : $Builtin.Int1) : $(Builtin.Int32, Builtin.Int1) // user: %30 - %30 = tuple_extract %29 : $(Builtin.Int32, Builtin.Int1), 0 // user: %31 - br bb5(%30 : $Builtin.Int32) // id: %31 - -bb5(%32 : $Builtin.Int32): // Preds: bb3 bb4 - br bb1(%32 : $Builtin.Int32, %15 : $Builtin.Int32) // id: %33 -} - -// Combine two init calls into one in the common dominator -// CHECK-LABEL: sil @test_common_dominator -// CHECK: bb0(%0 : $Builtin.Int1): -// CHECK: apply -// CHECK: bb1: -// CHECK-NOT: apply -// CHECK: return -sil @test_common_dominator : $@convention(thin) (Builtin.Int1) -> Int64 { -bb0(%0 : $Builtin.Int1): - %1 = function_ref @_TF4libga8MyGlobalSi : $@convention(thin) () -> Builtin.RawPointer - %2 = apply %1() : $@convention(thin) () -> Builtin.RawPointer - %3 = pointer_to_address %2 : $Builtin.RawPointer to [strict] $*Int64 - %4 = struct_element_addr %3 : $*Int64, #Int64._value - %5 = load %4 : $*Builtin.Int64 - cond_br %0, bb1, bb2(%5 : $Builtin.Int64) - -bb1: - %8 = apply %1() : $@convention(thin) () -> Builtin.RawPointer - %9 = pointer_to_address %8 : $Builtin.RawPointer to [strict] $*Int64 - %10 = struct_element_addr %9 : $*Int64, #Int64._value - %11 = load %10 : $*Builtin.Int64 - br bb2(%11 : $Builtin.Int64) - -bb2(%18 : $Builtin.Int64): - %19 = struct $Int64 (%18 : $Builtin.Int64) - return %19 : $Int64 -} - -// Combine two init calls into one in the common dominator -// CHECK-LABEL: sil @test_common_dominator2 -// CHECK: bb0(%0 : $Builtin.Int1): -// CHECK: apply -// CHECK: bb1: -// CHECK-NOT: apply -// CHECK: return -sil @test_common_dominator2 : $@convention(thin) (Builtin.Int1) -> Int64 { -bb0(%0 : $Builtin.Int1): - cond_br %0, bb1, bb2 - -bb1: - %1 = function_ref @_TF4libga8MyGlobalSi : $@convention(thin) () -> Builtin.RawPointer - %2 = apply %1() : $@convention(thin) () -> Builtin.RawPointer - %3 = pointer_to_address %2 : $Builtin.RawPointer to [strict] $*Int64 - %4 = struct_element_addr %3 : $*Int64, #Int64._value - %5 = load %4 : $*Builtin.Int64 - br bb3(%5 : $Builtin.Int64) - -bb2: - %11 = function_ref @_TF4libga8MyGlobalSi : $@convention(thin) () -> Builtin.RawPointer - %12 = apply %11() : $@convention(thin) () -> Builtin.RawPointer - %13 = pointer_to_address %12 : $Builtin.RawPointer to [strict] $*Int64 - %14 = struct_element_addr %13 : $*Int64, #Int64._value - %15 = load %14 : $*Builtin.Int64 - br bb3(%15 : $Builtin.Int64) - -bb3(%18 : $Builtin.Int64): - %19 = struct $Int64 (%18 : $Builtin.Int64) - return %19 : $Int64 -} - -// Test a special case: If there is a call in a loop and in its exit block, which is located -// before the loop, the init-call should still be hoisted out of the loop. -// CHECK-LABEL: sil @test_loopexit_and_loop -// CHECK: bb0(%0 : $Builtin.Int1): -// CHECK: apply -// CHECK: bb1: -// CHECK-NOT: apply -// CHECK: return -sil @test_loopexit_and_loop : $@convention(thin) (Builtin.Int1) -> Int64 { -bb0(%0 : $Builtin.Int1): - br bb2 - -bb1: - %1 = function_ref @_TF4libga8MyGlobalSi : $@convention(thin) () -> Builtin.RawPointer - %2 = apply %1() : $@convention(thin) () -> Builtin.RawPointer - %3 = pointer_to_address %2 : $Builtin.RawPointer to [strict] $*Int64 - %4 = struct_element_addr %3 : $*Int64, #Int64._value - %5 = load %4 : $*Builtin.Int64 - %r1 = struct $Int64 (%5 : $Builtin.Int64) - return %r1 : $Int64 - -bb2: - %11 = function_ref @_TF4libga8MyGlobalSi : $@convention(thin) () -> Builtin.RawPointer - %12 = apply %11() : $@convention(thin) () -> Builtin.RawPointer - %13 = pointer_to_address %12 : $Builtin.RawPointer to [strict] $*Int64 - %14 = struct_element_addr %13 : $*Int64, #Int64._value - %15 = load %14 : $*Builtin.Int64 - cond_br %0, bb2, bb1 -} - -// libg.MyGlobal.mutableAddressor : Swift.Int32 -sil [global_init] [dynamically_replacable] @_TF4libga8MyGlobalSi_dynamic : $@convention(thin) () -> Builtin.RawPointer - -// Don't hoist dynamic_function_ref calls. -// CHECK-LABEL: sil @test_loopexit_and_loop_dynamic -// CHECK: bb0(%0 : $Builtin.Int1): -// CHECK: bb1: -// CHECK: apply -// CHECK: return -// CHECK: bb2: -// CHECK: apply -// CHECK: cond_br -sil @test_loopexit_and_loop_dynamic : $@convention(thin) (Builtin.Int1) -> Int64 { -bb0(%0 : $Builtin.Int1): - br bb2 - -bb1: - %1 = dynamic_function_ref @_TF4libga8MyGlobalSi_dynamic : $@convention(thin) () -> Builtin.RawPointer - %2 = apply %1() : $@convention(thin) () -> Builtin.RawPointer - %3 = pointer_to_address %2 : $Builtin.RawPointer to [strict] $*Int64 - %4 = struct_element_addr %3 : $*Int64, #Int64._value - %5 = load %4 : $*Builtin.Int64 - %r1 = struct $Int64 (%5 : $Builtin.Int64) - return %r1 : $Int64 - -bb2: - %11 = dynamic_function_ref @_TF4libga8MyGlobalSi_dynamic : $@convention(thin) () -> Builtin.RawPointer - %12 = apply %11() : $@convention(thin) () -> Builtin.RawPointer - %13 = pointer_to_address %12 : $Builtin.RawPointer to [strict] $*Int64 - %14 = struct_element_addr %13 : $*Int64, #Int64._value - %15 = load %14 : $*Builtin.Int64 - cond_br %0, bb2, bb1 -} -// An init-call, which is guarded by an availability-check may not be speculated. -// In this test it may not be hoisted out of the loop. -// CHECK-LABEL: sil @test_availability_loop -// CHECK: [[INIT:%[0-9]+]] = function_ref @_TF4libga8MyGlobalSi -// CHECK: {{^bb2:}} -// CHECK-NEXT: apply [[INIT]]() -sil @test_availability_loop : $@convention(thin) (Builtin.Int1) -> () { -bb0(%0 : $Builtin.Int1): - %f1 = function_ref @test_availability : $@convention(thin) () -> Builtin.Int1 - %f2 = function_ref @_TF4libga8MyGlobalSi : $@convention(thin) () -> Builtin.RawPointer - br bb1 - -bb1: - %a1 = apply %f1() : $@convention(thin) () -> Builtin.Int1 - cond_br %a1, bb2, bb4 - -bb2: - %a2 = apply %f2() : $@convention(thin) () -> Builtin.RawPointer - br bb3 - -bb3: - cond_br %0, bb1, bb4 - -bb4: - %r = tuple () - return %r : $() -} - -// The init-call should be hoisted out of the inner loop, but not out of the -// outer loop, because of the availability check around the inner loop. -// CHECK-LABEL: sil @test_availability_loop_nest -// CHECK: [[INIT:%[0-9]+]] = function_ref @_TF4libga8MyGlobalSi -// CHECK: {{^bb2:}} -// CHECK-NEXT: apply [[INIT]]() -sil @test_availability_loop_nest : $@convention(thin) (Builtin.Int1) -> () { -bb0(%0 : $Builtin.Int1): - %f1 = function_ref @test_availability : $@convention(thin) () -> Builtin.Int1 - %f2 = function_ref @_TF4libga8MyGlobalSi : $@convention(thin) () -> Builtin.RawPointer - br bb1 - -bb1: - %a1 = apply %f1() : $@convention(thin) () -> Builtin.Int1 - cond_br %a1, bb2, bb4 - -bb2: - br bb3 - -bb3: - %a2 = apply %f2() : $@convention(thin) () -> Builtin.RawPointer - cond_br %0, bb3, bb4 - -bb4: - cond_br %0, bb1, bb5 - -bb5: - %r = tuple () - return %r : $() +private var testGlobal: Int64 + +sil_global private @globalinit_33_00F4D2139E6BDDFEC71E5005B67B5674_token0 : $Builtin.Word + +sil_global private @$s4test10testGlobalSivp : $Int64 + +sil private @globalinit_33_00F4D2139E6BDDFEC71E5005B67B5674_func0 : $@convention(c) () -> () { +bb0: + alloc_global @$s4test10testGlobalSivp + %1 = global_addr @$s4test10testGlobalSivp : $*Int64 + %2 = integer_literal $Builtin.Int64, 27 + %3 = struct $Int64 (%2 : $Builtin.Int64) + store %3 to %1 : $*Int64 + %5 = tuple () + return %5 : $() +} + +sil hidden [global_init] @$s4test10testGlobalSivau : $@convention(thin) () -> Builtin.RawPointer { +bb0: + %0 = global_addr @globalinit_33_00F4D2139E6BDDFEC71E5005B67B5674_token0 : $*Builtin.Word + %1 = address_to_pointer %0 : $*Builtin.Word to $Builtin.RawPointer + %2 = function_ref @globalinit_33_00F4D2139E6BDDFEC71E5005B67B5674_func0 : $@convention(c) () -> () + %3 = builtin "once"(%1 : $Builtin.RawPointer, %2 : $@convention(c) () -> ()) : $() + %4 = global_addr @$s4test10testGlobalSivp : $*Int64 + %5 = address_to_pointer %4 : $*Int64 to $Builtin.RawPointer + return %5 : $Builtin.RawPointer +} + +// CHECK-LABEL: sil @dont_propagate_global_with_multiple_writes +// CHECK: [[V:%[0-9]+]] = load +// CHECK: return [[V]] +// CHECK: } // end sil function 'dont_propagate_global_with_multiple_writes' +sil @dont_propagate_global_with_multiple_writes : $@convention(thin) (Int64) -> Int64 { +bb0(%0 : $Int64): + %2 = function_ref @$s4test10testGlobalSivau : $@convention(thin) () -> Builtin.RawPointer + %3 = apply %2() : $@convention(thin) () -> Builtin.RawPointer + %4 = pointer_to_address %3 : $Builtin.RawPointer to [strict] $*Int64 + %5 = integer_literal $Builtin.Int64, 42 + %6 = struct $Int64 (%5 : $Builtin.Int64) + %7 = begin_access [modify] [dynamic] [no_nested_conflict] %4 : $*Int64 + store %6 to %7 : $*Int64 + end_access %7 : $*Int64 + %33 = begin_access [read] [dynamic] [no_nested_conflict] %4 : $*Int64 + %35 = load %33 : $*Int64 + end_access %33 : $*Int64 + return %35 : $Int64 } -// The init-calls may not be moved to their common dominator because of -// availability guards. -// CHECK-LABEL: sil @test_availability_common_dominator -// CHECK: [[INIT:%[0-9]+]] = function_ref @_TF4libga8MyGlobalSi -// CHECK: {{^bb2:}} -// CHECK-NEXT: apply [[INIT]]() -// CHECK: {{^bb4:}} -// CHECK-NEXT: apply [[INIT]]() -sil @test_availability_common_dominator : $@convention(thin) (Builtin.Int1) -> () { -bb0(%0 : $Builtin.Int1): - %f1 = function_ref @test_availability : $@convention(thin) () -> Builtin.Int1 - %f2 = function_ref @_TF4libga8MyGlobalSi : $@convention(thin) () -> Builtin.RawPointer - cond_br %0, bb1, bb3 - -bb1: - %a1 = apply %f1() : $@convention(thin) () -> Builtin.Int1 - cond_br %a1, bb2, bb5 - -bb2: - %a2 = apply %f2() : $@convention(thin) () -> Builtin.RawPointer - br bb5 - -bb3: - %a3 = apply %f1() : $@convention(thin) () -> Builtin.Int1 - cond_br %0, bb4, bb5 - -bb4: - %a4 = apply %f2() : $@convention(thin) () -> Builtin.RawPointer - br bb5 - -bb5: - %r = tuple () - return %r : $() -} - -sil [_semantics "availability.test"] @test_availability : $@convention(thin) () -> Builtin.Int1 - diff --git a/test/SILOptimizer/licm_apply.sil b/test/SILOptimizer/licm_apply.sil index 0f6a3ae31b78a..a7b95254f4090 100644 --- a/test/SILOptimizer/licm_apply.sil +++ b/test/SILOptimizer/licm_apply.sil @@ -5,6 +5,8 @@ sil_stage canonical import Builtin import Swift +sil @unknown : $@convention(thin) () -> () + sil @read_from_param : $@convention(thin) (@inout Int64, Int64) -> Int64 { bb0(%0 : $*Int64, %1 : $Int64): debug_value %1 : $Int64 @@ -119,3 +121,190 @@ bb2: return %15 : $() } +sil [global_init] @$s4test10testGlobalSivau : $@convention(thin) () -> Builtin.RawPointer + +// CHECK-LABEL: sil @licm_global_init +// CHECK: apply +// CHECK: bb1: +// CHECK-NOT: apply +// CHECK: } // end sil function 'licm_global_init' +sil @licm_global_init : $@convention(thin) () -> () { +bb0: + br bb1 + +bb1: + %2 = function_ref @$s4test10testGlobalSivau : $@convention(thin) () -> Builtin.RawPointer + %3 = apply %2() : $@convention(thin) () -> Builtin.RawPointer + cond_br undef, bb1, bb2 + +bb2: + %15 = tuple () + return %15 : $() +} + +// CHECK-LABEL: sil @dont_licm_ginit_prehead_not_pdom +// CHECK-NOT: apply +// CHECK: bb1: +// CHECK: apply +// CHECK: bb2: +// CHECK-NOT: apply +// CHECK: } // end sil function 'dont_licm_ginit_prehead_not_pdom' +sil @dont_licm_ginit_prehead_not_pdom : $@convention(thin) () -> () { +bb0: + cond_br undef, bb1, bb2 + +bb1: + %2 = function_ref @$s4test10testGlobalSivau : $@convention(thin) () -> Builtin.RawPointer + %3 = apply %2() : $@convention(thin) () -> Builtin.RawPointer + cond_br undef, bb1, bb2 + +bb2: + %15 = tuple () + return %15 : $() +} + +// CHECK-LABEL: sil @licm_ginit_complex_cfg +// CHECK-NOT: apply +// CHECK: bb0: +// CHECK: apply +// CHECK: bb1: +// CHECK-NOT: apply +// CHECK: } // end sil function 'licm_ginit_complex_cfg' +sil @licm_ginit_complex_cfg : $@convention(thin) () -> () { +bb0: + br bb1 + +bb1: + cond_br undef, bb2, bb3 + +bb2: + br bb4 + +bb3: + br bb4 + +bb4: + %2 = function_ref @$s4test10testGlobalSivau : $@convention(thin) () -> Builtin.RawPointer + %3 = apply %2() : $@convention(thin) () -> Builtin.RawPointer + cond_br undef, bb1, bb5 + +bb5: + %15 = tuple () + return %15 : $() +} + +// CHECK-LABEL: sil @dont_licm_ginit_apply_not_pdom +// CHECK-NOT: apply +// CHECK: bb2: +// CHECK: apply +// CHECK: bb3: +// CHECK-NOT: apply +// CHECK: } // end sil function 'dont_licm_ginit_apply_not_pdom' +sil @dont_licm_ginit_apply_not_pdom : $@convention(thin) () -> () { +bb0: + br bb1 + +bb1: + cond_br undef, bb2, bb3 + +bb2: + %2 = function_ref @$s4test10testGlobalSivau : $@convention(thin) () -> Builtin.RawPointer + %3 = apply %2() : $@convention(thin) () -> Builtin.RawPointer + br bb4 + +bb3: + br bb4 + +bb4: + cond_br undef, bb1, bb5 + +bb5: + %15 = tuple () + return %15 : $() +} + +// CHECK-LABEL: sil @dont_licm_ginit_dependency1 +// CHECK: bb0: +// CHECK: [[I:%[0-9]+]] = function_ref @$s4test10testGlobalSivau +// CHECK: bb2: +// CHECK: [[U:%[0-9]+]] = function_ref @unknown +// CHECK: apply [[U]] +// CHECK: bb4: +// CHECK: apply [[I]] +// CHECK: bb5: +// CHECK: } // end sil function 'dont_licm_ginit_dependency1' +sil @dont_licm_ginit_dependency1 : $@convention(thin) () -> () { +bb0: + br bb1 + +bb1: + cond_br undef, bb2, bb3 + +bb2: + %f = function_ref @unknown : $@convention(thin) () -> () + %a = apply %f() : $@convention(thin) () -> () + br bb4 + +bb3: + br bb4 + +bb4: + %2 = function_ref @$s4test10testGlobalSivau : $@convention(thin) () -> Builtin.RawPointer + %3 = apply %2() : $@convention(thin) () -> Builtin.RawPointer + cond_br undef, bb1, bb5 + +bb5: + %15 = tuple () + return %15 : $() +} + +// CHECK-LABEL: sil @dont_licm_ginit_dependency2 +// CHECK: bb0: +// CHECK-DAG: [[I:%[0-9]+]] = function_ref @$s4test10testGlobalSivau +// CHECK-DAG: [[U:%[0-9]+]] = function_ref @unknown +// CHECK: bb1: +// CHECK: apply [[U]] +// CHECK: apply [[I]] +// CHECK: bb2: +// CHECK: } // end sil function 'dont_licm_ginit_dependency2' +sil @dont_licm_ginit_dependency2 : $@convention(thin) () -> () { +bb0: + br bb1 + +bb1: + %f = function_ref @unknown : $@convention(thin) () -> () + %a = apply %f() : $@convention(thin) () -> () + %2 = function_ref @$s4test10testGlobalSivau : $@convention(thin) () -> Builtin.RawPointer + %3 = apply %2() : $@convention(thin) () -> Builtin.RawPointer + cond_br undef, bb1, bb2 + +bb2: + %15 = tuple () + return %15 : $() +} + +// CHECK-LABEL: sil @licm_ginit_no_dependency +// CHECK: bb0: +// CHECK-DAG: [[I:%[0-9]+]] = function_ref @$s4test10testGlobalSivau +// CHECK-DAG: [[U:%[0-9]+]] = function_ref @unknown +// CHECK-DAG: apply [[I]] +// CHECK: bb1: +// CHECK: apply [[U]] +// CHECK: bb2: +// CHECK: } // end sil function 'licm_ginit_no_dependency' +sil @licm_ginit_no_dependency : $@convention(thin) () -> () { +bb0: + br bb1 + +bb1: + %2 = function_ref @$s4test10testGlobalSivau : $@convention(thin) () -> Builtin.RawPointer + %3 = apply %2() : $@convention(thin) () -> Builtin.RawPointer + %f = function_ref @unknown : $@convention(thin) () -> () + %a = apply %f() : $@convention(thin) () -> () + cond_br undef, bb1, bb2 + +bb2: + %15 = tuple () + return %15 : $() +} + diff --git a/test/SILOptimizer/licm_exclusivity.sil b/test/SILOptimizer/licm_exclusivity.sil index 6e7a5fd821cb7..39deb412efdb5 100644 --- a/test/SILOptimizer/licm_exclusivity.sil +++ b/test/SILOptimizer/licm_exclusivity.sil @@ -21,7 +21,7 @@ var globalY: X sil_global hidden @globalY : $X -sil hidden_external [global_init] @globalAddressor : $@convention(thin) () -> Builtin.RawPointer +sil hidden_external @globalAddressor : $@convention(thin) () -> Builtin.RawPointer // public func hoist_access_with_conflict() { // Tests Hoisting of begin/end access when there's a "sandwiched" unidentified access diff --git a/test/Serialization/Recovery/implementation-only-missing.swift b/test/Serialization/Recovery/implementation-only-missing.swift index 381d6510e0c36..311bfa67e00a4 100644 --- a/test/Serialization/Recovery/implementation-only-missing.swift +++ b/test/Serialization/Recovery/implementation-only-missing.swift @@ -15,6 +15,10 @@ // RUN: %target-swift-frontend -typecheck -DCLIENT_APP -primary-file %s -I %t -index-system-modules -index-store-path %t // RUN: %target-swift-frontend -emit-sil -DCLIENT_APP -primary-file %s -I %t -module-name client +//// Printing the public module should not crash when checking for overrides of +//// methods from the private module. +// RUN: %target-swift-ide-test -print-module -module-to-print=public_lib -source-filename=x -skip-overrides -I %t + #if PRIVATE_LIB @propertyWrapper @@ -38,11 +42,15 @@ public protocol HiddenProtocol { associatedtype Value } +public protocol HiddenProtocolWithOverride { + func hiddenOverride() +} + #elseif PUBLIC_LIB @_implementationOnly import private_lib -struct LibProtocolContraint { } +struct LibProtocolConstraint { } protocol LibProtocolTABound { } struct LibProtocolTA: LibProtocolTABound { } @@ -52,13 +60,13 @@ protocol LibProtocol { func hiddenRequirement( param: HiddenGenStruct - ) where A.Value == LibProtocolContraint + ) where A.Value == LibProtocolConstraint } extension LibProtocol where TA == LibProtocolTA { func hiddenRequirement( param: HiddenGenStruct - ) where A.Value == LibProtocolContraint { } + ) where A.Value == LibProtocolConstraint { } } public struct PublicStruct: LibProtocol { @@ -70,6 +78,10 @@ public struct PublicStruct: LibProtocol { public var wrappedVar: String } +struct StructWithOverride: HiddenProtocolWithOverride { + func hiddenOverride() {} +} + #elseif CLIENT_APP import public_lib diff --git a/test/SourceKit/CodeComplete/complete_build_session.swift b/test/SourceKit/CodeComplete/complete_build_session.swift new file mode 100644 index 0000000000000..fa6b7c7860587 --- /dev/null +++ b/test/SourceKit/CodeComplete/complete_build_session.swift @@ -0,0 +1,88 @@ +import Foo + +func test() { + +} + +// UNSUPPORTED: OS=windows-msvc + +// ----------------------------------------------------------------------------- +// Test that modifications for frameworks in '-Fsystem' doesn't affect the result. + +// RUN: %empty-directory(%t/ModuleCache) +// RUN: %empty-directory(%t/System/Frameworks) +// RUN: cp -R %S/../Inputs/build_session/Frameworks/Foo.framework %t/System/Frameworks/ +// RUN: cp -R %S/../Inputs/build_session/Frameworks/FooHelper.framework %t/System/Frameworks/ +// RUN: %sourcekitd-test \ +// RUN: -shell -- echo '## ONE' == \ +// RUN: -req=complete -pos=4:1 %s -- %s -D ONE -Fsystem %t/System/Frameworks -module-cache-path %t/ModuleCache == \ +// RUN: -shell -- cp -R %S/../Inputs/build_session/Frameworks_modified/Foo.framework %t/System/Frameworks/ == \ +// RUN: -shell -- cp -R %S/../Inputs/build_session/Frameworks_modified/FooHelper.framework %t/System/Frameworks/ == \ +// RUN: -shell -- echo '## TWO' == \ +// RUN: -req=complete -pos=4:1 %s -- %s -D TWO -Fsystem %t/System/Frameworks -module-cache-path %t/ModuleCache \ +// RUN: | tee %t.response | %FileCheck %s --check-prefix=CHECK_SYSTEM +// RUN: sleep 2 +// RUN: %sourcekitd-test \ +// RUN: -shell -- echo '## THREE' == \ +// RUN: -req=complete -pos=4:1 %s -- %s -D TWO -Fsystem %t/System/Frameworks -module-cache-path %t/ModuleCache \ +// RUN: | %FileCheck %s --check-prefix=CHECK_SYSTEM_2 + +// CHECK_SYSTEM-LABEL: ## ONE +// CHECK_SYSTEM-DAG: key.description: "fooFunc(arg: Int32)" +// CHECK_SYSTEM-DAG: key.description: "fooSubFunc(arg: Int32)" +// CHECK_SYSTEM-DAG: key.description: "fooHelperFunc(arg: Int32)" +// CHECK_SYSTEM-DAG: key.description: "fooHelperSubFunc(arg: Int32)" + +// CHECK_SYSTEM-LABEL: ## TWO +// CHECK_SYSTEM-DAG: key.description: "fooFunc(arg: Int32)" +// CHECK_SYSTEM-DAG: key.description: "fooSubFunc(arg: Int32)" +// CHECK_SYSTEM-DAG: key.description: "fooHelperFunc(arg: Int32)" +// CHECK_SYSTEM-DAG: key.description: "fooHelperSubFunc(arg: Int32)" + +// CHECK_SYSTEM_2-LABEL: ## THREE +// CHECK_SYSTEM_2-NOT: fooFunc( +// CHECK_SYSTEM_2-NOT: fooSubFunc( +// CHECK_SYSTEM_2-NOT: fooHelperFunc( +// CHECK_SYSTEM_2-NOT: fooHelperSubFunc( +// CHECK_SYSTEM_2-DAG: key.description: "fooFunc_mod(arg: Int32)" +// CHECK_SYSTEM_2-DAG: key.description: "fooSubFunc_mod(arg: Int32)" +// CHECK_SYSTEM_2-DAG: key.description: "fooHelperFunc_mod(arg: Int32)" +// CHECK_SYSTEM_2-DAG: key.description: "fooHelperSubFunc_mod(arg: Int32)" +// CHECK_SYSTEM_2-NOT: fooFunc( +// CHECK_SYSTEM_2-NOT: fooSubFunc( +// CHECK_SYSTEM_2-NOT: fooHelperFunc( +// CHECK_SYSTEM_2-NOT: fooHelperSubFunc( + +// ----------------------------------------------------------------------------- +// Test that modifications for frameworks in '-F' are immidiately propagated +// while modifications for frameworks in '-Fsystem' are not. + +// RUN: %empty-directory(%t/ModuleCache) +// RUN: %empty-directory(%t/Frameworks) +// RUN: %empty-directory(%t/System/Frameworks) +// RUN: cp -R %S/../Inputs/build_session/Frameworks/Foo.framework %t/Frameworks/ +// RUN: cp -R %S/../Inputs/build_session/Frameworks/FooHelper.framework %t/System/Frameworks/ +// RUN: %sourcekitd-test \ +// RUN: -shell -- echo '## ONE' == \ +// RUN: -req=complete -pos=4:1 %s -- %s -D ONE -F %t/Frameworks -Fsystem %t/System/Frameworks -module-cache-path %t/ModuleCache == \ +// RUN: -shell -- cp -R %S/../Inputs/build_session/Frameworks_modified/Foo.framework %t/Frameworks/ == \ +// RUN: -shell -- cp -R %S/../Inputs/build_session/Frameworks_modified/FooHelper.framework %t/System/Frameworks/ == \ +// RUN: -shell -- echo '## TWO' == \ +// RUN: -req=complete -pos=4:1 %s -- %s -D TWO -F %t/Frameworks -Fsystem %t/System/Frameworks -module-cache-path %t/ModuleCache \ +// RUN: | %FileCheck %s --check-prefix=CHECK_USER + +// CHECK_USER-LABEL: ## ONE +// CHECK_USER-DAG: key.description: "fooFunc(arg: Int32)" +// CHECK_USER-DAG: key.description: "fooSubFunc(arg: Int32)" +// CHECK_USER-DAG: key.description: "fooHelperFunc(arg: Int32)" +// CHECK_USER-DAG: key.description: "fooHelperSubFunc(arg: Int32)" + +// CHECK_USER-LABEL: ## TWO +// CHECK_USER-NOT: fooFunc( +// CHECK_USER-NOT: fooSubFunc( +// CHECK_USER-DAG: key.description: "fooFunc_mod(arg: Int32)" +// CHECK_USER-DAG: key.description: "fooSubFunc_mod(arg: Int32)" +// CHECK_USER-DAG: key.description: "fooHelperFunc(arg: Int32)" +// CHECK_USER-DAG: key.description: "fooHelperSubFunc(arg: Int32)" +// CHECK_USER-NOT: fooFunc( +// CHECK_USER-NOT: fooSubFunc( diff --git a/test/SourceKit/CursorInfo/cursor_info_cross_import.swift b/test/SourceKit/CursorInfo/cursor_info_cross_import.swift new file mode 100644 index 0000000000000..4ee54c88e1398 --- /dev/null +++ b/test/SourceKit/CursorInfo/cursor_info_cross_import.swift @@ -0,0 +1,25 @@ +import A +import B +import C + +func foo(x: From_ABAdditionsType) { + from_ABAdditions() + from__ABAdditionsCAdditions() + fromA() + fromB() + fromC() +} + +// RUN: %sourcekitd-test -req=cursor -print-raw-response -pos=5:13 %s -- -Xfrontend -enable-cross-import-overlays -I %S/../Inputs/CrossImport %s | %FileCheck -check-prefix=CHECK1 %s +// RUN: %sourcekitd-test -req=cursor -print-raw-response -pos=6:5 %s -- -Xfrontend -enable-cross-import-overlays -I %S/../Inputs/CrossImport %s | %FileCheck -check-prefix=CHECK2 %s +// RUN: %sourcekitd-test -req=cursor -print-raw-response -pos=7:5 %s -- -Xfrontend -enable-cross-import-overlays -I %S/../Inputs/CrossImport %s | %FileCheck -check-prefix=CHECK3 %s +// RUN: %sourcekitd-test -req=cursor -print-raw-response -pos=8:5 %s -- -Xfrontend -enable-cross-import-overlays -I %S/../Inputs/CrossImport %s | %FileCheck -check-prefix=CHECK4 %s +// RUN: %sourcekitd-test -req=cursor -print-raw-response -pos=9:5 %s -- -Xfrontend -enable-cross-import-overlays -I %S/../Inputs/CrossImport %s | %FileCheck -check-prefix=CHECK5 %s +// RUN: %sourcekitd-test -req=cursor -print-raw-response -pos=10:5 %s -- -Xfrontend -enable-cross-import-overlays -I %S/../Inputs/CrossImport %s | %FileCheck -check-prefix=CHECK6 %s + +// CHECK1: key.modulename: "A" +// CHECK2: key.modulename: "A" +// CHECK3: key.modulename: "A" +// CHECK4: key.modulename: "A" +// CHECK5: key.modulename: "B" +// CHECK6: key.modulename: "C" diff --git a/test/SourceKit/DocSupport/doc_swift_module_cross_import.swift b/test/SourceKit/DocSupport/doc_swift_module_cross_import.swift new file mode 100644 index 0000000000000..02cb0d6082298 --- /dev/null +++ b/test/SourceKit/DocSupport/doc_swift_module_cross_import.swift @@ -0,0 +1,14 @@ +// RUN: %empty-directory(%t.mod) +// RUN: %empty-directory(%t.mod/mcp) + +// Check doc info shows the decls from each of A's cross-import overlays and lists the required bystander modules. +// +// RUN: %sourcekitd-test -req=doc-info -module A -- -I %S/../Inputs/CrossImport -module-cache-path %t.mod/mcp > %t.response +// RUN: diff --strip-trailing-cr -u %s.A.response %t.response + +// Set up a cross-import module with doc comments and check the synthesized comments don't appear in the fully_annotated_decl entries. +// +// RUN: %target-swift-frontend -emit-module-path %t.mod/_OtherCAdditions.swiftmodule -emit-module-doc-path %t.mod/_OtherCAdditions.swiftdoc -module-cache-path %t.mod/mcp -I %S/../Inputs/CrossImport %S/../Inputs/CrossImport/_OtherCAdditions.swift -parse-as-library +// RUN: %sourcekitd-test -req=doc-info -module Other -- -target %target-triple -I %S/../Inputs/CrossImport -I %t.mod/ -module-cache-path %t.mod/mcp > %t.response +// RUN: diff --strip-trailing-cr -u %s.Other.response %t.response + diff --git a/test/SourceKit/DocSupport/doc_swift_module_cross_import.swift.A.response b/test/SourceKit/DocSupport/doc_swift_module_cross_import.swift.A.response new file mode 100644 index 0000000000000..7e9d1ec0efb5c --- /dev/null +++ b/test/SourceKit/DocSupport/doc_swift_module_cross_import.swift.A.response @@ -0,0 +1,237 @@ +import SwiftOnoneSupport + +func fromA() + + +// MARK: - B Additions + +import B + +// Available when B is imported with A +struct From_ABAdditionsType { +} + +// Available when B is imported with A +func from_ABAdditions() + + +// MARK: - B and C Additions + +import C + +// Available when B and C are imported with A +func from__ABAdditionsCAdditions() + +// Available when B and C are imported with A +func other(x x: A.From_ABAdditionsType) + + +[ + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 0, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 7, + key.length: 17 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 26, + key.length: 4 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 31, + key.length: 5 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 41, + key.length: 23 + }, + { + key.kind: source.lang.swift.syntaxtype.comment.mark, + key.offset: 44, + key.length: 19 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 65, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 72, + key.length: 1 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 75, + key.length: 39 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 114, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 121, + key.length: 20 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 147, + key.length: 39 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 186, + key.length: 4 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 191, + key.length: 16 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 212, + key.length: 29 + }, + { + key.kind: source.lang.swift.syntaxtype.comment.mark, + key.offset: 215, + key.length: 25 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 242, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 249, + key.length: 1 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 252, + key.length: 46 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 298, + key.length: 4 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 303, + key.length: 27 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 334, + key.length: 46 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 380, + key.length: 4 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 385, + key.length: 5 + }, + { + key.kind: source.lang.swift.syntaxtype.argument, + key.offset: 391, + key.length: 1 + }, + { + key.kind: source.lang.swift.syntaxtype.parameter, + key.offset: 393, + key.length: 1 + }, + { + key.kind: source.lang.swift.syntaxtype.typeidentifier, + key.offset: 396, + key.length: 1 + }, + { + key.kind: source.lang.swift.ref.struct, + key.name: "From_ABAdditionsType", + key.usr: "s:12_ABAdditions05From_A4TypeV", + key.offset: 398, + key.length: 20 + } +] +[ + { + key.kind: source.lang.swift.decl.function.free, + key.name: "fromA()", + key.usr: "s:1A5fromAyyF", + key.offset: 26, + key.length: 12, + key.fully_annotated_decl: "func fromA()" + }, + { + key.kind: source.lang.swift.decl.struct, + key.name: "From_ABAdditionsType", + key.usr: "s:12_ABAdditions05From_A4TypeV", + key.offset: 114, + key.length: 31, + key.fully_annotated_decl: "struct From_ABAdditionsType", + key.required_bystanders: [ + "B" + ] + }, + { + key.kind: source.lang.swift.decl.function.free, + key.name: "from_ABAdditions()", + key.usr: "s:12_ABAdditions05from_A0yyF", + key.offset: 186, + key.length: 23, + key.fully_annotated_decl: "func from_ABAdditions()", + key.required_bystanders: [ + "B" + ] + }, + { + key.kind: source.lang.swift.decl.function.free, + key.name: "from__ABAdditionsCAdditions()", + key.usr: "s:23__ABAdditionsCAdditions06from__aB0yyF", + key.offset: 298, + key.length: 34, + key.fully_annotated_decl: "func from__ABAdditionsCAdditions()", + key.required_bystanders: [ + "C", + "B" + ] + }, + { + key.kind: source.lang.swift.decl.function.free, + key.name: "other(x:)", + key.usr: "s:23__ABAdditionsCAdditions5other1xy01_A005From_A4TypeV_tF", + key.offset: 380, + key.length: 39, + key.fully_annotated_decl: "func other(x: From_ABAdditionsType)", + key.entities: [ + { + key.kind: source.lang.swift.decl.var.local, + key.keyword: "x", + key.name: "x", + key.offset: 396, + key.length: 22 + } + ], + key.required_bystanders: [ + "C", + "B" + ] + } +] diff --git a/test/SourceKit/DocSupport/doc_swift_module_cross_import.swift.Other.response b/test/SourceKit/DocSupport/doc_swift_module_cross_import.swift.Other.response new file mode 100644 index 0000000000000..eeac514ea02dd --- /dev/null +++ b/test/SourceKit/DocSupport/doc_swift_module_cross_import.swift.Other.response @@ -0,0 +1,92 @@ +import SwiftOnoneSupport + +func fromOther() + + +// MARK: - C Additions + +import C + +// Available when C is imported with Other +func from_OtherCAdditions() + + +[ + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 0, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 7, + key.length: 17 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 26, + key.length: 4 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 31, + key.length: 9 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 45, + key.length: 23 + }, + { + key.kind: source.lang.swift.syntaxtype.comment.mark, + key.offset: 48, + key.length: 19 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 69, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 76, + key.length: 1 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 79, + key.length: 43 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 122, + key.length: 4 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 127, + key.length: 20 + } +] +[ + { + key.kind: source.lang.swift.decl.function.free, + key.name: "fromOther()", + key.usr: "s:5Other04fromA0yyF", + key.offset: 26, + key.length: 16, + key.fully_annotated_decl: "func fromOther()" + }, + { + key.kind: source.lang.swift.decl.function.free, + key.name: "from_OtherCAdditions()", + key.usr: "s:16_OtherCAdditions05from_aB0yyF", + key.doc.full_as_xml: "from_OtherCAdditions()s:16_OtherCAdditions05from_aB0yyFfunc from_OtherCAdditions()This has some interesting documentation that shouldn’t be separated from the decl when we print the comment detailing its required bystanders in the generated interface of ‘Other’.", + key.offset: 122, + key.length: 27, + key.fully_annotated_decl: "func from_OtherCAdditions()", + key.required_bystanders: [ + "C" + ] + } +] diff --git a/test/SourceKit/Inputs/CrossImport/A.swiftcrossimport/B.swiftoverlay b/test/SourceKit/Inputs/CrossImport/A.swiftcrossimport/B.swiftoverlay new file mode 100644 index 0000000000000..59972574215ff --- /dev/null +++ b/test/SourceKit/Inputs/CrossImport/A.swiftcrossimport/B.swiftoverlay @@ -0,0 +1,5 @@ +%YAML 1.2 +--- +version: 1 +modules: + - name: _ABAdditions diff --git a/test/SourceKit/Inputs/CrossImport/A.swiftinterface b/test/SourceKit/Inputs/CrossImport/A.swiftinterface new file mode 100644 index 0000000000000..4a9e7a9f7dc09 --- /dev/null +++ b/test/SourceKit/Inputs/CrossImport/A.swiftinterface @@ -0,0 +1,6 @@ +// swift-interface-format-version: 1.0 +// swift-module-flags: -swift-version 5 -enable-library-evolution -module-name A + +import Swift + +public func fromA() diff --git a/test/SourceKit/Inputs/CrossImport/B.swiftinterface b/test/SourceKit/Inputs/CrossImport/B.swiftinterface new file mode 100644 index 0000000000000..8696b81e3a515 --- /dev/null +++ b/test/SourceKit/Inputs/CrossImport/B.swiftinterface @@ -0,0 +1,6 @@ +// swift-interface-format-version: 1.0 +// swift-module-flags: -swift-version 5 -enable-library-evolution -module-name B + +import Swift + +public func fromB() diff --git a/test/SourceKit/Inputs/CrossImport/C.swiftinterface b/test/SourceKit/Inputs/CrossImport/C.swiftinterface new file mode 100644 index 0000000000000..94a37f8cd6c2c --- /dev/null +++ b/test/SourceKit/Inputs/CrossImport/C.swiftinterface @@ -0,0 +1,6 @@ +// swift-interface-format-version: 1.0 +// swift-module-flags: -swift-version 5 -enable-library-evolution -module-name C + +import Swift + +public func fromC() diff --git a/test/SourceKit/Inputs/CrossImport/Other.swiftcrossimport/C.swiftoverlay b/test/SourceKit/Inputs/CrossImport/Other.swiftcrossimport/C.swiftoverlay new file mode 100644 index 0000000000000..45340e7b9a861 --- /dev/null +++ b/test/SourceKit/Inputs/CrossImport/Other.swiftcrossimport/C.swiftoverlay @@ -0,0 +1,5 @@ +%YAML 1.2 +--- +version: 1 +modules: + - name: _OtherCAdditions diff --git a/test/SourceKit/Inputs/CrossImport/Other.swiftinterface b/test/SourceKit/Inputs/CrossImport/Other.swiftinterface new file mode 100644 index 0000000000000..8e1a84b9c9709 --- /dev/null +++ b/test/SourceKit/Inputs/CrossImport/Other.swiftinterface @@ -0,0 +1,6 @@ +// swift-interface-format-version: 1.0 +// swift-module-flags: -swift-version 5 -enable-library-evolution -module-name Other + +import Swift + +public func fromOther() diff --git a/test/SourceKit/Inputs/CrossImport/_ABAdditions.swiftcrossimport/C.swiftoverlay b/test/SourceKit/Inputs/CrossImport/_ABAdditions.swiftcrossimport/C.swiftoverlay new file mode 100644 index 0000000000000..b26125a4ea762 --- /dev/null +++ b/test/SourceKit/Inputs/CrossImport/_ABAdditions.swiftcrossimport/C.swiftoverlay @@ -0,0 +1,5 @@ +%YAML 1.2 +--- +version: 1 +modules: + - name: __ABAdditionsCAdditions diff --git a/test/SourceKit/Inputs/CrossImport/_ABAdditions.swiftinterface b/test/SourceKit/Inputs/CrossImport/_ABAdditions.swiftinterface new file mode 100644 index 0000000000000..379fbcff822b2 --- /dev/null +++ b/test/SourceKit/Inputs/CrossImport/_ABAdditions.swiftinterface @@ -0,0 +1,9 @@ +// swift-interface-format-version: 1.0 +// swift-module-flags: -swift-version 5 -enable-library-evolution -module-name _ABAdditions + +import Swift +@_exported import A +import B + +public func from_ABAdditions() +public struct From_ABAdditionsType {} diff --git a/test/SourceKit/Inputs/CrossImport/_OtherCAdditions.swift b/test/SourceKit/Inputs/CrossImport/_OtherCAdditions.swift new file mode 100644 index 0000000000000..fd785ddf2572b --- /dev/null +++ b/test/SourceKit/Inputs/CrossImport/_OtherCAdditions.swift @@ -0,0 +1,7 @@ +@_exported import Other +import C + +/// This has some interesting documentation that shouldn't be separated from +/// the decl when we print the comment detailing its required bystanders in the +/// generated interface of 'Other'. +public func from_OtherCAdditions() {} diff --git a/test/SourceKit/Inputs/CrossImport/__ABAdditionsCAdditions.swiftinterface b/test/SourceKit/Inputs/CrossImport/__ABAdditionsCAdditions.swiftinterface new file mode 100644 index 0000000000000..ba32500970195 --- /dev/null +++ b/test/SourceKit/Inputs/CrossImport/__ABAdditionsCAdditions.swiftinterface @@ -0,0 +1,9 @@ +// swift-interface-format-version: 1.0 +// swift-module-flags: -swift-version 5 -enable-library-evolution -module-name __ABAdditionsCAdditions + +import Swift +@_exported import _ABAdditions +import C + +public func from__ABAdditionsCAdditions() +public func other(x: _ABAdditions.From_ABAdditionsType) diff --git a/test/SourceKit/Inputs/build_session/Frameworks/Foo.framework/Frameworks/FooSub.framework/Headers/FooSub.h b/test/SourceKit/Inputs/build_session/Frameworks/Foo.framework/Frameworks/FooSub.framework/Headers/FooSub.h new file mode 100644 index 0000000000000..b8d7f2d4eef03 --- /dev/null +++ b/test/SourceKit/Inputs/build_session/Frameworks/Foo.framework/Frameworks/FooSub.framework/Headers/FooSub.h @@ -0,0 +1,6 @@ +#if !defined(__FOOSUB_H__) +#define __FOOSUB_H__ 1 + +int fooSubFunc(int arg); + +#endif /* ! __FOOSUB_H__ */ diff --git a/test/SourceKit/Inputs/build_session/Frameworks/Foo.framework/Headers/Foo.h b/test/SourceKit/Inputs/build_session/Frameworks/Foo.framework/Headers/Foo.h new file mode 100644 index 0000000000000..d7f8c0c68efc1 --- /dev/null +++ b/test/SourceKit/Inputs/build_session/Frameworks/Foo.framework/Headers/Foo.h @@ -0,0 +1,12 @@ +/* Foo.h + Copyright (c) 1815, Napoleon Bonaparte. All rights reserved. +*/ +#if !defined(__FOO_H__) +#define __FOO_H__ 1 + +#import +#import + +int fooFunc(int arg); + +#endif /* ! __FOO_H__ */ diff --git a/test/SourceKit/Inputs/build_session/Frameworks/Foo.framework/Modules/module.modulemap b/test/SourceKit/Inputs/build_session/Frameworks/Foo.framework/Modules/module.modulemap new file mode 100644 index 0000000000000..78a2306883146 --- /dev/null +++ b/test/SourceKit/Inputs/build_session/Frameworks/Foo.framework/Modules/module.modulemap @@ -0,0 +1,9 @@ +framework module Foo { + umbrella header "Foo.h" + export * + framework module FooSub { + umbrella header "FooSub.h" + export * + } +} + diff --git a/test/SourceKit/Inputs/build_session/Frameworks/FooHelper.framework/Frameworks/FooHelperSub.framework/Headers/FooHelperSub.h b/test/SourceKit/Inputs/build_session/Frameworks/FooHelper.framework/Frameworks/FooHelperSub.framework/Headers/FooHelperSub.h new file mode 100644 index 0000000000000..83ba19ca15ae0 --- /dev/null +++ b/test/SourceKit/Inputs/build_session/Frameworks/FooHelper.framework/Frameworks/FooHelperSub.framework/Headers/FooHelperSub.h @@ -0,0 +1 @@ +int fooHelperSubFunc(int arg); diff --git a/test/SourceKit/Inputs/build_session/Frameworks/FooHelper.framework/Headers/FooHelper.h b/test/SourceKit/Inputs/build_session/Frameworks/FooHelper.framework/Headers/FooHelper.h new file mode 100644 index 0000000000000..9855e862a1319 --- /dev/null +++ b/test/SourceKit/Inputs/build_session/Frameworks/FooHelper.framework/Headers/FooHelper.h @@ -0,0 +1,3 @@ +#import + +int fooHelperFunc(int arg); diff --git a/test/SourceKit/Inputs/build_session/Frameworks/FooHelper.framework/Headers/FooHelperExplicit.h b/test/SourceKit/Inputs/build_session/Frameworks/FooHelper.framework/Headers/FooHelperExplicit.h new file mode 100644 index 0000000000000..2164daeeedc95 --- /dev/null +++ b/test/SourceKit/Inputs/build_session/Frameworks/FooHelper.framework/Headers/FooHelperExplicit.h @@ -0,0 +1 @@ +int fooHelperExplicitFunc(int a); diff --git a/test/SourceKit/Inputs/build_session/Frameworks/FooHelper.framework/Modules/module.modulemap b/test/SourceKit/Inputs/build_session/Frameworks/FooHelper.framework/Modules/module.modulemap new file mode 100644 index 0000000000000..b08073c053a59 --- /dev/null +++ b/test/SourceKit/Inputs/build_session/Frameworks/FooHelper.framework/Modules/module.modulemap @@ -0,0 +1,12 @@ +framework module FooHelper { + umbrella header "FooHelper.h" + + framework module FooHelperSub { + umbrella header "FooHelperSub.h" + } + + explicit module FooHelperExplicit { + header "FooHelperExplicit.h" + } +} + diff --git a/test/SourceKit/Inputs/build_session/Frameworks_modified/Foo.framework/Frameworks/FooSub.framework/Headers/FooSub.h b/test/SourceKit/Inputs/build_session/Frameworks_modified/Foo.framework/Frameworks/FooSub.framework/Headers/FooSub.h new file mode 100644 index 0000000000000..056e40b475050 --- /dev/null +++ b/test/SourceKit/Inputs/build_session/Frameworks_modified/Foo.framework/Frameworks/FooSub.framework/Headers/FooSub.h @@ -0,0 +1,6 @@ +#if !defined(__FOOSUB_H__) +#define __FOOSUB_H__ 1 + +int fooSubFunc_mod(int arg); + +#endif /* ! __FOOSUB_H__ */ diff --git a/test/SourceKit/Inputs/build_session/Frameworks_modified/Foo.framework/Headers/Foo.h b/test/SourceKit/Inputs/build_session/Frameworks_modified/Foo.framework/Headers/Foo.h new file mode 100644 index 0000000000000..b1a6b86e46bde --- /dev/null +++ b/test/SourceKit/Inputs/build_session/Frameworks_modified/Foo.framework/Headers/Foo.h @@ -0,0 +1,12 @@ +/* Foo.h + Copyright (c) 1815, Napoleon Bonaparte. All rights reserved. +*/ +#if !defined(__FOO_H__) +#define __FOO_H__ 1 + +#import +#import + +int fooFunc_mod(int arg); + +#endif /* ! __FOO_H__ */ diff --git a/test/SourceKit/Inputs/build_session/Frameworks_modified/Foo.framework/Modules/module.modulemap b/test/SourceKit/Inputs/build_session/Frameworks_modified/Foo.framework/Modules/module.modulemap new file mode 100644 index 0000000000000..78a2306883146 --- /dev/null +++ b/test/SourceKit/Inputs/build_session/Frameworks_modified/Foo.framework/Modules/module.modulemap @@ -0,0 +1,9 @@ +framework module Foo { + umbrella header "Foo.h" + export * + framework module FooSub { + umbrella header "FooSub.h" + export * + } +} + diff --git a/test/SourceKit/Inputs/build_session/Frameworks_modified/FooHelper.framework/Frameworks/FooHelperSub.framework/Headers/FooHelperSub.h b/test/SourceKit/Inputs/build_session/Frameworks_modified/FooHelper.framework/Frameworks/FooHelperSub.framework/Headers/FooHelperSub.h new file mode 100644 index 0000000000000..2369343effb09 --- /dev/null +++ b/test/SourceKit/Inputs/build_session/Frameworks_modified/FooHelper.framework/Frameworks/FooHelperSub.framework/Headers/FooHelperSub.h @@ -0,0 +1 @@ +int fooHelperSubFunc_mod(int arg); diff --git a/test/SourceKit/Inputs/build_session/Frameworks_modified/FooHelper.framework/Headers/FooHelper.h b/test/SourceKit/Inputs/build_session/Frameworks_modified/FooHelper.framework/Headers/FooHelper.h new file mode 100644 index 0000000000000..c6ef17d128bd3 --- /dev/null +++ b/test/SourceKit/Inputs/build_session/Frameworks_modified/FooHelper.framework/Headers/FooHelper.h @@ -0,0 +1,3 @@ +#import + +int fooHelperFunc_mod(int arg); diff --git a/test/SourceKit/Inputs/build_session/Frameworks_modified/FooHelper.framework/Headers/FooHelperExplicit.h b/test/SourceKit/Inputs/build_session/Frameworks_modified/FooHelper.framework/Headers/FooHelperExplicit.h new file mode 100644 index 0000000000000..d9236acd248cf --- /dev/null +++ b/test/SourceKit/Inputs/build_session/Frameworks_modified/FooHelper.framework/Headers/FooHelperExplicit.h @@ -0,0 +1 @@ +int fooHelperExplicitFunc_mod(int a); diff --git a/test/SourceKit/Inputs/build_session/Frameworks_modified/FooHelper.framework/Modules/module.modulemap b/test/SourceKit/Inputs/build_session/Frameworks_modified/FooHelper.framework/Modules/module.modulemap new file mode 100644 index 0000000000000..b08073c053a59 --- /dev/null +++ b/test/SourceKit/Inputs/build_session/Frameworks_modified/FooHelper.framework/Modules/module.modulemap @@ -0,0 +1,12 @@ +framework module FooHelper { + umbrella header "FooHelper.h" + + framework module FooHelperSub { + umbrella header "FooHelperSub.h" + } + + explicit module FooHelperExplicit { + header "FooHelperExplicit.h" + } +} + diff --git a/test/SourceKit/InterfaceGen/gen_swift_module_cross_import.swift b/test/SourceKit/InterfaceGen/gen_swift_module_cross_import.swift new file mode 100644 index 0000000000000..077fb53c72c41 --- /dev/null +++ b/test/SourceKit/InterfaceGen/gen_swift_module_cross_import.swift @@ -0,0 +1,22 @@ +// RUN: %empty-directory(%t.mod) +// RUN: %empty-directory(%t.mod/mcp) + +// Check the interface shows the decls from each of A's cross-import overlays. +// +// RUN: %sourcekitd-test -req=interface-gen -module A -- -I %S/../Inputs/CrossImport -module-cache-path %t.mod/mcp > %t.response +// RUN: diff --strip-trailing-cr -u %s.A.response %t.response + +// Make sure cursor info within the generated interface of A on one of the +// decls originally from a cross-import decls shows 'A' as the parent module. +// +// RUN: %sourcekitd-test -req=interface-gen-open -module A -- -I %S/../Inputs/CrossImport -module-cache-path %t.mod/mcp == -req=cursor -print-raw-response -pos=11:15 -- -I %S/../Inputs/CrossImport -Xfrontend -enable-cross-import-overlays > %t.response +// RUN: %FileCheck --input-file %t.response %s +// +// CHECK: key.name: "From_ABAdditionsType" +// CHECK: key.modulename: "A" + +// Set up a cross-import module with doc comments +// +// RUN: %target-swift-frontend -emit-module-path %t.mod/_OtherCAdditions.swiftmodule -emit-module-doc-path %t.mod/_OtherCAdditions.swiftdoc -module-cache-path %t.mod/mcp -I %S/../Inputs/CrossImport %S/../Inputs/CrossImport/_OtherCAdditions.swift -parse-as-library +// RUN: %sourcekitd-test -req=interface-gen -module Other -- -target %target-triple -I %S/../Inputs/CrossImport -I %t.mod/ -module-cache-path %t.mod/mcp > %t.response +// RUN: diff --strip-trailing-cr -u %s.Other.response %t.response diff --git a/test/SourceKit/InterfaceGen/gen_swift_module_cross_import.swift.A.response b/test/SourceKit/InterfaceGen/gen_swift_module_cross_import.swift.A.response new file mode 100644 index 0000000000000..47b8904897690 --- /dev/null +++ b/test/SourceKit/InterfaceGen/gen_swift_module_cross_import.swift.A.response @@ -0,0 +1,313 @@ +import SwiftOnoneSupport + +public func fromA() + + +// MARK: - B Additions + +import B + +// Available when B is imported with A +public struct From_ABAdditionsType { +} + +// Available when B is imported with A +public func from_ABAdditions() + + +// MARK: - B and C Additions + +import C + +// Available when B and C are imported with A +public func from__ABAdditionsCAdditions() + +// Available when B and C are imported with A +public func other(x: A.From_ABAdditionsType) + + +[ + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 0, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 7, + key.length: 17 + }, + { + key.kind: source.lang.swift.syntaxtype.attribute.builtin, + key.offset: 26, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 33, + key.length: 4 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 38, + key.length: 5 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 48, + key.length: 23 + }, + { + key.kind: source.lang.swift.syntaxtype.comment.mark, + key.offset: 51, + key.length: 19 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 72, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 79, + key.length: 1 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 82, + key.length: 39 + }, + { + key.kind: source.lang.swift.syntaxtype.attribute.builtin, + key.offset: 121, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 128, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 135, + key.length: 20 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 161, + key.length: 39 + }, + { + key.kind: source.lang.swift.syntaxtype.attribute.builtin, + key.offset: 200, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 207, + key.length: 4 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 212, + key.length: 16 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 233, + key.length: 29 + }, + { + key.kind: source.lang.swift.syntaxtype.comment.mark, + key.offset: 236, + key.length: 25 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 263, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 270, + key.length: 1 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 273, + key.length: 46 + }, + { + key.kind: source.lang.swift.syntaxtype.attribute.builtin, + key.offset: 319, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 326, + key.length: 4 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 331, + key.length: 27 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 362, + key.length: 46 + }, + { + key.kind: source.lang.swift.syntaxtype.attribute.builtin, + key.offset: 408, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 415, + key.length: 4 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 420, + key.length: 5 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 426, + key.length: 1 + }, + { + key.kind: source.lang.swift.syntaxtype.typeidentifier, + key.offset: 429, + key.length: 1 + }, + { + key.kind: source.lang.swift.syntaxtype.typeidentifier, + key.offset: 431, + key.length: 20 + } +] +[ + { + key.kind: source.lang.swift.ref.module, + key.offset: 7, + key.length: 17, + key.is_system: 1 + }, + { + key.kind: source.lang.swift.ref.module, + key.offset: 79, + key.length: 1 + }, + { + key.kind: source.lang.swift.ref.module, + key.offset: 270, + key.length: 1 + }, + { + key.kind: source.lang.swift.ref.module, + key.offset: 429, + key.length: 1 + }, + { + key.kind: source.lang.swift.ref.struct, + key.offset: 431, + key.length: 20 + } +] +[ + { + key.kind: source.lang.swift.decl.function.free, + key.accessibility: source.lang.swift.accessibility.public, + key.name: "fromA()", + key.offset: 33, + key.length: 12, + key.nameoffset: 38, + key.namelength: 7, + key.attributes: [ + { + key.offset: 26, + key.length: 6, + key.attribute: source.decl.attribute.public + } + ] + }, + { + key.kind: source.lang.swift.decl.struct, + key.accessibility: source.lang.swift.accessibility.public, + key.name: "From_ABAdditionsType", + key.offset: 128, + key.length: 31, + key.nameoffset: 135, + key.namelength: 20, + key.bodyoffset: 157, + key.bodylength: 1, + key.attributes: [ + { + key.offset: 121, + key.length: 6, + key.attribute: source.decl.attribute.public + } + ] + }, + { + key.kind: source.lang.swift.decl.function.free, + key.accessibility: source.lang.swift.accessibility.public, + key.name: "from_ABAdditions()", + key.offset: 207, + key.length: 23, + key.nameoffset: 212, + key.namelength: 18, + key.attributes: [ + { + key.offset: 200, + key.length: 6, + key.attribute: source.decl.attribute.public + } + ] + }, + { + key.kind: source.lang.swift.decl.function.free, + key.accessibility: source.lang.swift.accessibility.public, + key.name: "from__ABAdditionsCAdditions()", + key.offset: 326, + key.length: 34, + key.nameoffset: 331, + key.namelength: 29, + key.attributes: [ + { + key.offset: 319, + key.length: 6, + key.attribute: source.decl.attribute.public + } + ] + }, + { + key.kind: source.lang.swift.decl.function.free, + key.accessibility: source.lang.swift.accessibility.public, + key.name: "other(x:)", + key.offset: 415, + key.length: 37, + key.nameoffset: 420, + key.namelength: 32, + key.attributes: [ + { + key.offset: 408, + key.length: 6, + key.attribute: source.decl.attribute.public + } + ], + key.substructure: [ + { + key.kind: source.lang.swift.decl.var.parameter, + key.name: "x", + key.offset: 426, + key.length: 25, + key.typename: "A.From_ABAdditionsType", + key.nameoffset: 426, + key.namelength: 1 + } + ] + } +] diff --git a/test/SourceKit/InterfaceGen/gen_swift_module_cross_import.swift.Other.response b/test/SourceKit/InterfaceGen/gen_swift_module_cross_import.swift.Other.response new file mode 100644 index 0000000000000..cfe62635e1895 --- /dev/null +++ b/test/SourceKit/InterfaceGen/gen_swift_module_cross_import.swift.Other.response @@ -0,0 +1,147 @@ +import SwiftOnoneSupport + +public func fromOther() + + +// MARK: - C Additions + +import C + +// Available when C is imported with Other +/// This has some interesting documentation that shouldn't be separated from +/// the decl when we print the comment detailing its required bystanders in the +/// generated interface of 'Other'. +public func from_OtherCAdditions() + + +[ + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 0, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 7, + key.length: 17 + }, + { + key.kind: source.lang.swift.syntaxtype.attribute.builtin, + key.offset: 26, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 33, + key.length: 4 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 38, + key.length: 9 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 52, + key.length: 23 + }, + { + key.kind: source.lang.swift.syntaxtype.comment.mark, + key.offset: 55, + key.length: 19 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 76, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 83, + key.length: 1 + }, + { + key.kind: source.lang.swift.syntaxtype.comment, + key.offset: 86, + key.length: 43 + }, + { + key.kind: source.lang.swift.syntaxtype.doccomment, + key.offset: 129, + key.length: 77 + }, + { + key.kind: source.lang.swift.syntaxtype.doccomment, + key.offset: 206, + key.length: 80 + }, + { + key.kind: source.lang.swift.syntaxtype.doccomment, + key.offset: 286, + key.length: 36 + }, + { + key.kind: source.lang.swift.syntaxtype.attribute.builtin, + key.offset: 322, + key.length: 6 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 329, + key.length: 4 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 334, + key.length: 20 + } +] +[ + { + key.kind: source.lang.swift.ref.module, + key.offset: 7, + key.length: 17, + key.is_system: 1 + }, + { + key.kind: source.lang.swift.ref.module, + key.offset: 83, + key.length: 1 + } +] +[ + { + key.kind: source.lang.swift.decl.function.free, + key.accessibility: source.lang.swift.accessibility.public, + key.name: "fromOther()", + key.offset: 33, + key.length: 16, + key.nameoffset: 38, + key.namelength: 11, + key.attributes: [ + { + key.offset: 26, + key.length: 6, + key.attribute: source.decl.attribute.public + } + ] + }, + { + key.kind: source.lang.swift.decl.function.free, + key.accessibility: source.lang.swift.accessibility.public, + key.name: "from_OtherCAdditions()", + key.offset: 329, + key.length: 27, + key.nameoffset: 334, + key.namelength: 22, + key.docoffset: 129, + key.doclength: 193, + key.attributes: [ + { + key.offset: 322, + key.length: 6, + key.attribute: source.decl.attribute.public + } + ] + } +] diff --git a/test/SourceKit/Sema/sema_build_session.swift b/test/SourceKit/Sema/sema_build_session.swift new file mode 100644 index 0000000000000..af3eee728627c --- /dev/null +++ b/test/SourceKit/Sema/sema_build_session.swift @@ -0,0 +1,83 @@ +import Foo +import FooHelper.FooHelperExplicit + +func swiftFunc() -> Int { 1 } + +func test() { + _ = fooFunc(1) + _ = fooSubFunc(1) + _ = fooHelperFunc(1) + _ = fooHelperSubFunc(1) + _ = fooHelperExplicitFunc(1) + _ = swiftFunc() +} + +// UNSUPPORTED: OS=windows-msvc + +// ----------------------------------------------------------------------------- +// Test that modifications for frameworks in '-Fsystem' doesn't affect the result. + +// RUN: %empty-directory(%t/ModuleCache) +// RUN: %empty-directory(%t/System/Frameworks) +// RUN: cp -R %S/../Inputs/build_session/Frameworks/Foo.framework %t/System/Frameworks/ +// RUN: cp -R %S/../Inputs/build_session/Frameworks/FooHelper.framework %t/System/Frameworks/ +// RUN: %sourcekitd-test \ +// RUN: -shell -- echo '## ONE' == \ +// RUN: -req=sema %s -- %s -D ONE -Fsystem %t/System/Frameworks -module-cache-path %t/ModuleCache == \ +// RUN: -shell -- cp -R %S/../Inputs/build_session/Frameworks_modified/Foo.framework %t/System/Frameworks/ == \ +// RUN: -shell -- cp -R %S/../Inputs/build_session/Frameworks_modified/FooHelper.framework %t/System/Frameworks/ == \ +// RUN: -shell -- echo '## TWO' == \ +// RUN: -req=sema %s -- %s -D TWO -Fsystem %t/System/Frameworks -module-cache-path %t/ModuleCache \ +// RUN: | %FileCheck %s --check-prefix=CHECK_SYSTEM +// RUN: sleep 2 +// RUN: %sourcekitd-test \ +// RUN: -shell -- echo '## THREE' == \ +// RUN: -req=sema %s -- %s -D THREE -Fsystem %t/System/Frameworks -module-cache-path %t/ModuleCache \ +// RUN: | %FileCheck %s --check-prefix=CHECK_SYSTEM_2 + +// CHECK_SYSTEM-LABEL: ## ONE +// CHECK_SYSTEM-NOT: key.description + +// CHECK_SYSTEM-LABEL: ## TWO +// CHECK_SYSTEM-NOT: key.description + +// CHECK_SYSTEM_2-LABEL: ## THREE +// CHECK_SYSTEM_2: key.severity: source.diagnostic.severity.error, +// CHECK_SYSTEM_2-NEXT: key.description: "use of unresolved identifier 'fooFunc'", +// CHECK_SYSTEM_2: key.severity: source.diagnostic.severity.error, +// CHECK_SYSTEM_2-NEXT: key.description: "use of unresolved identifier 'fooSubFunc'", +// CHECK_SYSTEM_2: key.severity: source.diagnostic.severity.error, +// CHECK_SYSTEM_2-NEXT: key.description: "use of unresolved identifier 'fooHelperFunc'", +// CHECK_SYSTEM_2: key.severity: source.diagnostic.severity.error, +// CHECK_SYSTEM_2-NEXT: key.description: "use of unresolved identifier 'fooHelperSubFunc'", +// CHECK_SYSTEM_2: key.severity: source.diagnostic.severity.error, +// CHECK_SYSTEM_2-NEXT: key.description: "use of unresolved identifier 'fooHelperExplicitFunc'", + +// ----------------------------------------------------------------------------- +// Test that modifications for frameworks in '-F' are immidiately propagated +// while modifications for frameworks in '-Fsystem' are not. + +// RUN: %empty-directory(%t/ModuleCache) +// RUN: %empty-directory(%t/Frameworks) +// RUN: %empty-directory(%t/System/Frameworks) +// RUN: cp -R %S/../Inputs/build_session/Frameworks/Foo.framework %t/Frameworks/ +// RUN: cp -R %S/../Inputs/build_session/Frameworks/FooHelper.framework %t/System/Frameworks/ +// RUN: %sourcekitd-test \ +// RUN: -shell -- echo '## ONE' == \ +// RUN: -req=sema %s -- %s -D ONE -F %t/Frameworks -Fsystem %t/System/Frameworks -module-cache-path %t/ModuleCache == \ +// RUN: -shell -- cp -R %S/../Inputs/build_session/Frameworks_modified/Foo.framework %t/Frameworks/ == \ +// RUN: -shell -- cp -R %S/../Inputs/build_session/Frameworks_modified/FooHelper.framework %t/System/Frameworks/ == \ +// RUN: -shell -- echo '## TWO' == \ +// RUN: -req=sema %s -- %s -D TWO -F %t/Frameworks -Fsystem %t/System/Frameworks -module-cache-path %t/ModuleCache \ +// RUN: | tee %t.reponse | %FileCheck %s --check-prefix=CHECK_USER + +// CHECK_USER-LABEL: ## ONE +// CHECK_USER-NOT: key.description + +// CHECK_USER-LABEL: ## TWO +// CHECK_USER-NOT: key.severity: +// CHECK_USER: key.severity: source.diagnostic.severity.error, +// CHECK_USER-NEXT: key.description: "use of unresolved identifier 'fooFunc'", +// CHECK_USER: key.severity: source.diagnostic.severity.error, +// CHECK_USER-NEXT: key.description: "use of unresolved identifier 'fooSubFunc'", +// CHECK_USER-NOT: key.severity: diff --git a/test/TypeCoercion/overload_member.swift b/test/TypeCoercion/overload_member.swift index 4ebd3131fd7c2..84dc1482c1811 100644 --- a/test/TypeCoercion/overload_member.swift +++ b/test/TypeCoercion/overload_member.swift @@ -69,7 +69,7 @@ func test_static_method_value_coerce(_ a: A) { func test_mixed_overload(_ a: A, x: X, y: Y) { var x1 = a.mixed(x: x) x1 = x - var y1 = a.mixed(y: y) // expected-error {{static member 'mixed' cannot be used on instance of type 'A'}} {{12-12=A.}} + var y1 = a.mixed(y: y) // expected-error {{static member 'mixed' cannot be used on instance of type 'A'}} {{12-13=A}} A.mixed(x) // expected-error{{cannot convert value of type 'X' to expected argument type 'A'}} var x2 = A.mixed(a)(x: x) @@ -89,7 +89,7 @@ func test_mixed_overload_coerce(_ a: A, x: inout X, y: Y, z: Z) { func test_mixed_method_value_coerce(_ a: A) { var _ : (X) -> X = a.mixed var _ : (Y) -> Y = A.mixed - var _ : (Y) -> Y = a.mixed; // expected-error {{static member 'mixed' cannot be used on instance of type 'A'}} {{22-22=A.}} + var _ : (Y) -> Y = a.mixed; // expected-error {{static member 'mixed' cannot be used on instance of type 'A'}} {{22-23=A}} var _ : (A) -> (X) -> X = A.mixed } diff --git a/test/api-digester/stability-stdlib-abi-with-asserts.swift b/test/api-digester/stability-stdlib-abi-with-asserts.swift index e5dabff45aa7b..2f2a345a85407 100644 --- a/test/api-digester/stability-stdlib-abi-with-asserts.swift +++ b/test/api-digester/stability-stdlib-abi-with-asserts.swift @@ -1,3 +1,4 @@ +// REQUIRES: rdar60088553 // REQUIRES: OS=macosx // REQUIRES: swift_stdlib_asserts // RUN: %empty-directory(%t.tmp) diff --git a/test/api-digester/stability-stdlib-abi-without-asserts.swift b/test/api-digester/stability-stdlib-abi-without-asserts.swift index 6f81cd4f55789..c238c7b528f4a 100644 --- a/test/api-digester/stability-stdlib-abi-without-asserts.swift +++ b/test/api-digester/stability-stdlib-abi-without-asserts.swift @@ -1,3 +1,4 @@ +// REQUIRES: rdar60088553 // REQUIRES: OS=macosx // REQUIRES: swift_stdlib_no_asserts // RUN: %empty-directory(%t.tmp) diff --git a/test/decl/enum/enumtest.swift b/test/decl/enum/enumtest.swift index 94a68fdedbb7d..57b8117e65b0c 100644 --- a/test/decl/enum/enumtest.swift +++ b/test/decl/enum/enumtest.swift @@ -36,8 +36,7 @@ func test1a() -> unionSearchFlags { func test1b(_ b : Bool) { _ = 123 - _ = .description == 1 // expected-error {{instance member 'description' cannot be used on type 'Int'}} - // expected-error@-1 {{member 'description' in 'Int' produces result of type 'String', but context expects 'Int'}} + _ = .description == 1 // expected-error {{cannot infer contextual base in reference to member 'description'}} } enum MaybeInt { diff --git a/test/decl/operator/Inputs/lookup_moduleA.swift b/test/decl/operator/Inputs/lookup_moduleA.swift new file mode 100644 index 0000000000000..d0b492f9a7679 --- /dev/null +++ b/test/decl/operator/Inputs/lookup_moduleA.swift @@ -0,0 +1,36 @@ + +// Only declared in module A. +prefix operator >>> +public prefix func >>> (rhs: Int) {} + +precedencegroup DeclaredInModuleA {} + +// Declared in both modules A and B. +infix operator ??? +precedencegroup DeclaredInModulesAB {} + +// Declared in both modules A and B, but with a different +// precedence group in each. +infix operator ???? : DeclaredInModuleA + +// Declared in both modules A and B, but shadowed by lookup_other. +precedencegroup DeclaredInModulesABShadowed {} + +// Declared in both modules A and B. +postfix operator >> +postfix operator <<< + +precedencegroup P1 {} +infix operator ^^^^ : P1 + +infix operator &&& diff --git a/test/decl/operator/lookup_compatibility.swift b/test/decl/operator/lookup_compatibility.swift new file mode 100644 index 0000000000000..603ab9149fd3c --- /dev/null +++ b/test/decl/operator/lookup_compatibility.swift @@ -0,0 +1,125 @@ +// RUN: %empty-directory(%t) + +// RUN: %target-swift-frontend -emit-module %S/Inputs/lookup_moduleD.swift -module-name D -o %t -I %t +// RUN: %target-swift-frontend -emit-module %S/Inputs/lookup_moduleC.swift -module-name C -o %t -I %t +// RUN: %target-swift-frontend -emit-module %S/Inputs/lookup_moduleB.swift -module-name B -o %t -I %t +// RUN: %target-swift-frontend -emit-module %S/Inputs/lookup_moduleA.swift -module-name A -o %t -I %t +// RUN: %target-swift-frontend -emit-module %S/Inputs/lookup_module_exportsAC.swift -module-name ExportsAC -o %t -I %t +// RUN: %target-swift-frontend -typecheck -verify -primary-file %s %S/Inputs/lookup_other.swift %S/Inputs/lookup_other2.swift %S/Inputs/lookup_other_compat.swift -I %t + +import ExportsAC +import B + +infix operator ^^^ : DeclaredAcrossFiles +func ^^^ (lhs: Int, rhs: Int) -> Int { 0 } +func &&& (lhs: Int, rhs: Int) -> Int { 0 } + +// FIXME(SR-12132): The operator decl >>> is declared in module A, which we +// should be able to see through ExportsAC. +prefix func >>> (rhs: Double) {} // expected-error {{operator implementation without matching operator declaration}} + +// FIXME(SR-12132): We should also see precedencegroups in module A through +// ExportsAC. +infix operator ^^^^ : DeclaredInModuleA // expected-error {{unknown precedence group 'DeclaredInModuleA'}} + +// The operator decl for ??? is declared in both modules A and B, but has the +// same default precedence group in both, so there's no ambiguity. +func ??? (lhs: Int, rhs: Int) {} + +// Same for ???!, declared in modules ExportsAC and B, but has the same +// precedence group in both. +func ???! (lhs: Int, rhs: Int) {} + +// The operator decl for ???? is declared in both modules A and B, and has a +// different precedence group in each. This should therefore be ambiguous. +// However, for compatibility, we don't look through exports in other modules, +// so we don't see the one in module A. +func ???? (lhs: Int, rhs: Int) {} + +// The operator decl for ????! is declared in both modules ExportsAC and B, and +// has a different precedence group in each. Therefore ambiguous. +// FIXME: We shouldn't emit the unknown operator decl error. +func ????! (lhs: Int, rhs: Int) {} // expected-error {{ambiguous operator declarations found for operator}} +// expected-error@-1 {{operator implementation without matching operator declaration}} + +// Same as ????, the precedencegroup is declared in both modules A and B, but +// we don't look into module A for compatibility. +infix operator : DeclaredInModulesAB + +// The precedencegroup is declared in both modules ExportsAC and B, therefore +// ambiguous. +// FIXME: We shouldn't emit the 'unknown precedence group' error. +infix operator : DeclaredInModulesBExportsAC // expected-error {{multiple precedence groups found}} +// expected-error@-1 {{unknown precedence group 'DeclaredInModulesBExportsAC'}} + +// This precedencegroup is declared in this module as well as in both modules A +// and B. The decl in this module should shadow the imported ones, but for +// compatibility we don't see module A's decl and take module B's decl. +infix operator : DeclaredInModulesABShadowed + +// The operator decl for : ShadowsModuleA + +// This precedencegroup is declared in modules A, C, and ExportsAC, but the +// latter shadows both of the former. +infix operator : ShadowsModulesAC + +// This operator decl is declared in modules A, C, and ExportsAC, but the +// latter shadows both of the former. +func ????? (lhs: Int, rhs: Int) {} + +// This operator decl is declared in modules A, C, and ExportsAC, but the +// latter shadows both of the former, despite them having different +// precedencegroups. +func ?????? (lhs: Int, rhs: Int) {} + +// FIXME: Module D is imported through exports in both lookup_other and +// lookup_other2, but we fail to detect the fact that we're visiting the same +// thing twice. +infix operator <> : DeclaredInModuleD // expected-error {{unknown precedence group 'DeclaredInModuleD'}} + +// Also declared in lookup_other. To preserve compatibility, we allow an +// unambiguous lookup that will favor this declaration over lookup_other. +precedencegroup RedeclaredInModule {} +infix operator *** : RedeclaredInModule // Okay. + +func testOperatorLookup() { + // In lookup_other, DeclaredAcrossFiles is left associative, whereas in + // module B it is non-associative. Make sure we use module B's for + // compatibility. + _ = 5 ^^^ 5 ^^^ 5 + // expected-error@-1 {{adjacent operators are in unordered precedence groups 'AssignmentPrecedence' and 'DeclaredAcrossFiles'}} + // expected-error@-2 {{adjacent operators are in non-associative precedence group 'DeclaredAcrossFiles'}} + // expected-error@-3 {{cannot convert value of type '()' to expected argument type 'Int'}} + + // Same for &&&, in lookup_other it is declared as left associative. + _ = 5 &&& 5 &&& 5 // expected-error {{adjacent operators are in non-associative precedence group 'DefaultPrecedence'}} + + // The operator >>> is declared in module A, which we should be able to see + // through ExportsAC. + >>>1 + + // We've been evil and overriden TernaryPrecedence in both modules A and B. + // FIXME: We shouldn't emit the 'broken stdlib' error. + true ? () : () // expected-error {{multiple precedence groups found}} + // expected-error@-1 {{broken standard library: missing builtin precedence group 'TernaryPrecedence'}} +} + +precedencegroup CastingPrecedence { + lowerThan: AssignmentPrecedence +} + +func testBuiltinPrecedenceGroupOverriding() { + // Evil, but allowed. + var x = 0 + x = 0 as Int // expected-error {{cannot convert value of type '()' to type 'Int' in coercion}} +} diff --git a/test/decl/operator/redeclaration_compatibility.swift b/test/decl/operator/redeclaration_compatibility.swift new file mode 100644 index 0000000000000..c593c38a7d53c --- /dev/null +++ b/test/decl/operator/redeclaration_compatibility.swift @@ -0,0 +1,32 @@ +// RUN: %target-swift-frontend -typecheck -verify -primary-file %s %S/Inputs/redeclaration_other_compat.swift + +// We currently allow cross-file redeclarations. +precedencegroup RedeclaredAcrossFiles {} + +precedencegroup RedeclaredSameFile {} // expected-note {{previous precedence group declaration here}} +precedencegroup RedeclaredSameFile {} // expected-error {{precedence group redeclared}} + +precedencegroup RedeclaredSameFile2 { // expected-note {{previous precedence group declaration here}} + assignment: true +} +precedencegroup RedeclaredSameFile2 {} // expected-error {{precedence group redeclared}} + +// These are all declared in the other file, and so are allowed for now. +infix operator ^^^ +prefix operator >>> +postfix operator <<< +infix operator ^^^^ + +// This is declared as an infix operator in the other file, so no problem. +prefix operator &&& +postfix operator &&& + +infix operator %%% // expected-note {{previous operator declaration here}} +infix operator %%% // expected-error {{operator redeclared}} + +prefix operator %%% // expected-note {{previous operator declaration here}} +prefix operator %%% // expected-error {{operator redeclared}} + +precedencegroup P2 {} +infix operator *** : P2 // expected-note {{previous operator declaration here}} +infix operator *** // expected-error {{operator redeclared}} diff --git a/test/decl/var/property_wrappers_library_evolution.swift b/test/decl/var/property_wrappers_library_evolution.swift new file mode 100644 index 0000000000000..3a690c0913640 --- /dev/null +++ b/test/decl/var/property_wrappers_library_evolution.swift @@ -0,0 +1,20 @@ +// RUN: %target-swift-frontend -typecheck %s -verify -enable-library-evolution + +@propertyWrapper +public struct ResilientWrapper { + public var wrappedValue: T + + public init(wrappedValue: T, description: String) { + self.wrappedValue = wrappedValue + } +} + +func getHello() -> String { return "hello" } // expected-note 2 {{global function 'getHello()' is not '@usableFromInline' or public}} + +@frozen +public struct StructUsesPublishedAsPrivate { + public var integer: Int = 17 + + @ResilientWrapper(description: getHello()) // expected-error 2 {{global function 'getHello()' is internal and cannot be referenced from a property initializer in a '@frozen' type}} + var otherString: String = "World" +} diff --git a/test/expr/expressions.swift b/test/expr/expressions.swift index d7d8f7126cc16..2f739346437b5 100644 --- a/test/expr/expressions.swift +++ b/test/expr/expressions.swift @@ -709,7 +709,6 @@ func test() { func unusedExpressionResults() { // Unused l-value _ // expected-error{{'_' can only appear in a pattern or on the left side of an assignment}} - // expected-error@-1 {{expression resolves to an unused variable}} // Conditional Optional binding hides compiler error let optionalc:C? = nil diff --git a/test/lit.cfg b/test/lit.cfg index cabbca8f0c90f..5c5dc901fbb9f 100644 --- a/test/lit.cfg +++ b/test/lit.cfg @@ -1084,7 +1084,7 @@ elif run_os in ['windows-msvc']: ('%r -emit-pcm -target %s' % (config.swiftc, config.variant_triple)) -elif (run_os in ['linux-gnu', 'linux-gnueabihf', 'freebsd', 'windows-cygnus', 'windows-gnu'] or +elif (run_os in ['linux-gnu', 'linux-gnueabihf', 'freebsd', 'openbsd', 'windows-cygnus', 'windows-gnu'] or (kIsAndroid and run_os in ['linux-android', 'linux-androideabi'])): # Running lit and the compiler on Android itself is more like running on Linux, # ie the NDK and adb aren't needed, so use this instead. @@ -1107,6 +1107,12 @@ elif (run_os in ['linux-gnu', 'linux-gnueabihf', 'freebsd', 'windows-cygnus', 'w config.target_shared_library_prefix = 'lib' config.target_shared_library_suffix = ".so" config.target_sdk_name = "freebsd" + elif run_os == 'openbsd': + lit_config.note("Testing OpenBSD " + config.variant_triple) + config.target_object_format = "elf" + config.target_shared_library_prefix = 'lib' + config.target_shared_library_suffix = ".so" + config.target_sdk_name = "openbsd" elif kIsAndroid: lit_config.note("Testing Android " + config.variant_triple) config.target_object_format = "elf" diff --git a/test/stdlib/Dispatch.swift b/test/stdlib/Dispatch.swift index 69a1aca39c2c1..1f0403e2aade7 100644 --- a/test/stdlib/Dispatch.swift +++ b/test/stdlib/Dispatch.swift @@ -231,3 +231,55 @@ DispatchAPI.test("DispatchTimeInterval.never.equals") { expectTrue(DispatchTimeInterval.never != DispatchTimeInterval.seconds(10)); expectTrue(DispatchTimeInterval.seconds(10) == DispatchTimeInterval.seconds(10)); } + +// Only support 64bit +#if !(os(iOS) && (arch(i386) || arch(arm))) + +import Combine + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +private func clampedIntProduct(_ m1: Int, _ m2: UInt64) -> Int { + assert(m2 > 0, "multiplier must be positive") + guard m1 < Int.max, m2 < Int.max else { return Int.max } + let (result, overflow) = m1.multipliedReportingOverflow(by: Int(m2)) + if overflow { + return m1 > 0 ? Int.max : Int.min + } + return result +} + +@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) +extension DispatchTimeInterval { + fileprivate var nanoseconds: Int { + switch self { + case .seconds(let s): return clampedIntProduct(s, NSEC_PER_SEC) + case .milliseconds(let ms): return clampedIntProduct(ms, NSEC_PER_MSEC) + case .microseconds(let us): return clampedIntProduct(us, NSEC_PER_USEC) + case .nanoseconds(let ns): return ns + case .never: return Int.max + } + } +} + +DispatchAPI.test("DispatchTime.SchedulerTimeType.Stridable") { + if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) { + // Basic checks for time types + for i in stride(from:1, through: 100, by: 5) { + let time1 = DispatchTime(uptimeNanoseconds: UInt64(i)) + let time2 = DispatchTime(uptimeNanoseconds: UInt64(i + 1)) + let schedulerTime1 = DispatchQueue.SchedulerTimeType(time1) + let schedulerTime2 = DispatchQueue.SchedulerTimeType(time2) + let addedTime = time2.distance(to: time1) + let addedSchedulerTime = schedulerTime2.distance(to: schedulerTime1) + expectEqual(addedTime.nanoseconds, addedSchedulerTime.magnitude) + } + + + let time1 = DispatchQueue.SchedulerTimeType(.init(uptimeNanoseconds: 10000)) + let time2 = DispatchQueue.SchedulerTimeType(.init(uptimeNanoseconds: 10431)) + let addedTime = time2.distance(to: time1) + expectEqual(addedTime.magnitude, (10000 - 10431)) + } +} + +#endif \ No newline at end of file diff --git a/test/stdlib/Error.swift b/test/stdlib/Error.swift index 1d095e80c9100..fc40c77b061f3 100644 --- a/test/stdlib/Error.swift +++ b/test/stdlib/Error.swift @@ -193,7 +193,7 @@ ErrorTests.test("test dealloc empty error box") { var errors: [Error] = [] ErrorTests.test("willThrow") { - if #available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) { + if #available(macOS 10.15.4, iOS 13.4, watchOS 6.2, tvOS 13.4, *) { // Error isn't allowed in a @convention(c) function when ObjC interop is // not available, so pass it through an OpaquePointer. typealias WillThrow = @convention(c) (OpaquePointer) -> Void diff --git a/test/stdlib/ErrorBridged.swift b/test/stdlib/ErrorBridged.swift index c7fc1d051d681..4aa8c7156c885 100644 --- a/test/stdlib/ErrorBridged.swift +++ b/test/stdlib/ErrorBridged.swift @@ -770,7 +770,7 @@ ErrorBridgingTests.test("@objc error domains for nested types") { ErrorBridgingTests.test("error-to-NSObject casts") { let error = MyCustomizedError(code: 12345) - if #available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) { + if #available(macOS 10.15.4, iOS 13.4, watchOS 6.2, tvOS 13.4, *) { // Unconditional cast let nsErrorAsObject1 = unconditionalCast(error, to: NSObject.self) let nsError1 = unconditionalCast(nsErrorAsObject1, to: NSError.self) @@ -800,7 +800,7 @@ ErrorBridgingTests.test("NSError-to-Error casts") { expectTrue(something is Error) } - if #available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) { + if #available(macOS 10.15.4, iOS 13.4, watchOS 6.2, tvOS 13.4, *) { // TODO: Wrap some leak checking around this // Until then, this is a helpful debug tool should_not_leak_nserror() @@ -813,7 +813,7 @@ ErrorBridgingTests.test("CFError-to-Error casts") { expectTrue(something is Error) } - if #available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) { + if #available(macOS 10.15.4, iOS 13.4, watchOS 6.2, tvOS 13.4, *) { // TODO: Wrap some leak checking around this // Until then, this is a helpful debug tool should_not_leak_cferror() @@ -826,7 +826,7 @@ enum MyError: Error { ErrorBridgingTests.test("SR-9207 crash in failed cast to NSError") { - if #available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) { + if #available(macOS 10.15.4, iOS 13.4, watchOS 6.2, tvOS 13.4, *) { let error = MyError.someThing let foundationError = error as NSError diff --git a/test/stdlib/FloatConstants.swift b/test/stdlib/FloatConstants.swift index e9b491b9e1bf2..eb03d30e29622 100644 --- a/test/stdlib/FloatConstants.swift +++ b/test/stdlib/FloatConstants.swift @@ -2,7 +2,7 @@ #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) import Glibc #elseif os(Windows) import MSVCRT diff --git a/test/stdlib/MathConstants.swift b/test/stdlib/MathConstants.swift index 85ba2d0bca6b8..08284143b1a6d 100644 --- a/test/stdlib/MathConstants.swift +++ b/test/stdlib/MathConstants.swift @@ -2,7 +2,7 @@ #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) import Glibc #elseif os(Windows) import MSVCRT diff --git a/test/stdlib/NSEvent.swift b/test/stdlib/NSEvent.swift index 131d188ae65f9..a4b28fb68761f 100644 --- a/test/stdlib/NSEvent.swift +++ b/test/stdlib/NSEvent.swift @@ -17,10 +17,10 @@ func testSpecialKey(_ specialKey: NSEvent.SpecialKey, rawValue: Int) { NSEventTests.test("NSEvent.specialKey") { testSpecialKey(NSEvent.SpecialKey.upArrow, rawValue: 0xF700) - if #available(macOS 9999, *) { + if #available(macOS 10.15.4, *) { // NSEvent.SpecialKey.deleteForward used to have the wrong rawValue in - // macOS 10.15 and below. See https://github.com/apple/swift/pull/26853 - // (rdar://54725550). + // macOS versions below 10.15.4. See + // https://github.com/apple/swift/pull/26853 (rdar://54725550). testSpecialKey(NSEvent.SpecialKey.deleteForward, rawValue: 0xF728) } } diff --git a/test/stdlib/PrintFloat.swift.gyb b/test/stdlib/PrintFloat.swift.gyb index a7bf257fcc09a..e1060be0fee29 100644 --- a/test/stdlib/PrintFloat.swift.gyb +++ b/test/stdlib/PrintFloat.swift.gyb @@ -12,7 +12,7 @@ import StdlibUnittest #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) import Glibc #elseif os(Windows) import MSVCRT diff --git a/test/stdlib/Runtime.swift.gyb b/test/stdlib/Runtime.swift.gyb index 7f659e3a9dcfc..f98fbe0ea4173 100644 --- a/test/stdlib/Runtime.swift.gyb +++ b/test/stdlib/Runtime.swift.gyb @@ -12,7 +12,7 @@ import SwiftShims #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) import Glibc #elseif os(Windows) import MSVCRT @@ -605,7 +605,7 @@ Runtime.test("SwiftError layout constants for LLDB") { expectEqual(40, offsetof_SwiftError_typeMetadata.load(as: UInt.self)) expectEqual(72, sizeof_SwiftError.load(as: UInt.self)) #endif -#elseif os(Linux) || os(Android) || os(Windows) +#elseif os(Linux) || os(Android) || os(Windows) || os(OpenBSD) expectEqual(16, offsetof_SwiftError_typeMetadata.load(as: UInt.self)) expectEqual(32, sizeof_SwiftError.load(as: UInt.self)) #else diff --git a/test/stdlib/TestJSONEncoder.swift b/test/stdlib/TestJSONEncoder.swift index fdb3d45d6ff8c..0d1541d42d6bc 100644 --- a/test/stdlib/TestJSONEncoder.swift +++ b/test/stdlib/TestJSONEncoder.swift @@ -10,6 +10,7 @@ // REQUIRES: executable_test // REQUIRES: objc_interop // REQUIRES: rdar55727144 + import Swift import Foundation @@ -37,21 +38,16 @@ class TestJSONEncoder : TestJSONEncoderSuper { // MARK: - Encoding Top-Level Single-Value Types func testEncodingTopLevelSingleValueEnum() { - _testEncodeFailure(of: Switch.off) - _testEncodeFailure(of: Switch.on) - - _testRoundTrip(of: TopLevelWrapper(Switch.off)) - _testRoundTrip(of: TopLevelWrapper(Switch.on)) + _testRoundTrip(of: Switch.off) + _testRoundTrip(of: Switch.on) } func testEncodingTopLevelSingleValueStruct() { - _testEncodeFailure(of: Timestamp(3141592653)) - _testRoundTrip(of: TopLevelWrapper(Timestamp(3141592653))) + _testRoundTrip(of: Timestamp(3141592653)) } func testEncodingTopLevelSingleValueClass() { - _testEncodeFailure(of: Counter()) - _testRoundTrip(of: TopLevelWrapper(Counter())) + _testRoundTrip(of: Counter()) } // MARK: - Encoding Top-Level Structured Types @@ -94,13 +90,9 @@ class TestJSONEncoder : TestJSONEncoderSuper { func testEncodingTopLevelNullableType() { // EnhancedBool is a type which encodes either as a Bool or as nil. - _testEncodeFailure(of: EnhancedBool.true) - _testEncodeFailure(of: EnhancedBool.false) - _testEncodeFailure(of: EnhancedBool.fileNotFound) - - _testRoundTrip(of: TopLevelWrapper(EnhancedBool.true), expectedJSON: "{\"value\":true}".data(using: .utf8)!) - _testRoundTrip(of: TopLevelWrapper(EnhancedBool.false), expectedJSON: "{\"value\":false}".data(using: .utf8)!) - _testRoundTrip(of: TopLevelWrapper(EnhancedBool.fileNotFound), expectedJSON: "{\"value\":null}".data(using: .utf8)!) + _testRoundTrip(of: EnhancedBool.true, expectedJSON: "true".data(using: .utf8)!) + _testRoundTrip(of: EnhancedBool.false, expectedJSON: "false".data(using: .utf8)!) + _testRoundTrip(of: EnhancedBool.fileNotFound, expectedJSON: "null".data(using: .utf8)!) } func testEncodingMultipleNestedContainersWithTheSameTopLevelKey() { @@ -229,93 +221,74 @@ class TestJSONEncoder : TestJSONEncoderSuper { } // MARK: - Date Strategy Tests - func testEncodingDate() { + + // Disabled for now till we resolve rdar://52618414 + func x_testEncodingDate() { func formattedLength(of value: Double) -> Int { - let empty = UnsafeMutablePointer.allocate(capacity: 0) - defer { empty.deallocate() } - let length = snprintf(ptr: empty, 0, "%0.*g", DBL_DECIMAL_DIG, value) - return Int(length) + let empty = UnsafeMutablePointer.allocate(capacity: 0) + defer { empty.deallocate() } + let length = snprintf(ptr: empty, 0, "%0.*g", DBL_DECIMAL_DIG, value) + return Int(length) } // Duplicated to handle a special case func localTestRoundTrip(of value: T) { - var payload: Data! = nil - do { - let encoder = JSONEncoder() - payload = try encoder.encode(value) - } catch { - expectUnreachable("Failed to encode \(T.self) to JSON: \(error)") - } + var payload: Data! = nil + do { + let encoder = JSONEncoder() + payload = try encoder.encode(value) + } catch { + expectUnreachable("Failed to encode \(T.self) to JSON: \(error)") + } - do { - let decoder = JSONDecoder() - let decoded = try decoder.decode(T.self, from: payload) - - /// `snprintf`'s `%g`, which `JSONSerialization` uses internally for double values, does not respect - /// our precision requests in every case. This bug effects Darwin, FreeBSD, and Linux currently - /// causing this test (which uses the current time) to fail occasionally. - let evalEdgeCase: (Date, Date) -> () = { decodedDate, expectedDate in - if formattedLength(of: decodedDate.timeIntervalSinceReferenceDate) > DBL_DECIMAL_DIG { - let adjustedTimeIntervalSinceReferenceDate: (Date) -> Double = { - let adjustment = pow(10, Double(DBL_DECIMAL_DIG)) - return Double(floor(adjustment * $0.timeIntervalSinceReferenceDate) / adjustment) - } - - let decodedAprox = adjustedTimeIntervalSinceReferenceDate(decodedDate) - let valueAprox = adjustedTimeIntervalSinceReferenceDate(expectedDate) - expectEqual(decodedAprox, valueAprox, "\(T.self) did not round-trip to an equal value after DBL_DECIMAL_DIG adjustment \(decodedAprox) != \(valueAprox).") - } - } - - if let decodedDate = (decoded as? TopLevelWrapper)?.value, - let expectedDate = (value as? TopLevelWrapper)?.value { - evalEdgeCase(decodedDate, expectedDate) - return - } - - if let decodedDate = (decoded as? OptionalTopLevelWrapper)?.value, - let expectedDate = (value as? OptionalTopLevelWrapper)?.value { - evalEdgeCase(decodedDate, expectedDate) - return - } - - expectEqual(decoded, value, "\(T.self) did not round-trip to an equal value.") - } catch { - expectUnreachable("Failed to decode \(T.self) from JSON: \(error)") + do { + let decoder = JSONDecoder() + let decoded = try decoder.decode(T.self, from: payload) + + /// `snprintf`'s `%g`, which `JSONSerialization` uses internally for double values, does not respect + /// our precision requests in every case. This bug effects Darwin, FreeBSD, and Linux currently + /// causing this test (which uses the current time) to fail occasionally. + if formattedLength(of: (decoded as! Date).timeIntervalSinceReferenceDate) > DBL_DECIMAL_DIG + 2 { + let adjustedTimeIntervalSinceReferenceDate: (Date) -> Double = { date in + let adjustment = pow(10, Double(DBL_DECIMAL_DIG)) + return Double(floor(adjustment * date.timeIntervalSinceReferenceDate).rounded() / adjustment) + } + + let decodedAprox = adjustedTimeIntervalSinceReferenceDate(decoded as! Date) + let valueAprox = adjustedTimeIntervalSinceReferenceDate(value as! Date) + expectEqual(decodedAprox, valueAprox, "\(T.self) did not round-trip to an equal value after DBL_DECIMAL_DIG adjustment \(decodedAprox) != \(valueAprox).") + return } - } - - // Test the above `snprintf` edge case evaluation with known triggering cases - // Tests the two precision digits larger case - let knownBadDateTwoExtraDigits = Date(timeIntervalSinceReferenceDate: 0.0021413276231263384) - localTestRoundTrip(of: TopLevelWrapper(knownBadDateTwoExtraDigits)) + expectEqual(decoded, value, "\(T.self) did not round-trip to an equal value. \((decoded as! Date).timeIntervalSinceReferenceDate) != \((value as! Date).timeIntervalSinceReferenceDate)") + } catch { + expectUnreachable("Failed to decode \(T.self) from JSON: \(error)") + } + } - // Tests the one precision digit larger case - let knownBadDateOneExtraDigit = Date(timeIntervalSinceReferenceDate: 576487829.7193049) - localTestRoundTrip(of: TopLevelWrapper(knownBadDateOneExtraDigit)) + // Test the above `snprintf` edge case evaluation with a known triggering case + let knownBadDate = Date(timeIntervalSinceReferenceDate: 0.0021413276231263384) + localTestRoundTrip(of: knownBadDate) - // We can't encode a top-level Date, so it'll be wrapped in a dictionary. - localTestRoundTrip(of: TopLevelWrapper(Date())) + localTestRoundTrip(of: Date()) // Optional dates should encode the same way. - localTestRoundTrip(of: OptionalTopLevelWrapper(Date())) + localTestRoundTrip(of: Optional(Date())) } func testEncodingDateSecondsSince1970() { // Cannot encode an arbitrary number of seconds since we've lost precision since 1970. let seconds = 1000.0 - let expectedJSON = "{\"value\":1000}".data(using: .utf8)! + let expectedJSON = "1000".data(using: .utf8)! - // We can't encode a top-level Date, so it'll be wrapped in a dictionary. - _testRoundTrip(of: TopLevelWrapper(Date(timeIntervalSince1970: seconds)), + _testRoundTrip(of: Date(timeIntervalSince1970: seconds), expectedJSON: expectedJSON, dateEncodingStrategy: .secondsSince1970, dateDecodingStrategy: .secondsSince1970) // Optional dates should encode the same way. - _testRoundTrip(of: OptionalTopLevelWrapper(Date(timeIntervalSince1970: seconds)), + _testRoundTrip(of: Optional(Date(timeIntervalSince1970: seconds)), expectedJSON: expectedJSON, dateEncodingStrategy: .secondsSince1970, dateDecodingStrategy: .secondsSince1970) @@ -324,16 +297,15 @@ class TestJSONEncoder : TestJSONEncoderSuper { func testEncodingDateMillisecondsSince1970() { // Cannot encode an arbitrary number of seconds since we've lost precision since 1970. let seconds = 1000.0 - let expectedJSON = "{\"value\":1000000}".data(using: .utf8)! + let expectedJSON = "1000000".data(using: .utf8)! - // We can't encode a top-level Date, so it'll be wrapped in a dictionary. - _testRoundTrip(of: TopLevelWrapper(Date(timeIntervalSince1970: seconds)), + _testRoundTrip(of: Date(timeIntervalSince1970: seconds), expectedJSON: expectedJSON, dateEncodingStrategy: .millisecondsSince1970, dateDecodingStrategy: .millisecondsSince1970) // Optional dates should encode the same way. - _testRoundTrip(of: OptionalTopLevelWrapper(Date(timeIntervalSince1970: seconds)), + _testRoundTrip(of: Optional(Date(timeIntervalSince1970: seconds)), expectedJSON: expectedJSON, dateEncodingStrategy: .millisecondsSince1970, dateDecodingStrategy: .millisecondsSince1970) @@ -345,17 +317,16 @@ class TestJSONEncoder : TestJSONEncoderSuper { formatter.formatOptions = .withInternetDateTime let timestamp = Date(timeIntervalSince1970: 1000) - let expectedJSON = "{\"value\":\"\(formatter.string(from: timestamp))\"}".data(using: .utf8)! + let expectedJSON = "\"\(formatter.string(from: timestamp))\"".data(using: .utf8)! - // We can't encode a top-level Date, so it'll be wrapped in a dictionary. - _testRoundTrip(of: TopLevelWrapper(timestamp), + _testRoundTrip(of: timestamp, expectedJSON: expectedJSON, dateEncodingStrategy: .iso8601, dateDecodingStrategy: .iso8601) // Optional dates should encode the same way. - _testRoundTrip(of: OptionalTopLevelWrapper(timestamp), + _testRoundTrip(of: Optional(timestamp), expectedJSON: expectedJSON, dateEncodingStrategy: .iso8601, dateDecodingStrategy: .iso8601) @@ -368,16 +339,15 @@ class TestJSONEncoder : TestJSONEncoderSuper { formatter.timeStyle = .full let timestamp = Date(timeIntervalSince1970: 1000) - let expectedJSON = "{\"value\":\"\(formatter.string(from: timestamp))\"}".data(using: .utf8)! + let expectedJSON = "\"\(formatter.string(from: timestamp))\"".data(using: .utf8)! - // We can't encode a top-level Date, so it'll be wrapped in a dictionary. - _testRoundTrip(of: TopLevelWrapper(timestamp), + _testRoundTrip(of: timestamp, expectedJSON: expectedJSON, dateEncodingStrategy: .formatted(formatter), dateDecodingStrategy: .formatted(formatter)) // Optional dates should encode the same way. - _testRoundTrip(of: OptionalTopLevelWrapper(timestamp), + _testRoundTrip(of: Optional(timestamp), expectedJSON: expectedJSON, dateEncodingStrategy: .formatted(formatter), dateDecodingStrategy: .formatted(formatter)) @@ -393,15 +363,14 @@ class TestJSONEncoder : TestJSONEncoderSuper { } let decode = { (_: Decoder) throws -> Date in return timestamp } - // We can't encode a top-level Date, so it'll be wrapped in a dictionary. - let expectedJSON = "{\"value\":42}".data(using: .utf8)! - _testRoundTrip(of: TopLevelWrapper(timestamp), + let expectedJSON = "42".data(using: .utf8)! + _testRoundTrip(of: timestamp, expectedJSON: expectedJSON, dateEncodingStrategy: .custom(encode), dateDecodingStrategy: .custom(decode)) // Optional dates should encode the same way. - _testRoundTrip(of: OptionalTopLevelWrapper(timestamp), + _testRoundTrip(of: Optional(timestamp), expectedJSON: expectedJSON, dateEncodingStrategy: .custom(encode), dateDecodingStrategy: .custom(decode)) @@ -414,15 +383,14 @@ class TestJSONEncoder : TestJSONEncoderSuper { let encode = { (_: Date, _: Encoder) throws -> Void in } let decode = { (_: Decoder) throws -> Date in return timestamp } - // We can't encode a top-level Date, so it'll be wrapped in a dictionary. - let expectedJSON = "{\"value\":{}}".data(using: .utf8)! - _testRoundTrip(of: TopLevelWrapper(timestamp), + let expectedJSON = "{}".data(using: .utf8)! + _testRoundTrip(of: timestamp, expectedJSON: expectedJSON, dateEncodingStrategy: .custom(encode), dateDecodingStrategy: .custom(decode)) // Optional dates should encode the same way. - _testRoundTrip(of: OptionalTopLevelWrapper(timestamp), + _testRoundTrip(of: Optional(timestamp), expectedJSON: expectedJSON, dateEncodingStrategy: .custom(encode), dateDecodingStrategy: .custom(decode)) @@ -432,15 +400,14 @@ class TestJSONEncoder : TestJSONEncoderSuper { func testEncodingData() { let data = Data(bytes: [0xDE, 0xAD, 0xBE, 0xEF]) - // We can't encode a top-level Data, so it'll be wrapped in a dictionary. - let expectedJSON = "{\"value\":[222,173,190,239]}".data(using: .utf8)! - _testRoundTrip(of: TopLevelWrapper(data), + let expectedJSON = "[222,173,190,239]".data(using: .utf8)! + _testRoundTrip(of: data, expectedJSON: expectedJSON, dataEncodingStrategy: .deferredToData, dataDecodingStrategy: .deferredToData) // Optional data should encode the same way. - _testRoundTrip(of: OptionalTopLevelWrapper(data), + _testRoundTrip(of: Optional(data), expectedJSON: expectedJSON, dataEncodingStrategy: .deferredToData, dataDecodingStrategy: .deferredToData) @@ -449,12 +416,11 @@ class TestJSONEncoder : TestJSONEncoderSuper { func testEncodingDataBase64() { let data = Data(bytes: [0xDE, 0xAD, 0xBE, 0xEF]) - // We can't encode a top-level Data, so it'll be wrapped in a dictionary. - let expectedJSON = "{\"value\":\"3q2+7w==\"}".data(using: .utf8)! - _testRoundTrip(of: TopLevelWrapper(data), expectedJSON: expectedJSON) + let expectedJSON = "\"3q2+7w==\"".data(using: .utf8)! + _testRoundTrip(of: data, expectedJSON: expectedJSON) // Optional data should encode the same way. - _testRoundTrip(of: OptionalTopLevelWrapper(data), expectedJSON: expectedJSON) + _testRoundTrip(of: Optional(data), expectedJSON: expectedJSON) } func testEncodingDataCustom() { @@ -465,15 +431,14 @@ class TestJSONEncoder : TestJSONEncoderSuper { } let decode = { (_: Decoder) throws -> Data in return Data() } - // We can't encode a top-level Data, so it'll be wrapped in a dictionary. - let expectedJSON = "{\"value\":42}".data(using: .utf8)! - _testRoundTrip(of: TopLevelWrapper(Data()), + let expectedJSON = "42".data(using: .utf8)! + _testRoundTrip(of: Data(), expectedJSON: expectedJSON, dataEncodingStrategy: .custom(encode), dataDecodingStrategy: .custom(decode)) // Optional data should encode the same way. - _testRoundTrip(of: OptionalTopLevelWrapper(Data()), + _testRoundTrip(of: Optional(Data()), expectedJSON: expectedJSON, dataEncodingStrategy: .custom(encode), dataDecodingStrategy: .custom(decode)) @@ -484,15 +449,14 @@ class TestJSONEncoder : TestJSONEncoderSuper { let encode = { (_: Data, _: Encoder) throws -> Void in } let decode = { (_: Decoder) throws -> Data in return Data() } - // We can't encode a top-level Data, so it'll be wrapped in a dictionary. - let expectedJSON = "{\"value\":{}}".data(using: .utf8)! - _testRoundTrip(of: TopLevelWrapper(Data()), + let expectedJSON = "{}".data(using: .utf8)! + _testRoundTrip(of: Data(), expectedJSON: expectedJSON, dataEncodingStrategy: .custom(encode), dataDecodingStrategy: .custom(decode)) // Optional Data should encode the same way. - _testRoundTrip(of: OptionalTopLevelWrapper(Data()), + _testRoundTrip(of: Optional(Data()), expectedJSON: expectedJSON, dataEncodingStrategy: .custom(encode), dataDecodingStrategy: .custom(decode)) @@ -500,73 +464,74 @@ class TestJSONEncoder : TestJSONEncoderSuper { // MARK: - Non-Conforming Floating Point Strategy Tests func testEncodingNonConformingFloats() { - _testEncodeFailure(of: TopLevelWrapper(Float.infinity)) - _testEncodeFailure(of: TopLevelWrapper(-Float.infinity)) - _testEncodeFailure(of: TopLevelWrapper(Float.nan)) + _testEncodeFailure(of: Float.infinity) + _testEncodeFailure(of: Float.infinity) + _testEncodeFailure(of: -Float.infinity) + _testEncodeFailure(of: Float.nan) - _testEncodeFailure(of: TopLevelWrapper(Double.infinity)) - _testEncodeFailure(of: TopLevelWrapper(-Double.infinity)) - _testEncodeFailure(of: TopLevelWrapper(Double.nan)) + _testEncodeFailure(of: Double.infinity) + _testEncodeFailure(of: -Double.infinity) + _testEncodeFailure(of: Double.nan) // Optional Floats/Doubles should encode the same way. - _testEncodeFailure(of: OptionalTopLevelWrapper(Float.infinity)) - _testEncodeFailure(of: OptionalTopLevelWrapper(-Float.infinity)) - _testEncodeFailure(of: OptionalTopLevelWrapper(Float.nan)) + _testEncodeFailure(of: Float.infinity) + _testEncodeFailure(of: -Float.infinity) + _testEncodeFailure(of: Float.nan) - _testEncodeFailure(of: OptionalTopLevelWrapper(Double.infinity)) - _testEncodeFailure(of: OptionalTopLevelWrapper(-Double.infinity)) - _testEncodeFailure(of: OptionalTopLevelWrapper(Double.nan)) + _testEncodeFailure(of: Double.infinity) + _testEncodeFailure(of: -Double.infinity) + _testEncodeFailure(of: Double.nan) } func testEncodingNonConformingFloatStrings() { let encodingStrategy: JSONEncoder.NonConformingFloatEncodingStrategy = .convertToString(positiveInfinity: "INF", negativeInfinity: "-INF", nan: "NaN") let decodingStrategy: JSONDecoder.NonConformingFloatDecodingStrategy = .convertFromString(positiveInfinity: "INF", negativeInfinity: "-INF", nan: "NaN") - _testRoundTrip(of: TopLevelWrapper(Float.infinity), - expectedJSON: "{\"value\":\"INF\"}".data(using: .utf8)!, + _testRoundTrip(of: Float.infinity, + expectedJSON: "\"INF\"".data(using: .utf8)!, nonConformingFloatEncodingStrategy: encodingStrategy, nonConformingFloatDecodingStrategy: decodingStrategy) - _testRoundTrip(of: TopLevelWrapper(-Float.infinity), - expectedJSON: "{\"value\":\"-INF\"}".data(using: .utf8)!, + _testRoundTrip(of: -Float.infinity, + expectedJSON: "\"-INF\"".data(using: .utf8)!, nonConformingFloatEncodingStrategy: encodingStrategy, nonConformingFloatDecodingStrategy: decodingStrategy) // Since Float.nan != Float.nan, we have to use a placeholder that'll encode NaN but actually round-trip. - _testRoundTrip(of: TopLevelWrapper(FloatNaNPlaceholder()), - expectedJSON: "{\"value\":\"NaN\"}".data(using: .utf8)!, + _testRoundTrip(of: FloatNaNPlaceholder(), + expectedJSON: "\"NaN\"".data(using: .utf8)!, nonConformingFloatEncodingStrategy: encodingStrategy, nonConformingFloatDecodingStrategy: decodingStrategy) - _testRoundTrip(of: TopLevelWrapper(Double.infinity), - expectedJSON: "{\"value\":\"INF\"}".data(using: .utf8)!, + _testRoundTrip(of: Double.infinity, + expectedJSON: "\"INF\"".data(using: .utf8)!, nonConformingFloatEncodingStrategy: encodingStrategy, nonConformingFloatDecodingStrategy: decodingStrategy) - _testRoundTrip(of: TopLevelWrapper(-Double.infinity), - expectedJSON: "{\"value\":\"-INF\"}".data(using: .utf8)!, + _testRoundTrip(of: -Double.infinity, + expectedJSON: "\"-INF\"".data(using: .utf8)!, nonConformingFloatEncodingStrategy: encodingStrategy, nonConformingFloatDecodingStrategy: decodingStrategy) // Since Double.nan != Double.nan, we have to use a placeholder that'll encode NaN but actually round-trip. - _testRoundTrip(of: TopLevelWrapper(DoubleNaNPlaceholder()), - expectedJSON: "{\"value\":\"NaN\"}".data(using: .utf8)!, + _testRoundTrip(of: DoubleNaNPlaceholder(), + expectedJSON: "\"NaN\"".data(using: .utf8)!, nonConformingFloatEncodingStrategy: encodingStrategy, nonConformingFloatDecodingStrategy: decodingStrategy) // Optional Floats and Doubles should encode the same way. - _testRoundTrip(of: OptionalTopLevelWrapper(Float.infinity), - expectedJSON: "{\"value\":\"INF\"}".data(using: .utf8)!, + _testRoundTrip(of: Optional(Float.infinity), + expectedJSON: "\"INF\"".data(using: .utf8)!, nonConformingFloatEncodingStrategy: encodingStrategy, nonConformingFloatDecodingStrategy: decodingStrategy) - _testRoundTrip(of: OptionalTopLevelWrapper(-Float.infinity), - expectedJSON: "{\"value\":\"-INF\"}".data(using: .utf8)!, + _testRoundTrip(of: Optional(-Float.infinity), + expectedJSON: "\"-INF\"".data(using: .utf8)!, nonConformingFloatEncodingStrategy: encodingStrategy, nonConformingFloatDecodingStrategy: decodingStrategy) - _testRoundTrip(of: OptionalTopLevelWrapper(Double.infinity), - expectedJSON: "{\"value\":\"INF\"}".data(using: .utf8)!, + _testRoundTrip(of: Optional(Double.infinity), + expectedJSON: "\"INF\"".data(using: .utf8)!, nonConformingFloatEncodingStrategy: encodingStrategy, nonConformingFloatDecodingStrategy: decodingStrategy) - _testRoundTrip(of: OptionalTopLevelWrapper(-Double.infinity), - expectedJSON: "{\"value\":\"-INF\"}".data(using: .utf8)!, + _testRoundTrip(of: Optional(-Double.infinity), + expectedJSON: "\"-INF\"".data(using: .utf8)!, nonConformingFloatEncodingStrategy: encodingStrategy, nonConformingFloatDecodingStrategy: decodingStrategy) } @@ -998,25 +963,37 @@ class TestJSONEncoder : TestJSONEncoderSuper { } func testInterceptDecimal() { - let expectedJSON = "{\"value\":10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000}".data(using: .utf8)! + let expectedJSON = "10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000".data(using: .utf8)! // Want to make sure we write out a JSON number, not the keyed encoding here. // 1e127 is too big to fit natively in a Double, too, so want to make sure it's encoded as a Decimal. let decimal = Decimal(sign: .plus, exponent: 127, significand: Decimal(1)) - _testRoundTrip(of: TopLevelWrapper(decimal), expectedJSON: expectedJSON) + _testRoundTrip(of: decimal, expectedJSON: expectedJSON) // Optional Decimals should encode the same way. - _testRoundTrip(of: OptionalTopLevelWrapper(decimal), expectedJSON: expectedJSON) + _testRoundTrip(of: Optional(decimal), expectedJSON: expectedJSON) } func testInterceptURL() { // Want to make sure JSONEncoder writes out single-value URLs, not the keyed encoding. - let expectedJSON = "{\"value\":\"http:\\/\\/swift.org\"}".data(using: .utf8)! + let expectedJSON = "\"http:\\/\\/swift.org\"".data(using: .utf8)! let url = URL(string: "http://swift.org")! - _testRoundTrip(of: TopLevelWrapper(url), expectedJSON: expectedJSON) + _testRoundTrip(of: url, expectedJSON: expectedJSON) // Optional URLs should encode the same way. - _testRoundTrip(of: OptionalTopLevelWrapper(url), expectedJSON: expectedJSON) + _testRoundTrip(of: Optional(url), expectedJSON: expectedJSON) + } + + func testInterceptURLWithoutEscapingOption() { + if #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) { + // Want to make sure JSONEncoder writes out single-value URLs, not the keyed encoding. + let expectedJSON = "\"http://swift.org\"".data(using: .utf8)! + let url = URL(string: "http://swift.org")! + _testRoundTrip(of: url, expectedJSON: expectedJSON, outputFormatting: [.withoutEscapingSlashes]) + + // Optional URLs should encode the same way. + _testRoundTrip(of: Optional(url), expectedJSON: expectedJSON, outputFormatting: [.withoutEscapingSlashes]) + } } // MARK: - Type coercion @@ -1160,8 +1137,8 @@ class TestJSONEncoder : TestJSONEncoderSuper { throw CustomError.foo }) - let json = "{\"value\": 1}".data(using: .utf8)! - let _ = try! decoder.decode(EitherDecodable, TopLevelWrapper>.self, from: json) + let json = "1".data(using: .utf8)! + let _ = try! decoder.decode(EitherDecodable.self, from: json) } func testDecoderStateThrowOnDecodeCustomData() { @@ -1172,8 +1149,8 @@ class TestJSONEncoder : TestJSONEncoderSuper { throw CustomError.foo }) - let json = "{\"value\": 1}".data(using: .utf8)! - let _ = try! decoder.decode(EitherDecodable, TopLevelWrapper>.self, from: json) + let json = "1".data(using: .utf8)! + let _ = try! decoder.decode(EitherDecodable.self, from: json) } // MARK: - Helper Functions @@ -1662,47 +1639,6 @@ fileprivate struct _TestKey : CodingKey { } } -/// Wraps a type T so that it can be encoded at the top level of a payload. -fileprivate struct TopLevelWrapper : Codable, Equatable where T : Codable, T : Equatable { - let value: T - - init(_ value: T) { - self.value = value - } - - static func ==(_ lhs: TopLevelWrapper, _ rhs: TopLevelWrapper) -> Bool { - return lhs.value == rhs.value - } -} - -/// Wraps a type T (as T?) so that it can be encoded at the top level of a payload. -fileprivate struct OptionalTopLevelWrapper : Codable, Equatable where T : Codable, T : Equatable { - let value: T? - - init(_ value: T) { - self.value = value - } - - // Provide an implementation of Codable to encode(forKey:) instead of encodeIfPresent(forKey:). - private enum CodingKeys : String, CodingKey { - case value - } - - init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - value = try container.decode(T?.self, forKey: .value) - } - - func encode(to encoder: Encoder) throws { - var container = encoder.container(keyedBy: CodingKeys.self) - try container.encode(value, forKey: .value) - } - - static func ==(_ lhs: OptionalTopLevelWrapper, _ rhs: OptionalTopLevelWrapper) -> Bool { - return lhs.value == rhs.value - } -} - fileprivate struct FloatNaNPlaceholder : Codable, Equatable { init() {} @@ -1785,7 +1721,8 @@ JSONEncoderTests.test("testEncodingOutputFormattingDefault") { TestJSONEncoder() JSONEncoderTests.test("testEncodingOutputFormattingPrettyPrinted") { TestJSONEncoder().testEncodingOutputFormattingPrettyPrinted() } JSONEncoderTests.test("testEncodingOutputFormattingSortedKeys") { TestJSONEncoder().testEncodingOutputFormattingSortedKeys() } JSONEncoderTests.test("testEncodingOutputFormattingPrettyPrintedSortedKeys") { TestJSONEncoder().testEncodingOutputFormattingPrettyPrintedSortedKeys() } -JSONEncoderTests.test("testEncodingDate") { TestJSONEncoder().testEncodingDate() } +// disabled for now due to a Date bug rdar://52618414 +// JSONEncoderTests.test("testEncodingDate") { TestJSONEncoder().testEncodingDate() } JSONEncoderTests.test("testEncodingDateSecondsSince1970") { TestJSONEncoder().testEncodingDateSecondsSince1970() } JSONEncoderTests.test("testEncodingDateMillisecondsSince1970") { TestJSONEncoder().testEncodingDateMillisecondsSince1970() } JSONEncoderTests.test("testEncodingDateISO8601") { TestJSONEncoder().testEncodingDateISO8601() } @@ -1813,6 +1750,7 @@ JSONEncoderTests.test("testNestedContainerCodingPaths") { TestJSONEncoder().test JSONEncoderTests.test("testSuperEncoderCodingPaths") { TestJSONEncoder().testSuperEncoderCodingPaths() } JSONEncoderTests.test("testInterceptDecimal") { TestJSONEncoder().testInterceptDecimal() } JSONEncoderTests.test("testInterceptURL") { TestJSONEncoder().testInterceptURL() } +JSONEncoderTests.test("testInterceptURLWithoutEscapingOption") { TestJSONEncoder().testInterceptURLWithoutEscapingOption() } JSONEncoderTests.test("testTypeCoercion") { TestJSONEncoder().testTypeCoercion() } JSONEncoderTests.test("testDecodingConcreteTypeParameter") { TestJSONEncoder().testDecodingConcreteTypeParameter() } JSONEncoderTests.test("testEncoderStateThrowOnEncode") { TestJSONEncoder().testEncoderStateThrowOnEncode() } diff --git a/test/stdlib/TestScanner.swift b/test/stdlib/TestScanner.swift new file mode 100644 index 0000000000000..1864dc0032f53 --- /dev/null +++ b/test/stdlib/TestScanner.swift @@ -0,0 +1,511 @@ +// Copyright (c) 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// RUN: %empty-directory(%t) +// +// RUN: %target-clang %S/Inputs/FoundationBridge/FoundationBridge.m -c -o %t/FoundationBridgeObjC.o -g +// RUN: %target-build-swift %s -I %S/Inputs/FoundationBridge/ -Xlinker %t/FoundationBridgeObjC.o -o %t/TestScanner +// RUN: %target-codesign %t/TestScanner + +// RUN: %target-run %t/TestScanner > %t.txt +// REQUIRES: executable_test +// REQUIRES: objc_interop + +import Foundation + +#if FOUNDATION_XCTEST +import XCTest +#if DEVELOPING_SCANNERAPI_AS_SEPARATE_MODULE +import ScannerAPI +#endif +class TestScannerSuper : XCTestCase { } +#else +import StdlibUnittest +class TestScannerSuper { } +#endif + +fileprivate func withScanner(for string: String, invoking block: ((Scanner) throws -> Void)? = nil) rethrows { + let scanner = Scanner(string: string) + scanner.locale = Locale(identifier: "en_US_POSIX") + try block?(scanner) +} + +extension CharacterSet { + fileprivate init(unicodeScalarsIn string: String) { + // Needed because: rdar://47615913 + var set = CharacterSet() + for character in string { + for scalar in character.unicodeScalars { + set.insert(scalar) + } + } + + self = set + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +class TestScanner : TestScannerSuper { + func testScanFloatingPoint() { + // Leading whitespace: + withScanner(for: " 1.2345") { + expectEqual($0.scanFloat(), 1.2345 as Float, "Parsing with leading whitespace should work") + } + + // Test all digits and numbers 0..9 + - E e: + withScanner(for: "-1.23456789E123") { + expectEqual($0.scanDouble(), atof("-1.23456789E123"), "Parsing double with uppercase exponential notation") + } + + withScanner(for: "+1.23456789e0") { + expectEqual($0.scanDouble(), atof("+1.23456789e0"), "Parsing double with lowercase exponential notation") + } + + // Large magnitude: + let largeA = "1234567890123456789012345678901234567890123456789012345678901234" + withScanner(for: largeA) { + expectEqual($0.scanDouble(), atof(largeA), "Parsing large magnitude double") + + } + + let largeB = "\(largeA)\(largeA)" + withScanner(for: largeB) { + expectEqual($0.scanDouble(), atof(largeB), "Parsing large magnitude double") + + } + + // Doubles and ints: + withScanner(for: " 3.14 -89.1 0.0 0.0 -4.E-4 128 100.99 ") { + expectEqual($0.scanDouble(), atof("3.14"), "Doubles and ints: 1") + expectEqual($0.scanDouble(), atof("-89.1"), "Doubles and ints: 2") + expectEqual($0.scanDouble(), atof("0.0"), "Doubles and ints: 3") + expectEqual($0.scanDouble(), atof("0.0"), "Doubles and ints: 4") + expectEqual($0.scanDouble(), atof("-4.E-4"), "Doubles and ints: 5") + expectEqual($0.scanDouble(), atof("128"), "Doubles and ints: 6") + expectEqual($0.scanInt(), 100, "Doubles and ints: 7") // Make sure scanning ints does not consume the decimal separator + expectEqual($0.scanDouble(), atof(".99"), "Doubles and ints: 8") + } + + // Roundtrip: + withScanner(for: String(format: " %3.5f %3.5f ", 3.14 as Double, -100.00 as Double)) { + expectEqual($0.scanDouble(), atof("3.14"), "Roundtrip: 1") + expectEqual($0.scanDouble(), atof("-100"), "Roundtrip: 2") + } + } + + func testHexRepresentation() { + // Long sequence: + withScanner(for: " 9 F 0xF 98 0x98 0x00098 0x980000000 0x980000000 acdcg0xacdcg0XACDCg0xg fFfffffE 0?777\t\n 004321X ") { + expectEqual($0.scanInt32(representation: .hexadecimal), 9, "Same as decimal") + expectEqual($0.scanInt32(representation: .hexadecimal), 0xF, "Single digit") + expectEqual($0.scanInt32(representation: .hexadecimal), 0xF, "Single digit with 0x prefix") + expectEqual($0.scanInt32(representation: .hexadecimal), 0x98, "Two digits") + expectEqual($0.scanInt32(representation: .hexadecimal), 0x98, "Two digits with 0x prefix") + expectEqual($0.scanInt32(representation: .hexadecimal), 0x98, "Two digits with 0x prefix and leading zeros") + + expectEqual($0.scanInt32(representation: .hexadecimal), Int32.max, "Overflow") + expectEqual($0.scanUInt64(representation: .hexadecimal), 0x980000000 as UInt64, "Unsigned 64-bit") + + expectEqual($0.scanInt32(representation: .hexadecimal), 0xacdc, "Followed by non-hex-digit without space") + expectEqual($0.scanString("g"), "g", "Consume non-hex-digit") + expectEqual($0.scanInt32(representation: .hexadecimal), 0xacdc, "Followed by non-hex-digit without space, with 0x prefix") + expectEqual($0.scanString("g"), "g", "Consume non-hex-digit (2)") + expectEqual($0.scanInt32(representation: .hexadecimal), 0xacdc, "Followed by non-hex-digit without space, with 0X prefix") + expectEqual($0.scanString("g"), "g", "Consume non-hex-digit (3)") + expectEqual($0.scanInt32(representation: .hexadecimal), 0, "'0x' followed by non-hex-digit without space") + expectEqual($0.scanInt32(representation: .hexadecimal), nil, "'x' (after trying to parse '0xg' as hexadecimal) isn't parsed as hex int itself") + expectEqual($0.scanString("xg"), "xg", "Consume non-hex-digits (4)") + + expectEqual($0.scanInt64(representation: .hexadecimal), 0xfffffffe, "Mixed case, 64-bit") + + expectEqual($0.scanInt32(representation: .hexadecimal), 0, "0 prefixing complex whitespace sequence") + expectEqual($0.scanString("?"), "?", "Consume complex whitespace sequence (1)") + expectEqual($0.scanInt32(representation: .hexadecimal), 0x777, "777 inside complex whitespace sequence") + expectEqual($0.scanInt32(representation: .hexadecimal), 0x4321, "4321 with leading zeros inside complex whitespace sequence") + expectFalse($0.isAtEnd, "The X was not consumed") + expectEqual($0.scanString("X"), "X", "Consume the X") + expectTrue($0.isAtEnd, "The X was not consumed") + } + } + + func testUInt64() { + // UInt64 long sequence: + withScanner(for: String(format: "%llu %llu %llu 42 + 42 0 %llu", UInt64.max / 10, UInt64.max - 1, UInt64.max, UInt64.max)) { + expectEqual($0.scanUInt64(), UInt64.max / 10, "Order of magnitude close to max") + expectEqual($0.scanUInt64(), UInt64.max - 1, "One less than max") + expectEqual($0.scanUInt64(), UInt64.max, "Max") + expectEqual($0.scanUInt64(), 42 as UInt64, "Short-sized integer") + expectEqual($0.scanUInt64(), 42 as UInt64, "Short-sized integer, with sign, ignoring whitespace") + expectEqual($0.scanUInt64(), 0 as UInt64, "Zero") + expectEqual($0.scanUInt64(), UInt64.max, "Max again after zero (ignoring prefix whitespace without merging this with the zero)") + } + + // Overflow: + withScanner(for: "\(UInt64.max)0") { + expectEqual($0.scanUInt64(), UInt64.max, "Overflow") + } + } + + func testInt64() { + // Int64 long sequence: + withScanner(for: String(format: "%lld %lld %lld 42 - 42 0 -1 -1 %lld %lld", Int64.max / 10, Int64.max - 1, Int64.max, Int64.min, Int64.max)) { + expectEqual($0.scanInt64(), Int64.max / 10, "Order of magnitude close to max") + expectEqual($0.scanInt64(), Int64.max - 1, "One less than max") + expectEqual($0.scanInt64(), Int64.max, "Max") + expectEqual($0.scanInt64(), 42 as Int64, "Short-sized integer") + expectEqual($0.scanInt64(), -42 as Int64, "Short-sized integer, with sign, ignoring whitespace") + expectEqual($0.scanInt64(), 0 as Int64, "Zero") + expectEqual($0.scanInt64(), -1 as Int64, "Minus one") + expectEqual($0.scanInt64(), -1 as Int64, "Minus one after whitespace") + expectEqual($0.scanInt64(), Int64.min, "Min") + expectEqual($0.scanInt64(), Int64.max, "Max again after min (no joining it with preceding min even with ignroed whitespace)") + } + + // Overflow: + withScanner(for: "\(Int64.max)0") { + expectEqual($0.scanInt64(), Int64.max, "Overflow") + } + } + + func testInt32() { + // Int32 long sequence: + withScanner(for: String(format: "%d %d %d 42 - 42 0 -1 -1 %d %d", Int32.max / 10, Int32.max - 1, Int32.max, Int32.min, Int32.max)) { + expectEqual($0.scanInt32(), Int32.max / 10, "Order of magnitude close to max") + expectEqual($0.scanInt32(), Int32.max - 1, "One less than max") + expectEqual($0.scanInt32(), Int32.max, "Max") + expectEqual($0.scanInt32(), 42 as Int32, "Short-sized integer") + expectEqual($0.scanInt32(), -42 as Int32, "Short-sized integer, with sign, ignoring whitespace") + expectEqual($0.scanInt32(), 0 as Int32, "Zero") + expectEqual($0.scanInt32(), -1 as Int32, "Minus one") + expectEqual($0.scanInt32(), -1 as Int32, "Minus one after whitespace") + expectEqual($0.scanInt32(), Int32.min, "Min") + expectEqual($0.scanInt32(), Int32.max, "Max again after min (no joining it with preceding min even with ignroed whitespace)") + } + + // Overflow: + withScanner(for: "\(Int32.max)0") { + expectEqual($0.scanInt32(), Int32.max, "Overflow") + } + } + + func testScanCharacter() { + withScanner(for: " hello ") { + expectEqual($0.scanCharacter(), "h", "Hello! (h)") + expectEqual($0.scanCharacter(), "e", "Hello! (e)") + expectEqual($0.scanCharacter(), "l", "Hello! (l)") + expectEqual($0.scanCharacter(), "l", "Hello! (l)") + expectFalse($0.isAtEnd, "Not at end yet") + expectEqual($0.scanCharacter(), "o", "Hello! (o)") + expectTrue($0.isAtEnd, "At end (ignores trailing whitespace)") + } + + withScanner(for: " \tde\u{0301}mode\u{0301}\n\t\n ") { + expectEqual($0.scanCharacter(), "d", "Démodé! (d)") + expectEqual($0.scanCharacter(), "é", "Démodé! (é)") // Two code points in original, comparing to é (single code point) + expectEqual($0.scanCharacter(), "m", "Démodé! (m)") + expectEqual($0.scanCharacter(), "o", "Démodé! (o)") + expectEqual($0.scanCharacter(), "d", "Démodé! (d)") + expectFalse($0.isAtEnd, "Not at end yet") + expectEqual($0.scanCharacter(), "é", "Démodé! (é)") // Two code points in original, comparing to é (single code point) + expectTrue($0.isAtEnd, "At end (ignores trailing whitespace)") + } + + withScanner(for: " \t\n❤️ \t\t\n") { + expectFalse($0.isAtEnd, "Not at end yet") + expectEqual($0.scanCharacter(), "❤️", "Scan single grapheme (made of single code point)") + expectTrue($0.isAtEnd, "At end (ignores trailing whitespace)") + } + + withScanner(for: " \t👩‍👩‍👧‍👧\n\t\n ") { + expectFalse($0.isAtEnd, "Not at end yet") + expectEqual($0.scanCharacter(), "👩‍👩‍👧‍👧", "Scan single grapheme (made of multiple code points)") + expectTrue($0.isAtEnd, "At end (ignores trailing whitespace)") + } + + // Unicode 10.0 emoji: + withScanner(for: " \t\u{1f9db}\u{200d}\u{2640}\u{fe0f}\n\t\n ") { // VAMPIRE, ZERO-WIDTH JOINER, FEMALE SIGN, VARIATION SELECTOR-16 + expectFalse($0.isAtEnd, "Not at end yet") + expectEqual($0.scanCharacter(), "🧛‍♀️", "Scan single grapheme (made of multiple code points)") + expectTrue($0.isAtEnd, "At end (ignores trailing whitespace)") + } + } + + func testScanString() { + // Scan skipping whitespace: + withScanner(for: "h el lo ") { + expectEqual($0.scanString("hello"), nil, "Split 'hello': Cannot scan the whole word in one go") + expectEqual($0.scanString("h"), "h", "Split 'hello' (h)") + expectEqual($0.scanString("el"), "el", "Split 'hello' (el)") + expectEqual($0.scanString("lo"), "lo", "Split 'hello' (lo)") + expectTrue($0.isAtEnd, "Split 'hello': should be at end.") + } + + // Scan without whitespace to skip: + withScanner(for: "hello ") { + expectEqual($0.scanString("hello"), "hello", "Joined 'hello': Can scan the whole word in one go") + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanString("h"), "h", "Joined 'hello' (h)") + expectEqual($0.scanString("el"), "el", "Joined 'hello' (el)") + expectEqual($0.scanString("lo"), "lo", "Joined 'hello' (lo)") + expectTrue($0.isAtEnd, "Joined 'hello': should be at end.") + } + + // Scan without skipping whitespace: + withScanner(for: "h el lo ") { + $0.charactersToBeSkipped = nil + expectEqual($0.scanString("h"), "h", "Split 'hello', without skipping whitespace (h)") + expectEqual($0.scanString("el"), nil, "Split 'hello', without skipping whitespace (el can't be scanned without consuming whitespace)") + expectEqual($0.scanString(" "), " ", "Split 'hello', without skipping whitespace (consume whitespace 1)") + expectEqual($0.scanString("el"), "el", "Split 'hello', without skipping whitespace (el)") + expectEqual($0.scanString("lo"), nil, "Split 'hello', without skipping whitespace (lo can't be scanned without consuming whitespace)") + expectEqual($0.scanString(" "), " ", "Split 'hello', without skipping whitespace (consume whitespace 2)") + expectEqual($0.scanString("lo"), "lo", "Split 'hello', without skipping whitespace (lo)") + expectFalse($0.isAtEnd, "Split 'hello', without skipping whitespace: should not be at end without consuming trailing whitespace") + expectEqual($0.scanString(" "), " ", "Split 'hello', without skipping whitespace (consume whitespace 3)") + expectTrue($0.isAtEnd, "Split 'hello', without skipping whitespace: should be at end") + } + + // Case-insensitive scanning: + withScanner(for: "H eL lO ") { + $0.caseSensitive = false + expectEqual($0.scanString("h"), "H", "Case-insensitive split 'hello' (h)") + expectEqual($0.scanString("el"), "eL", "Case-insensitive split 'hello' (el)") + expectEqual($0.scanString("lo"), "lO", "Case-insensitive split 'hello' (lo)") + expectTrue($0.isAtEnd, "Case-insensitive split 'hello': should be at end.") + } + + // Equivalent graphemes: + withScanner(for: "e\u{0300}") { // 'e' with a combining grave accent, two code points + expectEqual($0.scanString("\u{00E8}" /* U+00E8 LATIN SMALL LETTER E WITH GRAVE, one code point */), $0.string, "Can scan different string to get original as long as all graphemes are equivalent") + } + + // Partial graphemes: + withScanner(for: "e\u{0301}\u{031A}\u{032B}") { + // We do not assert here that the legacy methods don't work because they behave inconsistently wrt graphemes, and are able to actually discern that a combination code point plus a sequence of combining diacriticals is actually not OK to scan. Check just the newer behavior here. + expectEqual($0.scanString("e"), nil, "New method must not split graphemes while scanning") + expectEqual($0.scanString("e\u{0301}"), nil, "New method must not split graphemes while scanning") + expectEqual($0.scanString("e\u{0301}\u{031A}"), nil, "New method must not split graphemes while scanning") + expectEqual($0.scanString("e\u{0301}\u{031A}\u{032B}"), "e\u{0301}\u{031A}\u{032B}", "New method must not split graphemes while scanning") + } + + withScanner(for: "Lily is a 👩🏻‍💻") { // That's: [ U+1F469 WOMAN, U+1F3FB EMOJI MODIFIER FITZPATRICK SCALE 1-2, U+200D ZERO-WIDTH JOINER, U+1F4BB PERSONAL COMPUTER ] + // The deprecated API can scan a string that's a prefix of the code point sequence of this string, but the new method cannot do so if it would split the final character. + // .scanString(_:into:) interacts inconsistently with emoji. U+1F469 WOMAN will not scan by itself, but U+1F469 WOMAN, U+1F3FB EMOJI MODIFIER FITZPATRICK SCALE 1-2 will scan even though it's only part of a grapheme. + // .scanString() is designed to work on graphemes, so it will not scan either of these sequences. + expectEqual($0.scanString("Lily is a \u{1F469}\u{1F3FB}", into: nil), true, "Legacy method can split graphemes while scanning") + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanString("Lily is a \u{1F469}\u{1F3FB}"), nil, "New method must not split graphemes while scanning") + expectEqual($0.scanString("Lily is a \u{1F469}\u{1F3FB}\u{200D}"), nil, "New method must not split graphemes while scanning") + + expectEqual($0.scanString("Lily is a \u{1F469}\u{1F3FB}\u{200D}\u{1F4BB}"), "Lily is a 👩🏻‍💻", "New method must work if graphemes would not be split while scanning") + expectTrue($0.isAtEnd, "After scanning the last grapheme, we are at end") + } + + // Legacy method interaction with partial grapheme scanning: + withScanner(for: "Lily is a 👩🏻‍💻!") { + expectEqual($0.scanString("Lily is a \u{1F469}\u{1F3FB}", into: nil), true, "Legacy method can split graphemes while scanning") + expectEqual(String($0.string[$0.currentIndex...]), "!", "The index to scan from is the one after the end of the grapheme") + expectEqual($0.scanString("!"), "!", "Scanning starts correctly from there") + } + + withScanner(for: "Lily is a 👩🏻‍💻") { + expectEqual($0.scanString("Lily is a \u{1F469}\u{1F3FB}", into: nil), true, "Legacy method can split graphemes while scanning") + expectFalse($0.isAtEnd, "The scanner can scan more using legacy methods, even though no whole graphemes are left to scan. This can only happen when legacy methods are invoked or the deprecated .scanLocation property is set directly.") + expectEqual($0.currentIndex, $0.string.endIndex, "The Swift.String.Index we will resume scanning from for new methods is correctly pointing to the end of the string") + } + } + + func testScanUpToString() { + // Scan skipping whitespace: + withScanner(for: " hel lo") { + expectEqual($0.scanUpToString("lo"), "hel ", "Leading whitespace is skipped but not trailing whitespace before the stop point") + expectEqual($0.scanString("lo"), "lo", "The up-to string can be scanned immediately afterwards") + } + + // Scan without skipping whitespace: + withScanner(for: " hel lo") { + $0.charactersToBeSkipped = nil + expectEqual($0.scanUpToString("lo"), " hel ", "No whitespace is skipped") + expectEqual($0.scanString("lo"), "lo", "The up-to string can be scanned immediately afterwards") + } + + // Case-insensitive: + withScanner(for: " hel LOo!") { + $0.caseSensitive = false + expectEqual($0.scanUpToString("lo"), "hel ", "Leading whitespace is skipped but not trailing whitespace before the stop point") + expectEqual($0.scanString("lo"), "LO", "The up-to string can be scanned immediately afterwards (and actual case is returned)") + } + + // Equivalent graphemes: + withScanner(for: "wow e\u{0300}") { // 'e' with a combining grave accent, two code points + expectEqual($0.scanUpToString("\u{00E8}" /* U+00E8 LATIN SMALL LETTER E WITH GRAVE, one code point */), "wow ", "Can scan different string to get original as long as all graphemes are equivalent") + expectEqual($0.scanString("\u{00E8}"), "e\u{0300}", "The up-to string can be scanned immediately afterwards (and actual source form is returned)") + } + + // Partial graphemes (diacritics): + withScanner(for: "wow e\u{0301}\u{031A}\u{032B} NOT FOUND") { + // We do not assert here that the legacy methods don't work because they behave inconsistently wrt graphemes, and are able to actually discern that a combination code point plus a sequence of combining diacriticals is actually not OK to scan. Check just the newer behavior here. + // The correct failure mode here is that the whole string should be returned on failure to match — the partial grapheme match won't stop scanUpToString(_:), which will keep looking later. This means that on failure to find, the methods will succeed -- this is why we reset .currentIndex after every invocation. + expectEqual($0.scanUpToString("e"), $0.string, "New method must go past graphemes that match part of the scan-up-to string while scanning") + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanUpToString("e\u{0301}"), $0.string, "New method must go past graphemes that match part of the scan-up-to string while scanning") + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanUpToString("e\u{0301}\u{031A}"), $0.string, "New method must go past graphemes that match part of the scan-up-to string while scanning") + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanUpToString("e\u{0301}\u{031A}\u{032B}"), "wow ", "New method must match a full grapheme and stop there") + expectEqual($0.scanString("e\u{0301}\u{031A}\u{032B}"), "e\u{0301}\u{031A}\u{032B}", "The up-to string can be scanned immediately afterwards") + } + + // Partial graphemes (emoji): + withScanner(for: "Lily is a 👩🏻‍💻 NOT FOUND") { // That's: [ U+1F469 WOMAN, U+1F3FB EMOJI MODIFIER FITZPATRICK SCALE 1-2, U+200D ZERO-WIDTH JOINER, U+1F4BB PERSONAL COMPUTER ] + expectEqual($0.scanUpToString("\u{1F469}"), $0.string, "New method must go past graphemes that match part of the scan-up-to string while scanning") + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanUpToString("\u{1F469}\u{1F3FB}"), $0.string, "New method must go past graphemes that match part of the scan-up-to string while scanning") + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanUpToString("\u{1F469}\u{1F3FB}\u{200D}"), $0.string, "New method must go past graphemes that match part of the scan-up-to string while scanning") + + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanUpToString("\u{1F469}\u{1F3FB}\u{200D}\u{1F4BB}"), "Lily is a ", "New method must work if graphemes would not be split while scanning") + expectEqual($0.scanString("👩🏻‍💻"), "👩🏻‍💻", "The up-to string can be scanned immediately afterwards") + } + } + + func testScanCharactersFromSet() { + // Scan skipping whitespace: + withScanner(for: " doremifasol123 whoa") { + expectEqual($0.scanCharacters(from: .alphanumerics), "doremifasol123", "Skip leading whitespace, but do stop when new whitespace occurs") + } + + // Scan without skipping whitespace: + withScanner(for: " doremifasol123 !!") { + $0.charactersToBeSkipped = nil + expectEqual($0.scanCharacters(from: .alphanumerics), nil, "Do not skip leading whitespace") + let combined = CharacterSet.alphanumerics.union(.whitespaces) + expectEqual($0.scanCharacters(from: combined), " doremifasol123 ", "Pick up whitespace when explicitly requested, including trailing whitespace, and stop before the final characters outside the set") + } + + // Case sensitivity does not impact .scanCharacters(from:) + withScanner(for: "wowWOW") { + $0.caseSensitive = false + expectEqual($0.scanCharacters(from: CharacterSet(charactersIn: "wo")), "wow", ".caseSensitive does not change which characters are found by scanCharacters(from:)") + } + + // Scan only full graphemes (diacritics): + withScanner(for: " e\u{0301}\u{031A}\u{032B} wow") { + expectEqual($0.scanCharacters(from: CharacterSet(charactersIn: "e")), nil, "Cannot scan a grapheme that contains one or more code points not in the set") + expectEqual($0.scanCharacters(from: CharacterSet(charactersIn: "e\u{0301}")), nil, "Cannot scan a grapheme that contains one or more code points not in the set") + expectEqual($0.scanCharacters(from: CharacterSet(charactersIn: "e\u{0301}\u{031A}")), nil, "Cannot scan a grapheme that contains one or more code points not in the set") + expectEqual($0.scanCharacters(from: CharacterSet(charactersIn: "e\u{0301}\u{031A}\u{032B}")), "e\u{0301}\u{031A}\u{032B}", "Can scan a grapheme if all of its code points are in the character set") + } + + // Scan only full graphemes (emoji): + withScanner(for: "Lily is a 👩🏻‍💻") { // That's: [ U+1F469 WOMAN, U+1F3FB EMOJI MODIFIER FITZPATRICK SCALE 1-2, U+200D ZERO-WIDTH JOINER, U+1F4BB PERSONAL COMPUTER ] + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanCharacters(from: CharacterSet(unicodeScalarsIn: "Lily is a \u{1F469}")), "Lily is a ", "Cannot scan a grapheme that contains one or more code points not in the set") + + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanCharacters(from: CharacterSet(unicodeScalarsIn: "Lily is a \u{1F469}\u{1F3FB}")), "Lily is a ", "Cannot scan a grapheme that contains one or more code points not in the set") + + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanCharacters(from: CharacterSet(unicodeScalarsIn: "Lily is a \u{1F469}\u{1F3FB}\u{200D}")), "Lily is a ","Cannot scan a grapheme that contains one or more code points not in the set") + + $0.currentIndex = $0.string.startIndex + let set = CharacterSet(unicodeScalarsIn: "Lily is a \u{1F469}\u{1F3FB}\u{200D}\u{1F4BB}") + expectEqual($0.scanCharacters(from: set), "Lily is a 👩🏻‍💻", "Can scan a grapheme if all of its code points are in the character set") + } + } + + func testScanUpToCharactersFromSet() { + // Scan skipping whitespace: + withScanner(for: " hel- lo ") { + let hyphen = CharacterSet(unicodeScalarsIn: "-") + + expectEqual($0.scanUpToCharacters(from: .alphanumerics), nil, "Whitespace should be skipped, and we should already be at a place where alphanumerics match 'hel'") + expectEqual($0.scanUpToCharacters(from: hyphen), "hel", "Whitespace should be skipped, and the rest captured until the separator") + expectEqual($0.scanCharacters(from: hyphen), "-", "You should be able to scan the up-to string immediately") + expectEqual($0.scanUpToCharacters(from: .alphanumerics), nil, "Whitespace should be skipped, and we should already be at a place where alphanumerics match 'lo'") + } + + // Scan without skipping whitespace: + withScanner(for: " hel- lo ") { + let hyphen = CharacterSet(unicodeScalarsIn: "-") + $0.charactersToBeSkipped = nil + + expectEqual($0.scanUpToCharacters(from: .alphanumerics), " ", "Whitespace should not be skipped.") + expectEqual($0.scanUpToCharacters(from: hyphen), "hel", "Move to the hyphen and scan the 'hel' in the process") + expectEqual($0.scanUpToCharacters(from: .whitespaces), "-", "Scan hyphen, to the whitespace") + expectEqual($0.scanUpToCharacters(from: .alphanumerics), " ", "Whitespace should not be skipped, and we should move at a place where alphanumerics match 'lo'") + expectEqual($0.scanString("lo"), "lo", "Consume the 'lo'") + } + + withScanner(for: "so, HELLO") { + $0.caseSensitive = false + expectEqual($0.scanUpToCharacters(from: CharacterSet(unicodeScalarsIn: "h")), "so, HELLO", ".caseSensitive should not affect .scanUpToCharacters(from:); a set of only 'h' should not match 'H'") + } + + // Equivalent graphemes: + withScanner(for: "wow e\u{0300}") { // 'e' with a combining grave accent, two code points + let set = CharacterSet(unicodeScalarsIn: "\u{00E8}" /* U+00E8 LATIN SMALL LETTER E WITH GRAVE, one code point */) + expectEqual($0.scanUpToCharacters(from: set), "wow e\u{0300}", "Scanning using a character set should only match graphemes if the specific code points they are composed of are in the character set. 'e' + combining accent is not matched by a set with just 'è', even though they make equivalent graphemes") + + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanUpToCharacters(from: CharacterSet(unicodeScalarsIn: "e\u{0300}")), "wow ", "Scanning does match if the specific code points are in the set") + + } + + // Partial graphemes (diacritics): + withScanner(for: "wow e\u{0301}\u{031A}\u{032B} NOT FOUND") { + // The correct failure mode here is that the whole string should be returned on failure to match — the partial grapheme match won't stop scanUpToCharacters(from:), which will keep looking later. This means that on failure to find, the methods will succeed -- this is why we reset .currentIndex after every invocation. + expectEqual($0.scanUpToCharacters(from: CharacterSet(unicodeScalarsIn: "e")), $0.string, "New method must go past graphemes that match part of the scan-up-to string while scanning") + + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanUpToCharacters(from: CharacterSet(unicodeScalarsIn: "e\u{0301}")), $0.string, "New method must go past graphemes that match part of the scan-up-to string while scanning") + + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanUpToCharacters(from: CharacterSet(unicodeScalarsIn: "e\u{0301}\u{031A}")), $0.string, "New method must go past graphemes that match part of the scan-up-to string while scanning") + + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanUpToCharacters(from: CharacterSet(unicodeScalarsIn: "e\u{0301}\u{031A}\u{032B}")), "wow ", "Scanning does match if the specific code points are in the set") + expectEqual($0.scanCharacters(from: CharacterSet(unicodeScalarsIn: "e\u{0301}\u{031A}\u{032B}")), "e\u{0301}\u{031A}\u{032B}", "The up-to character set can be scanned immediately afterwards") + } + + // Partial graphemes (emoji): + withScanner(for: "Lily is a 👩🏻‍💻 NOT FOUND") { // That's: [ U+1F469 WOMAN, U+1F3FB EMOJI MODIFIER FITZPATRICK SCALE 1-2, U+200D ZERO-WIDTH JOINER, U+1F4BB PERSONAL COMPUTER ] + expectEqual($0.scanUpToCharacters(from: CharacterSet(unicodeScalarsIn: "\u{1F469}")), $0.string, "New method must go past graphemes that match part of the scan-up-to string while scanning") + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanUpToCharacters(from: CharacterSet(unicodeScalarsIn: "\u{1F469}\u{1F3FB}")), $0.string, "New method must go past graphemes that match part of the scan-up-to string while scanning") + $0.currentIndex = $0.string.startIndex + expectEqual($0.scanUpToCharacters(from: CharacterSet(unicodeScalarsIn: "\u{1F469}\u{1F3FB}\u{200D}")), $0.string, "New method must go past graphemes that match part of the scan-up-to string while scanning") + + $0.currentIndex = $0.string.startIndex + let finalSet = CharacterSet(unicodeScalarsIn: "\u{1F469}\u{1F3FB}\u{200D}\u{1F4BB}") + expectEqual($0.scanUpToCharacters(from: finalSet), "Lily is a ", "New method must work if graphemes would not be split while scanning") + expectEqual($0.scanCharacters(from: finalSet), "👩🏻‍💻", "The up-to character set can be scanned immediately afterwards") + } + } +} + +#if !FOUNDATION_XCTEST +if #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) { + let testSuite = TestSuite("TestScanner") + let handler = TestScanner() + testSuite.test("testScanFloatingPoint") { handler.testScanFloatingPoint() } + testSuite.test("testHexRepresentation") { handler.testHexRepresentation() } + testSuite.test("testUInt64") { handler.testUInt64() } + testSuite.test("testInt64") { handler.testInt64() } + testSuite.test("testInt32") { handler.testInt32() } + testSuite.test("testScanCharacter") { handler.testScanCharacter() } + testSuite.test("testScanString") { handler.testScanString() } + testSuite.test("testScanUpToString") { handler.testScanUpToString() } + testSuite.test("testScanCharactersFromSet") { handler.testScanCharactersFromSet() } + testSuite.test("testScanUpToCharactersFromSet") { handler.testScanUpToCharactersFromSet() } + + runAllTests() +} +#endif + diff --git a/test/stdlib/URLSession.swift b/test/stdlib/URLSession.swift new file mode 100644 index 0000000000000..84b686dc7a26f --- /dev/null +++ b/test/stdlib/URLSession.swift @@ -0,0 +1,61 @@ +// RUN: %target-swift-frontend -typecheck %s +// REQUIRES: objc_interop +// REQUIRES: foundation + +import StdlibUnittest +import Foundation + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +private func testWebSocketTask() { + let task = URLSession.shared.webSocketTask(with: URL(string:"wss://test.example")!) + + task.resume() + + task.send(.string("Hello")) { error in + assert(error == nil) + } + + task.receive { result in + switch result { + case .success(.string(let string)): + assert(string == "Hello") + task.cancel() + default: + assertionFailure() + } + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +private func testURLError(_ error: Error) { + if let error = error as? URLError { + if error.networkUnavailableReason == .constrained { + // Handle Low Data Mode + } + if error.backgroundTaskCancelledReason == .backgroundUpdatesDisabled { + // Background refresh disabled + } + _ = try? error.downloadTaskResumeData?.write(to: URL(fileURLWithPath: "/tmp/1.data")) + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +private func testURLCache() { + _ = URLCache(memoryCapacity: 0, diskCapacity: 0) + _ = URLCache(memoryCapacity: 0, diskCapacity: 0, directory: URL(fileURLWithPath: "/tmp")) +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +private func testTaskMetrics(_ metrics: URLSessionTaskMetrics) { + if let transaction = metrics.transactionMetrics.last { + if transaction.remotePort == 443 { + // HTTPS default + } + if transaction.negotiatedTLSProtocolVersion == .TLSv13 { + // TLS 1.3 + } + if transaction.negotiatedTLSCipherSuite == .CHACHA20_POLY1305_SHA256 { + // CHACHA20_POLY1305_SHA256 + } + } +} diff --git a/test/stdlib/VarArgs.swift b/test/stdlib/VarArgs.swift index 94aef60bf1b40..ea9317fef5fe2 100644 --- a/test/stdlib/VarArgs.swift +++ b/test/stdlib/VarArgs.swift @@ -6,7 +6,7 @@ import Swift #if _runtime(_ObjC) import Darwin import CoreGraphics -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) import Glibc typealias CGFloat = Double #elseif os(Windows) diff --git a/test/stdlib/mmap.swift b/test/stdlib/mmap.swift index 7008923d4c660..d0e098493911b 100644 --- a/test/stdlib/mmap.swift +++ b/test/stdlib/mmap.swift @@ -6,7 +6,7 @@ import StdlibUnittest #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) import Glibc #else #error("Unsupported platform") diff --git a/test/stdlib/tgmath.swift.gyb b/test/stdlib/tgmath.swift.gyb index 6c624bc69bb98..fede221f148e8 100644 --- a/test/stdlib/tgmath.swift.gyb +++ b/test/stdlib/tgmath.swift.gyb @@ -19,7 +19,7 @@ #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin.C.tgmath -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) import Glibc #elseif os(Windows) import MSVCRT diff --git a/test/stdlib/tgmath_optimized.swift b/test/stdlib/tgmath_optimized.swift index 0501aab8cd3ab..fdf3b75299db1 100644 --- a/test/stdlib/tgmath_optimized.swift +++ b/test/stdlib/tgmath_optimized.swift @@ -6,7 +6,7 @@ #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) import Darwin -#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) +#elseif os(Linux) || os(FreeBSD) || os(OpenBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku) || os(WASI) import Glibc #elseif os(Windows) import MSVCRT diff --git a/tools/SourceKit/include/SourceKit/Core/LangSupport.h b/tools/SourceKit/include/SourceKit/Core/LangSupport.h index 5edfdd64e6def..d46c638b6ee68 100644 --- a/tools/SourceKit/include/SourceKit/Core/LangSupport.h +++ b/tools/SourceKit/include/SourceKit/Core/LangSupport.h @@ -485,6 +485,7 @@ struct DocEntityInfo { llvm::SmallString<64> LocalizationKey; std::vector GenericParams; std::vector GenericRequirements; + std::vector RequiredBystanders; unsigned Offset = 0; unsigned Length = 0; bool IsUnavailable = false; diff --git a/tools/SourceKit/lib/SwiftLang/SwiftASTManager.cpp b/tools/SourceKit/lib/SwiftLang/SwiftASTManager.cpp index 9ada6f0362b59..92c8c42a42d3f 100644 --- a/tools/SourceKit/lib/SwiftLang/SwiftASTManager.cpp +++ b/tools/SourceKit/lib/SwiftLang/SwiftASTManager.cpp @@ -31,6 +31,7 @@ #include "swift/Sema/IDETypeChecking.h" #include "llvm/ADT/FoldingSet.h" +#include "llvm/Support/Chrono.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/Path.h" @@ -374,7 +375,9 @@ struct SwiftASTManager::Implementation { std::shared_ptr Config, std::shared_ptr Stats, StringRef RuntimeResourcePath) : EditorDocs(EditorDocs), Config(Config), Stats(Stats), - RuntimeResourcePath(RuntimeResourcePath) {} + RuntimeResourcePath(RuntimeResourcePath), + SessionTimestamp(llvm::sys::toTimeT(std::chrono::system_clock::now())) { + } std::shared_ptr EditorDocs; std::shared_ptr Config; @@ -383,6 +386,7 @@ struct SwiftASTManager::Implementation { SourceManager SourceMgr; Cache ASTCache{ "sourcekit.swift.ASTCache" }; llvm::sys::Mutex CacheMtx; + std::time_t SessionTimestamp; WorkQueue ASTBuildQueue{ WorkQueue::Dequeuing::Serial, "sourcekit.swift.ASTBuilding" }; @@ -548,6 +552,18 @@ bool SwiftASTManager::initCompilerInvocation( if (Impl.Config->shouldOptimizeForIDE()) FrontendOpts.IgnoreSwiftSourceInfo = true; + // To save the time for module validation, consider the lifetime of ASTManager + // as a single build session. + // NOTE: 'SessionTimestamp - 1' because clang compares it with '<=' that may + // cause unnecessary validations if they happens within one second from + // the SourceKit startup. + ImporterOpts.ExtraArgs.push_back("-fbuild-session-timestamp=" + + std::to_string(Impl.SessionTimestamp - 1)); + ImporterOpts.ExtraArgs.push_back("-fmodules-validate-once-per-build-session"); + + auto &SearchPathOpts = Invocation.getSearchPathOptions(); + SearchPathOpts.DisableModulesValidateSystemDependencies = true; + // Disable expensive SIL options to reduce time spent in SILGen. disableExpensiveSILOptions(Invocation.getSILOptions()); diff --git a/tools/SourceKit/lib/SwiftLang/SwiftDocSupport.cpp b/tools/SourceKit/lib/SwiftLang/SwiftDocSupport.cpp index c081947a31e3a..a8209ebf47380 100644 --- a/tools/SourceKit/lib/SwiftLang/SwiftDocSupport.cpp +++ b/tools/SourceKit/lib/SwiftLang/SwiftDocSupport.cpp @@ -65,6 +65,7 @@ struct TextEntity { const Decl *Dcl = nullptr; TypeOrExtensionDecl SynthesizeTarget; const Decl *DefaultImplementationOf = nullptr; + ModuleDecl *UnderlyingModIfFromOverlay = nullptr; StringRef Argument; TextRange Range; unsigned LocOffset = 0; @@ -303,7 +304,8 @@ static bool initDocEntityInfo(const Decl *D, TypeOrExtensionDecl SynthesizedTarget, const Decl *DefaultImplementationOf, bool IsRef, bool IsSynthesizedExtension, DocEntityInfo &Info, - StringRef Arg = StringRef()) { + StringRef Arg = StringRef(), + ModuleDecl *ModIfFromOverlay = nullptr){ if (!IsRef && D->isImplicit()) return true; if (!D || isa(D) || @@ -409,6 +411,15 @@ static bool initDocEntityInfo(const Decl *D, SwiftLangSupport::printFullyAnnotatedGenericReq(Sig, OS); } } + + if (ModIfFromOverlay) { + ModuleDecl *MD = D->getModuleContext(); + SmallVector Bystanders; + ModIfFromOverlay->getAllBystandersForCrossImportOverlay(MD, Bystanders); + std::transform(Bystanders.begin(), Bystanders.end(), + std::back_inserter(Info.RequiredBystanders), + [](Identifier Bystander){ return Bystander.str().str(); }); + } } switch(D->getDeclContext()->getContextKind()) { @@ -446,7 +457,8 @@ static bool initDocEntityInfo(const TextEntity &Entity, if (initDocEntityInfo(Entity.Dcl, Entity.SynthesizeTarget, Entity.DefaultImplementationOf, /*IsRef=*/false, Entity.IsSynthesizedExtension, - Info, Entity.Argument)) + Info, Entity.Argument, + Entity.UnderlyingModIfFromOverlay)) return true; Info.Offset = Entity.Range.Offset; Info.Length = Entity.Range.Length; @@ -962,6 +974,16 @@ static bool getModuleInterfaceInfo(ASTContext &Ctx, StringRef ModuleName, Info.Text = std::string(OS.str()); Info.TopEntities = std::move(Printer.TopEntities); Info.References = std::move(Printer.References); + + // Add a reference to the main module on any entities from cross-import + // overlay modules (used to determine their bystanders later). + for (auto &Entity: Info.TopEntities) { + auto *EntityMod = Entity.Dcl->getModuleContext(); + if (!EntityMod || EntityMod == M) + continue; + if (M->isUnderlyingModuleOfCrossImportOverlay(EntityMod)) + Entity.UnderlyingModIfFromOverlay = M; + } return false; } diff --git a/tools/SourceKit/lib/SwiftLang/SwiftEditorInterfaceGen.cpp b/tools/SourceKit/lib/SwiftLang/SwiftEditorInterfaceGen.cpp index 7275933fb6b2c..fd0887468f6f9 100644 --- a/tools/SourceKit/lib/SwiftLang/SwiftEditorInterfaceGen.cpp +++ b/tools/SourceKit/lib/SwiftLang/SwiftEditorInterfaceGen.cpp @@ -519,6 +519,10 @@ bool SwiftInterfaceGenContext::isModule() const { return Impl.IsModule; } +ModuleDecl *SwiftInterfaceGenContext::getModuleDecl() const { + return Impl.Mod; +}; + bool SwiftInterfaceGenContext::matches(StringRef ModuleName, const swift::CompilerInvocation &Invok) { if (!Impl.IsModule) diff --git a/tools/SourceKit/lib/SwiftLang/SwiftInterfaceGenContext.h b/tools/SourceKit/lib/SwiftLang/SwiftInterfaceGenContext.h index cf68e9ea6958f..86ee185e86c81 100644 --- a/tools/SourceKit/lib/SwiftLang/SwiftInterfaceGenContext.h +++ b/tools/SourceKit/lib/SwiftLang/SwiftInterfaceGenContext.h @@ -21,6 +21,7 @@ namespace swift { class CompilerInvocation; class ValueDecl; + class ModuleDecl; } namespace SourceKit { @@ -58,6 +59,7 @@ class SwiftInterfaceGenContext : StringRef getDocumentName() const; StringRef getModuleOrHeaderName() const; bool isModule() const; + swift::ModuleDecl *getModuleDecl() const; bool matches(StringRef ModuleName, const swift::CompilerInvocation &Invok); diff --git a/tools/SourceKit/lib/SwiftLang/SwiftSourceDocInfo.cpp b/tools/SourceKit/lib/SwiftLang/SwiftSourceDocInfo.cpp index ac523ab55f2eb..a07d273940e59 100644 --- a/tools/SourceKit/lib/SwiftLang/SwiftSourceDocInfo.cpp +++ b/tools/SourceKit/lib/SwiftLang/SwiftSourceDocInfo.cpp @@ -716,7 +716,7 @@ getParamParentNameOffset(const ValueDecl *VD, SourceLoc Cursor) { /// Returns true on success, false on error (and sets `Diagnostic` accordingly). static bool passCursorInfoForDecl(SourceFile* SF, const ValueDecl *VD, - const ModuleDecl *MainModule, + ModuleDecl *MainModule, const Type ContainerTy, bool IsRef, bool RetrieveRefactoring, @@ -906,7 +906,26 @@ static bool passCursorInfoForDecl(SourceFile* SF, if (ClangMod) ModuleName = ClangMod->getFullModuleName(); } else if (VD->getModuleContext() != MainModule) { - ModuleName = VD->getModuleContext()->getName().str().str(); + ModuleDecl *MD = VD->getModuleContext(); + // If the decl is from a cross-import overlay module, report the overlay's + // underlying module as the owning module. + if (SF) { + // In a source file we map the imported overlays to the underlying + // modules they shadow. + while (MD->getNameStr().startswith("_")) { + auto *Underlying = SF->getModuleShadowedBySeparatelyImportedOverlay(MD); + if (!Underlying) + break; + MD = Underlying; + } + } else if (MainModule) { + // In a module interface we need to map the declared overlays of the main + // module (which are included in its generated interface) back to the main + // module itself. + if (MainModule->isUnderlyingModuleOfCrossImportOverlay(MD)) + MD = MainModule; + } + ModuleName = MD->getName().str().str(); } StringRef ModuleInterfaceName; if (auto IFaceGenRef = Lang.getIFaceGenContexts().find(ModuleName, Invok)) @@ -1642,10 +1661,9 @@ void SwiftLangSupport::getCursorInfo( Receiver); } else { std::string Diagnostic; // Unused. - // FIXME: Should pass the main module for the interface but currently - // it's not necessary. + ModuleDecl *MainModule = IFaceGenRef->getModuleDecl(); passCursorInfoForDecl( - /*SourceFile*/nullptr, Entity.Dcl, /*MainModule*/ nullptr, + /*SourceFile*/nullptr, Entity.Dcl, MainModule, Type(), Entity.IsRef, Actionables, ResolvedCursorInfo(), /*OrigBufferID=*/None, SourceLoc(), {}, *this, Invok, Diagnostic, {}, Receiver); diff --git a/tools/SourceKit/tools/sourcekitd-test/Options.td b/tools/SourceKit/tools/sourcekitd-test/Options.td index ee21922a9c19b..48b9d2101ecfa 100644 --- a/tools/SourceKit/tools/sourcekitd-test/Options.td +++ b/tools/SourceKit/tools/sourcekitd-test/Options.td @@ -149,6 +149,9 @@ def suppress_config_request : Flag<["-"], "suppress-config-request">, def module_cache_path: Separate<["-"], "module-cache-path">, HelpText<"module cache path">; def module_cache_path_EQ : Joined<["-"], "module-cache-path=">, Alias; +def shell: Flag<["-"], "shell">, + HelpText<"Run shell command">; + def help : Flag<["-", "--"], "help">, HelpText<"Display available options">; diff --git a/tools/SourceKit/tools/sourcekitd-test/TestOptions.cpp b/tools/SourceKit/tools/sourcekitd-test/TestOptions.cpp index 7a70a621246ff..48330f4ac9080 100644 --- a/tools/SourceKit/tools/sourcekitd-test/TestOptions.cpp +++ b/tools/SourceKit/tools/sourcekitd-test/TestOptions.cpp @@ -389,6 +389,10 @@ bool TestOptions::parseArgs(llvm::ArrayRef Args) { ModuleCachePath = InputArg->getValue(); break; + case OPT_shell: + ShellExecution = true; + break; + case OPT_UNKNOWN: llvm::errs() << "error: unknown argument: " << InputArg->getAsString(ParsedArgs) << '\n' diff --git a/tools/SourceKit/tools/sourcekitd-test/TestOptions.h b/tools/SourceKit/tools/sourcekitd-test/TestOptions.h index 0be899ad16597..7551ab48afee5 100644 --- a/tools/SourceKit/tools/sourcekitd-test/TestOptions.h +++ b/tools/SourceKit/tools/sourcekitd-test/TestOptions.h @@ -124,6 +124,7 @@ struct TestOptions { llvm::StringMap VFSFiles; llvm::Optional VFSName; llvm::Optional CancelOnSubsequentRequest; + bool ShellExecution = false; bool parseArgs(llvm::ArrayRef Args); void printHelp(bool ShowHidden) const; }; diff --git a/tools/SourceKit/tools/sourcekitd-test/sourcekitd-test.cpp b/tools/SourceKit/tools/sourcekitd-test/sourcekitd-test.cpp index 75d5ba7f1d405..163a6aaf480b3 100644 --- a/tools/SourceKit/tools/sourcekitd-test/sourcekitd-test.cpp +++ b/tools/SourceKit/tools/sourcekitd-test/sourcekitd-test.cpp @@ -26,6 +26,7 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/Path.h" +#include "llvm/Support/Program.h" #include "llvm/Support/Regex.h" #include "llvm/Support/Signals.h" #include "llvm/Support/Threading.h" @@ -387,6 +388,16 @@ static int handleJsonRequestPath(StringRef QueryPath, const TestOptions &Opts) { return Error ? 1 : 0; } +static int performShellExecution(ArrayRef Args) { + auto Program = llvm::sys::findProgramByName(Args[0]); + if (std::error_code ec = Program.getError()) { + llvm::errs() << "command not found: " << Args[0] << "\n"; + return ec.value(); + } + SmallVector execArgs(Args.begin(), Args.end()); + return llvm::sys::ExecuteAndWait(*Program, execArgs); +} + static int handleTestInvocation(TestOptions Opts, TestOptions &InitOpts); static int handleTestInvocation(ArrayRef Args, @@ -419,6 +430,9 @@ static int handleTestInvocation(ArrayRef Args, } } + if (Opts.ShellExecution) + return performShellExecution(Opts.CompilerArgs); + assert(Opts.repeatRequest >= 1); for (unsigned i = 0; i < Opts.repeatRequest; ++i) { if (int ret = handleTestInvocation(Opts, InitOpts)) { diff --git a/tools/SourceKit/tools/sourcekitd/lib/API/Requests.cpp b/tools/SourceKit/tools/sourcekitd/lib/API/Requests.cpp index 457db15ddd136..b4f691b00f01e 100644 --- a/tools/SourceKit/tools/sourcekitd/lib/API/Requests.cpp +++ b/tools/SourceKit/tools/sourcekitd/lib/API/Requests.cpp @@ -1573,11 +1573,15 @@ void SKDocConsumer::addDocEntityInfoToDict(const DocEntityInfo &Info, // while GenericParams is empty. if (!Info.GenericRequirements.empty()) { auto ReqArray = Elem.setArray(KeyGenericRequirements); + for (auto &Req : Info.GenericRequirements) { auto ReqElem = ReqArray.appendDictionary(); ReqElem.set(KeyDescription, Req); } } + + if (!Info.RequiredBystanders.empty()) + Elem.set(KeyRequiredBystanders, Info.RequiredBystanders); } void SKDocConsumer::failed(StringRef ErrDescription) { diff --git a/utils/gyb_sourcekit_support/UIDs.py b/utils/gyb_sourcekit_support/UIDs.py index 4f6b53d421b90..74022480fec38 100644 --- a/utils/gyb_sourcekit_support/UIDs.py +++ b/utils/gyb_sourcekit_support/UIDs.py @@ -177,6 +177,7 @@ def __init__(self, internal_name, external_name): KEY('VFSOptions', 'key.vfs.options'), KEY('Files', 'key.files'), KEY('OptimizeForIDE', 'key.optimize_for_ide'), + KEY('RequiredBystanders', 'key.required_bystanders'), ] diff --git a/validation-test/compiler_crashers_2_fixed/rdar60081992.swift b/validation-test/compiler_crashers_2_fixed/rdar60081992.swift new file mode 100644 index 0000000000000..47e96aa9f3850 --- /dev/null +++ b/validation-test/compiler_crashers_2_fixed/rdar60081992.swift @@ -0,0 +1,18 @@ +// RUN: %target-swift-frontend -typecheck %s + +struct Model where E: Comparable { + enum TimePeriod: CaseIterable { + case day, week, month, year + } +} + +struct MyStruct where E: Comparable { + init(entries: EA) where EA: BidirectionalCollection, EA.Element == E, EA.Index == Int { + typealias MDict = [Model.TimePeriod : T] where T: Numeric + func maxDict(_ keyPath: KeyPath, MDict>) -> MDict { + fatalError() + } + + fatalError() + } +}