Skip to content

Commit e70e59c

Browse files
committed
Enhance download_and_extract
Signed-off-by: jerome_Hsieh <[email protected]>
1 parent a9a0171 commit e70e59c

File tree

1 file changed

+35
-15
lines changed

1 file changed

+35
-15
lines changed

monai/apps/utils.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import json
1616
import logging
1717
import os
18+
import re
1819
import shutil
1920
import sys
2021
import tarfile
@@ -24,9 +25,11 @@
2425
from pathlib import Path
2526
from typing import TYPE_CHECKING, Any
2627
from urllib.error import ContentTooShortError, HTTPError, URLError
27-
from urllib.parse import urlparse
28+
from urllib.parse import unquote, urlparse
2829
from urllib.request import urlopen, urlretrieve
2930

31+
import requests
32+
3033
from monai.config.type_definitions import PathLike
3134
from monai.utils import look_up_option, min_version, optional_import
3235

@@ -298,6 +301,20 @@ def extractall(
298301
)
299302

300303

304+
def get_filename_from_url(data_url: str, timeout: float = 30.0) -> str:
    """
    Determine the file name that a download from ``data_url`` should be saved as.

    A ``HEAD`` request (following redirects) is sent to ``data_url``. If the response has a
    ``Content-Disposition`` header containing a plain ``filename=`` parameter, that value is
    used. Otherwise the name is derived from the URL itself via ``_basename``.

    Args:
        data_url: URL of the file to be downloaded.
        timeout: seconds to wait for the server before giving up, so that an unresponsive
            host cannot block the caller indefinitely.

    Returns:
        The file name taken from the response header (directory components removed), or the
        result of ``_basename(data_url)`` when the header gives no usable name.

    Raises:
        Exception: if the request fails or the response cannot be processed. The original
            error is chained as ``__cause__``.
    """
    try:
        response = requests.head(data_url, allow_redirects=True, timeout=timeout)
        content_disposition = response.headers.get("Content-Disposition")
        if content_disposition:
            # Stop at the closing quote or the next ";" so that trailing parameters
            # (e.g. `filename="a.zip"; size=10`) are not swallowed into the name.
            match = re.search(r'filename=\s*"?([^";]+)"?', content_disposition, flags=re.IGNORECASE)
            if match:
                # The header is server-controlled: keep only the final path component so the
                # value cannot point outside the directory the caller joins it onto.
                candidate = Path(match.group(1).strip().strip("'")).name
                if candidate not in ("", ".", ".."):
                    return candidate
        # No header, or no usable `filename=` value in it: fall back to the URL.
        return _basename(data_url)
    except Exception as e:
        raise Exception(f"Error processing URL: {e}") from e
316+
317+
301318
def download_and_extract(
302319
url: str,
303320
filepath: PathLike = "",
@@ -327,18 +344,21 @@ def download_and_extract(
327344
be False.
328345
progress: whether to display progress bar.
329346
"""
330-
url_filename_ext = "".join(Path(".", _basename(url)).resolve().suffixes)
331-
filepath_ext = "".join(Path(".", _basename(filepath)).resolve().suffixes)
332-
if filepath not in ["", "."]:
333-
if filepath_ext == "":
334-
new_filepath = filepath + url_filename_ext
335-
logger.warning(
336-
f"filepath={filepath}, which missing file extension. Auto-appending extension to: {new_filepath}"
337-
)
338-
filepath = new_filepath
339-
if filepath_ext and filepath_ext != url_filename_ext:
340-
logger.warning(f"Expected extension {url_filename_ext}, but get {filepath_ext}, may cause unexpected errors!")
341347
with tempfile.TemporaryDirectory() as tmp_dir:
342-
filename = filepath or Path(tmp_dir, _basename(url)).resolve()
343-
download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
344-
extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)
348+
if not filepath:
349+
filename = get_filename_from_url(url)
350+
full_path = Path(tmp_dir, filename)
351+
elif os.path.isdir(filepath) or not os.path.splitext(filepath)[1]:
352+
filename = get_filename_from_url(url)
353+
full_path = Path(os.path.join(filepath, filename))
354+
logger.warning(f"No compress file extension provided, downloading as: '{full_path}'")
355+
else:
356+
url_filename_ext = "".join(Path(".", _basename(url)).resolve().suffixes)
357+
filepath_ext = "".join(Path(".", _basename(filepath)).resolve().suffixes)
358+
if filepath_ext != url_filename_ext:
359+
raise ValueError(
360+
f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}"
361+
)
362+
full_path = Path(filepath)
363+
download_url(url=url, filepath=full_path, hash_val=hash_val, hash_type=hash_type, progress=progress)
364+
extractall(filepath=full_path, output_dir=output_dir, file_type=file_type, has_base=has_base)

0 commit comments

Comments
 (0)