|
28 | 28 | from urllib.parse import urlparse |
29 | 29 | from urllib.request import urlopen, urlretrieve |
30 | 30 |
|
31 | | -import requests |
32 | | - |
33 | 31 | from monai.config.type_definitions import PathLike |
34 | 32 | from monai.utils import look_up_option, min_version, optional_import |
35 | 33 |
|
| 34 | +requests, has_requests = optional_import("requests") |
36 | 35 | gdown, has_gdown = optional_import("gdown", "4.7.3") |
| 36 | +BeautifulSoup, has_bs4 = optional_import("bs4", name="BeautifulSoup") |
37 | 37 |
|
38 | 38 | if TYPE_CHECKING: |
39 | 39 | from tqdm import tqdm |
@@ -303,14 +303,29 @@ def extractall( |
303 | 303 |
|
def get_filename_from_url(data_url: str):
    """Determine the download filename for ``data_url``.

    For Google Drive links, the filename is read from the ``Content-Disposition``
    header (normal-size files); if absent, the download-warning HTML page served
    for large files is scraped for the ``uc-name-size`` span. For all other URLs,
    the ``Content-Disposition`` header is tried first, falling back to the last
    path component of the URL.

    Args:
        data_url: URL of the file to be downloaded.

    Returns:
        The filename as a string, or ``None`` for a Google Drive link whose
        filename cannot be determined from either the header or the HTML page.

    Raises:
        Exception: wraps any network or parsing error, chained to the original
            cause so the root failure stays visible in the traceback.
    """
    try:
        if "drive.google.com" in data_url:
            response = requests.head(data_url, allow_redirects=True)
            cd = response.headers.get("Content-Disposition")  # normal-size file case
            if cd:
                # RFC 6266 allows both quoted and unquoted filename= values;
                # handle both instead of assuming a leading double quote.
                match = re.search(r'filename="?([^";]+)"?', cd)
                if match:
                    return match.group(1)
            response = requests.get(data_url)
            if "text/html" in response.headers.get("Content-Type", ""):  # big-size file case
                soup = BeautifulSoup(response.text, "html.parser")
                filename_div = soup.find("span", {"class": "uc-name-size"})
                if filename_div:
                    return filename_div.find("a").text
            return None
        else:
            response = requests.head(data_url, allow_redirects=True)
            content_disposition = response.headers.get("Content-Disposition")
            if content_disposition:
                filename = re.findall("filename=(.+)", content_disposition)
                return filename[0].strip('"').strip("'")
            else:
                return _basename(data_url)
    except Exception as e:
        # chain the cause: without `from e` the original network/parse error is lost
        raise Exception(f"Error processing URL: {e}") from e
316 | 331 |
|
@@ -344,21 +359,18 @@ def download_and_extract( |
344 | 359 | be False. |
345 | 360 | progress: whether to display progress bar. |
346 | 361 | """ |
| 362 | + url_filename_ext = "".join(Path(get_filename_from_url(url)).suffixes) |
| 363 | + filepath_ext = "".join(Path(_basename(filepath)).suffixes) |
| 364 | + if filepath not in ["", "."]: |
| 365 | + if filepath_ext == "": |
| 366 | + new_filepath = Path(filepath).with_suffix(url_filename_ext) |
| 367 | + logger.warning( |
| 368 | + f"filepath={filepath}, which missing file extension. Auto-appending extension to: {new_filepath}" |
| 369 | + ) |
| 370 | + filepath = new_filepath |
| 371 | + if filepath_ext and filepath_ext != url_filename_ext: |
| 372 | + raise ValueError(f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}") |
347 | 373 | with tempfile.TemporaryDirectory() as tmp_dir: |
348 | | - if not filepath: |
349 | | - filename = get_filename_from_url(url) |
350 | | - full_path = Path(tmp_dir, filename) |
351 | | - elif os.path.isdir(filepath) or not os.path.splitext(filepath)[1]: |
352 | | - filename = get_filename_from_url(url) |
353 | | - full_path = Path(os.path.join(filepath, filename)) |
354 | | - logger.warning(f"No compress file extension provided, downloading as: '{full_path}'") |
355 | | - else: |
356 | | - url_filename_ext = "".join(Path(".", _basename(url)).resolve().suffixes) |
357 | | - filepath_ext = "".join(Path(".", _basename(filepath)).resolve().suffixes) |
358 | | - if filepath_ext != url_filename_ext: |
359 | | - raise ValueError( |
360 | | - f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}" |
361 | | - ) |
362 | | - full_path = Path(filepath) |
363 | | - download_url(url=url, filepath=full_path, hash_val=hash_val, hash_type=hash_type, progress=progress) |
364 | | - extractall(filepath=full_path, output_dir=output_dir, file_type=file_type, has_base=has_base) |
| 374 | + filename = filepath or Path(tmp_dir, get_filename_from_url(url)).resolve() |
| 375 | + download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress) |
| 376 | + extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base) |
0 commit comments