Skip to content

Commit

Permalink
fix Github actions
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#5274

Add a new test to cover PyTorch 2.0 (we'll probably drop the support for 1.x).

Differential Revision: D56911192
  • Loading branch information
Yanghan Wang authored and facebook-github-bot committed May 3, 2024
1 parent 6163074 commit b2a2eb7
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 11 deletions.
20 changes: 12 additions & 8 deletions .github/workflows/workflow.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
name: CI
on: [push, pull_request]
on:
push:
pull_request:
schedule:
- cron: "0 0 * * *" # @daily

# Run linter with github actions for quick feedbacks.
# Run macos tests with github actions. Linux (CPU & GPU) tests currently runs on CircleCI
Expand Down Expand Up @@ -37,14 +41,12 @@ jobs:
strategy:
fail-fast: false
matrix:
torch: ["1.8", "1.9", "1.10"]
torch: ["1.10", "2.2.2"]
include:
- torch: "1.8"
torchvision: 0.9
- torch: "1.9"
torchvision: "0.10"
- torch: "1.10"
torchvision: "0.11.1"
torchvision: "0.11.2"
- torch: "2.2.2"
torchvision: "0.17.2"
env:
# point datasets to ~/.torch so it's cached by CI
DETECTRON2_DATASETS: ~/.torch/datasets
Expand All @@ -66,11 +68,13 @@ jobs:
- name: Install dependencies
run: |
python -m pip install -U pip
python -m pip install ninja opencv-python-headless onnx pytest-xdist
python -m pip install wheel ninja opencv-python-headless onnx pytest-xdist
python -m pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
# install from github to get latest; install iopath first since fvcore depends on it
python -m pip install -U 'git+https://github.com/facebookresearch/iopath'
python -m pip install -U 'git+https://github.com/facebookresearch/fvcore'
wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
python collect_env.py
- name: Build and install
run: |
Expand Down
4 changes: 3 additions & 1 deletion detectron2/layers/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from torchvision.ops import boxes as box_ops
from torchvision.ops import nms # noqa . for compatibility

from detectron2.layers.wrappers import disable_torch_compiler


def batched_nms(
boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float
Expand All @@ -22,7 +24,7 @@ def batched_nms(

# Note: this function (nms_rotated) might be moved into
# torchvision/ops/boxes.py in the future
@torch.compiler.disable
@disable_torch_compiler
def nms_rotated(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float):
"""
Performs non-maximum suppression (NMS) on the rotated boxes according
Expand Down
4 changes: 3 additions & 1 deletion detectron2/layers/roi_align_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from detectron2.layers.wrappers import disable_torch_compiler


class _ROIAlignRotated(Function):
@staticmethod
@torch.compiler.disable
@disable_torch_compiler
def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio):
ctx.save_for_backward(roi)
ctx.output_size = _pair(output_size)
Expand Down
17 changes: 16 additions & 1 deletion detectron2/layers/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
is implemented
"""

import functools
import warnings
from typing import List, Optional
import torch
Expand Down Expand Up @@ -39,14 +40,28 @@ def shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) -> tor


def check_if_dynamo_compiling():
if TORCH_VERSION >= (1, 14):
if TORCH_VERSION >= (2, 1):
from torch._dynamo import is_compiling

return is_compiling()
else:
return False


def disable_torch_compiler(func):
if TORCH_VERSION >= (2, 1):
# Use the torch.compiler.disable decorator if supported
@torch.compiler.disable
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper
else:
# Return the function unchanged if torch.compiler.disable is not supported
return func


def cat(tensors: List[torch.Tensor], dim: int = 0):
"""
Efficient version of torch.cat that avoids a copy if there is only a single element in a list
Expand Down
6 changes: 6 additions & 0 deletions tests/test_export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import io
import unittest
import warnings
import onnx
import torch
from packaging import version
from torch.hub import _check_module_exists

from detectron2 import model_zoo
Expand Down Expand Up @@ -95,6 +97,10 @@ def inference_func(model, image):
inference_func,
)

@unittest.skipIf(
version.Version(onnx.version.version) == version.Version("1.16.0"),
"This test fails on ONNX Runtime 1.16",
)
def testKeypointHead(self):
class M(torch.nn.Module):
def __init__(self):
Expand Down

0 comments on commit b2a2eb7

Please sign in to comment.