diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml index 7952539b0b..26b51716d4 100644 --- a/.github/workflows/workflow.yml +++ b/.github/workflows/workflow.yml @@ -1,5 +1,9 @@ name: CI -on: [push, pull_request] +on: + push: + pull_request: + schedule: + - cron: "0 0 * * *" # @daily # Run linter with github actions for quick feedbacks. # Run macos tests with github actions. Linux (CPU & GPU) tests currently runs on CircleCI @@ -37,14 +41,14 @@ jobs: strategy: fail-fast: false matrix: - torch: ["1.8", "1.9", "1.10"] + torch: ["1.9", "1.10", "2.2.2"] include: - - torch: "1.8" - torchvision: 0.9 - torch: "1.9" torchvision: "0.10" - torch: "1.10" torchvision: "0.11.1" + - torch: "2.2.2" + torchvision: "0.17.2" env: # point datasets to ~/.torch so it's cached by CI DETECTRON2_DATASETS: ~/.torch/datasets @@ -66,11 +70,13 @@ jobs: - name: Install dependencies run: | python -m pip install -U pip - python -m pip install ninja opencv-python-headless onnx pytest-xdist + python -m pip install wheel ninja opencv-python-headless onnx pytest-xdist python -m pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html # install from github to get latest; install iopath first since fvcore depends on it python -m pip install -U 'git+https://github.com/facebookresearch/iopath' python -m pip install -U 'git+https://github.com/facebookresearch/fvcore' + wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py + python collect_env.py - name: Build and install run: | diff --git a/detectron2/layers/nms.py b/detectron2/layers/nms.py index 65afb746bf..37ba18b2af 100644 --- a/detectron2/layers/nms.py +++ b/detectron2/layers/nms.py @@ -5,6 +5,8 @@ from torchvision.ops import boxes as box_ops from torchvision.ops import nms # noqa . for compatibility +from detectron2.layers.wrappers import disable_torch_compiler + def batched_nms( boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float @@ -22,7 +24,7 @@ def batched_nms( # Note: this function (nms_rotated) might be moved into # torchvision/ops/boxes.py in the future -@torch.compiler.disable +@disable_torch_compiler def nms_rotated(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float): """ Performs non-maximum suppression (NMS) on the rotated boxes according diff --git a/detectron2/layers/roi_align_rotated.py b/detectron2/layers/roi_align_rotated.py index 7e25a310db..12dd00118c 100644 --- a/detectron2/layers/roi_align_rotated.py +++ b/detectron2/layers/roi_align_rotated.py @@ -5,10 +5,12 @@ from torch.autograd.function import once_differentiable from torch.nn.modules.utils import _pair +from detectron2.layers.wrappers import disable_torch_compiler + class _ROIAlignRotated(Function): @staticmethod - @torch.compiler.disable + @disable_torch_compiler def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio): ctx.save_for_backward(roi) ctx.output_size = _pair(output_size) diff --git a/detectron2/layers/wrappers.py b/detectron2/layers/wrappers.py index fb3cb38b9d..668c6bbd63 100644 --- a/detectron2/layers/wrappers.py +++ b/detectron2/layers/wrappers.py @@ -8,6 +8,7 @@ is implemented """ +import functools import warnings from typing import List, Optional import torch @@ -39,7 +40,7 @@ def shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) -> tor def check_if_dynamo_compiling(): - if TORCH_VERSION >= (1, 14): + if TORCH_VERSION >= (2, 1): from torch._dynamo import is_compiling return is_compiling() @@ -47,6 +48,19 @@ def check_if_dynamo_compiling(): return False +def disable_torch_compiler(func): + if TORCH_VERSION >= (2, 1): + # Use the torch.compiler.disable decorator if supported + @torch.compiler.disable + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + return wrapper + else: + # Return the function unchanged if torch.compiler.disable is not supported + return func + + def cat(tensors: List[torch.Tensor], dim: int = 0): """ Efficient version of torch.cat that avoids a copy if there is only a single element in a list diff --git a/tests/test_export_onnx.py b/tests/test_export_onnx.py index aa15e1a406..e5c3d6e17e 100644 --- a/tests/test_export_onnx.py +++ b/tests/test_export_onnx.py @@ -3,7 +3,9 @@ import io import unittest import warnings +import onnx import torch +from packaging import version from torch.hub import _check_module_exists from detectron2 import model_zoo @@ -95,6 +97,10 @@ def inference_func(model, image): inference_func, ) + @unittest.skipIf( + version.Version(onnx.version.version) == version.Version("1.16.0"), + "This test fails on ONNX Runtime 1.16", + ) def testKeypointHead(self): class M(torch.nn.Module): def __init__(self):