Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subtape + Fix Recursion #171

Open
wants to merge 10 commits into
base: subtape
Choose a base branch
from
29 changes: 21 additions & 8 deletions src/tapedfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ mutable struct TapedFunction{F, TapeType}
arg_binding_slots::Vector{Int} # arg indices in binding_values
retval_binding_slot::Int # 0 indicates the function has not returned
deepcopy_types::Type # use a Union type for multiple types
subtapes::IdDict{Any,TapedFunction}

function TapedFunction{F, T}(f::F, args...; cache=false, deepcopy_types=Union{}) where {F, T}
args_type = _accurate_typeof.(args)
Expand All @@ -72,7 +73,7 @@ mutable struct TapedFunction{F, TapeType}
ir = _infer(f, args_type)
binding_values, slots, tape = translate!(RawTape(), ir)

tf = new{F, T}(f, length(args), ir, tape, 1, binding_values, slots, 0, deepcopy_types)
tf = new{F, T}(f, length(args), ir, tape, 1, binding_values, slots, 0, deepcopy_types, IdDict{Any,TapedFunction}())
TRCache[cache_key] = tf # set cache
return tf
end
Expand All @@ -82,7 +83,7 @@ mutable struct TapedFunction{F, TapeType}

function TapedFunction{F, T0}(tf::TapedFunction{F, T1}) where {F, T0, T1}
new{F, T0}(tf.func, tf.arity, tf.ir, tf.tape,
tf.counter, tf.binding_values, tf.arg_binding_slots, 0, tf.deepcopy_types)
tf.counter, tf.binding_values, tf.arg_binding_slots, 0, tf.deepcopy_types, tf.subtapes)
end

TapedFunction(tf::TapedFunction{F, T}) where {F, T} = TapedFunction{F, T}(tf)
Expand Down Expand Up @@ -219,8 +220,9 @@ function (instr::Instruction{F})(tf::TapedFunction, callback=nothing) where F
output = if is_primitive(func, inputs...)
func(inputs...)
else
tf_inner = TapedFunction(func, inputs..., cache=true)
tf_inner(inputs...; callback=callback)
tf_inner = get!(tf.subtapes, instr, TapedFunction(func, inputs..., cache=true))
markus7800 marked this conversation as resolved.
Show resolved Hide resolved
# continuation=false breaks "Multiple func calls subtapes" and "Copying task with subtapes"
tf_inner(inputs...; callback=callback, continuation=true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC continuation=true is now needed because we might be calling the same TapedFunction multiple times?

EDIT: In contrast to before when were just constructing a new one every time.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think your understanding is correct.

In the tape, a function call constitutes a single instruction.
Previous to my changes, this function call was only allowed to call produce(val) once.
I think the reason was that without subtapes, we were only able to continue the execution after the call instruction.
This was safe to do if we only allow one produce and throw an error otherwise.

By flagging a function g as non-primitive is_primitive(typeof(g),...) = false, we tell Libtask that we want to be able to interrupt the execution in g (e.g. at produce(4)).

For this to work any parent tape function f (the caller) has to actually own the subtape of g.
The instruction counter of g has to be preserved, such that with continuation=true we continue exactly at the correct instruction in g. And yes, g is called multiple times with an updated counter.

Thus, when we fork/copy a task, we also have to copy all subtapes with their counters.

Also, when reusing a cached tape we have to reset all the counters of the subtapes counter = 1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense; thank you! And again, great work:)

end
_update_var!(tf, instr.output, output)
tf.counter += 1
Expand Down Expand Up @@ -416,10 +418,13 @@ function translate!!(var::IRVar, line::Expr,
# Only some of the function calls can be optimized even though many of their results are
# inferred as constants: we only optimize primitive and datatype constants for now. For
# optimised function calls, we will evaluate the function at compile-time and cache results.
if isconst
v = ir.ssavaluetypes[var.id].val
_canbeoptimized(v) && return _const_instruction(var, v, bindings, ir)
end

# even if return value is const. The call may has side-effects.
# In particular, it may contain produce statements.
# if isconst
# v = ir.ssavaluetypes[var.id].val
# _canbeoptimized(v) && return _const_instruction(var, v, bindings, ir)
# end
markus7800 marked this conversation as resolved.
Show resolved Hide resolved
args = map(_bind_fn, line.args)
# args[1] is the function
func = line.args[1]
Expand Down Expand Up @@ -515,5 +520,13 @@ end
function Base.copy(tf::TapedFunction)
new_tf = TapedFunction(tf)
new_tf.binding_values = copy_bindings(tf.binding_values, tf.deepcopy_types)
new_tf.subtapes = IdDict{Any,TapedFunction}(func => copy(subtape) for (func, subtape) in tf.subtapes)
return new_tf
end

# when copying we want to keep the counters
# but if we instantiate new TapedTask, we have to reset the counters of cached tapes
function reset_counters!(tf::TapedFunction)
tf.counter = 1
reset_counters!.(values(tf.subtapes))
markus7800 marked this conversation as resolved.
Show resolved Hide resolved
end
3 changes: 2 additions & 1 deletion src/tapedtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ function TapedTask(f, args...; deepcopy_types=nothing) # deepcoy Array and Ref b
deepcopy = Union{BASE_COPY_TYPES, deepcopy_types}
end
tf = TapedFunction(f, args...; cache=true, deepcopy_types=deepcopy)
reset_counters!(tf) # we have to reset the counters of cached tapes
TapedTask(tf, args...)
end

Expand Down Expand Up @@ -122,7 +123,7 @@ end
function produce(val)
is_in_tapedtask() || return nothing
ttask = current_task().storage[:tapedtask]::TapedTask
length(ttask.produced_val) > 1 &&
length(ttask.produced_val) > 0 &&
error("There is a produced value which is not consumed.")
push!(ttask.produced_val, val)
return nothing
Expand Down
55 changes: 55 additions & 0 deletions test/tape_copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,59 @@
@test consume(ttask) == 1
@test consume(ttask2) == 2
end

@testset "Copying task with subtapes" begin
function f2()
produce(1)
produce(2)
end

function g2()
f2()
end

Libtask.is_primitive(::typeof(f2), args...) = false

ttask = TapedTask(g2)
@test consume(ttask) == 1

ttask2 = copy(ttask)
@test consume(ttask2) == 2
@test consume(ttask) == 2

@test consume(ttask2) === nothing
@test consume(ttask) === nothing
end

@testset "Multiple func calls subtapes" begin
function f3()
produce(1)
produce(2)
end

function g3()
f3()
f3()
end

Libtask.is_primitive(::typeof(f3), args...) = false

ttask = TapedTask(g3)

@test consume(ttask) == 1
ttask2 = copy(ttask)
@test consume(ttask) == 2
@test consume(ttask) == 1
ttask3 = copy(ttask)
@test consume(ttask) == 2
@test consume(ttask) === nothing

@test consume(ttask2) == 2
@test consume(ttask2) == 1
@test consume(ttask2) == 2
@test consume(ttask2) === nothing

@test consume(ttask3) == 2
@test consume(ttask3) === nothing
end
end
56 changes: 55 additions & 1 deletion test/tapedtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,5 +155,59 @@
@test ttask2.task.exception isa BoundsError
end
end

@testset "Too much producers" begin
markus7800 marked this conversation as resolved.
Show resolved Hide resolved
function f()
produce(1)
produce(2)
end

function g()
f()
end

ttask = TapedTask(g)
@test_throws Exception consume(ttask)
end

@testset "Multiple producers for non-primitive" begin
function f2()
produce(1)
produce(2)
end
Libtask.is_primitive(::typeof(f2), args...) = false

function g2()
f2()
end

ttask = TapedTask(g2)
@test consume(ttask) == 1
@test consume(ttask) == 2
@test consume(ttask) === nothing
end

@testset "Run two times" begin
function f4()
produce(2)
end

function g4()
produce(1)
f4()
end

Libtask.is_primitive(::typeof(f4), args...) = false

ttask = TapedTask(g4)
@test consume(ttask) == 1
@test consume(ttask) == 2
@test consume(ttask) === nothing

ttask = TapedTask(g4)
@test consume(ttask) == 1
@test consume(ttask) == 2
@test consume(ttask) === nothing
end
end
end
end
76 changes: 71 additions & 5 deletions test/tf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,75 @@ Libtask.is_primitive(::typeof(foo), args...) = false
@test typeof(r) === Float64
end
@testset "recurse into function" begin
tf = Libtask.TapedFunction(bar, 5.0)
count = 0
tf(4.0; callback=() -> (count += 1))
@test count == 9
# tf = Libtask.TapedFunction(bar, 5.0)
# count = 0
# tf(4.0; callback=() -> (count += 1))
# @test count == 9

markus7800 marked this conversation as resolved.
Show resolved Hide resolved
function recurse(n::Int)
if n == 0
return 0
end
recurse(n-1)
produce(n)
end
Libtask.is_primitive(::typeof(recurse), args...) = false
ttask = TapedTask(recurse, 3)

@test consume(ttask) == 1
@test consume(ttask) == 2
@test consume(ttask) == 3
@test consume(ttask) === nothing

function recurse2(n::Int)
if n == 0
return 0
end
produce(n)
recurse2(n-1)
end
Libtask.is_primitive(::typeof(recurse2), args...) = false
ttask = TapedTask(recurse2, 3)

@test consume(ttask) == 3
@test consume(ttask) == 2
@test consume(ttask) == 1
@test consume(ttask) === nothing
end

@testset "Not optimize mutating call" begin

function f!(a)
a[1] = 2
return 1
end

function g1()
a = [1,2]
a[2] = f!(a)
produce(a[1])
end

ttask = TapedTask(g1)
@test consume(ttask) == 2
@test consume(ttask) === nothing
end

@testset "Not optimize producing call" begin
function f2()
produce(2)
return 1
end

function g2()
a = [1]
a[1] = f2()
produce(a[1])
end

ttask = TapedTask(g2)
@test consume(ttask) == 2
@test consume(ttask) == 1
@test consume(ttask) === nothing
end
end
end
Loading