Skip to content

Commit 51fc9fb

Browse files
update monai hosting download
Signed-off-by: Yiheng Wang <[email protected]>
1 parent a790590 commit 51fc9fb

File tree

1 file changed

+31
-6
lines changed

1 file changed

+31
-6
lines changed

monai/bundle/scripts.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import json
1616
import os
1717
import re
18+
import urllib
1819
import warnings
1920
import zipfile
2021
from collections.abc import Mapping, Sequence
@@ -25,6 +26,7 @@
2526
from textwrap import dedent
2627
from typing import Any, Callable
2728

29+
import requests
2830
import torch
2931
from torch.cuda import is_available
3032

@@ -206,6 +208,16 @@ def _download_from_monaihosting(download_path: Path, filename: str, version: str
206208
extractall(filepath=filepath, output_dir=download_path, has_base=True)
207209

208210

211+
def _download_from_bundle_info(download_path: Path, filename: str, version: str, progress: bool) -> None:
212+
bundle_info = get_bundle_info(name=filename, version=version)
213+
if not bundle_info:
214+
raise ValueError(f"Bundle info not found for {filename} v{version}.")
215+
url = bundle_info["source"]
216+
filepath = download_path / f"{filename}_v{version}.zip"
217+
download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
218+
extractall(filepath=filepath, output_dir=download_path, has_base=True)
219+
220+
209221
def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str:
210222
if name.startswith(prefix):
211223
return name
@@ -307,10 +319,10 @@ def _get_latest_bundle_version_monaihosting(name):
307319
if has_requests:
308320
resp = requests_get(full_url)
309321
resp.raise_for_status()
310-
else:
311-
raise ValueError("NGC API requires requests package. Please install it.")
312-
model_info = json.loads(resp.text)
313-
return model_info["model"]["latestVersionIdStr"]
322+
model_info = json.loads(resp.text)
323+
return model_info["model"]["latestVersionIdStr"]
324+
325+
raise ValueError("NGC API requires requests package. Please install it.")
314326

315327

316328
def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
@@ -416,7 +428,11 @@ def _get_latest_bundle_version(
416428
name = _add_ngc_prefix(name)
417429
return _get_latest_bundle_version_ngc(name)
418430
elif source == "monaihosting":
419-
return _get_latest_bundle_version_monaihosting(name)
431+
try:
432+
return _get_latest_bundle_version_monaihosting(name)
433+
except requests.exceptions.HTTPError:
434+
# for monaihosting bundles, if cannot find the version, get from model zoo model_info.json
435+
return get_bundle_versions(name)["latest_version"]
420436
elif source == "ngc_private":
421437
headers = kwargs.pop("headers", {})
422438
name = _add_ngc_prefix(name)
@@ -585,7 +601,16 @@ def download(
585601
name_ver = "_v".join([name_, version_]) if version_ is not None else name_
586602
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
587603
elif source_ == "monaihosting":
588-
_download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_)
604+
try:
605+
_download_from_monaihosting(
606+
download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
607+
)
608+
except urllib.error.HTTPError:
609+
# for monaihosting bundles, if cannot download from default host, download according to bundle_info
610+
_download_from_bundle_info(
611+
download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
612+
)
613+
589614
elif source_ == "ngc":
590615
_download_from_ngc(
591616
download_path=bundle_dir_,

0 commit comments

Comments
 (0)