|
15 | 15 | import json |
16 | 16 | import os |
17 | 17 | import re |
| 18 | +import urllib |
18 | 19 | import warnings |
19 | 20 | import zipfile |
20 | 21 | from collections.abc import Mapping, Sequence |
|
25 | 26 | from textwrap import dedent |
26 | 27 | from typing import Any, Callable |
27 | 28 |
|
| 29 | +import requests |
28 | 30 | import torch |
29 | 31 | from torch.cuda import is_available |
30 | 32 |
|
@@ -206,6 +208,16 @@ def _download_from_monaihosting(download_path: Path, filename: str, version: str |
206 | 208 | extractall(filepath=filepath, output_dir=download_path, has_base=True) |
207 | 209 |
|
208 | 210 |
|
| 211 | +def _download_from_bundle_info(download_path: Path, filename: str, version: str, progress: bool) -> None: |
| 212 | + bundle_info = get_bundle_info(name=filename, version=version) |
| 213 | + if not bundle_info: |
| 214 | + raise ValueError(f"Bundle info not found for {filename} v{version}.") |
| 215 | + url = bundle_info["source"] |
| 216 | + filepath = download_path / f"{filename}_v{version}.zip" |
| 217 | + download_url(url=url, filepath=filepath, hash_val=None, progress=progress) |
| 218 | + extractall(filepath=filepath, output_dir=download_path, has_base=True) |
| 219 | + |
| 220 | + |
209 | 221 | def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str: |
210 | 222 | if name.startswith(prefix): |
211 | 223 | return name |
@@ -307,10 +319,10 @@ def _get_latest_bundle_version_monaihosting(name): |
307 | 319 | if has_requests: |
308 | 320 | resp = requests_get(full_url) |
309 | 321 | resp.raise_for_status() |
310 | | - else: |
311 | | - raise ValueError("NGC API requires requests package. Please install it.") |
312 | | - model_info = json.loads(resp.text) |
313 | | - return model_info["model"]["latestVersionIdStr"] |
| 322 | + model_info = json.loads(resp.text) |
| 323 | + return model_info["model"]["latestVersionIdStr"] |
| 324 | + |
| 325 | + raise ValueError("NGC API requires requests package. Please install it.") |
314 | 326 |
|
315 | 327 |
|
316 | 328 | def _examine_monai_version(monai_version: str) -> tuple[bool, str]: |
@@ -416,7 +428,11 @@ def _get_latest_bundle_version( |
416 | 428 | name = _add_ngc_prefix(name) |
417 | 429 | return _get_latest_bundle_version_ngc(name) |
418 | 430 | elif source == "monaihosting": |
419 | | - return _get_latest_bundle_version_monaihosting(name) |
| 431 | + try: |
| 432 | + return _get_latest_bundle_version_monaihosting(name) |
| 433 | + except requests.exceptions.HTTPError: |
| 434 | + # for monaihosting bundles, if cannot find the version, get from model zoo model_info.json |
| 435 | + return get_bundle_versions(name)["latest_version"] |
420 | 436 | elif source == "ngc_private": |
421 | 437 | headers = kwargs.pop("headers", {}) |
422 | 438 | name = _add_ngc_prefix(name) |
@@ -585,7 +601,16 @@ def download( |
585 | 601 | name_ver = "_v".join([name_, version_]) if version_ is not None else name_ |
586 | 602 | _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_) |
587 | 603 | elif source_ == "monaihosting": |
588 | | - _download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_) |
| 604 | + try: |
| 605 | + _download_from_monaihosting( |
| 606 | + download_path=bundle_dir_, filename=name_, version=version_, progress=progress_ |
| 607 | + ) |
| 608 | + except urllib.error.HTTPError: |
| 609 | + # for monaihosting bundles, if cannot download from default host, download according to bundle_info |
| 610 | + _download_from_bundle_info( |
| 611 | + download_path=bundle_dir_, filename=name_, version=version_, progress=progress_ |
| 612 | + ) |
| 613 | + |
589 | 614 | elif source_ == "ngc": |
590 | 615 | _download_from_ngc( |
591 | 616 | download_path=bundle_dir_, |
|
0 commit comments