Skip to content

Commit

Permalink
use pretrained weights for torch.compile
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Jul 7, 2023
1 parent b64adb4 commit a10fe37
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,6 @@ def test_jit_trace(self, name: str, inputs: Tensor):

@pytest.mark.skipif(not hasattr(torch, "compile"), reason="torch.compile() is not available")
def test_compile(self, name: str, inputs: Tensor):
m = getattr(backbones, name)().eval()
m = getattr(backbones, name)(pretrained=True).eval()
m_compiled = torch.compile(m)
torch.testing.assert_close(m(inputs), m_compiled(inputs))

0 comments on commit a10fe37

Please sign in to comment.