|
11 | 11 |
|
12 | 12 | from __future__ import annotations |
13 | 13 |
|
| 14 | +import os |
| 15 | +import shutil |
14 | 16 | import hashlib |
15 | 17 | import json |
16 | 18 | import logging |
@@ -80,7 +82,6 @@ def get_logger( |
80 | 82 | logger = get_logger("monai.apps") |
81 | 83 | __all__.append("logger") |
82 | 84 |
|
83 | | - |
84 | 85 | def _basename(p: PathLike) -> str: |
85 | 86 | """get the last part of the path (removing the trailing slash if it exists)""" |
86 | 87 | sep = os.path.sep + (os.path.altsep or "") + "/ " |
@@ -121,7 +122,32 @@ def update_to(self, b: int = 1, bsize: int = 1, tsize: int | None = None) -> Non |
121 | 122 | logger.error(f"Download failed from {url} to {filepath}.") |
122 | 123 | raise e |
123 | 124 |
|
124 | | - |
def safe_extract_member(member, extract_to) -> str:
    """
    Validate an archive member's path and return where it may be written.

    The member name is normalised and rejected if it is absolute or still
    contains a ``..`` component. The resulting destination is then resolved
    (following symlinks) and must be located inside ``extract_to``.

    Args:
        member: archive entry to check. Objects exposing ``filename``
            (e.g. ``zipfile.ZipInfo``) or ``name`` (e.g. ``tarfile.TarInfo``)
            are supported; anything else is converted with ``str()``.
        extract_to: directory the archive is being extracted into.

    Returns:
        The normalised path ``extract_to/<member path>``.

    Raises:
        ValueError: if the member path is absolute, contains ``..`` after
            normalisation, or resolves to a location outside ``extract_to``.
    """
    # Pick the path attribute used by the archive format.
    if hasattr(member, "filename"):
        member_path = member.filename  # zipfile
    elif hasattr(member, "name"):
        member_path = member.name  # tarfile
    else:
        member_path = str(member)

    member_path = os.path.normpath(member_path)

    if os.path.isabs(member_path) or ".." in member_path.split(os.sep):
        raise ValueError(f"Unsafe path detected in archive: {member_path}")

    full_path = os.path.normpath(os.path.join(extract_to, member_path))

    # Resolve symlinks on both sides so that a link already present inside the
    # target directory cannot redirect the write somewhere else.
    root_real = os.path.realpath(extract_to)
    target_real = os.path.realpath(full_path)

    # commonpath compares whole path components, so it neither confuses
    # "/data" with "/data_evil" nor breaks when the root is "/" itself.
    try:
        contained = os.path.commonpath([root_real, target_real]) == root_real
    except ValueError:
        # Raised e.g. for paths on different drives: definitely not contained.
        contained = False

    if not contained:
        raise ValueError(f"Path traversal attack detected: {member_path}")

    return full_path
| 150 | + |
125 | 151 | def check_hash(filepath: PathLike, val: str | None = None, hash_type: str = "md5") -> bool: |
126 | 152 | """ |
127 | 153 | Verify hash signature of specified file. |
@@ -287,14 +313,28 @@ def extractall( |
287 | 313 | logger.info(f"Writing into directory: {output_dir}.") |
288 | 314 | _file_type = file_type.lower().strip() |
289 | 315 | if filepath.name.endswith("zip") or _file_type == "zip": |
290 | | - zip_file = zipfile.ZipFile(filepath) |
291 | | - zip_file.extractall(output_dir) |
292 | | - zip_file.close() |
| 316 | + with zipfile.ZipFile(filepath, 'r') as zip_file: |
| 317 | + for member in zip_file.infolist(): |
| 318 | + if member.is_dir(): |
| 319 | + continue |
| 320 | + safe_path = safe_extract_member(member, output_dir) |
| 321 | + os.makedirs(os.path.dirname(safe_path), exist_ok=True) |
| 322 | + with zip_file.open(member) as source: |
| 323 | + with open(safe_path, 'wb') as target: |
| 324 | + shutil.copyfileobj(source, target) |
293 | 325 | return |
294 | 326 | if filepath.name.endswith("tar") or filepath.name.endswith("tar.gz") or "tar" in _file_type: |
295 | | - tar_file = tarfile.open(filepath) |
296 | | - tar_file.extractall(output_dir) |
297 | | - tar_file.close() |
| 327 | + with tarfile.open(filepath, 'r') as tar_file: |
| 328 | + for member in tar_file.getmembers(): |
| 329 | + if not member.isfile(): |
| 330 | + continue |
| 331 | + |
| 332 | + safe_path = safe_extract_member(member, output_dir) |
| 333 | + os.makedirs(os.path.dirname(safe_path), exist_ok=True) |
| 334 | + with tar_file.extractfile(member) as source: |
| 335 | + if source: |
| 336 | + with open(safe_path, 'wb') as target: |
| 337 | + shutil.copyfileobj(source, target) |
298 | 338 | return |
299 | 339 | raise NotImplementedError( |
300 | 340 | f'Unsupported file type, available options are: ["zip", "tar.gz", "tar"]. name={filepath} type={file_type}.' |
|
0 commit comments