Skip to content

Commit

Permalink
Merge pull request #496 from swiftwasm/maxd/master-merge
Browse files Browse the repository at this point in the history
Resolve conflicts with master
  • Loading branch information
kateinoigakukun authored Mar 26, 2020
2 parents eac57a3 + ea7ca46 commit 84373d3
Show file tree
Hide file tree
Showing 304 changed files with 10,218 additions and 3,047 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ jobs:
runs-on: ubuntu-18.04

steps:
- name: Free disk space
run: |
df -h
sudo swapoff -a
sudo rm -f /swapfile
sudo apt clean
docker rmi $(docker image ls -aq)
df -h
- uses: actions/checkout@v1
with:
path: swift
Expand Down
File renamed without changes.
61 changes: 61 additions & 0 deletions docs/SIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5774,6 +5774,67 @@ The rules on generic substitutions are identical to those of ``apply``.
Differentiable Programming
~~~~~~~~~~~~~~~~~~~~~~~~~~

differentiable_function
```````````````````````
::

sil-instruction ::= 'differentiable_function'
sil-differentiable-function-parameter-indices
sil-value ':' sil-type
sil-differentiable-function-derivative-functions-clause?

sil-differentiable-function-parameter-indices ::=
'[' 'parameters' [0-9]+ (' ' [0-9]+)* ']'
sil-differentiable-derivative-functions-clause ::=
'with_derivative'
'{' sil-value ':' sil-type ',' sil-value ':' sil-type '}'

differentiable_function [parameters 0] %0 : $(T) -> T \
with_derivative {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}

Creates a ``@differentiable`` function from an original function operand and
derivative function operands (optional). There are two derivative function
kinds: a Jacobian-vector products (JVP) function and a vector-Jacobian products
(VJP) function.

``[parameters ...]`` specifies parameter indices that the original function is
differentiable with respect to.

The ``with_derivative`` clause specifies the derivative function operands
associated with the original function.

The differentiation transformation canonicalizes all `differentiable_function`
instructions, generating derivative functions if necessary to fill in derivative
function operands.

In raw SIL, the ``with_derivative`` clause is optional. In canonical SIL, the
``with_derivative`` clause is mandatory.


differentiable_function_extract
```````````````````````````````
::

sil-instruction ::= 'differentiable_function_extract'
'[' sil-differentiable-function-extractee ']'
sil-value ':' sil-type
('as' sil-type)?

sil-differentiable-function-extractee ::= 'original' | 'jvp' | 'vjp'

differentiable_function_extract [original] %0 : $@differentiable (T) -> T
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T \
as $(@in_constant T) -> (T, (T.TangentVector) -> T.TangentVector)

Extracts the original function or a derivative function from the given
``@differentiable`` function. The extractee is one of the following:
``[original]``, ``[jvp]``, or ``[vjp]``.

In lowered SIL, an explicit extractee type may be provided. This is currently
used by the LoadableByAddress transformation, which rewrites function types.

differentiability_witness_function
``````````````````````````````````
::
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/ASTMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,9 @@ class ASTMangler : public Mangler {

void appendProtocolConformance(const ProtocolConformance *conformance);
void appendProtocolConformanceRef(const RootProtocolConformance *conformance);
void appendAnyProtocolConformance(CanGenericSignature genericSig,
CanType conformingType,
ProtocolConformanceRef conformance);
void appendConcreteProtocolConformance(
const ProtocolConformance *conformance);
void appendDependentProtocolConformance(const ConformanceAccessPath &path);
Expand Down
57 changes: 18 additions & 39 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1675,12 +1675,11 @@ struct DeclNameRefWithLoc {
DeclNameLoc Loc;
};

/// Attribute that marks a function as differentiable and optionally specifies
/// custom associated derivative functions: 'jvp' and 'vjp'.
/// Attribute that marks a function as differentiable.
///
/// Examples:
/// @differentiable(jvp: jvpFoo where T : FloatingPoint)
/// @differentiable(wrt: (self, x, y), jvp: jvpFoo)
/// @differentiable(where T : FloatingPoint)
/// @differentiable(wrt: (self, x, y))
class DifferentiableAttr final
: public DeclAttribute,
private llvm::TrailingObjects<DifferentiableAttr,
Expand All @@ -1696,16 +1695,6 @@ class DifferentiableAttr final
bool Linear;
/// The number of parsed differentiability parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The JVP function.
Optional<DeclNameRefWithLoc> JVP;
/// The VJP function.
Optional<DeclNameRefWithLoc> VJP;
/// The JVP function (optional), resolved by the type checker if JVP name is
/// specified.
FuncDecl *JVPFunction = nullptr;
/// The VJP function (optional), resolved by the type checker if VJP name is
/// specified.
FuncDecl *VJPFunction = nullptr;
/// The differentiability parameter indices, resolved by the type checker.
/// The bit stores whether the parameter indices have been computed.
///
Expand All @@ -1720,36 +1709,35 @@ class DifferentiableAttr final
/// attribute's where clause requirements. This is set only if the attribute
/// has a where clause.
GenericSignature DerivativeGenericSignature;
/// The source location of the implicitly inherited protocol requirement
/// `@differentiable` attribute. Used for diagnostics, not serialized.
///
/// This is set during conformance type-checking, only for implicit
/// `@differentiable` attributes created for non-public protocol witnesses of
/// protocol requirements with `@differentiable` attributes.
SourceLoc ImplicitlyInheritedDifferentiableAttrLocation;

explicit DifferentiableAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
ArrayRef<ParsedAutoDiffParameter> parameters,
Optional<DeclNameRefWithLoc> jvp,
Optional<DeclNameRefWithLoc> vjp,
TrailingWhereClause *clause);

explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
IndexSubset *parameterIndices,
Optional<DeclNameRefWithLoc> jvp,
Optional<DeclNameRefWithLoc> vjp,
GenericSignature derivativeGenericSignature);

public:
static DifferentiableAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
ArrayRef<ParsedAutoDiffParameter> params,
Optional<DeclNameRefWithLoc> jvp,
Optional<DeclNameRefWithLoc> vjp,
TrailingWhereClause *clause);

static DifferentiableAttr *create(AbstractFunctionDecl *original,
bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
IndexSubset *parameterIndices,
Optional<DeclNameRefWithLoc> jvp,
Optional<DeclNameRefWithLoc> vjp,
GenericSignature derivativeGenSig);

Decl *getOriginalDeclaration() const { return OriginalDeclaration; }
Expand All @@ -1758,16 +1746,6 @@ class DifferentiableAttr final
/// Should only be used by parsing and deserialization.
void setOriginalDeclaration(Decl *originalDeclaration);

/// Get the optional 'jvp:' function name and location.
/// Use this instead of `getJVPFunction` to check whether the attribute has a
/// registered JVP.
Optional<DeclNameRefWithLoc> getJVP() const { return JVP; }

/// Get the optional 'vjp:' function name and location.
/// Use this instead of `getVJPFunction` to check whether the attribute has a
/// registered VJP.
Optional<DeclNameRefWithLoc> getVJP() const { return VJP; }

private:
/// Returns true if the given `@differentiable` attribute has been
/// type-checked.
Expand Down Expand Up @@ -1800,10 +1778,13 @@ class DifferentiableAttr final
DerivativeGenericSignature = derivativeGenSig;
}

FuncDecl *getJVPFunction() const { return JVPFunction; }
void setJVPFunction(FuncDecl *decl);
FuncDecl *getVJPFunction() const { return VJPFunction; }
void setVJPFunction(FuncDecl *decl);
SourceLoc getImplicitlyInheritedDifferentiableAttrLocation() const {
return ImplicitlyInheritedDifferentiableAttrLocation;
}
void getImplicitlyInheritedDifferentiableAttrLocation(SourceLoc loc) {
assert(isImplicit());
ImplicitlyInheritedDifferentiableAttrLocation = loc;
}

/// Get the derivative generic environment for the given `@differentiable`
/// attribute and original function.
Expand All @@ -1812,9 +1793,7 @@ class DifferentiableAttr final

// Print the attribute to the given stream.
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
// If `omitDerivativeFunctions` is true, omit printing derivative functions.
void print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause = false,
bool omitDerivativeFunctions = false) const;
void print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause = false) const;

static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Differentiable;
Expand Down
62 changes: 62 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,41 @@ struct AutoDiffDerivativeFunctionKind {
}
};

/// A component of a SIL `@differentiable` function-typed value.
struct NormalDifferentiableFunctionTypeComponent {
enum innerty : unsigned { Original = 0, JVP = 1, VJP = 2 } rawValue;

NormalDifferentiableFunctionTypeComponent() = default;
NormalDifferentiableFunctionTypeComponent(innerty rawValue)
: rawValue(rawValue) {}
NormalDifferentiableFunctionTypeComponent(
AutoDiffDerivativeFunctionKind kind);
explicit NormalDifferentiableFunctionTypeComponent(unsigned rawValue)
: NormalDifferentiableFunctionTypeComponent((innerty)rawValue) {}
explicit NormalDifferentiableFunctionTypeComponent(StringRef name);
operator innerty() const { return rawValue; }

/// Returns the derivative function kind, if the component is a derivative
/// function.
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
};

/// A component of a SIL `@differentiable(linear)` function-typed value.
struct LinearDifferentiableFunctionTypeComponent {
enum innerty : unsigned {
Original = 0,
Transpose = 1,
} rawValue;

LinearDifferentiableFunctionTypeComponent() = default;
LinearDifferentiableFunctionTypeComponent(innerty rawValue)
: rawValue(rawValue) {}
explicit LinearDifferentiableFunctionTypeComponent(unsigned rawValue)
: LinearDifferentiableFunctionTypeComponent((innerty)rawValue) {}
explicit LinearDifferentiableFunctionTypeComponent(StringRef name);
operator innerty() const { return rawValue; }
};

/// A derivative function configuration, uniqued in `ASTContext`.
/// Identifies a specific derivative function given an original function.
class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
Expand Down Expand Up @@ -406,6 +441,33 @@ GenericSignature getConstrainedDerivativeGenericSignature(
GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance,
bool isTranspose = false);

/// Retrieve config from the function name of a variant of
/// `Builtin.applyDerivative`, e.g. `Builtin.applyDerivative_jvp_arity2`.
/// Returns true if the function name is parsed successfully.
bool getBuiltinApplyDerivativeConfig(
StringRef operationName, AutoDiffDerivativeFunctionKind &kind,
unsigned &arity, bool &rethrows);

/// Retrieve config from the function name of a variant of
/// `Builtin.applyTranspose`, e.g. `Builtin.applyTranspose_arity2`.
/// Returns true if the function name is parsed successfully.
bool getBuiltinApplyTransposeConfig(
StringRef operationName, unsigned &arity, bool &rethrows);

/// Retrieve config from the function name of a variant of
/// `Builtin.differentiableFunction` or `Builtin.linearFunction`, e.g.
/// `Builtin.differentiableFunction_arity1_throws`.
/// Returns true if the function name is parsed successfully.
bool getBuiltinDifferentiableOrLinearFunctionConfig(
StringRef operationName, unsigned &arity, bool &throws);

/// Retrieve config from the function name of a variant of
/// `Builtin.differentiableFunction` or `Builtin.linearFunction`, e.g.
/// `Builtin.differentiableFunction_arity1_throws`.
/// Returns true if the function name is parsed successfully.
bool getBuiltinDifferentiableOrLinearFunctionConfig(
StringRef operationName, unsigned &arity, bool &throws);

} // end namespace autodiff

} // end namespace swift
Expand Down
12 changes: 12 additions & 0 deletions include/swift/AST/Builtins.def
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,18 @@ BUILTIN_SIL_OPERATION(ConvertStrongToUnownedUnsafe, "convertStrongToUnownedUnsaf
/// now.
BUILTIN_SIL_OPERATION(ConvertUnownedUnsafeToGuaranteed, "convertUnownedUnsafeToGuaranteed", Special)

/// applyDerivative
BUILTIN_SIL_OPERATION(ApplyDerivative, "applyDerivative", Special)

/// applyTranspose
BUILTIN_SIL_OPERATION(ApplyTranspose, "applyTranspose", Special)

/// differentiableFunction
BUILTIN_SIL_OPERATION(DifferentiableFunction, "differentiableFunction", Special)

/// linearFunction
BUILTIN_SIL_OPERATION(LinearFunction, "linearFunction", Special)

#undef BUILTIN_SIL_OPERATION

// BUILTIN_RUNTIME_CALL - A call into a runtime function.
Expand Down
Loading

0 comments on commit 84373d3

Please sign in to comment.