Skip to content

Commit

Permalink
fix lint and type errors
Browse files (browse the repository at this point in the history)
  • Loading branch information
malwarefrank committed Mar 17, 2024
1 parent 4edd63e commit fdc1204
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 26 deletions.
13 changes: 7 additions & 6 deletions src/dnfile/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from pefile import Structure

from . import enums, errors
from .utils import LazyList as _LazyList, read_compressed_int as _read_compressed_int
from .utils import LazyList as _LazyList
from .utils import read_compressed_int as _read_compressed_int

if TYPE_CHECKING:
from . import stream
Expand All @@ -25,16 +26,16 @@


class CompressedInt(int):
raw_size: Optional[int] = None
__data__: Optional[bytes] = None
value: Optional[int] = None
raw_size: int
__data__: bytes
value: int
rva: Optional[int] = None

def to_bytes(self):
return self.__data__

@classmethod
def read(cls, data: bytes, rva: Optional[int] = None) -> "CompressedInt":
def read(cls, data: bytes, rva: Optional[int] = None) -> Optional["CompressedInt"]:
result = _read_compressed_int(data)
if result is None:
return None
Expand Down Expand Up @@ -119,7 +120,7 @@ def get_dword_at_rva(self, rva):
class HeapItem(abc.ABC):
rva: Optional[int] = None
# original data from file
__data__: bytes = None
__data__: bytes
# interpreted value
value: Optional[bytes] = None

Expand Down
28 changes: 15 additions & 13 deletions src/dnfile/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

import struct as _struct
import logging
import collections as _collections
from typing import Dict, List, Tuple, Union, Optional
from binascii import hexlify as _hexlify
import collections as _collections

from pefile import MAX_STRING_LENGTH, Structure

Expand All @@ -38,7 +38,7 @@ class HeapItemString(base.HeapItem, _collections.abc.Sequence):
def __init__(self, data, encoding="utf-8"):
super().__init__(data)
self.encoding = encoding
self.value = self.__data__.decode(encoding)
self.value: str = self.__data__.decode(encoding)

def __str__(self) -> str:
return self.value
Expand All @@ -56,13 +56,14 @@ def __len__(self):


class HeapItemBinary(base.HeapItem, _collections.abc.Sequence):
item_size: Optional[base.CompressedInt]
item_size: base.CompressedInt

def __init__(self, data: bytes, rva: Optional[int] = None):
# read compressed int, which has a max size of four bytes
self.item_size = base.CompressedInt.read(data[:4], rva)
if self.item_size is None:
size = base.CompressedInt.read(data[:4], rva)
if size is None:
raise ValueError("invalid compressed int")
self.item_size = size
base.HeapItem.__init__(self, data[:self.item_size.raw_size + self.item_size], rva)

# read data
Expand Down Expand Up @@ -217,6 +218,7 @@ def __init__(self, data: Union[bytes, HeapItemBinary], encoding="utf-16"):
elif isinstance(data, HeapItemBinary):
HeapItemBinary.__init__(self, data.to_bytes())

buf = self.to_bytes()[self.item_size.raw_size:]
if self.item_size % 2 == 1:
# > This final byte holds the value 1 if and only if any UTF16
# > character within the string has any bit set in its top byte,
Expand All @@ -226,7 +228,6 @@ def __init__(self, data: Union[bytes, HeapItemBinary], encoding="utf-16"):
# via ECMA-335 6th edition, II.24.2.4
#
# Trim this trailing flag, which is not part of the string.
buf = self.value
self.flag = buf[-1]
str_buf = buf[:-1]
if self.flag == 0x00:
Expand All @@ -246,10 +247,10 @@ def __init__(self, data: Union[bytes, HeapItemBinary], encoding="utf-16"):
# these strings are probably best interpreted as bytes.
pass
else:
logger.warning("unexpected string flag value: 0x%02x", flag)
logger.warning(f"unexpected string flag value: 0x{self.flag:02x}")
else:
logger.warning("string missing trailing flag")
str_buf = self.value
str_buf = buf

self.value = str_buf.decode(encoding)

Expand All @@ -264,11 +265,12 @@ def __len__(self):


class UserStringHeap(BinaryHeap):
def get(self, index) -> Optional[bytes]:
data = super(UserStringHeap, self).get(index)
if data is None:
def get_bytes(self, index) -> Optional[bytes]:
bin_item = super(UserStringHeap, self).get(index)
if bin_item is None:
return None

data = bin_item.to_bytes()[bin_item.item_size.raw_size:]
flag: int = 0
if len(data) % 2 == 1:
# > This final byte holds the value 1 if and only if any UTF16
Expand Down Expand Up @@ -304,7 +306,7 @@ def get(self, index) -> Optional[bytes]:

return data

def get_us(self, index, encoding="utf-16") -> Optional[UserString]:
def get(self, index, encoding="utf-16") -> Optional[UserString]:
bin_item = super().get(index)
if bin_item is None:
return None
Expand Down Expand Up @@ -365,7 +367,7 @@ def get_str(self, index, as_bytes=False):
parts[0], parts[1], parts[2], part3, part4
)

def get(self, index) -> HeapItemGuid:
def get(self, index) -> Optional[HeapItemGuid]:
if index is None or index < 1:
return None

Expand Down
2 changes: 1 addition & 1 deletion tests/test_invalid_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_duplicate_stream():
dn = dnfile.dnPE(path)

assert b"#US" in dn.net.metadata.streams
assert dn.net.user_strings.get_us(1).value == "BBBBBBBB"
assert dn.net.user_strings.get(1).value == "BBBBBBBB"


def test_unknown_stream():
Expand Down
12 changes: 6 additions & 6 deletions tests/test_invalid_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ def test_unpaired_surrogate():
assert dn.net.user_strings is not None

assert b"#US" in dn.net.metadata.streams
assert dn.net.user_strings.get(1) == b"\xD0\xDD"
assert dn.net.user_strings.get_bytes(1) == b"\xD0\xDD"
with pytest.raises(UnicodeDecodeError):
assert dn.net.user_strings.get_us(1)
assert dn.net.user_strings.get(1)


def test_raw_binary():
Expand All @@ -30,10 +30,10 @@ def test_raw_binary():

# short MZ header
assert b"#US" in dn.net.metadata.streams
assert dn.net.user_strings.get(1) == b"\x4D\x5A\x90\x00"
assert dn.net.user_strings.get_bytes(1) == b"\x4D\x5A\x90\x00"

# somehow this is valid utf-16
s = dn.net.user_strings.get_us(1)
s = dn.net.user_strings.get(1)
assert s is not None
assert s.value == b"\x4D\x5A\x90\x00".decode("utf-16")

Expand All @@ -49,8 +49,8 @@ def test_string_decoder():

# "Hello World" ^ 0xFF
assert b"#US" in dn.net.metadata.streams
assert dn.net.user_strings.get(1) == b"\xb7\xff\x9a\xff\x93\xff\x93\xff\x90\xff\xdf\xff\xa8\xff\x90\xff\x8d\xff\x93\xff\x9b\xff"
assert dn.net.user_strings.get_bytes(1) == b"\xb7\xff\x9a\xff\x93\xff\x93\xff\x90\xff\xdf\xff\xa8\xff\x90\xff\x8d\xff\x93\xff\x9b\xff"

# somehow this is valid utf-16
s = dn.net.user_strings.get_us(1)
s = dn.net.user_strings.get(1)
assert s is not None

0 comments on commit fdc1204

Please sign in to comment.