forked from swiftlang/swift
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdifferentiable_attr_access_control.swift
32 lines (26 loc) · 1.92 KB
/
differentiable_attr_access_control.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// RUN: %target-swift-frontend -typecheck -verify %s
// If the original function is "exported" (public or @usableFromInline), then
// its primal and adjoint, when specified, must also be exported.
// Ok: all public. The original function `foo1` and its adjoint `dfoo1` are
// both `public`, so the access-control check has nothing to report.
@differentiable(reverse, adjoint: dfoo1(_:primal:seed:))
public func foo1(_ x: Float) -> Float { return 1 }
// Adjoint referenced by the attribute above via its full name
// `dfoo1(_:primal:seed:)`; only its access level matters for this test.
public func dfoo1(_ x: Float, primal: Float, seed: Float) -> Float { return 1 }
// Ok: all internal. None of these declarations carries an access modifier, so
// they default to `internal`. An internal original function is not exported,
// so its primal and adjoint do not need to be exported either.
//
// Empty struct used as the type of the primal's `checkpoints` result; it is
// also reused by the `pbar2`/`dbar2` case at the end of this file.
struct CheckpointsFoo {}
@differentiable(reverse, primal: pfoo2(_:), adjoint: dfoo2(_:checkpoints:originalValue:seed:))
func foo2(_ x: Float) -> Float { return 1 }
// Primal: takes the original argument and returns a labeled tuple of
// `checkpoints` and `originalValue`.
func pfoo2(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) { return (CheckpointsFoo(), 1) }
// Adjoint: takes the original argument, the primal's two tuple elements (as
// `checkpoints:` and `originalValue:`), and a `seed`.
func dfoo2(_ x: Float, checkpoints: CheckpointsFoo, originalValue: Float, seed: Float) -> Float { return 1 }
// Ok: all private. A `private` original function is not exported, so a
// `private` adjoint is acceptable and no diagnostic should be produced.
@differentiable(reverse, adjoint: dfoo3(_:primal:seed:))
private func foo3(_ x: Float) -> Float { return 1 }
private func dfoo3(_ x: Float, primal: Float, seed: Float) -> Float { return 1 }
// Error: adjoint not exported. `bar1` is `public`, but its adjoint `dbar1` is
// `private`, so the attribute must be rejected. The diagnostic is reported on
// the `@differentiable` attribute line itself.
@differentiable(reverse, adjoint: dbar1(_:primal:seed:)) // expected-error {{adjoint 'dbar1(_:primal:seed:)' is required to either be public or @usableFromInline because the original function 'bar1' is public or @usableFromInline}}
public func bar1(_ x: Float) -> Float { return 1 }
private func dbar1(_ x: Float, primal: Float, seed: Float) -> Float { return 1 }
// Error: primal not exported. `bar2` is `@usableFromInline`, which counts as
// exported, while its primal `pbar2` has no access modifier and is therefore
// only `internal`. The diagnostic is reported on the attribute line.
// NOTE(review): the adjoint `dbar2` is also only `internal`, yet only the
// primal is diagnosed here. Presumably the checker stops at the first
// offending function; confirm this is intended rather than a missed diagnostic.
// NOTE(review): this message names the primal by its base name ('pbar2'),
// whereas the `bar1` case names the adjoint by its full name; confirm the
// asymmetry is intentional in the compiler's wording.
@differentiable(reverse, primal: pbar2(_:), adjoint: dbar2(_:checkpoints:originalValue:seed:)) // expected-error {{primal 'pbar2' is required to either be public or @usableFromInline because the original function 'bar2' is public or @usableFromInline}}
@usableFromInline func bar2(_ x: Float) -> Float { return 1 }
func pbar2(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) { return (CheckpointsFoo(), 1) }
func dbar2(_ x: Float, checkpoints: CheckpointsFoo, originalValue: Float, seed: Float) -> Float { return 1 }