Skip to content

Commit 109a1aa

Browse files
committed
Enhance download_and_extract
Signed-off-by: jerome_Hsieh <[email protected]>
1 parent fcc269e commit 109a1aa

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

monai/apps/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def extractall(
301301
)
302302

303303

304-
def _get_filename_from_url(data_url: str):
304+
def get_filename_from_url(data_url: str):
305305
try:
306306
if "drive.google.com" in data_url:
307307
response = requests.head(data_url, allow_redirects=True)
@@ -359,7 +359,7 @@ def download_and_extract(
359359
be False.
360360
progress: whether to display progress bar.
361361
"""
362-
url_filename_ext = "".join(Path(_get_filename_from_url(url)).suffixes)
362+
url_filename_ext = "".join(Path(get_filename_from_url(url)).suffixes)
363363
filepath_ext = "".join(Path(_basename(filepath)).suffixes)
364364
if filepath not in ["", "."]:
365365
if filepath_ext == "":
@@ -371,6 +371,6 @@ def download_and_extract(
371371
if filepath_ext and filepath_ext != url_filename_ext:
372372
raise ValueError(f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}")
373373
with tempfile.TemporaryDirectory() as tmp_dir:
374-
filename = filepath or Path(tmp_dir, _get_filename_from_url(url)).resolve()
374+
filename = filepath or Path(tmp_dir, get_filename_from_url(url)).resolve()
375375
download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
376376
extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)

tests/test_download_and_extract.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
from parameterized import parameterized
2121

2222
from monai.apps import download_and_extract, download_url, extractall
23-
from tests.utils import skip_if_downloading_fails, skip_if_quick, testing_data_config
23+
from tests.utils import SkipIfNoModule, skip_if_downloading_fails, skip_if_quick, testing_data_config
2424

2525

26+
@SkipIfNoModule("requests")
2627
class TestDownloadAndExtract(unittest.TestCase):
2728

2829
@skip_if_quick

0 commit comments

Comments
 (0)