Commit 8e6438c

Authored by emcastillo
Merge pull request #796 from emcastillo/custom_torch_ops
Add an interface for registering custom ops in `torch.ops.ppe`
2 parents 73082f6 + f1fd387 commit 8e6438c
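
As a quick orientation, the interface added by this PR can be exercised roughly as follows. This is a minimal sketch inferred from the diffs below; the `scale*` helper names are illustrative and not part of the change.

import pytorch_pfn_extras as ppe
import torch


# Hypothetical op: multiply the input by two.
def scale(a):
    return a * 2


def scale_meta(a):
    # Shape inference only; runs on meta tensors.
    return torch.empty_like(a)


# The backward op receives the upstream gradient followed by the
# forward arguments.
def scale_bwd(g, a):
    return g * 2


def scale_bwd_meta(g, a):
    return torch.empty_like(a)


fwd = ppe.ops.OpDesc(scale, scale_meta, "(Tensor a) -> Tensor")
bwd = ppe.ops.OpDesc(scale_bwd, scale_bwd_meta, "(Tensor g, Tensor a) -> Tensor")
ppe.ops.register("scale", fwd, bwd)

# The op is now callable as a torch primitive (requires PyTorch >= 2.0).
y = torch.ops.ppe.scale(torch.ones(4, requires_grad=True))
y.sum().backward()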

File tree

6 files changed: +183 -0 lines changed


pytorch_pfn_extras/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -28,4 +28,5 @@
 from pytorch_pfn_extras.runtime._to import to  # NOQA

 if requires("2.0.0"):
+    from pytorch_pfn_extras import ops  # NOQA
     from pytorch_pfn_extras._dynamo import compile  # NOQA

pytorch_pfn_extras/ops/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from pytorch_pfn_extras.ops.register import OpDesc, register  # NOQA

pytorch_pfn_extras/ops/register.py

Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@
+from typing import Any, Callable, cast
+
+import torch
+import torch.library
+
+# Libraries used to store the ops definitions
+library = torch.library.Library("ppe", "DEF")
+library_impl = torch.library.Library("ppe", "IMPL", "CompositeExplicitAutograd")
+library_autograd_impl = torch.library.Library("ppe", "IMPL", "Autograd")
+library_meta_impl = torch.library.Library("ppe", "IMPL", "Meta")
+
+
+class OpDesc:
+    """Metadata to register an op to torch.library.
+
+    Attributes:
+        op (callable): code to be executed in the forward/backward of the op.
+        meta (callable): function to perform shape inference for forward/backward
+            passes.
+        signature (str): Arguments and return type of the function, e.g.
+            ``"(Tensor a, Tensor b) -> Tensor[]"``.
+    """
+
+    def __init__(
+        self,
+        op: Callable[..., Any],
+        meta: Callable[..., Any],
+        signature: str,
+    ) -> None:
+        self.op = op
+        self.meta = meta
+        self.signature = signature
+
+
+def _get_autograd(name: str) -> Callable[..., Any]:
+    class RunBackward(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, *args, **kwargs):  # type: ignore[no-untyped-def]
+            ctx.save_for_backward(*args)
+            op_h = torch._C._dispatch_find_schema_or_throw(
+                f"ppe::{name}_fwd", ""
+            )
+            return torch._C._dispatch_call_boxed(op_h, *args, **kwargs)
+
+        @staticmethod
+        def backward(ctx, *args):  # type: ignore[no-untyped-def]
+            i_args = tuple(ctx.saved_tensors)
+            op_h = torch._C._dispatch_find_schema_or_throw(
+                f"ppe::{name}_bwd", ""
+            )
+            return torch._C._dispatch_call_boxed(op_h, *(args + i_args), **{})
+
+    return cast(Callable[..., Any], RunBackward.apply)
+
+
+def register(
+    name: str,
+    fwd_op: OpDesc,
+    bwd_op: OpDesc,
+) -> None:
+    """
+    Register a custom op under ``torch.ops.ppe.name``.
+
+    The function appears as a primitive op in the forward and backward
+    ``torch.fx.Graph``s after compiling torch code with the ``aot_autograd``
+    backend. Note that the backward op is called with the upstream gradients
+    followed by all of the forward arguments: if the forward was
+    ``fwd_op(x, y)``, the custom bwd_op needs a signature like
+    ``bwd_op(grad_output, x, y)``.
+
+    Arguments:
+        name (str): name of the op, shows how it is registered in ``torch.ops.ppe``.
+        fwd_op (ppe.ops.OpDesc): code that is executed in the forward pass.
+        bwd_op (ppe.ops.OpDesc): code that is executed in the backward pass.
+    """
+    function_sig = f"{name}{fwd_op.signature}"
+    function_fwd_sig = f"{name}_fwd{fwd_op.signature}"
+    function_bwd_sig = f"{name}_bwd{bwd_op.signature}"
+    for s in (function_sig, function_fwd_sig, function_bwd_sig):
+        library.define(s)
+
+    def function(*args):  # type: ignore[no-untyped-def]
+        op_h = torch._C._dispatch_find_schema_or_throw(f"ppe::{name}_fwd", "")
+        return torch._C._dispatch_call_boxed(op_h, *args, **{})
+
+    library_impl.impl(name, function)
+    library_impl.impl(f"{name}_fwd", fwd_op.op)
+    library_impl.impl(f"{name}_bwd", bwd_op.op)
+    library_meta_impl.impl(f"{name}_fwd", fwd_op.meta)
+    library_meta_impl.impl(f"{name}_bwd", bwd_op.meta)
+    library_autograd_impl.impl(name, _get_autograd(name))
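
Because the backward op is called with the upstream gradient(s) followed by all of the saved forward arguments, an op with two inputs would be described roughly like this. This is an illustrative sketch, not part of the commit; the `mm*` names are hypothetical.

import pytorch_pfn_extras as ppe
import torch


def mm(a, b):
    return a @ b


def mm_meta(a, b):
    # Shape inference only: (n, k) @ (k, m) -> (n, m).
    return a.new_empty((a.shape[0], b.shape[1]))


# Gradient of the output comes first, then the forward inputs.
def mm_bwd(g, a, b):
    return g @ b.transpose(-2, -1), a.transpose(-2, -1) @ g


def mm_bwd_meta(g, a, b):
    return torch.empty_like(a), torch.empty_like(b)


fwd = ppe.ops.OpDesc(mm, mm_meta, "(Tensor a, Tensor b) -> Tensor")
bwd = ppe.ops.OpDesc(
    mm_bwd, mm_bwd_meta, "(Tensor g, Tensor a, Tensor b) -> Tensor[]"
)
ppe.ops.register("mm", fwd, bwd)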

stubs/torch/_C/__init__.pyi

Lines changed: 2 additions & 0 deletions

@@ -4073,6 +4073,8 @@ def _activate_cuda_trace() -> None: ...
 # Defined in torch/csrc/Module.cpp
 def _current_graph_task_id() -> _int: ...
 def _current_autograd_node() -> _Node: ...
+def _dispatch_find_schema_or_throw(name: str, postfix: str) -> Any: ...
+def _dispatch_call_boxed(op: Any, args: Any, kwargs: Any) -> Any: ...

 class _OutOfMemoryError:
     pass

stubs/torch/library/__init__.pyi

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+# flake8: noqa
+from typing import Any, Callable
+
+class Library:
+    def __init__(self, ns: str, kind: str, dispatch_key: str = "") -> None: ...
+    def impl(
+        self, name: str, fn: Callable[..., Any], dispatch_key: str = ""
+    ) -> None: ...
+    def define(self, name: str) -> None: ...
Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
+import sys
+
+import pytest
+import pytorch_pfn_extras as ppe
+import torch
+
+
+def _get_function_nodes(fx_module):
+    return [
+        node for node in fx_module.graph.nodes if node.op == "call_function"
+    ]
+
+
+@pytest.mark.skipif(
+    not ppe.requires("2.1.0") or sys.platform == "win32",
+    reason="torch custom ops only works for PyTorch>=2.1 and linux",
+)
+def test_register():
+    def test(a):
+        return a * 2
+
+    def test_bwd(g, a):
+        return g
+
+    def test_meta(a):
+        return torch.empty_like(a)
+
+    def test_bwd_meta(g, a):
+        return torch.empty_like(a)
+
+    fwd_op = ppe.ops.OpDesc(test, test_meta, "(Tensor a) -> Tensor")
+    bwd_op = ppe.ops.OpDesc(
+        test_bwd, test_bwd_meta, "(Tensor g, Tensor a) -> Tensor"
+    )
+    ppe.ops.register("test", fwd_op, bwd_op)
+
+    class TestModule(torch.nn.Module):
+        def forward(self, a):
+            # Call the custom function
+            return torch.ops.ppe.test(a)
+
+    found_fwd_op = False
+    found_bwd_op = False
+
+    from functorch.compile import make_boxed_func
+    from torch._dynamo.backends.common import aot_autograd
+
+    # Detect the custom ops
+    def fwd_compiler_fn(fx_module: torch.fx.GraphModule, _):
+        nonlocal found_fwd_op
+        function_nodes = _get_function_nodes(fx_module)
+        assert len(function_nodes) == 1
+        found_fwd_op = (
+            function_nodes[0].target is torch.ops.ppe.test_fwd.default
+        )
+        return make_boxed_func(fx_module)
+
+    def bwd_compiler_fn(fx_module: torch.fx.GraphModule, _):
+        nonlocal found_bwd_op
+        function_nodes = _get_function_nodes(fx_module)
+        assert len(function_nodes) == 1
+        found_bwd_op = (
+            function_nodes[0].target is torch.ops.ppe.test_bwd.default
+        )
+        return make_boxed_func(fx_module)
+
+    aot_backend = aot_autograd(  # type: ignore[no-untyped-call]
+        fw_compiler=fwd_compiler_fn,
+        bw_compiler=bwd_compiler_fn,
+    )
+    m = TestModule()
+    torch._dynamo.reset()
+    module_opt = torch.compile(m, fullgraph=True, backend=aot_backend)
+    shape = [1, 16, 2048, 128]
+    x = torch.ones(shape, requires_grad=True)
+    y = module_opt(x)
+    y.sum().backward()
+    assert found_fwd_op
+    assert found_bwd_op