Skip to content

Commit 885d2d2

Browse files
authored
Fix event count for multiple handlers (#556)
* Fix event count for multiple handlers If multiple event handlers are registered for the same event, the counter is incremented before each handler. This results in filters like `every: 10` to never match any events. * Add test for event counts * Remove dummy handler from filter tests * Add multiple handlers when testing event counts
1 parent 40aede6 commit 885d2d2

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

lib/axon/loop.ex

+2-2
Original file line numberDiff line numberDiff line change
@@ -1881,15 +1881,15 @@ defmodule Axon.Loop do
18811881
# attached to the loop.
18821882
# TODO(seanmor5): Custom events
18831883
defp fire_event(event, handler_fns, state, debug?) do
1884+
state = update_counts(state, event)
1885+
18841886
handler_fns[event]
18851887
|> Enum.reverse()
18861888
|> Enum.reduce_while({:continue, state}, fn {handler, filter}, {_, state} ->
18871889
if debug? do
18881890
Logger.debug("Axon.Loop fired event #{inspect(event)}")
18891891
end
18901892

1891-
state = update_counts(state, event)
1892-
18931893
if filter.(state, event) do
18941894
case handler.(state) do
18951895
{:continue, %State{} = state} ->

test/axon/loop_test.exs

+75
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,17 @@ defmodule Axon.LoopTest do
483483
end)
484484
end
485485

486+
def send_event_counts_handler(loop, event) do
487+
Axon.Loop.handle_event(loop, event, fn state ->
488+
send(self(), {event, state.event_counts})
489+
{:continue, state}
490+
end)
491+
end
492+
493+
def continue_handler(loop, event) do
494+
Axon.Loop.handle_event(loop, event, &{:continue, &1})
495+
end
496+
486497
test "fires correctly on :started" do
487498
ExUnit.CaptureIO.capture_io(fn ->
488499
run_dummy_loop!(:started, 5, 10)
@@ -596,6 +607,70 @@ defmodule Axon.LoopTest do
596607

597608
refute_received _
598609
end
610+
611+
test "events are counted correctly" do
612+
model = Axon.input("foo")
613+
614+
data =
615+
Stream.repeatedly(fn ->
616+
xs = Nx.tensor([[Enum.random(0..10)]])
617+
ys = Nx.greater(xs, 5)
618+
{xs, ys}
619+
end)
620+
621+
ExUnit.CaptureIO.capture_io(fn ->
622+
# loop with multiple :iteration_started handlers to test that
623+
# the event counts are only incremented once per event
624+
model
625+
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
626+
|> send_event_counts_handler(:started)
627+
|> send_event_counts_handler(:epoch_started)
628+
|> continue_handler(:iteration_started)
629+
|> continue_handler(:iteration_started)
630+
|> send_event_counts_handler(:epoch_completed)
631+
|> Axon.Loop.run(data, %{}, epochs: 2, iterations: 10)
632+
end)
633+
634+
assert_received {:started,
635+
%{
636+
started: 1
637+
}}
638+
639+
assert_received {:epoch_started,
640+
%{
641+
started: 1,
642+
epoch_started: 1
643+
}}
644+
645+
assert_received {:epoch_completed,
646+
%{
647+
started: 1,
648+
epoch_started: 1,
649+
epoch_completed: 1,
650+
iteration_started: 10,
651+
iteration_completed: 10
652+
}}
653+
654+
assert_received {:epoch_started,
655+
%{
656+
started: 1,
657+
epoch_started: 2,
658+
epoch_completed: 1,
659+
iteration_started: 10,
660+
iteration_completed: 10
661+
}}
662+
663+
assert_received {:epoch_completed,
664+
%{
665+
started: 1,
666+
epoch_started: 2,
667+
epoch_completed: 2,
668+
iteration_started: 20,
669+
iteration_completed: 20
670+
}}
671+
672+
refute_received _
673+
end
599674
end
600675

601676
describe "filters" do

0 commit comments

Comments
 (0)