Skip to content

Commit d3c711b

Browse files
authored
Update utils.py
Changed to previous content, the fix will be filed in a new PR Signed-off-by: h3rrr <[email protected]>
1 parent a8ed1df commit d3c711b

File tree

1 file changed

+9
-47
lines changed

1 file changed

+9
-47
lines changed

monai/apps/utils.py

Lines changed: 9 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111

1212
from __future__ import annotations
1313

14-
import os
15-
import shutil
1614
import hashlib
1715
import json
1816
import logging
17+
import os
1918
import re
19+
import shutil
2020
import sys
2121
import tarfile
2222
import tempfile
@@ -80,6 +80,7 @@ def get_logger(
8080
logger = get_logger("monai.apps")
8181
__all__.append("logger")
8282

83+
8384
def _basename(p: PathLike) -> str:
8485
"""get the last part of the path (removing the trailing slash if it exists)"""
8586
sep = os.path.sep + (os.path.altsep or "") + "/ "
@@ -120,31 +121,6 @@ def update_to(self, b: int = 1, bsize: int = 1, tsize: int | None = None) -> Non
120121
logger.error(f"Download failed from {url} to {filepath}.")
121122
raise e
122123

123-
def safe_extract_member(member, extract_to):
124-
"""Securely verify compressed package member paths to prevent path traversal attacks"""
125-
# Get member path (handle different compression formats)
126-
if hasattr(member, 'filename'):
127-
member_path = member.filename # zipfile
128-
elif hasattr(member, 'name'):
129-
member_path = member.name # tarfile
130-
else:
131-
member_path = str(member)
132-
133-
member_path = os.path.normpath(member_path)
134-
135-
if os.path.isabs(member_path) or '..' in member_path.split(os.sep):
136-
raise ValueError(f"Unsafe path detected in archive: {member_path}")
137-
138-
full_path = os.path.join(extract_to, member_path)
139-
full_path = os.path.normpath(full_path)
140-
141-
extract_to_abs = os.path.abspath(extract_to)
142-
full_path_abs = os.path.abspath(full_path)
143-
144-
if not (full_path_abs == extract_to_abs or full_path_abs.startswith(extract_to_abs + os.sep)):
145-
raise ValueError(f"Path traversal attack detected: {member_path}")
146-
147-
return full_path
148124

149125
def check_hash(filepath: PathLike, val: str | None = None, hash_type: str = "md5") -> bool:
150126
"""
@@ -311,28 +287,14 @@ def extractall(
311287
logger.info(f"Writing into directory: {output_dir}.")
312288
_file_type = file_type.lower().strip()
313289
if filepath.name.endswith("zip") or _file_type == "zip":
314-
with zipfile.ZipFile(filepath, 'r') as zip_file:
315-
for member in zip_file.infolist():
316-
if member.is_dir():
317-
continue
318-
safe_path = safe_extract_member(member, output_dir)
319-
os.makedirs(os.path.dirname(safe_path), exist_ok=True)
320-
with zip_file.open(member) as source:
321-
with open(safe_path, 'wb') as target:
322-
shutil.copyfileobj(source, target)
290+
zip_file = zipfile.ZipFile(filepath)
291+
zip_file.extractall(output_dir)
292+
zip_file.close()
323293
return
324294
if filepath.name.endswith("tar") or filepath.name.endswith("tar.gz") or "tar" in _file_type:
325-
with tarfile.open(filepath, 'r') as tar_file:
326-
for member in tar_file.getmembers():
327-
if not member.isfile():
328-
continue
329-
330-
safe_path = safe_extract_member(member, output_dir)
331-
os.makedirs(os.path.dirname(safe_path), exist_ok=True)
332-
with tar_file.extractfile(member) as source:
333-
if source:
334-
with open(safe_path, 'wb') as target:
335-
shutil.copyfileobj(source, target)
295+
tar_file = tarfile.open(filepath)
296+
tar_file.extractall(output_dir)
297+
tar_file.close()
336298
return
337299
raise NotImplementedError(
338300
f'Unsupported file type, available options are: ["zip", "tar.gz", "tar"]. name={filepath} type={file_type}.'

0 commit comments

Comments
 (0)