Skip to content

Commit

Permalink
add test case
Browse files Browse the repository at this point in the history
  • Loading branch information
markus7800 committed Jan 17, 2024
1 parent 51944bc commit ffc5805
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/tapedfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ function (instr::Instruction{F})(tf::TapedFunction, callback=nothing) where F
func(inputs...)
else
tf_inner = get!(tf.subtapes, instr, TapedFunction(func, inputs..., cache=true))
# continuation=false breaks "Multiple func calls subtapes" and "Copying task with subtapes"
tf_inner(inputs...; callback=callback, continuation=true)
end
_update_var!(tf, instr.output, output)
Expand Down
24 changes: 12 additions & 12 deletions test/tape_copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,18 +194,18 @@
end

@testset "Copying task with subtapes" begin
function f()
function f2()
produce(1)
produce(2)
end

function g()
f()
function g2()
f2()
end

Libtask.is_primitive(::typeof(f), args...) = false
Libtask.is_primitive(::typeof(f2), args...) = false

ttask = TapedTask(g)
ttask = TapedTask(g2)
@test consume(ttask) == 1

ttask2 = copy(ttask)
Expand All @@ -217,19 +217,19 @@
end

@testset "Multiple func calls subtapes" begin
function f()
function f3()
produce(1)
produce(2)
end

function g()
f()
f()
function g3()
f3()
f3()
end

Libtask.is_primitive(::typeof(f), args...) = false
Libtask.is_primitive(::typeof(f3), args...) = false

ttask = TapedTask(g)
ttask = TapedTask(g3)

@test consume(ttask) == 1
ttask2 = copy(ttask)
Expand All @@ -248,4 +248,4 @@
@test consume(ttask3) === nothing

end
end
end
17 changes: 17 additions & 0 deletions test/tapedtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,22 @@
ttask = TapedTask(g)
@test_throws Exception consume(ttask)
end

@testset "Multiple producers for non-primitive" begin
function f2()
produce(1)
produce(2)
end
Libtask.is_primitive(::typeof(f2), args...) = false

function g2()
f2()
end

ttask = TapedTask(g2)
@test consume(ttask) == 1
@test consume(ttask) == 2
@test consume(ttask) === nothing
end
end
end

0 comments on commit ffc5805

Please sign in to comment.