|
11 | 11 |
|
12 | 12 | from __future__ import annotations |
13 | 13 |
|
14 | | -import os |
| 14 | +import os |
15 | 15 | import shutil |
16 | 16 | import hashlib |
17 | 17 | import json |
18 | 18 | import logging |
19 | | -import os |
20 | 19 | import re |
21 | | -import shutil |
22 | 20 | import sys |
23 | 21 | import tarfile |
24 | 22 | import tempfile |
@@ -122,32 +120,32 @@ def update_to(self, b: int = 1, bsize: int = 1, tsize: int | None = None) -> Non |
122 | 120 | logger.error(f"Download failed from {url} to {filepath}.") |
123 | 121 | raise e |
124 | 122 |
|
def safe_extract_member(member, extract_to):
    """Validate an archive member's path and return its extraction target.

    Guards against path traversal ("zip slip") by rejecting member paths
    that are absolute, contain a ``..`` component, or resolve to a location
    outside of ``extract_to``.

    Args:
        member: An archive member. Objects with a ``filename`` attribute
            (e.g. ``zipfile.ZipInfo``) or a ``name`` attribute
            (e.g. ``tarfile.TarInfo``) are supported; anything else is
            converted with ``str()`` and treated as the member path.
        extract_to: Directory the archive is being extracted into.

    Returns:
        The normalized path ``extract_to/<member path>``. It is relative
        when ``extract_to`` is relative.

    Raises:
        ValueError: If the member path is unsafe (absolute, contains
            ``..``, or escapes ``extract_to``).
    """
    # Get member path (handle different compression formats)
    if hasattr(member, 'filename'):
        member_path = member.filename  # zipfile
    elif hasattr(member, 'name'):
        member_path = member.name  # tarfile
    else:
        member_path = str(member)

    # Collapse "a/./b" and "a/x/../b" so that any remaining ".." component
    # is a genuine attempt to climb out of the extraction directory.
    member_path = os.path.normpath(member_path)

    if os.path.isabs(member_path) or '..' in member_path.split(os.sep):
        raise ValueError(f"Unsafe path detected in archive: {member_path}")

    full_path = os.path.normpath(os.path.join(extract_to, member_path))

    extract_to_abs = os.path.abspath(extract_to)
    full_path_abs = os.path.abspath(full_path)

    # Compare whole path components instead of string prefixes: a prefix
    # test built from `extract_to_abs + os.sep` turns a root directory such
    # as "/" into "//" and then rejects every member.
    try:
        common = os.path.commonpath([extract_to_abs, full_path_abs])
    except ValueError:
        # Raised e.g. for paths on different drives: no common ancestor,
        # so the target cannot be inside the extraction directory.
        common = None

    if common is None or os.path.normcase(common) != os.path.normcase(extract_to_abs):
        raise ValueError(f"Path traversal attack detected: {member_path}")

    return full_path
150 | | - |
| 148 | + |
151 | 149 | def check_hash(filepath: PathLike, val: str | None = None, hash_type: str = "md5") -> bool: |
152 | 150 | """ |
153 | 151 | Verify hash signature of specified file. |
@@ -313,27 +311,27 @@ def extractall( |
313 | 311 | logger.info(f"Writing into directory: {output_dir}.") |
314 | 312 | _file_type = file_type.lower().strip() |
315 | 313 | if filepath.name.endswith("zip") or _file_type == "zip": |
316 | | - with zipfile.ZipFile(filepath, 'r') as zip_file: |
317 | | - for member in zip_file.infolist(): |
318 | | - if member.is_dir(): |
319 | | - continue |
320 | | - safe_path = safe_extract_member(member, output_dir) |
321 | | - os.makedirs(os.path.dirname(safe_path), exist_ok=True) |
322 | | - with zip_file.open(member) as source: |
323 | | - with open(safe_path, 'wb') as target: |
| 314 | + with zipfile.ZipFile(filepath, 'r') as zip_file: |
| 315 | + for member in zip_file.infolist(): |
| 316 | + if member.is_dir(): |
| 317 | + continue |
| 318 | + safe_path = safe_extract_member(member, output_dir) |
| 319 | + os.makedirs(os.path.dirname(safe_path), exist_ok=True) |
| 320 | + with zip_file.open(member) as source: |
| 321 | + with open(safe_path, 'wb') as target: |
324 | 322 | shutil.copyfileobj(source, target) |
325 | 323 | return |
326 | 324 | if filepath.name.endswith("tar") or filepath.name.endswith("tar.gz") or "tar" in _file_type: |
327 | | - with tarfile.open(filepath, 'r') as tar_file: |
328 | | - for member in tar_file.getmembers(): |
329 | | - if not member.isfile(): |
330 | | - continue |
331 | | - |
332 | | - safe_path = safe_extract_member(member, output_dir) |
333 | | - os.makedirs(os.path.dirname(safe_path), exist_ok=True) |
334 | | - with tar_file.extractfile(member) as source: |
335 | | - if source: |
336 | | - with open(safe_path, 'wb') as target: |
| 325 | + with tarfile.open(filepath, 'r') as tar_file: |
| 326 | + for member in tar_file.getmembers(): |
| 327 | + if not member.isfile(): |
| 328 | + continue |
| 329 | + |
| 330 | + safe_path = safe_extract_member(member, output_dir) |
| 331 | + os.makedirs(os.path.dirname(safe_path), exist_ok=True) |
| 332 | + with tar_file.extractfile(member) as source: |
| 333 | + if source: |
| 334 | + with open(safe_path, 'wb') as target: |
337 | 335 | shutil.copyfileobj(source, target) |
338 | 336 | return |
339 | 337 | raise NotImplementedError( |
|
0 commit comments