TorchBlade converts PyTorch workloads to MHLO based on Torch-MLIR, and then compiles the MHLO modules with the BladeDISC compiler.
The BladeDISC Dev Team is cooperating with the community to add Torch-to-MHLO conversion to Torch-MLIR, especially for fully dynamic shape support. See RFC llvm/torch-mlir#999. Community developers interested in this work are welcome to join.
TorchBlade's Torch operator conversion generally follows Torch-MLIR. Developers working on this part are recommended to read the following:
TorchBlade clusters supported operation regions into subgraphs and converts the subgraphs from TorchScript into MHLO modules via Torch-MLIR; finally, they are compiled and executed by BladeDISC.
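For orientation, a typical entry point looks like this minimal sketch (based on the torch_blade.optimize API from the project README; exact arguments may differ across versions):

```python
import torch
import torch_blade

model = torch.nn.Linear(16, 8).eval()
dummy = torch.randn(2, 16)

with torch.no_grad():
    # Clusters supported ops into subgraphs, converts them to MHLO via
    # Torch-MLIR, and compiles the result with BladeDISC; the returned
    # object behaves like a regular TorchScript module.
    optimized = torch_blade.optimize(model, allow_tracing=True,
                                     model_inputs=(dummy,))
print(optimized(dummy))
```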
As shown in torch_mlir_op_filter.cpp:GetTorchMlirWhiteList(), an operator's name must appear in the white list for TorchBlade to treat it as supported and attempt the conversion.
For debugging, you can use the environment variable TORCH_MHLO_OP_WHITE_LIST to add extra operators to the white list without re-compilation. For example, set export TORCH_MHLO_OP_WHITE_LIST="aten::batch_norm_a;aten::activation_b;" before a run.
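The same can be done from Python, as a minimal sketch (assuming the variable is read when TorchBlade converts subgraphs; the operator names are placeholders):

```python
import os

# Append extra operators to the white list without rebuilding TorchBlade.
# The names below are placeholders for illustration.
os.environ["TORCH_MHLO_OP_WHITE_LIST"] = "aten::batch_norm_a;aten::activation_b;"

import torch_blade  # the white list is consulted during conversion
```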
You can find ATen native node schema definitions at ATen/native.
TorchBlade also provides a tool function, node_schema_str, that returns the schema of the input node.
```python
import torch
import torch_blade.tools as tools

@torch.jit.script
def add(x, y):
    return x + y

print(add.graph)
# Print the schema of every node in the scripted graph.
for n in add.graph.nodes():
    print(tools.node_schema_str(n))
```
This prints:

```
graph(%x.1 : Tensor,
      %y.1 : Tensor):
  %4 : int = prim::Constant[value=1]()
  %5 : Tensor = aten::add(%x.1, %y.1, %4)
  return (%5)

aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)
```
Note that data types and shapes are critical to modern deep learning compilers. TorchBlade performs its own shape analysis on TorchScript:
- It focuses on rank and dynamic shape inference
- To support subgraph clustering, compilation, and caching, ranks and data types are needed before the conversion to Torch-MLIR
To avoid exceptions, TorchBlade skips the shape and type refinement passes from Torch-MLIR. Please add type and shape analyses when they are missing; see shape_analysis.cpp.
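To see why this extra analysis matters, note that a scripted TorchScript graph by itself types values only as plain Tensor, with no rank or dtype:

```python
import torch

@torch.jit.script
def scale(x: torch.Tensor):
    return x * 2.0

# Every value in the printed graph is typed as bare "Tensor"; TorchBlade's
# shape analysis refines ranks and dtypes before clustering and conversion.
print(scale.graph)
```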
If you see an exception raised because of "failed to legalize operation torch.operator", it means that the Torch operator is missing from the Torch-Dialect of Torch-MLIR.
You must add the new operator to Torch-Dialect; see Torch-MLIR Update ODS.
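As a rough sketch of that step (assuming Torch-MLIR's torch_ods_gen.py generator and its emit helper, as described in the Update ODS docs; aten::my_new_op is a hypothetical placeholder):

```python
# Inside Torch-MLIR's torch_ods_gen.py: declare the JIT schema so the
# generator emits an ODS definition for the op in the Torch dialect.
# "aten::my_new_op" is a hypothetical placeholder, not a real operator.
emit("aten::my_new_op : (Tensor, Tensor) -> (Tensor)")
```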
There are two ways to add a lowering:
- Decompose a Torch operator into more primitive operators (see the sketch after this list)
- Convert a Torch operator into MHLO directly
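To illustrate the first approach, a decomposition re-expresses an op through more primitive ops that already have lowerings; for example, aten::silu can be written as x * sigmoid(x). The real rewrites live in the C++ passes named below; this is only the math:

```python
import torch

def silu_decomposed(x: torch.Tensor) -> torch.Tensor:
    # aten::silu(x) == x * sigmoid(x): expressed via the primitive
    # aten::mul and aten::sigmoid ops.
    return x * torch.sigmoid(x)

x = torch.randn(4)
assert torch.allclose(silu_decomposed(x), torch.nn.functional.silu(x))
```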
You can add BladeDISC-specific Torch operator lowerings to DiscDecomposeComplexOps and DiscTorchToMhlo.
In general, it is encouraged to add lowerings to the passes in Torch-MLIR first. Please refer to Torch-MLIR Torch Ops Lowering as well.
In rare cases, BladeDISC will fail to generate high-performance executables for the input MHLO modules. To fix this, you can dive into the BladeDISC code generation pass pipeline; see A Walkthrough of the BladeDISC Pass Pipeline.
Feel free to file an issue to the BladeDISC Dev Team if something blocks you: https://github.com/alibaba/BladeDISC/issues.
End-to-end unit tests are required as well. Please refer to the unit tests in pytorch_blade/tests/disc; an example:
```python
import torch
from tests.disc.testing_base import DiscTestCase  # base class from the test suite


class TestDiscActivation(DiscTestCase):
    def test_relu(self):
        relu = torch.nn.ReLU()
        x = torch.randn([2, 4, 16, 16], device=self.device)
        # Converts the module through TorchBlade/DISC and checks the results.
        self._test_cvt_to_disc(relu, (x,))
```