
Commit

Load libpyg.so first to let torch.library.register_fake find custom operators (#329)

Part of pyg-team/pytorch_geometric#8890.


This PR reorders import statements so that the decorator
`torch.library.register_fake` (to be used in future PRs) can find pyg-lib's
custom operators. It also includes minor clean-up. A minimal sketch of the
intended registration follows the changed-file summary below.
akihironitta authored Jul 25, 2024
1 parent 5731c0d commit be36298
Showing 9 changed files with 116 additions and 47 deletions.
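
For context on the motivation above, here is a minimal sketch (not part of this commit) of the kind of fake-implementation registration a future PR could add once libpyg.so has been loaded at import time. The operator schema and the output shape for pyg::segment_matmul below are simplified assumptions based on its docstring:

import torch

import pyg_lib  # Importing pyg_lib loads libpyg.so, which defines torch.ops.pyg.*.


# Registering a fake (meta) implementation fails if the custom operator does
# not exist yet, hence the import reordering in this PR:
@torch.library.register_fake('pyg::segment_matmul')
def _(inputs, ptr, other):
    # Shape-only implementation used for tracing: [N, K] x [B, K, M] -> [N, M].
    return inputs.new_empty(inputs.size(0), other.size(-1))
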
27 changes: 27 additions & 0 deletions .pre-commit-config.yaml
@@ -25,6 +25,26 @@ repos:
name: Check packaging
args: [--min=10, .]

- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
hooks:
- id: pyupgrade
name: Upgrade Python syntax
args: [--py38-plus]

- repo: https://github.com/PyCQA/autoflake
rev: v2.3.1
hooks:
- id: autoflake
name: Remove unused imports and variables
args: [
--remove-all-unused-imports,
--remove-unused-variables,
--remove-duplicate-keys,
--ignore-init-module-imports,
--in-place,
]

- repo: https://github.com/google/yapf
rev: v0.40.2
hooks:
@@ -37,6 +57,13 @@ repos:
- id: isort
name: Sort imports

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.3
hooks:
- id: ruff
name: Ruff formatting
args: [--fix, --exit-non-zero-on-fix]

- repo: https://github.com/PyCQA/flake8
rev: 7.1.0
hooks:
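
For illustration only (not part of the diff), the --py38-plus flag makes pyupgrade rewrite older idioms such as the following; the exact set of rewrites depends on the pinned pyupgrade version:

value = 42

tags = {'a', 'b'}           # was: tags = set(['a', 'b'])
label = f'value: {value}'   # was: label = 'value: {}'.format(value)
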
27 changes: 27 additions & 0 deletions .ruff.toml
@@ -0,0 +1,27 @@
include = ["pyproject.toml", "pyg_lib/**/*.py"]
extend-exclude = [
"pyg_lib/testing.py",
"test",
"tools",
"setup.py",
"benchmark",
]
src = ["pyg_lib"]
line-length = 80
target-version = "py38"

[lint]
select = [
"D",
]
ignore = [
"D100", # TODO Don't ignore "Missing docstring in public module"
"D104", # TODO Don't ignore "Missing docstring in public package"
"D205", # Ignore "blank line required between summary line and description"
]

[lint.pydocstyle]
convention = "google"

[format]
quote-style = "single"
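
For illustration (not part of this commit), a docstring in the Google convention that the selected "D" rules enforce:

def clamp(value: float, low: float, high: float) -> float:
    """Clamps a value to the closed interval [low, high].

    Args:
        value: The number to clamp.
        low: Lower bound of the interval.
        high: Upper bound of the interval.

    Returns:
        The value limited to the range [low, high].
    """
    return max(low, min(value, high))
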
13 changes: 6 additions & 7 deletions pyg_lib/__init__.py
@@ -5,21 +5,16 @@

import torch

import pyg_lib.ops # noqa
import pyg_lib.sampler # noqa
import pyg_lib.partition # noqa

from .home import get_home_dir, set_home_dir
from pyg_lib.home import get_home_dir, set_home_dir

__version__ = '0.4.0'

# * `libpyg.so`: The name of the shared library file.
# * `torch.ops.pyg`: The used namespace.
# * `pyg_lib`: The name of the Python package.
# TODO Make naming more consistent.


def load_library(lib_name: str):
def load_library(lib_name: str) -> None:
if bool(os.getenv('BUILD_DOCS', 0)):
return

@@ -38,6 +33,10 @@ def load_library(lib_name: str):

load_library('libpyg')

import pyg_lib.ops # noqa
import pyg_lib.sampler # noqa
import pyg_lib.partition # noqa


def cuda_version() -> int:
r"""Returns the CUDA version for which :obj:`pyg_lib` was compiled with.
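
A quick check (illustrative, not part of the diff) that the custom operator namespace is populated once pyg_lib is imported; it assumes a build in which torch.ops.pyg.cuda_version is defined by libpyg.so:

import torch

import pyg_lib  # load_library('libpyg') now runs before the submodule imports.

print(pyg_lib.cuda_version())        # e.g., -1 for a CPU-only build
print(torch.ops.pyg.cuda_version())  # The raw custom operator behind it.
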
40 changes: 22 additions & 18 deletions pyg_lib/ops/__init__.py
@@ -5,7 +5,7 @@
from torch import Tensor


def pytreeify(cls):
def _pytreeify(cls):
r"""A pytree is Python nested data structure. It is a tree in the sense
that nodes are Python collections (e.g., list, tuple, dict) and the leaves
are Python values.
@@ -56,7 +56,7 @@ def new_backward(ctx, *flat_grad_outputs):
return cls


@pytreeify
@_pytreeify
class GroupedMatmul(torch.autograd.Function):
@staticmethod
def forward(ctx, args: Tuple[Tensor]) -> Tuple[Tensor]:
@@ -96,13 +96,15 @@ def backward(ctx, *outs_grad: Tuple[Tensor]) -> Tuple[Tensor]:
return tuple(inputs_grad + others_grad)


def grouped_matmul(inputs: List[Tensor], others: List[Tensor],
biases: Optional[List[Tensor]] = None) -> List[Tensor]:
def grouped_matmul(
inputs: List[Tensor],
others: List[Tensor],
biases: Optional[List[Tensor]] = None,
) -> List[Tensor]:
r"""Performs dense-dense matrix multiplication according to groups,
utilizing dedicated kernels that effectively parallelize over groups.
.. code-block:: python
Example:
inputs = [torch.randn(5, 16), torch.randn(3, 32)]
others = [torch.randn(16, 32), torch.randn(32, 64)]
@@ -135,14 +137,17 @@ def grouped_matmul(inputs: List[Tensor], others: List[Tensor],
return outs


def segment_matmul(inputs: Tensor, ptr: Tensor, other: Tensor,
bias: Optional[Tensor] = None) -> Tensor:
def segment_matmul(
inputs: Tensor,
ptr: Tensor,
other: Tensor,
bias: Optional[Tensor] = None,
) -> Tensor:
r"""Performs dense-dense matrix multiplication according to segments along
the first dimension of :obj:`inputs` as given by :obj:`ptr`, utilizing
dedicated kernels that effectively parallelize over groups.
.. code-block:: python
Example:
inputs = torch.randn(8, 16)
ptr = torch.tensor([0, 5, 8])
other = torch.randn(2, 16, 32)
@@ -153,11 +158,11 @@ def segment_matmul(inputs: Tensor, ptr: Tensor, other: Tensor,
assert out[5:8] == inputs[5:8] @ other[1]
Args:
input (torch.Tensor): The left operand 2D matrix of shape
inputs (torch.Tensor): The left operand 2D matrix of shape
:obj:`[N, K]`.
ptr (torch.Tensor): Compressed vector of shape :obj:`[B + 1]`, holding
the boundaries of segments.
For best performance, given as a CPU tensor.
the boundaries of segments. For best performance, given as a CPU
tensor.
other (torch.Tensor): The right operand 3D tensor of shape
:obj:`[B, K, M]`.
bias (torch.Tensor, optional): Optional bias term of shape
@@ -181,7 +186,7 @@ def sampled_add(
) -> Tensor:
r"""Performs a sampled **addition** of :obj:`left` and :obj:`right`
according to the indices specified in :obj:`left_index` and
:obj:`right_index`:
:obj:`right_index`.
.. math::
\textrm{out} = \textrm{left}[\textrm{left_index}] +
@@ -213,7 +218,7 @@ def sampled_sub(
) -> Tensor:
r"""Performs a sampled **subtraction** of :obj:`left` by :obj:`right`
according to the indices specified in :obj:`left_index` and
:obj:`right_index`:
:obj:`right_index`.
.. math::
\textrm{out} = \textrm{left}[\textrm{left_index}] -
@@ -245,7 +250,7 @@ def sampled_mul(
) -> Tensor:
r"""Performs a sampled **multiplication** of :obj:`left` and :obj:`right`
according to the indices specified in :obj:`left_index` and
:obj:`right_index`:
:obj:`right_index`.
.. math::
\textrm{out} = \textrm{left}[\textrm{left_index}] *
@@ -277,7 +282,7 @@ def sampled_div(
) -> Tensor:
r"""Performs a sampled **division** of :obj:`left` by :obj:`right`
according to the indices specified in :obj:`left_index` and
:obj:`right_index`:
:obj:`right_index`.
.. math::
\textrm{out} = \textrm{left}[\textrm{left_index}] /
@@ -351,7 +356,6 @@ def softmax_csr(
:rtype: :class:`Tensor`
Examples:
>>> src = torch.randn(4, 4)
>>> ptr = torch.tensor([0, 4])
>>> softmax(src, ptr)
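
As a plain PyTorch reference for the sampled operators above (illustrative only, not the fused library kernels), sampled_add with both index tensors given computes the following:

import torch

left = torch.randn(6, 8)
right = torch.randn(4, 8)
left_index = torch.tensor([0, 2, 5])
right_index = torch.tensor([1, 1, 3])

# Reference for pyg_lib.ops.sampled_add(left, right, left_index, right_index):
expected = left[left_index] + right[right_index]
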
21 changes: 14 additions & 7 deletions pyg_lib/ops/scatter_reduce.py
@@ -13,9 +13,10 @@


@triton.jit
def fused_scatter_reduce_kernel(inputs_ptr, index_ptr, out_ptr, num_feats,
num_reductions, numel, REDUCE0, REDUCE1,
REDUCE2, REDUCE3, BLOCK_SIZE: tl.constexpr):
def _fused_scatter_reduce_forward_kernel(inputs_ptr, index_ptr, out_ptr,
num_feats, num_reductions, numel,
REDUCE0, REDUCE1, REDUCE2, REDUCE3,
BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE

@@ -82,8 +83,13 @@ def fused_scatter_reduce_kernel(inputs_ptr, index_ptr, out_ptr, num_feats,
tl.atomic_max(out_ptr + out_offsets, inputs, mask=mask)


def fused_scatter_reduce(inputs: Tensor, index: Tensor, dim_size: int,
reduce_list: List[str]) -> Tensor:
def fused_scatter_reduce(
inputs: Tensor,
index: Tensor,
dim_size: int,
reduce_list: List[str],
) -> Tensor:
r"""Fuses multiple scatter operations into a single kernel."""
# TODO (matthias): Add support for `out`.
# TODO (matthias): Add backward functionality.
# TODO (matthias): Add support for inputs.dim() != 2.
@@ -129,9 +135,10 @@ def fused_scatter_reduce(inputs: Tensor, index: Tensor, dim_size: int,

# TODO (matthias) Do not compute "sum" and "mean" reductions twice.

grid = lambda meta: (triton.cdiv(inputs.numel(), meta['BLOCK_SIZE']), )
grid = lambda meta: ( # noqa: E731
triton.cdiv(inputs.numel(), meta['BLOCK_SIZE']), )

fused_scatter_reduce_kernel[grid](
_fused_scatter_reduce_forward_kernel[grid](
inputs,
index,
out,
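
For reference (illustrative only, and assuming the fused kernel concatenates the per-reduction outputs along the feature dimension), an unfused equivalent of a ['sum', 'max'] reduce list could look like this:

import torch

inputs = torch.randn(10, 4)
index = torch.randint(0, 3, (10, ))
dim_size = 3

out_sum = torch.zeros(dim_size, 4).index_add_(0, index, inputs)
out_max = torch.full((dim_size, 4), float('-inf')).index_reduce_(
    0, index, inputs, reduce='amax')
out = torch.cat([out_sum, out_max], dim=-1)  # Shape: [3, 8].
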
18 changes: 11 additions & 7 deletions pyg_lib/sampler/__init__.py
@@ -61,7 +61,7 @@ def neighbor_sample(
:obj:`node_time` as default for seed nodes.
Needs to be specified in case edge-level sampling is used via
:obj:`edge_time`. (default: :obj:`None`)
edge-weight (torch.Tensor, optional): If given, will perform biased
edge_weight (torch.Tensor, optional): If given, will perform biased
sampling based on the weight of each edge. (default: :obj:`None`)
csc (bool, optional): If set to :obj:`True`, assumes that the graph is
given in CSC format :obj:`(colptr, row)`. (default: :obj:`False`)
@@ -117,10 +117,8 @@ def hetero_neighbor_sample(
.. note ::
Similar to :meth:`neighbor_sample`, but expects a dictionary of node
types (:obj:`str`) and edge types (:obj:`Tuple[str, str, str]`) for
each non-boolean argument.
Args:
kwargs: Arguments of :meth:`neighbor_sample`.
each non-boolean argument. See :meth:`neighbor_sample` for more
details.
"""
src_node_types = {k[0] for k in rowptr_dict.keys()}
dst_node_types = {k[-1] for k in rowptr_dict.keys()}
@@ -193,8 +191,14 @@ def subgraph(
return torch.ops.pyg.subgraph(rowptr, col, nodes, return_edge_id)


def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int,
p: float = 1.0, q: float = 1.0) -> Tensor:
def random_walk(
rowptr: Tensor,
col: Tensor,
seed: Tensor,
walk_length: int,
p: float = 1.0,
q: float = 1.0,
) -> Tensor:
r"""Samples random walks of length :obj:`walk_length` from all node
indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`, as
described in the `"node2vec: Scalable Feature Learning for Networks"
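
A minimal usage sketch for random_walk (assuming a pyg-lib build with the sampler available); the returned shape is assumed to be [num_seeds, walk_length + 1]:

import torch

import pyg_lib

rowptr = torch.tensor([0, 1, 2, 3])  # CSR graph: each node has one neighbor.
col = torch.tensor([1, 2, 0])        # Edges 0->1, 1->2, 2->0.
seed = torch.tensor([0, 0])

walks = pyg_lib.sampler.random_walk(rowptr, col, seed, walk_length=3)
# On this cycle graph, each walk should read 0, 1, 2, 0.
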
4 changes: 2 additions & 2 deletions pyg_lib/testing.py
@@ -39,9 +39,9 @@ def onlyTriton(func: Callable) -> Callable:
def withCUDA(func: Callable) -> Callable:
import pytest

devices = [torch.device('cpu')]
devices = [pytest.param(torch.device('cpu'), id='cpu')]
if torch.cuda.is_available():
devices.append(torch.device('cuda:0'))
devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0'))

return pytest.mark.parametrize('device', devices)(func)

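
A usage sketch for the decorator (illustrative): with the pytest.param ids above, the generated test ids read [cpu] and [cuda:0]:

import torch

from pyg_lib.testing import withCUDA


@withCUDA
def test_add(device):
    x = torch.ones(4, device=device)
    assert torch.equal(x + x, torch.full((4, ), 2.0, device=device))
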
10 changes: 5 additions & 5 deletions setup.cfg
@@ -13,14 +13,14 @@ classifiers=
Programming Language :: Python :: 3.12
Programming Language :: Python :: 3 :: Only

[aliases]
test=pytest

[tool:pytest]
addopts=--capture=no --ignore=third_party
testpaths=test
addopts=--capture=no --ignore=third_party --color=yes -vv

[flake8]
ignore=E731
exclude=
third_party
build

[isort]
multi_line_output=3
3 changes: 2 additions & 1 deletion test/test_triton.py
@@ -24,7 +24,8 @@ def add_kernel(x_ptr, y_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
def add(x: Tensor, y: Tensor) -> Tensor:
out = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and out.is_cuda
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
grid = lambda meta: ( # noqa: E731
triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
return out

