Adding option for broadcasting tensors similar to torch WIP #3417
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅
❌ Your project check has failed because the head coverage (63.43%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

```
@@           Coverage Diff           @@
##             main    #3417   +/-   ##
=======================================
  Coverage   63.43%   63.43%
=======================================
  Files         981      981
  Lines      109705   109705
=======================================
+ Hits        69589    69592      +3
+ Misses      40116    40113      -3
```
I think having broadcasting for tensors of different rank would address a real gap in usability 🙂
But we should be thoughtful about the right approach (and I haven't really thought about it yet).
Some general comments on the current macro:
- it only applies left padding to both operands, but broadcasting can be a bit more complex (see the sketch after this list)
- the syntax is a little verbose with the type annotations vs type inference
- having a per-op macro like `add_broadcast!` is helpful to apply different broadcasting rules, but can lead to macro explosion; we can probably have an element-wise broadcast and a matmul broadcast to cover almost all cases
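For reference, a minimal runtime sketch of the PyTorch-style elemwise rule (right-align the shapes, left-pad the shorter one with 1s, expand size-1 dims); the name and runtime formulation are illustrative, since the PR works with const-generic ranks at compile time:

```rust
/// PyTorch-style elemwise broadcast of two shapes (runtime sketch).
/// Returns None when the shapes are incompatible.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let mut out = Vec::with_capacity(rank);
    for i in 0..rank {
        // Right-align both shapes; a missing leading dim counts as 1.
        let da = if i >= rank - a.len() { a[i - (rank - a.len())] } else { 1 };
        let db = if i >= rank - b.len() { b[i - (rank - b.len())] } else { 1 };
        out.push(match (da, db) {
            _ if da == db => da,
            (1, _) => db, // a size-1 dim expands to match the other operand
            (_, 1) => da,
            _ => return None, // incompatible shapes
        });
    }
    Some(out)
}

// Left padding alone is not enough: [4, 1, 3] vs [2, 3] pads the second
// operand to [1, 2, 3], but the size-1 dims must also expand:
// broadcast_shape(&[4, 1, 3], &[2, 3]) == Some(vec![4, 2, 3])
```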
Thanks for the review!
So, that kinda sounds like a yay, not a nay on usefulness? 😁
Could you give me a quick example? I thought I had all the bases covered here as far as the general way PyTorch handles broadcasting. If there are some corner cases I'd love to account for them. https://docs.pytorch.org/docs/stable/notes/broadcasting.html
Agreed! I couldn't think of a better way to pass the generic constants to the actual function to compare which one is larger (the comparison has to happen at compile time, and as far as I know this is the only way to do it?) while still having it look intuitive next to the rest of the Burn/Rust tensor type syntax.
Any ideas would be greatly appreciated, as I know that once this syntax is incorporated into the API, you are pretty much stuck with it. I've since updated it a bit, but it's still verbose:
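Something like this at the call site (hypothetical, not the final syntax):

```rust
// The ranks still have to be spelled out so the macro can pick the
// larger one as the output rank at expansion time; this is the
// verbosity being discussed (hypothetical syntax).
let c = add_broadcast!(lhs: Tensor<B, 2>, rhs: Tensor<B, 4>);
```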
Agreed, this is why I only did it for one op; I just wanted to show an example of what could be done. Personally I really dislike macros, I feel that they obfuscate the code and make debugging difficult, so the less the better! I'll look into incorporating matmul.
Sorry for the late response, notification got lost in the recent flood 😅
Yeah I definitely see the usefulness! Just gotta nail the usage/implementation 🙂
Actually this kinda ties into the third point regarding elemwise and matmul broadcast rules. For elemwise I believe it's simply left padding; it's matmul that has different rules. Not entirely sure yet about the best approach, I'd have to think about the best way to tackle the syntax and rules 🤔
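For illustration, a shape-level sketch of PyTorch-style matmul broadcasting (reusing the hypothetical broadcast_shape from the earlier sketch; assumes both operands have rank >= 2):

```rust
/// Output shape of a broadcast matmul: batch dims broadcast elemwise,
/// the trailing two dims contract as (m, k) x (k, n) -> (m, n).
fn matmul_broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let (a_batch, a_mat) = a.split_at(a.len() - 2);
    let (b_batch, b_mat) = b.split_at(b.len() - 2);
    if a_mat[1] != b_mat[0] {
        return None; // inner dimensions must agree
    }
    // The elemwise rule applies only to the leading batch dims.
    let mut out = broadcast_shape(a_batch, b_batch)?;
    out.extend_from_slice(&[a_mat[0], b_mat[1]]);
    Some(out)
}

// [2, 1, 4, 5] @ [3, 5, 6]: batch dims [2, 1] and [3] broadcast to
// [2, 3], matrix dims contract to [4, 6], giving [2, 3, 4, 6].
```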
This adds functionality to broadcast two tensors with potentially different static ranks to a common rank and then perform an operation on them. Before adding more operations via macros, I wanted to make sure this would be helpful to the community, since as of now I use it the following way:
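Roughly along these lines (a hypothetical reconstruction; names, shapes, and the exact macro syntax are illustrative):

```rust
use burn::tensor::{backend::Backend, Tensor};

// Hypothetical usage: broadcast a rank-2 bias against a rank-3 batch.
// The lower-rank operand is left-padded to rank 3, then added elemwise.
fn add_bias<B: Backend>(batch: Tensor<B, 3>, bias: Tensor<B, 2>) -> Tensor<B, 3> {
    // batch: [32, 4, 8], bias: [1, 8] -> broadcast result: [32, 4, 8]
    add_broadcast!(batch: Tensor<B, 3>, bias: Tensor<B, 2>)
}
```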
I have not fully incorporated this into the Burn project, since I wanted to get a "useful or not useful" verdict first.
Functionality for additional operators can easily be added and then used as in the following example:
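For instance, if a multiplication variant were added through the same pattern (hypothetical name):

```rust
// Same call-site shape as add_broadcast!, just a different elemwise op
// (mul_broadcast! is a hypothetical name, not part of this PR).
let scaled = mul_broadcast!(weights: Tensor<B, 2>, batch: Tensor<B, 4>);
```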
Note: Macros were used since the broadcast operation requires the rank of the resulting tensor (a const generic) to be known at compile time.
Checklist
- [x] The `cargo run-checks` command has been executed.

Related Issues/PRs
#3344
#1499
Changes
I've added the logic to reshape the tensors to a common rank via a broadcast function and then perform operations on them.
Testing
Multiple unit tests; also integrated into my own code base.