diff --git a/tests/conftest.py b/tests/conftest.py index 7d622c3977..4d5a43c810 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from pathlib import Path from typing import Any import pytest @@ -16,3 +17,8 @@ def load(*args: Any, progress: bool = False, **kwargs: Any) -> Any: @pytest.fixture def load_state_dict_from_url(monkeypatch: MonkeyPatch) -> None: monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) + + +@pytest.fixture(autouse=True) +def torch_hub(tmp_path: Path) -> None: + torch.hub.set_dir(tmp_path)