diff --git a/julia/Project.toml b/julia/Project.toml index bbe8d88d6..4c2b53e71 100644 --- a/julia/Project.toml +++ b/julia/Project.toml @@ -3,14 +3,18 @@ uuid = "bb22f25d-cb49-471c-b017-930e329a2928" version = "0.1.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CombinedParsers = "5ae71ed2-6f8a-4ed1-b94f-e14e8158f19e" [compat] +ChainRulesCore = "^1.0" CombinedParsers = "^0.2" +Zygote = "^0.6.22" julia = "^1.6" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test"] +test = ["Test", "Zygote"] diff --git a/julia/README.md b/julia/README.md index 03144e7c2..74b59a602 100644 --- a/julia/README.md +++ b/julia/README.md @@ -2,7 +2,7 @@ DexCall provides a mechanism for calling dex-lang code from JuliaLang. Three main mechanism are provided for this: `evaluate`, `DexModule` and the `dex_func` string macro. -Two helper methods are also provided: `juliaize` and `NativeFunction`. +Several helper methods are also provided: `juliaize`, `dexize`, and `NativeFunction`. ## `evaluate`: just run a single Dex expression. `evaluate` takes in a Dex expression as a string and runs it, returning a `Atom` (see below). @@ -53,7 +53,7 @@ julia> m.addTwo(m.y) "[44., 44., 44.]" ``` -## Atoms: `juliaize`, `NativeFunction` +## Atoms: `juliaize`, `dexize` and `NativeFunction` `evaluate` and the contents of a `DexModule` are returned as `Atom`s. These can be displayed, but not much else. @@ -87,6 +87,8 @@ julia> typeof(convert(Int64, m.x)) Int64 ``` +The inverse of `juliaize` is `dexize`, it is currently very limited. + To convert function `Atom`s into something you can execute as if it was a regular julia function use `NativeFunction`. This will compile it and handing inputs and outputs without needing to del with `Atom`s directly. diff --git a/julia/src/DexCall.jl b/julia/src/DexCall.jl index 90977e791..1f5f454c0 100644 --- a/julia/src/DexCall.jl +++ b/julia/src/DexCall.jl @@ -1,14 +1,16 @@ "Calling Dex from Julia" module DexCall + using ChainRulesCore using CombinedParsers using CombinedParsers.Regexp - export evaluate, DexError, DexModule, juliaize, NativeFunction, @dex_func_str + export evaluate, DexError, DexModule, dexize, juliaize, NativeFunction, @dex_func_str include("api_types.jl") include("api.jl") include("evaluate.jl") include("native_function.jl") + include("chainrules.jl") # use this to disable free'ing haskell objects after we have closed the RTS const NO_FREE = Ref(false) diff --git a/julia/src/api_types.jl b/julia/src/api_types.jl index d6b206501..a343dd43a 100644 --- a/julia/src/api_types.jl +++ b/julia/src/api_types.jl @@ -55,11 +55,3 @@ function CAtom(atm::Ptr{HsAtom}) iszero(success) && throw_from_dex() return result[] end - -""" - juliaize(x) - -Get the corresponding Julia object from some output of Dex. -""" -juliaize(x::CAtom) = bust_union(x) -juliaize(x::Ptr{HsAtom}) = juliaize(CAtom(x)) \ No newline at end of file diff --git a/julia/src/chainrules.jl b/julia/src/chainrules.jl new file mode 100644 index 000000000..bd497f57e --- /dev/null +++ b/julia/src/chainrules.jl @@ -0,0 +1,93 @@ + +function ChainRulesCore.frule((_, ẋs...), f_native::NativeFunction{R}, xs...) where R + f = f_native.atom + env = f.ctx + env = insert(env, "f", f.ptr) + + primal_binders = f_native.argument_signature + primal_args_sig = repr_sig(primal_binders) + primal_args = extract_arg_names(primal_binders) + + tangent_binders = generate_tangent_binders(primal_binders) + tangent_args_sig = repr_sig(tangent_binders) + tangent_args = extract_arg_names(tangent_binders) + + primal_res_sig = repr_result_sig(f_native.result_signature) + dual_res_sig = "($primal_res_sig&$primal_res_sig)" + m = DexModule(""" + def frule_inner $primal_args_sig->$tangent_args_sig : $dual_res_sig = + (y, pushforward) = linearize f $primal_args + dy = pushforward $tangent_args + (y, dy) + """, + env + ) + # Convert the Atom into `NativeFunction` so can work with any argument type: + frule_inner_native = NativeFunction(m.frule_inner) + return frule_inner_native(xs..., ẋs...) +end + +extract_arg_names(binds::Vector{Binder}) = join((bind.name for bind in binds), " ") +""" +Given the `Binder` for the signature of a primal argument/s constructs the matching one +for the tangent. +For now we only support types with tangent type matcing primal type +""" +function generate_tangent_binders(pbinds::Vector{Binder}) + return [generate_tangent_binder(pbind) for pbind in pbinds if !pbind.implicit] +end +function generate_tangent_binder(pbind::Binder) + pbind.implicit && throw(DomainError(pbind, "Implict arguments have no tangents")) + return Binder(Symbol(:d, pbind.name), pbind.type, false) +end + +function ChainRulesCore.frule((_, ẋ), f::Atom, x::Atom) + ẋ isa Atom || throw(DomainError(ẋ, "Tangent to an Atom must be an Atom")) + env = f.ctx + env = insert(env, "f", f.ptr) + env = insert(env, "dx", ẋ.ptr) + env = insert(env, "x", x.ptr) + + m = DexModule(raw""" + (y, pushforward) = linearize f x + dy = pushforward dx + """, + env + ) + return m.y, m.dy +end + +function ChainRulesCore.rrule(f::Atom, x::Atom) + env = f.ctx + env = insert(env, "f", f.ptr) + env = insert(env, "x", x.ptr) + + m = DexModule(raw""" + (y, pushforward) = linearize f x + pullback = transposeLinear pushforward + """, + env + ) + + # It is important that we close over `m` as otherwise the env may be destroyed by GC + pullback(x̄::Atom)= (NoTangent(), m.pullback(x̄)) + return m.y, pullback +end + +ChainRulesCore.frule((_, ẋ), ::typeof(juliaize), x) = juliaize(x), juliaize(ẋ) +function ChainRulesCore.rrule(::typeof(juliaize), x::Atom) + env= x.ctx + + # pullback must take a julia typed cotangent and give back a dex typed cotangent + juliaize_pullback(ȳ) = (NoTangent(), dexize(ȳ, env)) + return juliaize(x), juliaize_pullback +end + + +ChainRulesCore.frule((_, ẋ), ::typeof(dexize), x) = dexize(x), dexize(ẋ) +function ChainRulesCore.rrule(::typeof(dexize), x) + # pullback must take a dex typed cotangent and give back a julia typed cotangent + dexize_pullback(ȳ) = (NoTangent(), juliaize(ȳ)) + return dexize(x), dexize_pullback +end + diff --git a/julia/src/evaluate.jl b/julia/src/evaluate.jl index f86d16a3e..ccd7bed6c 100644 --- a/julia/src/evaluate.jl +++ b/julia/src/evaluate.jl @@ -13,9 +13,37 @@ end Base.show(io::IO, atom::Atom) = show(io, print(atom.ptr)) +""" + juliaize(x) + +Get the corresponding Julia object from some output of Dex. +""" +juliaize(x::CAtom) = bust_union(x) +juliaize(x::Ptr{HsAtom}) = juliaize(CAtom(x)) juliaize(x::Atom) = juliaize(x.ptr) Base.convert(::Type{T}, atom::Atom) where {T<:Number} = convert(T, juliaize(atom)) +""" + dexize(x) + +Get the corresponding Dex object from some output of Julia. + +NB: this is currently a hack that goes via string processing. +""" +function dexize(x::Float32, _module=PRELUDE, env=_module) + isnan(x) && return evaluate("nan", _module, env) + x === Inf32 && return evaluate("infinity", _module, env) + x === -Inf32 && return evaluate("-infinity", _module, env) + + str = repr(x) + if endswith(str, "f0") + evaluate(str[1:end-2], _module, env) + else + # convert "123f45" into "123 * (intpow 10.0 45)" + evaluate(replace(str, "f"=> " * (intpow 10.0 ") * ")", _module, env) + end +end + function (self::Atom)(args...) # TODO: Make those calls more hygenic @@ -60,12 +88,18 @@ julia> m.y "84" ``` """ -function DexModule(source::AbstractString) - ctx = dex_eval(PRELUDE, source) +function DexModule(source::AbstractString, parent_ctx=PRELUDE) + ctx = dex_eval(parent_ctx, source) ctx == C_NULL && throw_from_dex() m = DexModule(ctx) finalizer(m) do _m - destroy_context(getfield(_m, :ctx)) + # TODO: Undo commenting this out. But for now this causes a lot of problems. + # DexModule will often go out of scope, while a Atom attached to that context still + # exists. Possibly we need to make the ctx a mutable struct everywhere, and then + # attach the finalizer there. + #(also will let us delete some manual destroys in other palces) + + #destroy_context(getfield(_m, :ctx)) end return m end diff --git a/julia/src/native_function.jl b/julia/src/native_function.jl index 69e06edb2..aa898acdd 100644 --- a/julia/src/native_function.jl +++ b/julia/src/native_function.jl @@ -87,6 +87,32 @@ end ArrayBuilder{T}(size) where T = ArrayBuilder{T,length(size)}(size) +"representation of this as it would appear in a dex `def` function signature" +repr_sig(binders::AbstractVector{Binder}) = join(Iterators.map(repr_sig, binders), "->") +function repr_sig(binder::Binder) + str = "($(binder.name):$(repr_sig(binder.type)))" + if binder.implicit + str *= "?" + end + return str +end +function repr_sig(builder::ArrayBuilder{T}) where {T} + sizes_repr = Iterators.map(size_element -> "Fin $size_element", builder.size) + return join(sizes_repr, "=>") * "=>" * repr_sig(T) +end + +# For most types, like Int32 and Float64 Dex and Julia use identical names +# and for Symbols represent implicts they are also are made into strings by `string` +repr_sig(x) = string(x) +# TODO: Word8, Char etc? + + +"representation of this as it would appear in the result part of a dex `def` function signature" +repr_result_sig(x) = repr_sig(x) +function repr_result_sig(binders::AbstractVector{Binder}) + return "(" * join(Iterators.map(repr_result_sig, binders), "&") * ")" +end +repr_result_sig(binder::Binder) = repr_result_sig(binder.type) """ NativeFunction{R} @@ -98,14 +124,14 @@ Usually constructed using [`@dex_func_str`](@ref), or via `NativeFunction(atom)` on some [`DexCall.Atom`](@ref). """ struct NativeFunction{R} <: Function - c_func_ptr::Ptr{Nothing} + atom::Atom # non-compiled Atom form of this function + c_func_ptr::Ptr{Nothing} # compiled C API form of this function argument_signature::Vector{Binder} result_signature::Vector{Binder} end -NativeFunction(atom::Atom, jit=JIT) = NativeFunction(atom.ptr, atom.ctx, jit) -function NativeFunction(atom::Ptr{HsAtom}, ctx=PRELUDE, jit=JIT) - c_func_ptr = compile(ctx, atom, jit) +function NativeFunction(atom::Atom, ctx=atom.ctx, jit=JIT) + c_func_ptr = compile(ctx, atom.ptr, jit) sig_ptr = get_function_signature(c_func_ptr, jit) sig_ptr == C_NULL && error("Failed to retrieve the function signature") @@ -114,6 +140,7 @@ function NativeFunction(atom::Ptr{HsAtom}, ctx=PRELUDE, jit=JIT) result_signature = parse_sig(signature.res) R = result_type(result_signature) f = NativeFunction{R}( + atom, c_func_ptr, parse_sig(signature.arg), result_signature @@ -295,6 +322,7 @@ function parse_sig(sig) parser("i32"=>Int32), parser("i64"=>Int64), parser("i8"=>Int8), + # TODO: Word8, Char etc? ) size_ele = NumericParser(Int) | name sizes = join(Repeat(size_ele),",") diff --git a/julia/test/chainrules.jl b/julia/test/chainrules.jl new file mode 100644 index 000000000..49d48fe51 --- /dev/null +++ b/julia/test/chainrules.jl @@ -0,0 +1,52 @@ +const double_dex = evaluate(raw"\x:Float. 2.0 * x") + +@testset "frule: dexize, evaluate, juliaize" begin + a, ȧ = frule((NoTangent(), 10f0), dexize, 1.5f0) + b, ḃ = frule((NoTangent(), ȧ), double_dex, a) + c, ċ = frule((NoTangent(), ḃ), juliaize, b) + @test c === 3.0f0 + @test ċ === 20f0 +end + +@testset "rrule: dexize, evaluate, juliaize" begin + x = 1.5f0 + a, a_pb = rrule(dexize, x) + b, b_pb = rrule(double_dex, a) + c, c_pb = rrule(juliaize, b) + + @test c === 3.0f0 + c̄ = 10f0 + _, b̄ = c_pb(c̄) + _, ā = b_pb(b̄) + _, x̄ = a_pb(ā) + + @test x̄ === 20f0 +end + +@testset "Integration Test: Zygote.jl" begin + double_via_dex = juliaize ∘ double_dex ∘ dexize + y, pb= Zygote.pullback(double_via_dex, 1.5f0) + @test y == 3f0 + @test pb(1f0) == (2f0,) +end + + +@testset "frule NativeFunction" begin + dex_func"decimate_dex = \x:Float. x/10.0" + @test frule((NoTangent(), 50f0), decimate_dex, 150f0) === (15f0, 5f0) + + dex_func"sum3_dex = \x:(Fin 3=>Float). sum x" + @test frule((NoTangent(), [1f0, 10f0, 100f0]), sum3_dex, [1f0, 2f0, 3f0]) === (6f0, 111f0) + + dex_func"twovec_dex = \x:(Float32). [x,x]" + twovec_dex(1f2) + @test frule((NoTangent(), 10f0), twovec_dex, 4f0) == ([4f0, 4f0], [10f0,10f0]) + + # With Implicts + dex_func"def mysum_dex (arg0:Int32)?-> (arg1:Fin arg0 => Float32) : Float32 = sum arg1" + @test_broken frule((NoTangent(), [1f0, 10f0, 100f0, 1000f0]), mysum_dex, [1f0, 2f0, 3f0, 4f0]) === (10f0, 1111f0) + + # With multiple arguments + dex_func"add_dex = \x:Float32 y:Float32. x+y" + @test_broken frule((NoTangent(), 10f0, 100f0), add_dex, 1f0, 2f0) +end \ No newline at end of file diff --git a/julia/test/evaluate.jl b/julia/test/evaluate.jl index e6d0f6f6f..3536b1e3e 100644 --- a/julia/test/evaluate.jl +++ b/julia/test/evaluate.jl @@ -21,6 +21,16 @@ @test juliaize(evaluate("IToW8 65")) === Int8(65) end + @testset "dexize" begin + @test juliaize(dexize(0f0)) === 0f0 + @test juliaize(dexize(42f0)) === 42f0 + @test juliaize(dexize(123f15)) === 123f15 + @test dexize(123f15) isa DexCall.Atom + @test isnan(juliaize(dexize(NaN32))) + @test (juliaize(dexize(Inf32))) == Inf32 + @test (juliaize(dexize(-Inf32))) == -Inf32 + end + @testset "Atom function call" begin m = DexModule(""" def addOne (x: Float) : Float = x + 1.0 diff --git a/julia/test/native_function.jl b/julia/test/native_function.jl index dcb0cc37d..fc96d096f 100644 --- a/julia/test/native_function.jl +++ b/julia/test/native_function.jl @@ -1,6 +1,7 @@ @testset "native_function.jl" begin @testset "signature parser" begin + # Testing Implementation details, can remove if implementation changes @testset "$example" for example in ( "arg0:f32", "arg0:f32,arg1:f32", @@ -17,6 +18,20 @@ @test DexCall.parse_sig(example) isa Vector{DexCall.Binder} end end + + @testset "signature repr" begin + # Testing Implementation details, can remove if implementation changes + as_in_dex_sig = DexCall.repr_sig ∘ DexCall.parse_sig + @test as_in_dex_sig("arg0:f32") == "(arg0:Float32)" + @test as_in_dex_sig("arg0:f32,arg1:f32") == "(arg0:Float32)->(arg1:Float32)" + @test as_in_dex_sig("arg0:i64,arg1:i32") == "(arg0:Int64)->(arg1:Int32)" + @test as_in_dex_sig("arg0:f32[10]") == "(arg0:Fin 10=>Float32)" + @test as_in_dex_sig("?arg0:i32,arg1:f32[arg0]") == "(arg0:Int32)?->(arg1:Fin arg0=>Float32)" + @test as_in_dex_sig("arg2:f32[arg0]") == "(arg2:Fin arg0=>Float32)" + @test as_in_dex_sig("?arg0:i32,?arg1:i32,arg2:f32[arg0,arg1]") == "(arg0:Int32)?->(arg1:Int32)?->(arg2:Fin arg0=>Fin arg1=>Float32)" + @test as_in_dex_sig("arg3:f32[arg1,arg0]") == "(arg3:Fin arg1=>Fin arg0=>Float32)" + @test as_in_dex_sig("arg0:f32,?arg1:i32,arg2:f32[arg1]") == "(arg0:Float32)->(arg1:Int32)?->(arg2:Fin arg1=>Float32)" + end @testset "dex_func anon funcs" begin @test dex_func"\x:Float. exp x"(0f0) === 1f0 diff --git a/julia/test/runtests.jl b/julia/test/runtests.jl index dc4065c12..94af13c1a 100644 --- a/julia/test/runtests.jl +++ b/julia/test/runtests.jl @@ -1,8 +1,11 @@ using Test using DexCall +using ChainRulesCore +using Zygote # for integration tests @testset "DexCall" begin include("api.jl") include("evaluate.jl") include("native_function.jl") + include("chainrules.jl") end \ No newline at end of file