Skip to content

Commit

Permalink
remove torchscript and add torch.compile
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Jul 7, 2023
1 parent d12d1b5 commit 8642b2c
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tests/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_model_creation(self, name: str):
assert isinstance(m, backbones.BaseBackbone)

def test_pretrained_weights(self, name: str):
m = getattr(backbones, name)(pretrained=True)
getattr(backbones, name)(pretrained=True)

def test_attributes(self, name: str):
m = getattr(backbones, name)()
Expand Down Expand Up @@ -62,10 +62,12 @@ def test_get_feature_maps(self, name: str, inputs: Tensor):
assert len(out.shape) == 4
assert out.shape[1] == out_c

def test_jit_script(self, name: str):
m = getattr(backbones, name)()
torch.jit.script(m)

def test_jit_trace(self, name: str, inputs: Tensor):
m = getattr(backbones, name)()
torch.jit.script(m, inputs)
torch.jit.trace(m, inputs)

@pytest.mark.skipif(not hasattr(torch, "compile"), reason="torch.compile() is not available")
def test_compile(self, name: str, inputs: Tensor):
m = getattr(backbones, name)().eval()
m_compiled = torch.compile(m)
torch.testing.assert_close(m(inputs), m_compiled(inputs))

0 comments on commit 8642b2c

Please sign in to comment.