-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add inner loop aggregation options #9
Conversation
Also added a inner_loop_kwargs argument to train_loop. Added tests for inner_loop aggregation and inner_loop_kwargs.
ciclo/callbacks.py
Outdated
@@ -23,6 +23,9 @@ | |||
from ciclo.utils import get_batch_size, is_scalar | |||
|
|||
|
|||
InnerLoopAggregation = Literal["last", "mean", "sum", "min", "max", "first"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
Thanks @JamesAllingham for doing this! I like the approach of transposing as it makes implementing reduction fairly easy. |
Codecov Report
@@ Coverage Diff @@
## main #9 +/- ##
==========================================
+ Coverage 77.90% 78.39% +0.48%
==========================================
Files 12 12
Lines 1539 1578 +39
==========================================
+ Hits 1199 1237 +38
- Misses 340 341 +1
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
No problem! Happy to contribute to this library. Thanks for the positive feedback. That's a great idea – done. Let me know what you think of the changes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
My pleasure! Thanks for fixing the pre-commit. |
Thanks @JamesAllingham for doing this! |
Fixes #8.
This PR introduces an
aggregation
option for theinner_loop
callback, which controls how the logs from the inner loop are aggregated. The default behavior of returning the last value for each entry in a collection is maintained. This is equivalent to manually settingaggregation="last"
.aggregation
supports a string value, one of"last"
,"first"
,"min"
,"max"
,"mean
", and"sum"
or aCallable
. In this case, all values in all collections will be aggregated as specified. However, since some collections might require different aggregations (e.g.,"stateful_metrics"
and"metrics"
likely want"last"
and"mean"
, respectively) it is also possible to specify adict
from collection to aggregation such as{"stateful_metrics": "last", "metrics": "mean"}
. If a collection is not specified in this dictionary, then the"last"
aggregation is used.A
inner_loop_kwargs
argument has been added totrain_loop
in order to surface the aggregation options to a user of that API.The PR also adds tests for both
aggregation
andinner_loop_kwargs
.Let me know what you think :) I also wasn't sure how best to document this functionality.