Skip to content

Commit

Permalink
store and run subtapes + remove optimization of const return type fun…
Browse files Browse the repository at this point in the history
…ctions
  • Loading branch information
markus7800 committed Jan 17, 2024
1 parent 789c1b3 commit e3f1903
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 17 deletions.
22 changes: 14 additions & 8 deletions src/tapedfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ mutable struct TapedFunction{F, TapeType}
arg_binding_slots::Vector{Int} # arg indices in binding_values
retval_binding_slot::Int # 0 indicates the function has not returned
deepcopy_types::Type # use a Union type for multiple types
subtapes::IdDict{Any,TapedFunction}

function TapedFunction{F, T}(f::F, args...; cache=false, deepcopy_types=Union{}) where {F, T}
args_type = _accurate_typeof.(args)
Expand All @@ -72,7 +73,7 @@ mutable struct TapedFunction{F, TapeType}
ir = _infer(f, args_type)
binding_values, slots, tape = translate!(RawTape(), ir)

tf = new{F, T}(f, length(args), ir, tape, 1, binding_values, slots, 0, deepcopy_types)
tf = new{F, T}(f, length(args), ir, tape, 1, binding_values, slots, 0, deepcopy_types, IdDict{Any,TapedFunction}())
TRCache[cache_key] = tf # set cache
return tf
end
Expand All @@ -82,7 +83,7 @@ mutable struct TapedFunction{F, TapeType}

function TapedFunction{F, T0}(tf::TapedFunction{F, T1}) where {F, T0, T1}
new{F, T0}(tf.func, tf.arity, tf.ir, tf.tape,
tf.counter, tf.binding_values, tf.arg_binding_slots, 0, tf.deepcopy_types)
tf.counter, tf.binding_values, tf.arg_binding_slots, 0, tf.deepcopy_types, tf.subtapes)
end

TapedFunction(tf::TapedFunction{F, T}) where {F, T} = TapedFunction{F, T}(tf)
Expand Down Expand Up @@ -219,8 +220,9 @@ function (instr::Instruction{F})(tf::TapedFunction, callback=nothing) where F
output = if is_primitive(func, inputs...)
func(inputs...)
else
tf_inner = TapedFunction(func, inputs..., cache=true)
tf_inner(inputs...; callback=callback)
tf_inner = get!(tf.subtapes, instr, TapedFunction(func, inputs..., cache=true))
tf_inner(inputs...; callback=callback, continuation=true)
@assert tf_inner.retval_binding_slot != 0 # tf_inner should return
end
_update_var!(tf, instr.output, output)
tf.counter += 1
Expand Down Expand Up @@ -416,10 +418,13 @@ function translate!!(var::IRVar, line::Expr,
# Only some of the function calls can be optimized even though many of their results are
# inferred as constants: we only optimize primitive and datatype constants for now. For
# optimised function calls, we will evaluate the function at compile-time and cache results.
if isconst
v = ir.ssavaluetypes[var.id].val
_canbeoptimized(v) && return _const_instruction(var, v, bindings, ir)
end

# even if return value is const. The call may has side-effects.
# In particular, it may contain produce statements.
# if isconst
# v = ir.ssavaluetypes[var.id].val
# _canbeoptimized(v) && return _const_instruction(var, v, bindings, ir)
# end
args = map(_bind_fn, line.args)
# args[1] is the function
func = line.args[1]
Expand Down Expand Up @@ -515,5 +520,6 @@ end
function Base.copy(tf::TapedFunction)
new_tf = TapedFunction(tf)
new_tf.binding_values = copy_bindings(tf.binding_values, tf.deepcopy_types)
new_tf.subtapes = IdDict{Any,TapedFunction}(func => copy(subtape) for (func, subtape) in tf.subtapes)
return new_tf
end
58 changes: 57 additions & 1 deletion test/tape_copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,60 @@
@test consume(ttask) == 1
@test consume(ttask2) == 2
end
end

@testset "Copying task with subtapes" begin
function f()
produce(1)
produce(2)
end

function g()
f()
end

Libtask.is_primitive(::typeof(f), args...) = false

ttask = TapedTask(g)
@test consume(ttask) == 1

ttask2 = copy(ttask)
@test consume(ttask2) == 2
@test consume(ttask) == 2

@test consume(ttask2) === nothing
@test consume(ttask) === nothing
end

@testset "Multiple func calls subtapes" begin
function f()
produce(1)
produce(2)
end

function g()
f()
f()
end

Libtask.is_primitive(::typeof(f), args...) = false

ttask = TapedTask(g)

@test consume(ttask) == 1
ttask2 = copy(ttask)
@test consume(ttask) == 2
@test consume(ttask) == 1
ttask3 = copy(ttask)
@test consume(ttask) == 2
@test consume(ttask) === nothing

@test consume(ttask2) == 2
@test consume(ttask2) == 1
@test consume(ttask2) == 2
@test consume(ttask2) === nothing

@test consume(ttask3) == 2
@test consume(ttask3) === nothing

end
end
4 changes: 1 addition & 3 deletions test/tapedtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,6 @@
end

@testset "Too much producers" begin


function f()
produce(1)
produce(2)
Expand All @@ -172,4 +170,4 @@
@test_throws Exception consume(ttask)
end
end
end
end
76 changes: 71 additions & 5 deletions test/tf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,75 @@ Libtask.is_primitive(::typeof(foo), args...) = false
@test typeof(r) === Float64
end
@testset "recurse into function" begin
tf = Libtask.TapedFunction(bar, 5.0)
count = 0
tf(4.0; callback=() -> (count += 1))
@test count == 9
# tf = Libtask.TapedFunction(bar, 5.0)
# count = 0
# tf(4.0; callback=() -> (count += 1))
# @test count == 9

function recurse(n::Int)
if n == 0
return 0
end
recurse(n-1)
produce(n)
end
Libtask.is_primitive(::typeof(recurse), args...) = false
ttask = TapedTask(recurse, 3)

@test consume(ttask) == 1
@test consume(ttask) == 2
@test consume(ttask) == 3
@test consume(ttask) === nothing

function recurse2(n::Int)
if n == 0
return 0
end
produce(n)
recurse2(n-1)
end
Libtask.is_primitive(::typeof(recurse2), args...) = false
ttask = TapedTask(recurse2, 3)

@test consume(ttask) == 3
@test consume(ttask) == 2
@test consume(ttask) == 1
@test consume(ttask) === nothing
end

@testset "Not optimize mutating call" begin

function f!(a)
a[1] = 2
return 1
end

function g1()
a = [1,2]
a[2] = f!(a)
produce(a[1])
end

ttask = TapedTask(g1)
@test consume(ttask) == 2
@test consume(ttask) === nothing
end

@testset "Not optimize producing call" begin
function f2()
produce(2)
return 1
end

function g2()
a = [1]
a[1] = f2()
produce(a[1])
end

ttask = TapedTask(g2)
@test consume(ttask) == 2
@test consume(ttask) == 1
@test consume(ttask) === nothing
end
end
end

0 comments on commit e3f1903

Please sign in to comment.