Skip to content

Commit 8143ac3

Browse files
committed
Enhance download_and_extract
Signed-off-by: jerome_Hsieh <[email protected]>
1 parent 7a26dcd commit 8143ac3

File tree

1 file changed

+38
-26
lines changed

1 file changed

+38
-26
lines changed

monai/apps/utils.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@
2828
from urllib.parse import urlparse
2929
from urllib.request import urlopen, urlretrieve
3030

31-
import requests
32-
3331
from monai.config.type_definitions import PathLike
3432
from monai.utils import look_up_option, min_version, optional_import
3533

34+
requests, has_requests = optional_import("requests")
3635
gdown, has_gdown = optional_import("gdown", "4.7.3")
36+
BeautifulSoup, has_bs4 = optional_import("bs4", name="BeautifulSoup")
3737

3838
if TYPE_CHECKING:
3939
from tqdm import tqdm
@@ -303,14 +303,29 @@ def extractall(
303303

304304
def get_filename_from_url(data_url: str):
305305
try:
306-
response = requests.head(data_url, allow_redirects=True)
307-
content_disposition = response.headers.get("Content-Disposition")
308-
if content_disposition:
309-
filename = re.findall("filename=(.+)", content_disposition)
310-
return filename[0].strip('"').strip("'")
306+
if "drive.google.com" in data_url:
307+
response = requests.head(data_url, allow_redirects=True)
308+
cd = response.headers.get("Content-Disposition") # Normal size file case
309+
if cd:
310+
filename = cd.split('filename="')[1].split('"')[0]
311+
return filename
312+
response = requests.get(data_url)
313+
if "text/html" in response.headers.get("Content-Type", ""): # Big size file case
314+
soup = BeautifulSoup(response.text, "html.parser")
315+
filename_div = soup.find("span", {"class": "uc-name-size"})
316+
if filename_div:
317+
filename = filename_div.find("a").text
318+
return filename
319+
return None
311320
else:
312-
filename = _basename(data_url)
313-
return filename
321+
response = requests.head(data_url, allow_redirects=True)
322+
content_disposition = response.headers.get("Content-Disposition")
323+
if content_disposition:
324+
filename = re.findall("filename=(.+)", content_disposition)
325+
return filename[0].strip('"').strip("'")
326+
else:
327+
filename = _basename(data_url)
328+
return filename
314329
except Exception as e:
315330
raise Exception(f"Error processing URL: {e}")
316331

@@ -344,21 +359,18 @@ def download_and_extract(
344359
be False.
345360
progress: whether to display progress bar.
346361
"""
362+
url_filename_ext = "".join(Path(get_filename_from_url(url)).suffixes)
363+
filepath_ext = "".join(Path(_basename(filepath)).suffixes)
364+
if filepath not in ["", "."]:
365+
if filepath_ext == "":
366+
new_filepath = Path(filepath).with_suffix(url_filename_ext)
367+
logger.warning(
368+
f"filepath={filepath}, which missing file extension. Auto-appending extension to: {new_filepath}"
369+
)
370+
filepath = new_filepath
371+
if filepath_ext and filepath_ext != url_filename_ext:
372+
raise ValueError(f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}")
347373
with tempfile.TemporaryDirectory() as tmp_dir:
348-
if not filepath:
349-
filename = get_filename_from_url(url)
350-
full_path = Path(tmp_dir, filename)
351-
elif os.path.isdir(filepath) or not os.path.splitext(filepath)[1]:
352-
filename = get_filename_from_url(url)
353-
full_path = Path(os.path.join(filepath, filename))
354-
logger.warning(f"No compress file extension provided, downloading as: '{full_path}'")
355-
else:
356-
url_filename_ext = "".join(Path(".", _basename(url)).resolve().suffixes)
357-
filepath_ext = "".join(Path(".", _basename(filepath)).resolve().suffixes)
358-
if filepath_ext != url_filename_ext:
359-
raise ValueError(
360-
f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}"
361-
)
362-
full_path = Path(filepath)
363-
download_url(url=url, filepath=full_path, hash_val=hash_val, hash_type=hash_type, progress=progress)
364-
extractall(filepath=full_path, output_dir=output_dir, file_type=file_type, has_base=has_base)
374+
filename = filepath or Path(tmp_dir, get_filename_from_url(url)).resolve()
375+
download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
376+
extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)

0 commit comments

Comments
 (0)