Skip to content

Commit 1d92253

Browse files
fix mypy error
Signed-off-by: Yiheng Wang <[email protected]>
1 parent 4a4a738 commit 1d92253

File tree

1 file changed

+17
-19
lines changed

1 file changed

+17
-19
lines changed

monai/bundle/scripts.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from textwrap import dedent
2727
from typing import Any, Callable
2828

29-
import requests
3029
import torch
3130
from torch.cuda import is_available
3231

@@ -60,7 +59,7 @@
6059
validate, _ = optional_import("jsonschema", name="validate")
6160
ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
6261
Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
63-
requests_get, has_requests = optional_import("requests", name="get")
62+
requests, has_requests = optional_import("requests")
6463
onnx, _ = optional_import("onnx")
6564
huggingface_hub, _ = optional_import("huggingface_hub")
6665

@@ -234,7 +233,7 @@ def _get_all_download_files(request_url: str, headers: dict | None = None) -> li
234233
if not has_requests:
235234
raise ValueError("requests package is required, please install it.")
236235
headers = {} if headers is None else headers
237-
response = requests_get(request_url, headers=headers)
236+
response = requests.get(request_url, headers=headers)
238237
response.raise_for_status()
239238
model_info = json.loads(response.text)
240239

@@ -278,7 +277,7 @@ def _download_from_ngc_private(
278277
request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
279278
if has_requests:
280279
headers = {} if headers is None else headers
281-
response = requests_get(request_url, headers=headers)
280+
response = requests.get(request_url, headers=headers)
282281
response.raise_for_status()
283282
else:
284283
raise ValueError("NGC API requires requests package. Please install it.")
@@ -301,7 +300,7 @@ def _get_ngc_token(api_key, retry=0):
301300
url = "https://authn.nvidia.com/token?service=ngc"
302301
headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key}
303302
if has_requests:
304-
response = requests_get(url, headers=headers)
303+
response = requests.get(url, headers=headers)
305304
if not response.ok:
306305
# retry 3 times, if failed, raise an error.
307306
if retry < 3:
@@ -315,12 +314,15 @@ def _get_ngc_token(api_key, retry=0):
315314

316315
def _get_latest_bundle_version_monaihosting(name):
317316
full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}"
318-
requests_get, has_requests = optional_import("requests", name="get")
319317
if has_requests:
320-
resp = requests_get(full_url)
321-
resp.raise_for_status()
322-
model_info = json.loads(resp.text)
323-
return model_info["model"]["latestVersionIdStr"]
318+
resp = requests.get(full_url)
319+
try:
320+
resp.raise_for_status()
321+
model_info = json.loads(resp.text)
322+
return model_info["model"]["latestVersionIdStr"]
323+
except requests.exceptions.HTTPError:
324+
# for monaihosting bundles, if cannot find the version, get from model zoo model_info.json
325+
return get_bundle_versions(name)["latest_version"]
324326

325327
raise ValueError("NGC API requires requests package. Please install it.")
326328

@@ -400,14 +402,14 @@ def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers:
400402
version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements
401403
if headers:
402404
version_header.update(headers)
403-
resp = requests_get(version_endpoint, headers=version_header)
405+
resp = requests.get(version_endpoint, headers=version_header)
404406
resp.raise_for_status()
405407
model_info = json.loads(resp.text)
406408
latest_versions = _list_latest_versions(model_info)
407409

408410
for version in latest_versions:
409411
file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json"
410-
resp = requests_get(file_endpoint, headers=headers)
412+
resp = requests.get(file_endpoint, headers=headers)
411413
metadata = json.loads(resp.text)
412414
resp.raise_for_status()
413415
# if the package version is not available or the model is compatible with the package version
@@ -428,11 +430,7 @@ def _get_latest_bundle_version(
428430
name = _add_ngc_prefix(name)
429431
return _get_latest_bundle_version_ngc(name)
430432
elif source == "monaihosting":
431-
try:
432-
return _get_latest_bundle_version_monaihosting(name)
433-
except requests.exceptions.HTTPError:
434-
# for monaihosting bundles, if cannot find the version, get from model zoo model_info.json
435-
return get_bundle_versions(name)["latest_version"]
433+
return _get_latest_bundle_version_monaihosting(name)
436434
elif source == "ngc_private":
437435
headers = kwargs.pop("headers", {})
438436
name = _add_ngc_prefix(name)
@@ -817,9 +815,9 @@ def _get_all_bundles_info(
817815

818816
if auth_token is not None:
819817
headers = {"Authorization": f"Bearer {auth_token}"}
820-
resp = requests_get(request_url, headers=headers)
818+
resp = requests.get(request_url, headers=headers)
821819
else:
822-
resp = requests_get(request_url)
820+
resp = requests.get(request_url)
823821
resp.raise_for_status()
824822
else:
825823
raise ValueError("requests package is required, please install it.")

0 commit comments

Comments
 (0)