Skip to content

Commit a188d3a

Browse files
update bundle download api
Signed-off-by: Yiheng Wang <[email protected]>
1 parent 8aef9a9 commit a188d3a

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

monai/bundle/scripts.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,11 +600,15 @@ def download(
600600
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
601601
elif source_ == "monaihosting":
602602
try:
603+
extract_path = os.path.join(bundle_dir_, name_)
604+
huggingface_hub.snapshot_download(repo_id=f"MONAI/{name_}", revision=version_, local_dir=extract_path)
605+
except (huggingface_hub.errors.RevisionNotFoundError, huggingface_hub.errors.RepositoryNotFoundError):
606+
# if bundle or version not found from huggingface, download from ngc monaihosting
603607
_download_from_monaihosting(
604608
download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
605609
)
606610
except urllib.error.HTTPError:
607-
# for monaihosting bundles, if cannot download from default host, download according to bundle_info
611+
# if also cannot download from ngc monaihosting, download according to bundle_info
608612
_download_from_bundle_info(
609613
download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
610614
)

monai/utils/jupyter_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def plot_engine_status(
234234

235235

236236
def _get_loss_from_output(
237-
output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor,
237+
output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor
238238
) -> torch.Tensor:
239239
"""Returns a single value from the network output, which is a dict or tensor."""
240240

tests/bundle/test_bundle_download.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363

6464
TEST_CASE_6 = [["models/model.pt", "configs/train.json"], "renalStructures_CECT_segmentation", "0.1.0"]
6565

66+
TEST_CASE_6_HF = [["models/model.pt", "configs/train.yaml"], "mednist_ddpm", "1.0.1"]
67+
6668
TEST_CASE_7 = [
6769
["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"],
6870
"test_bundle",
@@ -193,6 +195,7 @@ def test_ngc_private_source_download_bundle(self, bundle_files, bundle_name, _ur
193195

194196
@parameterized.expand([TEST_CASE_6])
195197
@skip_if_quick
198+
@skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.")
196199
def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, version):
197200
with skip_if_downloading_fails():
198201
# download a single file from url, also use `args_file`
@@ -239,6 +242,7 @@ def test_list_latest_versions(self):
239242
self.assertEqual(_list_latest_versions(data), ["1.1", "1.0"])
240243

241244
@skip_if_quick
245+
@skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.")
242246
@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
243247
def test_download_monaihosting(self, mock_get_versions):
244248
"""Test checking MONAI version from a metadata file."""
@@ -333,6 +337,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
333337

334338
@parameterized.expand([TEST_CASE_8])
335339
@skip_if_quick
340+
@skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.")
336341
def test_load_weights_with_net_override(self, bundle_name, device, net_override):
337342
with skip_if_downloading_fails():
338343
# download bundle, and load weights from the downloaded path

0 commit comments

Comments
 (0)