Sequential Macro #2565
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #2565 +/- ##
==========================================
- Coverage 82.37% 82.35% -0.02%
==========================================
Files 825 828 +3
Lines 105643 105773 +130
==========================================
+ Hits 87026 87112 +86
- Misses 18617 18661 +44
/// gen_sequential! {
///     // No config
///     Relu,
///     Sigmoid;
///     // Has config
///     DropoutConfig => Dropout,
///     LeakyReluConfig => LeakyRelu;
///     // Requires a backend (<B>)
///     LinearConfig => Linear
/// }
/// ```
I would use pattern matching to differentiate modules that require a backend and config:
Sequential!(
Relu, // Without config
Dropout(DropoutConfig), // With config
Linear(LinearConfig; B), // With config + Generics
Custom(; B), // No config + Generics
Custom2(config; B, A, C), // With config + many generics
)
While this would be my ideal approach, it would require rewriting the macro as a proc macro, since declarative macros have no good way to branch on whether a value is present (specifically the whole `($cfg$(; $($generic),+))?` fragment). Separating the modules into multiple blocks tells me that the structs in a given block need `.init()` called if they have a config, or `.init(device)` called if they have a backend-dependent config, without needing actual Rust code to differentiate them. A proc macro would fix this, but it would also require a new crate just for the macro and some more advanced parsing techniques.
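To make the block-separation idea concrete, here is a minimal sketch of a declarative macro that picks the initialization call purely from which `;`-separated block a module appears in. It is not the macro from this PR: `gen_layers!`, `Layer`, `LayerConfig`, `Device`, and the toy `Relu`/`Dropout`/`Linear` types are all hypothetical stand-ins, and no Burn types are used.

```rust
// Toy stand-ins for real layers; none of these are Burn types.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Device(u8);

#[derive(Debug, Default, PartialEq)]
struct Relu;

#[derive(Debug, PartialEq)]
struct Dropout { prob: f64 }
struct DropoutConfig { prob: f64 }
impl DropoutConfig {
    // Config-only initialization: no device needed.
    fn init(&self) -> Dropout { Dropout { prob: self.prob } }
}

#[derive(Debug, PartialEq)]
struct Linear { d_in: usize, d_out: usize, device: Device }
struct LinearConfig { d_in: usize, d_out: usize }
impl LinearConfig {
    // Backend-dependent initialization: needs a device.
    fn init(&self, device: &Device) -> Linear {
        Linear { d_in: self.d_in, d_out: self.d_out, device: *device }
    }
}

// Hypothetical macro: three blocks separated by `;`.
// Block 1: modules without a config (built via `Default`).
// Block 2: `Config => Module`, initialized with `.init()`.
// Block 3: `Config => Module`, initialized with `.init(device)`.
macro_rules! gen_layers {
    (
        $($plain:ident),* ;
        $($cfg:ident => $module:ident),* ;
        $($dev_cfg:ident => $dev_module:ident),*
    ) => {
        #[derive(Debug, PartialEq)]
        enum Layer {
            $($plain($plain),)*
            $($module($module),)*
            $($dev_module($dev_module),)*
        }

        enum LayerConfig {
            $($plain,)*
            $($module($cfg),)*
            $($dev_module($dev_cfg),)*
        }

        impl LayerConfig {
            // The block a module was listed in decides which call is emitted.
            fn init(&self, device: &Device) -> Layer {
                match self {
                    $(LayerConfig::$plain => Layer::$plain($plain::default()),)*
                    $(LayerConfig::$module(cfg) => Layer::$module(cfg.init()),)*
                    $(LayerConfig::$dev_module(cfg) => Layer::$dev_module(cfg.init(device)),)*
                }
            }
        }
    };
}

gen_layers! {
    Relu;
    DropoutConfig => Dropout;
    LinearConfig => Linear
}
```

Because each block has its own matcher, the macro never has to inspect an optional fragment; the trade-off is the more rigid call syntax discussed above.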
Additionally, how would multiple generics work? Do all generics need to be unique? If I define `A` as a generic across `Custom2` and `Custom3`, is it the same generic? I assume it would be. We would probably designate `B` as reserved, meaning "needs a device passed to it on initialization".
Should I look into rewriting this as a more complicated but more flexible proc macro or should I keep the simpler but slightly more rigid declarative macro?
Hmm, I actually don't think the macro helps much, and I don't think we should implement a proc macro. Maybe the real solution would be to implement a `Forward` trait instead of simply having a method. We could then support tuples as sequential layers. The `Forward` trait would be totally decoupled from the `Module` trait and only used to simplify composing multiple `forward` methods.
That would be a good addition. I saw @wingertge's implementation (https://github.com/wingertge/craft-burn/blob/5f02609dadfe206aeb4c16db873446fd80101bd2/src/lib.rs). Maybe @wingertge can also give feedback on this. This is how it was used: https://github.com/wingertge/craft-burn/blob/5f02609dadfe206aeb4c16db873446fd80101bd2/src/refine.rs#L27
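For illustration, the tuple idea could look roughly like the sketch below. This is an assumption about the shape of such a trait, not Burn's API and not a copy of the linked implementation: the `Forward` signature, the `Scale`, `Relu`, and `Sum` layers, and the use of `Vec<f32>` in place of tensors are all hypothetical.

```rust
/// Hypothetical trait: maps an input to an output, with no link to any `Module` trait.
trait Forward<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

// Toy layers operating on plain vectors instead of tensors.
struct Scale(f32);
impl Forward<Vec<f32>> for Scale {
    type Output = Vec<f32>;
    fn forward(&self, input: Vec<f32>) -> Vec<f32> {
        input.into_iter().map(|x| x * self.0).collect()
    }
}

struct Relu;
impl Forward<Vec<f32>> for Relu {
    type Output = Vec<f32>;
    fn forward(&self, input: Vec<f32>) -> Vec<f32> {
        input.into_iter().map(|x| x.max(0.0)).collect()
    }
}

struct Sum;
impl Forward<Vec<f32>> for Sum {
    type Output = f32;
    fn forward(&self, input: Vec<f32>) -> f32 {
        input.iter().sum()
    }
}

// A pair of layers is itself a layer: feed the first layer's output to the second.
impl<I, A, B> Forward<I> for (A, B)
where
    A: Forward<I>,
    B: Forward<A::Output>,
{
    type Output = B::Output;
    fn forward(&self, input: I) -> Self::Output {
        self.1.forward(self.0.forward(input))
    }
}

// Same idea for triples; larger arities would typically be generated by a small macro.
impl<I, A, B, C> Forward<I> for (A, B, C)
where
    A: Forward<I>,
    B: Forward<A::Output>,
    C: Forward<B::Output>,
{
    type Output = C::Output;
    fn forward(&self, input: I) -> Self::Output {
        self.2.forward(self.1.forward(self.0.forward(input)))
    }
}
```

With these impls, `(Scale(2.0), Relu, Sum).forward(vec![-1.0, 0.5, 2.0])` chains the three layers in order, and the compiler checks that each layer's output type matches the next layer's input type.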
This PR has been marked as stale because it has not been updated for over a month
Checklist
The `run-checks all` script has been executed.

Changes
Burn currently lacks an analog to PyTorch's `nn.Sequential`, so I created a sequential macro to generate a similar structure. I used a macro because there is no unifying trait across all modules, specifically with regard to how they are initialized. I created a relatively lenient system that sorts modules into bins depending on how they are initialized, then generates a structure from that information. This structure can be initialized like normal (using `SequentialConfig`, a struct containing an enum that allows each step to be customized individually) and then executed with `Sequential::forward`.

Testing
I have tested various combinations of modules with the macro and it performs as expected. I also added a unit test to ensure that the macro expands correctly on each new build.