|
15 | 15 | import json |
16 | 16 | import logging |
17 | 17 | import os |
| 18 | +import re |
18 | 19 | import shutil |
19 | 20 | import sys |
20 | 21 | import tarfile |
|
24 | 25 | from pathlib import Path |
25 | 26 | from typing import TYPE_CHECKING, Any |
26 | 27 | from urllib.error import ContentTooShortError, HTTPError, URLError |
27 | | -from urllib.parse import urlparse |
| 28 | +from urllib.parse import unquote, urlparse |
28 | 29 | from urllib.request import urlopen, urlretrieve |
29 | 30 |
|
| 31 | +import requests |
| 32 | + |
30 | 33 | from monai.config.type_definitions import PathLike |
31 | 34 | from monai.utils import look_up_option, min_version, optional_import |
32 | 35 |
|
@@ -298,6 +301,20 @@ def extractall( |
298 | 301 | ) |
299 | 302 |
|
300 | 303 |
|
| 304 | +def get_filename_from_url(data_url: str): |
| 305 | + try: |
| 306 | + response = requests.head(data_url, allow_redirects=True) |
| 307 | + content_disposition = response.headers.get("Content-Disposition") |
| 308 | + if content_disposition: |
| 309 | + filename = re.findall("filename=(.+)", content_disposition) |
| 310 | + return filename[0].strip('"').strip("'") |
| 311 | + else: |
| 312 | + filename = _basename(data_url) |
| 313 | + return filename |
| 314 | + except Exception as e: |
| 315 | + raise Exception(f"Error processing URL: {e}") |
| 316 | + |
| 317 | + |
301 | 318 | def download_and_extract( |
302 | 319 | url: str, |
303 | 320 | filepath: PathLike = "", |
@@ -327,18 +344,21 @@ def download_and_extract( |
327 | 344 | be False. |
328 | 345 | progress: whether to display progress bar. |
329 | 346 | """ |
330 | | - url_filename_ext = "".join(Path(".", _basename(url)).resolve().suffixes) |
331 | | - filepath_ext = "".join(Path(".", _basename(filepath)).resolve().suffixes) |
332 | | - if filepath not in ["", "."]: |
333 | | - if filepath_ext == "": |
334 | | - new_filepath = filepath + url_filename_ext |
335 | | - logger.warning( |
336 | | - f"filepath={filepath}, which missing file extension. Auto-appending extension to: {new_filepath}" |
337 | | - ) |
338 | | - filepath = new_filepath |
339 | | - if filepath_ext and filepath_ext != url_filename_ext: |
340 | | - logger.warning(f"Expected extension {url_filename_ext}, but get {filepath_ext}, may cause unexpected errors!") |
341 | 347 | with tempfile.TemporaryDirectory() as tmp_dir: |
342 | | - filename = filepath or Path(tmp_dir, _basename(url)).resolve() |
343 | | - download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress) |
344 | | - extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base) |
| 348 | + if not filepath: |
| 349 | + filename = get_filename_from_url(url) |
| 350 | + full_path = Path(tmp_dir, filename) |
| 351 | + elif os.path.isdir(filepath) or not os.path.splitext(filepath)[1]: |
| 352 | + filename = get_filename_from_url(url) |
| 353 | + full_path = Path(os.path.join(filepath, filename)) |
| 354 | + logger.warning(f"No compress file extension provided, downloading as: '{full_path}'") |
| 355 | + else: |
| 356 | + url_filename_ext = "".join(Path(".", _basename(url)).resolve().suffixes) |
| 357 | + filepath_ext = "".join(Path(".", _basename(filepath)).resolve().suffixes) |
| 358 | + if filepath_ext != url_filename_ext: |
| 359 | + raise ValueError( |
| 360 | + f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}" |
| 361 | + ) |
| 362 | + full_path = Path(filepath) |
| 363 | + download_url(url=url, filepath=full_path, hash_val=hash_val, hash_type=hash_type, progress=progress) |
| 364 | + extractall(filepath=full_path, output_dir=output_dir, file_type=file_type, has_base=has_base) |
0 commit comments