Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subtape + Fix Recursion #171

Open
wants to merge 10 commits into
base: subtape
Choose a base branch
from
28 changes: 19 additions & 9 deletions src/tapedfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ mutable struct TapedFunction{F, TapeType}
arg_binding_slots::Vector{Int} # arg indices in binding_values
retval_binding_slot::Int # 0 indicates the function has not returned
deepcopy_types::Type # use a Union type for multiple types
subtapes::IdDict{Any,TapedFunction}

function TapedFunction{F, T}(f::F, args...; cache=false, deepcopy_types=Union{}) where {F, T}
args_type = _accurate_typeof.(args)
Expand All @@ -66,13 +67,14 @@ mutable struct TapedFunction{F, TapeType}
if cache && haskey(TRCache, cache_key) # use cache
cached_tf = TRCache[cache_key]::TapedFunction{F, T}
tf = copy(cached_tf)
tf.counter = 1
# we have to reset the counters of cached tapes (also the counters of subtapes)
reset_counters!(tf)
return tf
end
ir = _infer(f, args_type)
binding_values, slots, tape = translate!(RawTape(), ir)

tf = new{F, T}(f, length(args), ir, tape, 1, binding_values, slots, 0, deepcopy_types)
tf = new{F, T}(f, length(args), ir, tape, 1, binding_values, slots, 0, deepcopy_types, IdDict{Any,TapedFunction}())
TRCache[cache_key] = tf # set cache
return tf
end
Expand All @@ -82,7 +84,7 @@ mutable struct TapedFunction{F, TapeType}

function TapedFunction{F, T0}(tf::TapedFunction{F, T1}) where {F, T0, T1}
new{F, T0}(tf.func, tf.arity, tf.ir, tf.tape,
tf.counter, tf.binding_values, tf.arg_binding_slots, 0, tf.deepcopy_types)
tf.counter, tf.binding_values, tf.arg_binding_slots, 0, tf.deepcopy_types, tf.subtapes)
end

TapedFunction(tf::TapedFunction{F, T}) where {F, T} = TapedFunction{F, T}(tf)
Expand Down Expand Up @@ -219,8 +221,11 @@ function (instr::Instruction{F})(tf::TapedFunction, callback=nothing) where F
output = if is_primitive(func, inputs...)
func(inputs...)
else
tf_inner = TapedFunction(func, inputs..., cache=true)
tf_inner(inputs...; callback=callback)
tf_inner = get!(tf.subtapes, instr) do
TapedFunction(func, inputs...; cache=true)
end
# continuation=false breaks "Multiple func calls subtapes" and "Copying task with subtapes"
tf_inner(inputs...; callback=callback, continuation=true)
end
_update_var!(tf, instr.output, output)
tf.counter += 1
Expand Down Expand Up @@ -416,10 +421,7 @@ function translate!!(var::IRVar, line::Expr,
# Only some of the function calls can be optimized even though many of their results are
# inferred as constants: we only optimize primitive and datatype constants for now. For
# optimised function calls, we will evaluate the function at compile-time and cache results.
if isconst
v = ir.ssavaluetypes[var.id].val
_canbeoptimized(v) && return _const_instruction(var, v, bindings, ir)
end

args = map(_bind_fn, line.args)
# args[1] is the function
func = line.args[1]
Expand Down Expand Up @@ -515,5 +517,13 @@ end
function Base.copy(tf::TapedFunction)
new_tf = TapedFunction(tf)
new_tf.binding_values = copy_bindings(tf.binding_values, tf.deepcopy_types)
new_tf.subtapes = IdDict{Any,TapedFunction}(func => copy(subtape) for (func, subtape) in tf.subtapes)
return new_tf
end

# when copying we want to keep the counters
# but if we instantiate new TapedTask, we have to reset the counters of cached tapes
function reset_counters!(tf::TapedFunction)
tf.counter = 1
foreach(reset_counters!, values(tf.subtapes))
end
2 changes: 1 addition & 1 deletion src/tapedtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ end
function produce(val)
is_in_tapedtask() || return nothing
ttask = current_task().storage[:tapedtask]::TapedTask
length(ttask.produced_val) > 1 &&
length(ttask.produced_val) > 0 &&
error("There is a produced value which is not consumed.")
push!(ttask.produced_val, val)
return nothing
Expand Down
55 changes: 55 additions & 0 deletions test/tape_copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,59 @@
@test consume(ttask) == 1
@test consume(ttask2) == 2
end

@testset "Copying task with subtapes" begin
function f2()
produce(1)
produce(2)
end

function g2()
f2()
end

Libtask.is_primitive(::typeof(f2), args...) = false

ttask = TapedTask(g2)
@test consume(ttask) == 1

ttask2 = copy(ttask)
@test consume(ttask2) == 2
@test consume(ttask) == 2

@test consume(ttask2) === nothing
@test consume(ttask) === nothing
end

@testset "Multiple func calls subtapes" begin
function f3()
produce(1)
produce(2)
end

function g3()
f3()
f3()
end

Libtask.is_primitive(::typeof(f3), args...) = false

ttask = TapedTask(g3)

@test consume(ttask) == 1
ttask2 = copy(ttask)
@test consume(ttask) == 2
@test consume(ttask) == 1
ttask3 = copy(ttask)
@test consume(ttask) == 2
@test consume(ttask) === nothing

@test consume(ttask2) == 2
@test consume(ttask2) == 1
@test consume(ttask2) == 2
@test consume(ttask2) === nothing

@test consume(ttask3) == 2
@test consume(ttask3) === nothing
end
end
56 changes: 55 additions & 1 deletion test/tapedtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,5 +155,59 @@
@test ttask2.task.exception isa BoundsError
end
end

@testset "Too many producers" begin
function f()
produce(1)
produce(2)
end

function g()
f()
end

ttask = TapedTask(g)
@test_throws Exception consume(ttask)
end

@testset "Multiple producers for non-primitive" begin
function f2()
produce(1)
produce(2)
end
Libtask.is_primitive(::typeof(f2), args...) = false

function g2()
f2()
end

ttask = TapedTask(g2)
@test consume(ttask) == 1
@test consume(ttask) == 2
@test consume(ttask) === nothing
end

@testset "Run two times" begin
function f4()
produce(2)
end

function g4()
produce(1)
f4()
end

Libtask.is_primitive(::typeof(f4), args...) = false

ttask = TapedTask(g4)
@test consume(ttask) == 1
@test consume(ttask) == 2
@test consume(ttask) === nothing

ttask = TapedTask(g4)
@test consume(ttask) == 1
@test consume(ttask) == 2
@test consume(ttask) === nothing
end
end
end
end
71 changes: 66 additions & 5 deletions test/tf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,70 @@ Libtask.is_primitive(::typeof(foo), args...) = false
@test typeof(r) === Float64
end
@testset "recurse into function" begin
tf = Libtask.TapedFunction(bar, 5.0)
count = 0
tf(4.0; callback=() -> (count += 1))
@test count == 9
function recurse(n::Int)
if n == 0
return 0
end
recurse(n-1)
produce(n)
end
Libtask.is_primitive(::typeof(recurse), args...) = false
ttask = TapedTask(recurse, 3)

@test consume(ttask) == 1
@test consume(ttask) == 2
@test consume(ttask) == 3
@test consume(ttask) === nothing

function recurse2(n::Int)
if n == 0
return 0
end
produce(n)
recurse2(n-1)
end
Libtask.is_primitive(::typeof(recurse2), args...) = false
ttask = TapedTask(recurse2, 3)

@test consume(ttask) == 3
@test consume(ttask) == 2
@test consume(ttask) == 1
@test consume(ttask) === nothing
end

@testset "Not optimize mutating call" begin

function f!(a)
a[1] = 2
return 1
end

function g1()
a = [1,2]
a[2] = f!(a)
produce(a[1])
end

ttask = TapedTask(g1)
@test consume(ttask) == 2
@test consume(ttask) === nothing
end

@testset "Not optimize producing call" begin
function f2()
produce(2)
return 1
end

function g2()
a = [1]
a[1] = f2()
produce(a[1])
end

ttask = TapedTask(g2)
@test consume(ttask) == 2
@test consume(ttask) == 1
@test consume(ttask) === nothing
end
end
end
Loading