Skip to content

Commit dc49f8e

Browse files
authored
Update utils.py
rollback Signed-off-by: h3rrr <[email protected]>
1 parent c9d19a1 commit dc49f8e

File tree

1 file changed

+9
-54
lines changed

1 file changed

+9
-54
lines changed

monai/apps/utils.py

Lines changed: 9 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111

1212
from __future__ import annotations
1313

14-
import os
15-
import shutil
1614
import hashlib
1715
import json
1816
import logging
17+
import os
1918
import re
19+
import shutil
2020
import sys
2121
import tarfile
2222
import tempfile
@@ -80,6 +80,7 @@ def get_logger(
8080
logger = get_logger("monai.apps")
8181
__all__.append("logger")
8282

83+
8384
def _basename(p: PathLike) -> str:
8485
"""get the last part of the path (removing the trailing slash if it exists)"""
8586
sep = os.path.sep + (os.path.altsep or "") + "/ "
@@ -120,36 +121,6 @@ def update_to(self, b: int = 1, bsize: int = 1, tsize: int | None = None) -> Non
120121
logger.error(f"Download failed from {url} to {filepath}.")
121122
raise e
122123

123-
def safe_extract_member(member, extract_to):
124-
"""Securely verify compressed package member paths to prevent path traversal attacks"""
125-
# Get member path (handle different compression formats)
126-
if hasattr(member, 'filename'):
127-
member_path = member.filename # zipfile
128-
elif hasattr(member, 'name'):
129-
member_path = member.name # tarfile
130-
else:
131-
member_path = str(member)
132-
133-
if hasattr(member, 'issym') and member.issym():
134-
raise ValueError(f"Symbolic link detected in archive: {member_path}")
135-
if hasattr(member, 'islnk') and member.islnk():
136-
raise ValueError(f"Hard link detected in archive: {member_path}")
137-
138-
member_path = os.path.normpath(member_path)
139-
140-
if os.path.isabs(member_path) or '..' in member_path.split(os.sep):
141-
raise ValueError(f"Unsafe path detected in archive: {member_path}")
142-
143-
full_path = os.path.join(extract_to, member_path)
144-
full_path = os.path.normpath(full_path)
145-
146-
extract_to_abs = os.path.abspath(extract_to)
147-
full_path_abs = os.path.abspath(full_path)
148-
149-
if not (full_path_abs == extract_to_abs or full_path_abs.startswith(extract_to_abs + os.sep)):
150-
raise ValueError(f"Path traversal attack detected: {member_path}")
151-
152-
return full_path
153124

154125
def check_hash(filepath: PathLike, val: str | None = None, hash_type: str = "md5") -> bool:
155126
"""
@@ -316,30 +287,14 @@ def extractall(
316287
logger.info(f"Writing into directory: {output_dir}.")
317288
_file_type = file_type.lower().strip()
318289
if filepath.name.endswith("zip") or _file_type == "zip":
319-
with zipfile.ZipFile(filepath, 'r') as zip_file:
320-
for member in zip_file.infolist():
321-
safe_path = safe_extract_member(member, output_dir)
322-
if member.is_dir():
323-
continue
324-
325-
os.makedirs(os.path.dirname(safe_path), exist_ok=True)
326-
with zip_file.open(member) as source:
327-
with open(safe_path, 'wb') as target:
328-
shutil.copyfileobj(source, target)
290+
zip_file = zipfile.ZipFile(filepath)
291+
zip_file.extractall(output_dir)
292+
zip_file.close()
329293
return
330294
if filepath.name.endswith("tar") or filepath.name.endswith("tar.gz") or "tar" in _file_type:
331-
with tarfile.open(filepath, 'r') as tar_file:
332-
for member in tar_file.getmembers():
333-
safe_path = safe_extract_member(member, output_dir)
334-
if not member.isfile():
335-
continue
336-
337-
os.makedirs(os.path.dirname(safe_path), exist_ok=True)
338-
source = tar_file.extractfile(member)
339-
if source is not None:
340-
with source:
341-
with open(safe_path, 'wb') as target:
342-
shutil.copyfileobj(source, target)
295+
tar_file = tarfile.open(filepath)
296+
tar_file.extractall(output_dir)
297+
tar_file.close()
343298
return
344299
raise NotImplementedError(
345300
f'Unsupported file type, available options are: ["zip", "tar.gz", "tar"]. name={filepath} type={file_type}.'

0 commit comments

Comments
 (0)