Skip to content

Commit

Permalink
reset cached tape counters when instantiating new TapedTask
Browse files Browse the repository at this point in the history
  • Loading branch information
markus7800 committed Jan 17, 2024
1 parent ffc5805 commit 171ef24
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 1 deletion.
7 changes: 7 additions & 0 deletions src/tapedfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -523,3 +523,10 @@ function Base.copy(tf::TapedFunction)
new_tf.subtapes = IdDict{Any,TapedFunction}(func => copy(subtape) for (func, subtape) in tf.subtapes)
return new_tf
end

# when copying we want to keep the counters
# but if we instantiate new TapedTask, we have to reset the counters of cached tapes
function reset_counters!(tf::TapedFunction)
tf.counter = 1
reset_counters!.(values(tf.subtapes))
end
1 change: 1 addition & 0 deletions src/tapedtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ function TapedTask(f, args...; deepcopy_types=nothing) # deepcoy Array and Ref b
deepcopy = Union{BASE_COPY_TYPES, deepcopy_types}
end
tf = TapedFunction(f, args...; cache=true, deepcopy_types=deepcopy)
reset_counters!(tf) # we have to reset the counters of cached tapes
TapedTask(tf, args...)
end

Expand Down
1 change: 0 additions & 1 deletion test/tape_copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,5 @@

@test consume(ttask3) == 2
@test consume(ttask3) === nothing

end
end
23 changes: 23 additions & 0 deletions test/tapedtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,5 +186,28 @@
@test consume(ttask) == 2
@test consume(ttask) === nothing
end

@testset "Run two times" begin
function f4()
produce(2)
end

function g4()
produce(1)
f4()
end

Libtask.is_primitive(::typeof(f4), args...) = false

ttask = TapedTask(g4)
@test consume(ttask) == 1
@test consume(ttask) == 2
@test consume(ttask) === nothing

ttask = TapedTask(g4)
@test consume(ttask) == 1
@test consume(ttask) == 2
@test consume(ttask) === nothing
end
end
end

0 comments on commit 171ef24

Please sign in to comment.