Skip to content

Commit bad1907

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 2c72dbc commit bad1907

File tree

1 file changed

+44
-46
lines changed

1 file changed

+44
-46
lines changed

monai/apps/utils.py

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,12 @@
1111

1212
from __future__ import annotations
1313

14-
import os
14+
import os
1515
import shutil
1616
import hashlib
1717
import json
1818
import logging
19-
import os
2019
import re
21-
import shutil
2220
import sys
2321
import tarfile
2422
import tempfile
@@ -122,32 +120,32 @@ def update_to(self, b: int = 1, bsize: int = 1, tsize: int | None = None) -> Non
122120
logger.error(f"Download failed from {url} to {filepath}.")
123121
raise e
124122

125-
def safe_extract_member(member, extract_to):
126-
"""Securely verify compressed package member paths to prevent path traversal attacks"""
127-
# Get member path (handle different compression formats)
128-
if hasattr(member, 'filename'):
129-
member_path = member.filename # zipfile
130-
elif hasattr(member, 'name'):
131-
member_path = member.name # tarfile
132-
else:
133-
member_path = str(member)
134-
135-
member_path = os.path.normpath(member_path)
136-
137-
if os.path.isabs(member_path) or '..' in member_path.split(os.sep):
138-
raise ValueError(f"Unsafe path detected in archive: {member_path}")
139-
140-
full_path = os.path.join(extract_to, member_path)
141-
full_path = os.path.normpath(full_path)
142-
143-
extract_to_abs = os.path.abspath(extract_to)
144-
full_path_abs = os.path.abspath(full_path)
145-
146-
if not (full_path_abs == extract_to_abs or full_path_abs.startswith(extract_to_abs + os.sep)):
147-
raise ValueError(f"Path traversal attack detected: {member_path}")
148-
123+
def safe_extract_member(member, extract_to):
124+
"""Securely verify compressed package member paths to prevent path traversal attacks"""
125+
# Get member path (handle different compression formats)
126+
if hasattr(member, 'filename'):
127+
member_path = member.filename # zipfile
128+
elif hasattr(member, 'name'):
129+
member_path = member.name # tarfile
130+
else:
131+
member_path = str(member)
132+
133+
member_path = os.path.normpath(member_path)
134+
135+
if os.path.isabs(member_path) or '..' in member_path.split(os.sep):
136+
raise ValueError(f"Unsafe path detected in archive: {member_path}")
137+
138+
full_path = os.path.join(extract_to, member_path)
139+
full_path = os.path.normpath(full_path)
140+
141+
extract_to_abs = os.path.abspath(extract_to)
142+
full_path_abs = os.path.abspath(full_path)
143+
144+
if not (full_path_abs == extract_to_abs or full_path_abs.startswith(extract_to_abs + os.sep)):
145+
raise ValueError(f"Path traversal attack detected: {member_path}")
146+
149147
return full_path
150-
148+
151149
def check_hash(filepath: PathLike, val: str | None = None, hash_type: str = "md5") -> bool:
152150
"""
153151
Verify hash signature of specified file.
@@ -313,27 +311,27 @@ def extractall(
313311
logger.info(f"Writing into directory: {output_dir}.")
314312
_file_type = file_type.lower().strip()
315313
if filepath.name.endswith("zip") or _file_type == "zip":
316-
with zipfile.ZipFile(filepath, 'r') as zip_file:
317-
for member in zip_file.infolist():
318-
if member.is_dir():
319-
continue
320-
safe_path = safe_extract_member(member, output_dir)
321-
os.makedirs(os.path.dirname(safe_path), exist_ok=True)
322-
with zip_file.open(member) as source:
323-
with open(safe_path, 'wb') as target:
314+
with zipfile.ZipFile(filepath, 'r') as zip_file:
315+
for member in zip_file.infolist():
316+
if member.is_dir():
317+
continue
318+
safe_path = safe_extract_member(member, output_dir)
319+
os.makedirs(os.path.dirname(safe_path), exist_ok=True)
320+
with zip_file.open(member) as source:
321+
with open(safe_path, 'wb') as target:
324322
shutil.copyfileobj(source, target)
325323
return
326324
if filepath.name.endswith("tar") or filepath.name.endswith("tar.gz") or "tar" in _file_type:
327-
with tarfile.open(filepath, 'r') as tar_file:
328-
for member in tar_file.getmembers():
329-
if not member.isfile():
330-
continue
331-
332-
safe_path = safe_extract_member(member, output_dir)
333-
os.makedirs(os.path.dirname(safe_path), exist_ok=True)
334-
with tar_file.extractfile(member) as source:
335-
if source:
336-
with open(safe_path, 'wb') as target:
325+
with tarfile.open(filepath, 'r') as tar_file:
326+
for member in tar_file.getmembers():
327+
if not member.isfile():
328+
continue
329+
330+
safe_path = safe_extract_member(member, output_dir)
331+
os.makedirs(os.path.dirname(safe_path), exist_ok=True)
332+
with tar_file.extractfile(member) as source:
333+
if source:
334+
with open(safe_path, 'wb') as target:
337335
shutil.copyfileobj(source, target)
338336
return
339337
raise NotImplementedError(

0 commit comments

Comments
 (0)