-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor C++ ops to take output tensor.
- Loading branch information
Showing
6 changed files
with
85 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
import grouped_gemm.ops | ||
import grouped_gemm.backend |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# NOTE: Torch needs to be imported before the custom | ||
# extensions. Otherwise libc10.so cannot be found. | ||
import torch | ||
|
||
# TODO(tgale): Wrap this in a try-block with better | ||
# error message and instructions for building the | ||
# c++ operations. | ||
import grouped_gemm_backend as backend | ||
|
||
|
||
def _allocate_output(a, b, batch_sizes, trans_a, trans_b): | ||
assert not (trans_a and trans_b) | ||
assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" | ||
assert a.ndim == 2, "Expected 2d tensor for 'a'" | ||
assert b.ndim == (2 if trans_a else 3) | ||
|
||
shape = ( | ||
(batch_sizes.shape[0], a.shape[1], b.shape[1]) | ||
if trans_a else | ||
(a.shape[0], (b.shape[1] if trans_b else b.shape[2])) | ||
) | ||
return torch.empty(*shape, device=a.device, dtype=a.dtype) | ||
|
||
def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): | ||
if c is None: | ||
c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) | ||
backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) | ||
return c | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters