Commit 39f1015
[torchlib] Implement torch.ops.prims.broadcast_in_dim.default (#2382)
This PR implements the missing
`torch.ops.prims.broadcast_in_dim.default` operation that appears in
BERT_pytorch and other PyTorch models.
## Overview
The `broadcast_in_dim` operation is a primitive that broadcasts a tensor
to a target shape by specifying which dimensions of the output
correspond to the input tensor dimensions. This is different from
standard broadcasting operations.
## Implementation Details
**Function signature:**
```python
def prims_broadcast_in_dim(
a: TensorType, shape: INT64, broadcast_dimensions: Sequence[int]
) -> TensorType:
```
**Parameters:**
- `a`: Input tensor to broadcast
- `shape`: Target output shape
- `broadcast_dimensions`: Specifies which dimensions of the output shape
correspond to the input tensor dimensions
**Example:**
```python
# Input tensor: [3, 4]
# Target shape: [2, 3, 5, 4]
# broadcast_dimensions: [1, 3]
# Result: Input dimension 0 (size 3) maps to output dimension 1
# Input dimension 1 (size 4) maps to output dimension 3
# Output dimensions 0 and 2 are broadcasted (filled from size 1)
```
Fixes #2218. Fix pytorch/pytorch#135343
---------
Signed-off-by: Justin Chu <[email protected]>
Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Co-authored-by: justinchuby <[email protected]>
Co-authored-by: Justin Chu <[email protected]>1 parent 8ed3521 commit 39f1015
File tree
3 files changed
+60
-2
lines changed- onnxscript/function_libs/torch_lib/ops
- tests/function_libs/torch_lib
3 files changed
+60
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
176 | 176 | | |
177 | 177 | | |
178 | 178 | | |
| 179 | + | |
179 | 180 | | |
180 | | - | |
| 181 | + | |
181 | 182 | | |
182 | 183 | | |
183 | 184 | | |
184 | | - | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
185 | 206 | | |
186 | 207 | | |
187 | 208 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
87 | 87 | | |
88 | 88 | | |
89 | 89 | | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
90 | 119 | | |
91 | 120 | | |
92 | 121 | | |
| |||
2687 | 2716 | | |
2688 | 2717 | | |
2689 | 2718 | | |
| 2719 | + | |
| 2720 | + | |
| 2721 | + | |
| 2722 | + | |
| 2723 | + | |
| 2724 | + | |
| 2725 | + | |
2690 | 2726 | | |
2691 | 2727 | | |
2692 | 2728 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2136 | 2136 | | |
2137 | 2137 | | |
2138 | 2138 | | |
| 2139 | + | |
2139 | 2140 | | |
2140 | 2141 | | |
2141 | 2142 | | |
| |||
0 commit comments