From 773f74b7d93e73be9e0f73410fc28c1cd38f394a Mon Sep 17 00:00:00 2001 From: malwarefrank <42877127+malwarefrank@users.noreply.github.com> Date: Sat, 23 Mar 2024 03:54:34 +0000 Subject: [PATCH] make HeapItem classes more consistent, remove sequence compatibility, and add more comment blocks --- src/dnfile/base.py | 48 +++++++- src/dnfile/stream.py | 207 ++++++++++++---------------------- tests/test_invalid_strings.py | 19 ++-- tests/test_parse.py | 19 ++-- 4 files changed, 135 insertions(+), 158 deletions(-) diff --git a/src/dnfile/base.py b/src/dnfile/base.py index 3198163..bb0672e 100644 --- a/src/dnfile/base.py +++ b/src/dnfile/base.py @@ -10,7 +10,7 @@ import logging import functools as _functools import itertools as _itertools -from typing import TYPE_CHECKING, Dict, List, Type, Tuple, Union, Generic, TypeVar, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, List, Type, Tuple, Union, Generic, TypeVar, Optional, Sequence from pefile import Structure @@ -118,34 +118,70 @@ def get_dword_at_rva(self, rva): class HeapItem(abc.ABC): + """ + HeapItem is a base class for items retrieved from any of the + heap streams, for example #Strings, #US, #GUID, and #Blob. + + It can be used to access the raw underlying data, the RVA + from which it was retrieved, an optional interpreted value, + and the bytes representation of the value. + + Each heap stream .get() call returns a subclass with these + and optionally additional members. + """ + rva: Optional[int] = None # original data from file __data__: bytes # interpreted value - value: Optional[bytes] = None + value: Any = None def __init__(self, data: bytes, rva: Optional[int] = None): self.rva = rva self.__data__ = data - def to_bytes(self): + def value_bytes(self): + """ + Return the raw bytes underlying the interpreted value. + + For the base HeapItem, this is the same as the raw_data. 
+ """ return self.__data__ @property def raw_size(self): + """ + Number of bytes read from the stream, including any header, + value, and footer. + """ return len(self.__data__) + @property + def raw_data(self): + """ + The bytes read from the stream, including any header, + value, and footer. + """ + return self.__data__ + def __eq__(self, other): + """ + Two HeapItems are equal if their raw data is the same or their + interpreted values are the same and not None. + + A HeapItem is equal to a bytes object if the HeapItem's value as bytes + is equal to the bytes object. + """ if isinstance(other, HeapItem): - return self.to_bytes() == other.to_bytes() or (self.value is not None and self.value == other.value) + return self.raw_data == other.raw_data or (self.value is not None and self.value == other.value) elif isinstance(other, bytes): - return self.to_bytes() == other + return self.value_bytes() == other return False class ClrHeap(ClrStream): @abc.abstractmethod - def get(self, index): + def get(self, index: int): raise NotImplementedError() diff --git a/src/dnfile/stream.py b/src/dnfile/stream.py index 19b36d9..38e6d2b 100644 --- a/src/dnfile/stream.py +++ b/src/dnfile/stream.py @@ -13,7 +13,6 @@ import struct as _struct import logging -import collections as _collections from typing import Dict, List, Tuple, Union, Optional from binascii import hexlify as _hexlify @@ -32,33 +31,44 @@ class GenericStream(base.ClrStream): pass -class HeapItemString(base.HeapItem, _collections.abc.Sequence): +class HeapItemString(base.HeapItem): + """ + A HeapItemString is a HeapItem with an encoding. The .value member + is the decoded string or None if there was a UnicodeDecodeError. + + A HeapItemString can be compared directly to a str. 
+ """ encoding: Optional[str] - def __init__(self, data, encoding="utf-8"): - super().__init__(data) + def __init__(self, data: bytes, rva: Optional[int] = None, encoding="utf-8"): + super().__init__(data, rva=rva) self.encoding = encoding - self.value: str = self.__data__.decode(encoding) + try: + self.value: Optional[str] = self.__data__.decode(encoding) + except UnicodeDecodeError as e: + self.value = None def __str__(self) -> str: - return self.value + return self.value or "" def __eq__(self, other): if isinstance(other, str): - return str(self) == other + return self.value == other return super().__eq__(other) - def __getitem__(self, i): - return self.value[i] - - def __len__(self): - return len(self.value) +class HeapItemBinary(base.HeapItem): + """ + A HeapItemBinary is a HeapItem with an item_size. The .item_size + is the parsed compressed integer at the RVA in the binary heap. + The .value is the bytes following the compressed integer. -class HeapItemBinary(base.HeapItem, _collections.abc.Sequence): + A HeapItemBinary can be compared directly to a bytes object. 
+ """ item_size: base.CompressedInt def __init__(self, data: bytes, rva: Optional[int] = None): + self.rva = rva # read compressed int, which has a max size of four bytes size = base.CompressedInt.read(data[:4], rva) if size is None: @@ -70,6 +80,9 @@ def __init__(self, data: bytes, rva: Optional[int] = None): offset = self.item_size.raw_size self.value = self.__data__[offset:offset + self.item_size] + def value_bytes(self): + return self.__data__[self.item_size.raw_size:] + def __eq__(self, other): if isinstance(other, base.HeapItem): return base.HeapItem.__eq__(self, other) @@ -77,12 +90,6 @@ def __eq__(self, other): return self.value == other return False - def __getitem__(self, i): - return self.value[i] - - def __len__(self): - return len(self.value) - class StringsHeap(base.ClrHeap): offset_size = 0 @@ -93,21 +100,15 @@ def get_str(self, index, max_length=MAX_STRING_LENGTH, encoding="utf-8", as_byte Returns None on error, or string, or bytes if as_bytes is True. """ - if not self.__data__ or index is None or not max_length: - return None - - if index >= len(self.__data__): - raise IndexError("index out of range") + item = self.get(index, max_length, encoding) - offset = index - end = self.__data__.find(b"\x00", offset) - if end - offset > max_length: + if item is None: return None - data = self.__data__[offset:end] + if as_bytes: - return data - s = data.decode(encoding) - return s + return item.value_bytes() + + return item.value def get(self, index, max_length=MAX_STRING_LENGTH, encoding="utf-8") -> Optional[HeapItemString]: """ @@ -125,54 +126,35 @@ def get(self, index, max_length=MAX_STRING_LENGTH, encoding="utf-8") -> Optional if end - offset > max_length: return None - item = HeapItemString(self.__data__[offset:end], encoding) - item.rva = self.rva + offset + item = HeapItemString(self.__data__[offset:end], rva=self.rva + offset, encoding=encoding) return item class BinaryHeap(base.ClrHeap): - def get_with_size(self, index) -> Optional[Tuple[bytes, 
int]]: - if self.__data__ is None: - logger.warning("stream has no data") - return None - - if index >= len(self.__data__): - logger.warning("stream is too small: wanted: 0x%x found: 0x%x", index, len(self.__data__)) + def get_with_size(self, index: int) -> Optional[Tuple[bytes, int]]: + try: + item = self.get(index) + except IndexError: return None - offset = index - - # read compressed int length - buf = self.__data__[offset:offset + 4] - ret = read_compressed_int(buf) - if ret is None: - # invalid compressed int length, such as invalid leading flags. - logger.warning("stream entry has invalid compressed int") + if item is None: return None - data_length, length_size = ret + return item.value_bytes(), item.raw_size - # read data - offset = offset + length_size - data = self.__data__[offset:offset + data_length] - - return data, length_size + data_length - - def get_bytes(self, index) -> Optional[bytes]: + def get_bytes(self, index: int) -> Optional[bytes]: try: - ret = self.get_with_size(index) + item = self.get(index) except IndexError: return None - if ret is None: + if item is None: return None - data, _ = ret + return item.value_bytes() - return data - - def get(self, index) -> Optional[HeapItemBinary]: + def get(self, index: int) -> Optional[HeapItemBinary]: if self.__data__ is None: logger.warning("stream has no data") return None @@ -184,14 +166,12 @@ def get(self, index) -> Optional[HeapItemBinary]: offset = index try: - item = HeapItemBinary(self.__data__[index:]) + item = HeapItemBinary(self.__data__[index:], rva=self.rva + offset) except ValueError as e: # possible invalid compressed int length, such as invalid leading flags. 
- logger.warning(f"stream entry error - {e} @ RVA=0x{hex(self.rva)}") + logger.warning(f"stream entry error - {e} @ RVA=0x{hex(self.rva + offset)}") return None - item.rva = self.rva + offset - return item @@ -199,7 +179,7 @@ class BlobHeap(BinaryHeap): pass -class UserString(HeapItemBinary, HeapItemString, _collections.abc.Sequence): +class UserString(HeapItemBinary, HeapItemString): """ The #US or UserStrings stream should contain UTF-16 strings. Each entry in the stream includes a byte indicating whether @@ -211,14 +191,14 @@ class UserString(HeapItemBinary, HeapItemString, _collections.abc.Sequence): flag: Optional[int] = None - def __init__(self, data: Union[bytes, HeapItemBinary], encoding="utf-16"): + def __init__(self, data: Union[bytes, HeapItemBinary], rva: Optional[int] = None, encoding="utf-16"): self.encoding = encoding if isinstance(data, bytes): - HeapItemBinary.__init__(self, data) + HeapItemBinary.__init__(self, data, rva=rva) elif isinstance(data, HeapItemBinary): - HeapItemBinary.__init__(self, data.to_bytes()) + HeapItemBinary.__init__(self, data.raw_data, rva=rva or data.rva) - buf = self.to_bytes()[self.item_size.raw_size:] + buf = self.__data__[self.item_size.raw_size:] if self.item_size % 2 == 1: # > This final byte holds the value 1 if and only if any UTF16 # > character within the string has any bit set in its top byte, @@ -252,59 +232,29 @@ def __init__(self, data: Union[bytes, HeapItemBinary], encoding="utf-16"): logger.warning("string missing trailing flag") str_buf = buf - self.value = str_buf.decode(encoding) + try: + self.value = str_buf.decode(encoding) + except UnicodeDecodeError as e: + logger.warning(f"UserString decode error (rva:0x{self.rva:08x}): {e}") + self.value = None + + def value_bytes(self): + if self.flag is None: + return self.__data__[self.item_size.raw_size:] + return self.__data__[self.item_size.raw_size:-1] def __eq__(self, other): return HeapItemString.__eq__(self, other) - def __getitem__(self, i): - return 
self.value[i] - - def __len__(self): - return len(self.value) - class UserStringHeap(BinaryHeap): def get_bytes(self, index) -> Optional[bytes]: - bin_item = super(UserStringHeap, self).get(index) - if bin_item is None: - return None + item = self.get(index) - data = bin_item.to_bytes()[bin_item.item_size.raw_size:] - flag: int = 0 - if len(data) % 2 == 1: - # > This final byte holds the value 1 if and only if any UTF16 - # > character within the string has any bit set in its top byte, - # > or its low byte is any of the following: - # > 0x01–0x08, 0x0E–0x1F, 0x27, 0x2D, 0x7F. - # - # via ECMA-335 6th edition, II.24.2.4 - # - # Trim this trailing flag, which is not part of the string. - flag = data[-1] - if flag == 0x00: - # > Otherwise, it holds 0. - # - # via ECMA-335 6th edition, II.24.2.4 - # - # *should* be a normal UTF-16 string, but still not - # make sense. - pass - elif flag == 0x01: - # > The 1 signifies Unicode characters that require handling - # > beyond that normally provided for 8-bit encoding sets. - # - # via ECMA-335 6th edition, II.24.2.4 - # - # these strings are probably best interpreted as bytes. 
- pass - else: - logger.warning("unexpected string flag value: 0x%02x", flag) - data = data[:-1] - else: - logger.warning("string missing trailing flag") + if item is None: + return None - return data + return item.value_bytes() def get(self, index, encoding="utf-16") -> Optional[UserString]: bin_item = super().get(index) @@ -312,8 +262,6 @@ def get(self, index, encoding="utf-16") -> Optional[UserString]: return None us_item = UserString(bin_item, encoding=encoding) - us_item.rva = bin_item.rva - us_item.item_size = bin_item.item_size return us_item @@ -344,28 +292,15 @@ class GuidHeap(base.ClrHeap): offset_size = 0 def get_str(self, index, as_bytes=False): - if index is None or index < 1: - return None + item = self.get(index) - size = 128 // 8 # number of bytes in a guid - # offset into the GUID stream - offset = (index - 1) * size - - if offset + size > len(self.__data__): - raise IndexError("index out of range") + if item is None: + return None - data = self.__data__[offset:offset + size] if as_bytes: - return data - # convert to string - parts = _struct.unpack_from(" Optional[HeapItemGuid]: if index is None or index < 1: diff --git a/tests/test_invalid_strings.py b/tests/test_invalid_strings.py index dcfb462..ee326d7 100644 --- a/tests/test_invalid_strings.py +++ b/tests/test_invalid_strings.py @@ -14,9 +14,11 @@ def test_unpaired_surrogate(): assert dn.net.user_strings is not None assert b"#US" in dn.net.metadata.streams - assert dn.net.user_strings.get_bytes(1) == b"\xD0\xDD" - with pytest.raises(UnicodeDecodeError): - assert dn.net.user_strings.get(1) + item = dn.net.user_strings.get(1) + assert item is not None + assert item.flag == 0x01 + assert item.value_bytes() == b"\xD0\xDD" + assert item.value is None def test_raw_binary(): @@ -30,12 +32,11 @@ def test_raw_binary(): # short MZ header assert b"#US" in dn.net.metadata.streams - assert dn.net.user_strings.get_bytes(1) == b"\x4D\x5A\x90\x00" - - # somehow this is valid utf-16 s = 
dn.net.user_strings.get(1) assert s is not None + # somehow this is valid utf-16 assert s.value == b"\x4D\x5A\x90\x00".decode("utf-16") + assert s.value_bytes() == b"\x4D\x5A\x90\x00" def test_string_decoder(): @@ -49,7 +50,11 @@ def test_string_decoder(): # "Hello World" ^ 0xFF assert b"#US" in dn.net.metadata.streams - assert dn.net.user_strings.get_bytes(1) == b"\xb7\xff\x9a\xff\x93\xff\x93\xff\x90\xff\xdf\xff\xa8\xff\x90\xff\x8d\xff\x93\xff\x9b\xff" + item = dn.net.user_strings.get(1) + assert item is not None + assert item.raw_data == b"\x17\xb7\xff\x9a\xff\x93\xff\x93\xff\x90\xff\xdf\xff\xa8\xff\x90\xff\x8d\xff\x93\xff\x9b\xff\x01" + assert item.flag == 0x01 + assert item.value_bytes() == b"\xb7\xff\x9a\xff\x93\xff\x93\xff\x90\xff\xdf\xff\xa8\xff\x90\xff\x8d\xff\x93\xff\x9b\xff" # somehow this is valid utf-16 s = dn.net.user_strings.get(1) diff --git a/tests/test_parse.py b/tests/test_parse.py index 0e4955b..318e8a8 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -297,7 +297,8 @@ def test_heap_items(): # HeapItem buf = b"1234567890" item = dnfile.base.HeapItem(b"1234567890") - assert item.to_bytes() == b"1234567890" + assert item.raw_data == b"1234567890" + assert item.value_bytes() == buf assert item == b"1234567890" assert item.raw_size == len(b"1234567890") assert item.rva is None @@ -310,36 +311,36 @@ def test_heap_items(): buf = b"1234567890" buf_decoded = buf.decode("utf-8") str_item = dnfile.stream.HeapItemString(buf) - assert str_item.to_bytes() == buf + assert str_item.raw_data == buf assert str_item.encoding == "utf-8" assert str_item.value == buf_decoded + assert str_item.value_bytes() == buf assert str_item == buf_decoded - assert len(str_item) == len(buf_decoded) - assert str_item[5] == buf_decoded[5] # HeapItemBinary buf_with_compressed_int_len = b"\x0a1234567890" buf = b"1234567890" bin_item = dnfile.stream.HeapItemBinary(buf_with_compressed_int_len) - assert bin_item.to_bytes() == buf_with_compressed_int_len + assert 
bin_item.raw_data == buf_with_compressed_int_len assert bin_item.raw_size == len(buf_with_compressed_int_len) assert bin_item.item_size == len(buf) + assert bin_item.value_bytes() == buf assert bin_item == buf # UserString buf_with_flag = b"\x151\x002\x003\x004\x005\x006\x007\x008\x009\x000\x00\x00" us_str = "1234567890" us_item = dnfile.stream.UserString(buf_with_flag) - assert us_item.to_bytes() == buf_with_flag + assert us_item.raw_data == buf_with_flag + assert us_item.value_bytes() == buf_with_flag[1:-1] assert us_item.value == us_str assert us_item == us_str - assert len(us_item) == len(us_str) - assert us_item[5] == us_str[5] # GUID guid_bytes = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f" guid_item = dnfile.stream.HeapItemGuid(guid_bytes) assert guid_item is not None - assert guid_item.to_bytes() == guid_bytes + assert guid_item.raw_data == guid_bytes + assert guid_item.value_bytes() == guid_bytes assert guid_item.value == guid_bytes assert str(guid_item) == "03020100-0504-0706-0809-0a0b0c0d0e0f"