
Surprising history.collect behavior with train_loop + test_step. #10

Open
JamesAllingham opened this issue Aug 21, 2023 · 2 comments


@JamesAllingham
Contributor

I found the `history.collect` behavior slightly surprising when `test_step` runs less frequently than `train_step`. Concretely, when I request a train metric and a test metric in the same `collect` call, the train metric is returned only at the frequency of the test metric. Consider the following toy setup:

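```python
import ciclo

def increment(state, key):
    # Bump the counter stored under `key` and log its new value as a metric.
    state[key] += 1
    logs = ciclo.logs()
    logs.add_metric(key, state[key])
    return logs, state

state = {"a": 0, "b": 0}

_, history, _ = ciclo.train_loop(
    state,
    ciclo.elapse(range(6)),
    {
        # "a" is incremented on every train step...
        ciclo.on_train_step: lambda state: increment(state, "a"),
        # ...while "b" is only incremented during the test loop, which runs once per epoch.
        ciclo.on_test_step: lambda state: increment(state, "b"),
    },
    test_dataset=lambda: ciclo.elapse(range(1)),
    epoch_duration=2,
    stop=6,
)
```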

If we collect a train metric and a test metric together, we get:

```python
steps, a, b_test = history.collect("steps", "a", "b_test")
print(steps)    # [0, 2, 4]
print(a)        # [1, 3, 5]
print(b_test)   # [1, 2, 3]
```

where both `steps` and the train metric (`a`) are subsampled to the frequency of the test metric.

Compare this to collecting the train and test metrics separately:

```python
steps_a, a = history.collect("steps", "a")
steps_b, b_test = history.collect("steps", "b_test")
print(steps_a)  # [0, 1, 2, 3, 4, 5]
print(a)        # [1, 2, 3, 4, 5, 6]
print(steps_b)  # [0, 2, 4]
print(b_test)   # [1, 2, 3]
```

where the train steps and metric are not subsampled.

I wouldn't say this is a bug, but perhaps the behavior should be documented somewhere or included in some examples?

@cgarciae
Owner

cgarciae commented Aug 21, 2023

Yeah, this should be documented. `collect` only returns rows in which all of the requested keys appear, so more frequent keys get subsampled down to the least frequent one; if the keys never appear together, you get empty lists. I've mainly used `collect` for plotting, where this behavior makes sense. If there are other use cases, maybe we can generalize the behavior with a flag.
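To make the semantics described above concrete, here is a minimal pure-Python sketch. It assumes the history can be viewed as a list of flat dicts, one per logged step, with test metrics merged into the row of the step at which they were logged. `collect_rows` is a hypothetical helper that mirrors the described behavior; it is not ciclo's actual implementation.

```python
from typing import Any, Dict, List, Sequence

def collect_rows(history: Sequence[Dict[str, Any]], *keys: str) -> List[List[Any]]:
    """Hypothetical stand-in for history.collect: keep only rows that contain every key."""
    # Rows missing any requested key are dropped entirely, so frequently
    # logged metrics are subsampled down to the least frequent key.
    rows = [row for row in history if all(k in row for k in keys)]
    # Transpose into one list per requested key, preserving row order.
    return [[row[k] for row in rows] for k in keys]


# Toy history shaped like the example (assumed layout): "a" is logged on
# every step, "b_test" only on steps 0, 2 and 4.
history = [
    {"steps": 0, "a": 1, "b_test": 1},
    {"steps": 1, "a": 2},
    {"steps": 2, "a": 3, "b_test": 2},
    {"steps": 3, "a": 4},
    {"steps": 4, "a": 5, "b_test": 3},
    {"steps": 5, "a": 6},
]

steps, a, b_test = collect_rows(history, "steps", "a", "b_test")
print(steps, a, b_test)  # [0, 2, 4] [1, 3, 5] [1, 2, 3]

steps_a, a_all = collect_rows(history, "steps", "a")
print(steps_a, a_all)  # [0, 1, 2, 3, 4, 5] [1, 2, 3, 4, 5, 6]

# Keys that never appear in the same row yield empty lists.
print(collect_rows([{"x": 1}, {"y": 2}], "x", "y"))  # [[], []]
```

Filtering on "all keys present" is what makes the output convenient for plotting: every returned list has the same length, so the series can be passed straight to a plotting call without further alignment.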

@JamesAllingham
Contributor Author

JamesAllingham commented Aug 21, 2023

In terms of use cases, I'm not sure a flag is necessary. I was using `collect` as a debugging aid to grab all of the metrics at once, and was slightly confused that the results didn't match the Keras logger output during training.

I imagine that simply calling collect twice would be enough of a workaround for most trivial cases like mine.
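As a sketch of that workaround, the per-metric results of separate `collect` calls can be merged back into a single table keyed by step, with `None` wherever a metric was not logged. `align_by_step` is a hypothetical helper, not part of ciclo; the input lists are the values printed by the two separate `collect` calls in the issue, hard-coded here so the example runs on its own.

```python
from typing import Any, Dict, Optional, Sequence, Tuple

def align_by_step(
    series: Dict[str, Tuple[Sequence[int], Sequence[Any]]],
) -> Dict[int, Dict[str, Optional[Any]]]:
    """Hypothetical helper: merge (steps, values) pairs into one table keyed by step.

    Unlike a single multi-key collect, steps where a metric is missing are
    kept and filled with None instead of being dropped.
    """
    # Union of all steps seen by any metric, in ascending order.
    all_steps = sorted({s for steps, _ in series.values() for s in steps})
    table = {step: {name: None for name in series} for step in all_steps}
    for name, (steps, values) in series.items():
        for step, value in zip(steps, values):
            table[step][name] = value
    return table


# In practice these would come from, e.g.:
#   steps_a, a = history.collect("steps", "a")
#   steps_b, b_test = history.collect("steps", "b_test")
steps_a, a = [0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6]
steps_b, b_test = [0, 2, 4], [1, 2, 3]

table = align_by_step({"a": (steps_a, a), "b_test": (steps_b, b_test)})
for step, row in table.items():
    # Odd steps have no test metric here, so their b_test entry is None.
    print(step, row)
```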
