Skip to content

Commit 03e1e13

Browse files
committed
enhance download_and_extract
1 parent 13b96ae commit 03e1e13

File tree

1 file changed

+44
-4
lines changed

1 file changed

+44
-4
lines changed

monai/apps/utils.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,47 @@ def download_and_extract(
327327
be False.
328328
progress: whether to display progress bar.
329329
"""
330-
with tempfile.TemporaryDirectory() as tmp_dir:
331-
filename = filepath or Path(tmp_dir, _basename(url)).resolve()
332-
download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
333-
extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)
330+
def download_and_extract(
331+
url: str,
332+
filepath: PathLike = "",
333+
output_dir: PathLike = ".",
334+
hash_val: str | None = None,
335+
hash_type: str = "md5",
336+
file_type: str = "",
337+
has_base: bool = True,
338+
progress: bool = True,
339+
) -> None:
340+
"""
341+
Download file from URL and extract it to the output directory.
342+
343+
Args:
344+
url: source URL link to download file.
345+
filepath: the file path of the downloaded compressed file.
346+
use this option to keep the directly downloaded compressed file, to avoid further repeated downloads.
347+
output_dir: target directory to save extracted files.
348+
default is the current directory.
349+
hash_val: expected hash value to validate the downloaded file.
350+
if None, skip hash validation.
351+
hash_type: 'md5' or 'sha1', defaults to 'md5'.
352+
file_type: string of file type for decompressing. Leave it empty to infer the type from url's base file name.
353+
has_base: whether the extracted files have a base folder. This flag is used when checking if the existing
354+
folder is a result of `extractall`, if it is, the extraction is skipped. For example, if A.zip is unzipped
355+
to folder structure `A/*.png`, this flag should be True; if B.zip is unzipped to `*.png`, this flag should
356+
be False.
357+
progress: whether to display progress bar.
358+
"""
359+
urlFilenameExtension = ''.join(Path(".", _basename(url)).resolve().suffixes)
360+
if filepath:
361+
FilepathExtenstion = ''.join(Path(".", _basename(filepath)).resolve().suffixes)
362+
if urlFilenameExtension != FilepathExtenstion:
363+
raise NotImplementedError(
364+
f'The file types do not match: url={urlFilenameExtension}, but filepath={FilepathExtenstion}'
365+
)
366+
else:
367+
with tempfile.TemporaryDirectory() as tmp_dir:
368+
if filepath:
369+
filename = filepath
370+
else:
371+
filename = Path(tmp_dir, _basename(url)).resolve()
372+
download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
373+
extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)

0 commit comments

Comments
 (0)