
Conversation

Contributor

catch-twenty-two commented Jul 24, 2025

This adds functionality to broadcast two tensors with potentially different static ranks to a common rank and then perform an operation on them. Before adding more operations via macros, I wanted to make sure this would be helpful to the community; as of now I use it in the following way:

I have not fully incorporated this into the Burn project yet, since I first wanted feedback on whether or not it would be useful.

    // Needs roughly these imports where the test lives (exact paths may vary):
    // use burn_ndarray::{NdArray, NdArrayDevice};
    // use burn_tensor::Tensor;
    #[test]
    fn test_broadcast_multi_dims_values() {
        let device = &NdArrayDevice::default();
        type B = NdArray<f32>;

        let a = Tensor::<B, 3>::from_data(
            [
                [[2, 8, 7, 2], [9, 14, 13, 12], [9, 14, 13, 12]],
                [[2, 8, 7, 2], [9, 14, 13, 12], [9, 14, 13, 12]],
            ],
            device,
        );

        let b = Tensor::<B, 2>::from_data([[4, 11, 10, 5]], device);

        // Broadcast both operands to the common rank (3) before the element-wise op.
        let (a, b) = broadcast!(a:Tensor<B, 3>, b:Tensor<B, 2>);

        let a_add_b = a.add(b);

        Tensor::<B, 3>::from_data(
            [
                [[6, 19, 17, 7], [13, 25, 23, 17], [13, 25, 23, 17]],
                [[6, 19, 17, 7], [13, 25, 23, 17], [13, 25, 23, 17]],
            ],
            device,
        )
        .into_data()
        .assert_eq(&a_add_b.to_data(), true);
    }

Functionality can easily be added to incorporate operators, which would then be used as in the following example:

let a = Tensor::<B, 6>::empty([7, 6, 2, 3, 1, 9], device);
let b = Tensor::<B, 4>::empty([2, 1, 7, 1], device);

let a = add_broadcast!(a: Tensor<B, 6>, b: Tensor<B, 4>);

Note: Macros were used because the broadcast operation needs to know the rank of the resulting tensors (a const generic) at compile time.
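
For reference, the expansion for the test above could look roughly like the following sketch (the code this PR actually generates may differ in details). It relies on Burn's Tensor::unsqueeze::<D>(), which left-pads the shape with size-1 dimensions so that the backend's usual element-wise broadcasting on size-1 dimensions then applies:

    // Hypothetical expansion of `broadcast!(a: Tensor<B, 3>, b: Tensor<B, 2>)`.
    // `b` has the lower rank, so it is unsqueezed from rank 2 to rank 3: its
    // shape goes from [1, 4] to [1, 1, 4], which then broadcasts against
    // `a`'s [2, 3, 4] in the subsequent element-wise `add`.
    let (a, b): (Tensor<B, 3>, Tensor<B, 3>) = (a, b.unsqueeze::<3>());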

Checklist

  • Confirmed that cargo run-checks command has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#3344
#1499

Changes

I've added the logic to correctly reshape the tensors via a broadcast function so that operations can then be performed on them.

Testing

Multiple unit tests; the change has also been integrated into my own code base.


codecov bot commented Jul 24, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 63.43%. Comparing base (38874eb) to head (6df1934).

❌ Your project check has failed because the head coverage (63.43%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #3417   +/-   ##
=======================================
  Coverage   63.43%   63.43%           
=======================================
  Files         981      981           
  Lines      109705   109705           
=======================================
+ Hits        69589    69592    +3     
+ Misses      40116    40113    -3     


Member

laggui left a comment


I think having broadcasting for tensors of different rank would address a real gap in usability 🙂

But we should be thoughtful about the right approach (and I haven't really thought about it yet).

Some general comments on the current macro:

  • it only applies left padding to both operands, but broadcasting can be a bit more complex
  • the syntax is a little verbose with the type annotations vs type inference
  • having a per-op macro like add_broadcast! is helpful to apply different broadcasting rules, but can lead to macro explosion; we can probably have an element-wise broadcast and matmul broadcast to cover almost all cases

Contributor Author

catch-twenty-two commented Jul 30, 2025

Thanks for the review!

I think having broadcasting for tensors of different rank would address a real gap in usability 🙂

So, that kinda sounds like a yay, not a nay on usefulness? 😁

But we should be thoughtful about the right approach (and I haven't really thought about it yet).

Some general comments on the current macro:

  • it only applies left padding to both operands, but broadcasting can be a bit more complex

Could you give me a quick example? I thought I had all the bases covered here as far as the general way PyTorch handles broadcasting. If there are some corner cases I'd love to account for them.

https://docs.pytorch.org/docs/stable/notes/broadcasting.html

  • the syntax is a little verbose with the type annotations vs type inference

Agreed! I couldn't think of a better way to pass the generic constants to the actual function (as far as I know, this is the only way to do it?) in order to compare which one is larger (which has to be done at compile time) while still having it look intuitively like the rest of the Burn/Rust tensor type syntax:

let a: Tensor<B, D, K> = etc.

Any ideas would be greatly appreciated, as I know that once this syntax is incorporated into the API, you are pretty much stuck with it.

I've since updated it a bit, but it's still verbose:

broadcast!(
    a: Tensor<Backend, RANK_A>,
    b: Tensor<RANK_B>
)

  • having a per-op macro like add_broadcast! is helpful to apply different broadcasting rules, but can lead to macro explosion; we can probably have an element-wise broadcast and matmul broadcast to cover almost all cases

Agreed, this is why I only did it for one; I just wanted to show an example of what could be done. Personally I really dislike macros: I feel that they obfuscate the code and make debugging difficult, so the fewer the better! I'll look into incorporating matmul.

Member

laggui commented Aug 5, 2025

Sorry for the late response, the notification got lost in the recent flood 😅

So, that kinda sounds like a yay, not a nay on usefulness? 😁

Yeah I definitely see the usefulness! Just gotta nail the usage/implementation 🙂

Could you give me a quick example? I thought I had all the bases covered here as far as the general way PyTorch handles broadcasting. If there are some corner cases I'd love to account for them.

Actually this kinda ties into the third point regarding elemwise and matmul broadcast rules. For elemwise I believe it's simply left padding; it's matmul that has different rules.
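
To make the distinction concrete, here's a rough sketch of the two rule sets (PyTorch-style semantics, shape logic only, not Burn API; the helper name below is just for illustration):

    // A sketch of the shape rules under discussion (PyTorch-style semantics).
    // `elemwise_broadcast_shape` is a hypothetical helper, not part of Burn.

    /// Element-wise rule: left-pad the shorter shape with 1s, then every pair
    /// of dimensions must be equal or one of them must be 1.
    fn elemwise_broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
        let rank = a.len().max(b.len());
        // Dimension `i` of a shape once it has been left-padded to `rank`.
        let dim = |s: &[usize], i: usize| {
            if i + s.len() >= rank { s[i + s.len() - rank] } else { 1 }
        };
        (0..rank)
            .map(|i| match (dim(a, i), dim(b, i)) {
                (x, y) if x == y => Some(x),
                (1, y) => Some(y),
                (x, 1) => Some(x),
                _ => None,
            })
            .collect()
    }

    fn main() {
        // Element-wise: [2, 3, 4] (+) [1, 4] -> [2, 3, 4]
        assert_eq!(elemwise_broadcast_shape(&[2, 3, 4], &[1, 4]), Some(vec![2, 3, 4]));
        // Matmul instead keeps the last two dims as matrix dims and applies the
        // element-wise rule only to the leading batch dims:
        // [7, 1, 3, 4] @ [2, 4, 5] -> batch [7, 1] x [2] -> [7, 2], matrix 3x5,
        // giving [7, 2, 3, 5].
    }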

Not entirely sure yet about the best approach, I'd have to think about the best way to tackle the syntax and rules 🤔
