diff --git a/floss/api_hooks.py b/floss/api_hooks.py index 318b9363d..0c549eb28 100644 --- a/floss/api_hooks.py +++ b/floss/api_hooks.py @@ -33,24 +33,25 @@ class ApiMonitor(viv_utils.emulator_drivers.Monitor): - """ - The ApiMonitor observes emulation and cleans up API function returns. - """ + """The ApiMonitor observes emulation and cleans up API function returns.""" def __init__(self, function_index): self.function_index = function_index super().__init__() def apicall(self, emu, api, argv): + """Log API calls and their arguments.""" pc = emu.getProgramCounter() logger.trace("apicall: 0x%x %s %s", pc, api, argv) def prehook(self, emu, op, startpc): + """Log the start of an instruction.""" # overridden from Monitor # helpful for debugging decoders, but super verbose! logger.trace("prehook: 0x%x %s", startpc, op) def posthook(self, emu, op, endpc): + """Log the end of an instruction.""" # overridden from Monitor if op.mnem == "ret": try: @@ -60,11 +61,14 @@ def posthook(self, emu, op, endpc): # TODO remove stack fixes? works sometimes, but does it add value? def _check_return(self, emu, op): - """ - Ensure that the target of the return is within the allowed set of functions. + """Ensure that the target of the return is within the allowed set of functions. Do nothing, if return address is valid. If return address is invalid: _fix_return modifies program counter and stack pointer if a valid return address is found on the stack or raises an Exception if no valid return address is found. + + Args: + emu: The emulator. + op: The opcode. """ function_start = self.function_index[op.va] return_addresses = self._get_return_vas(emu, function_start) @@ -85,8 +89,14 @@ def _check_return(self, emu, op): logger.trace("Return address 0x%08x is valid, returning", return_address) def _get_return_vas(self, emu, function_start): - """ - Get the list of valid addresses to which a function should return. + """Get the list of valid addresses to which a function should return. + + Args: + emu: The emulator. + function_start: The start address of the function. + + Returns: + A set of valid return addresses. """ return_vas = set([]) callers = emu.vw.getCallers(function_start) @@ -97,10 +107,17 @@ def _get_return_vas(self, emu, function_start): return return_vas def _fix_return(self, emu, return_address, return_addresses): - """ - Find a valid return address from return_addresses on the stack. Adjust the stack accordingly + """Find a valid return address from return_addresses on the stack. Adjust the stack accordingly or raise an Exception if no valid address is found within the search boundaries. Modify program counter and stack pointer, so the emulator does not return to a garbage address. + + Args: + emu: The emulator. + return_address: The return address. + return_addresses: The set of valid return addresses. + + Raises: + Exception: If no valid return address is found. """ fu.dump_stack(emu) NUM_ADDRESSES = 4 @@ -121,14 +138,21 @@ def _fix_return(self, emu, return_address, return_addresses): class DemoHook: + """A demo hook to demonstrate the API of the hook classes.""" + def __call__( - self, emu: viv_utils.emulator_drivers.EmulatorDriver, api: Tuple[str, Any, str, str, List], argv: List + self, + emu: viv_utils.emulator_drivers.EmulatorDriver, + api: Tuple[str, Any, str, str, List], + argv: List, ): # api: (rettype, retname, callconv, funcname, [(argtype, argname), ...)] ... 
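The hook protocol that DemoHook illustrates is small: a hook is any callable taking (emu, api, argv); it emulates an API's observable effect and returns True once it has handled the call, so the driver skips the real import. A minimal sketch under those assumptions, reusing the fu.contains_funcname and fu.call_return helpers from this module (GetTickCountHook is a hypothetical example, not part of this change):

    class GetTickCountHook:
        """Hook calls to GetTickCount and return a fixed tick count (hypothetical example)."""

        def __call__(self, emu, api, argv):
            if fu.contains_funcname(api, ("GetTickCount",)):
                # any stable value works here; it only needs to satisfy timing checks
                fu.call_return(emu, api, argv, 0x1234)
                return True
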
class GetProcessHeapHook: + """Hook calls to GetProcessHeap and return a fake heap handle.""" + def __call__(self, emu, api, argv): if fu.contains_funcname(api, ("GetProcessHeap",)): fu.call_return(emu, api, argv, 42) @@ -136,6 +160,8 @@ def __call__(self, emu, api, argv): class GetModuleFileNameHook: + """Hook calls to GetModuleFileName and return the name of the current module.""" + def __call__(self, emu, api, argv): if fu.contains_funcname(api, ("GetModuleFileNameA",)): unicode = False @@ -166,9 +192,7 @@ def __call__(self, emu, api, argv): class MemoryAllocationHook: - """ - Hook calls to memory allocation functions: allocate memory and return pointer to this memory. - """ + """Hook calls to memory allocation functions: allocate memory and return pointer to this memory.""" _heap_addr = HEAP_BASE @@ -176,6 +200,15 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def _allocate_mem(self, emu, size): + """Allocate memory and return the address of the allocated memory. The memory is initialized with zeros. + + Args: + emu: The emulator. + size: The size of the memory to allocate. + + Returns: + The address of the allocated memory. + """ va = self._heap_addr # align to 16-byte boundary (64-bit), also works for 32-bit, which is normally 8-bytes size = fu.round_(size, 16) @@ -205,8 +238,7 @@ def __call__(self, emu, api, argv): class CppNewObjectHook(MemoryAllocationHook): - """ - Hook calls to: + """Hook calls to: - C++ new operator Thanks to @BenjaminSoelberg """ @@ -233,6 +265,8 @@ def __call__(self, emu, api, argv): class MemoryFreeHook: + """Hook calls to memory free functions: free memory and return success.""" + def __call__(self, emu, api, argv): if fu.contains_funcname(api, ("free", "free_base", "VirtualFree", "HeapFree", "RtlFreeHeap")): # If the function succeeds, the return value is nonzero. @@ -241,6 +275,8 @@ def __call__(self, emu, api, argv): class MemcpyHook: + """Hook calls to memory copy functions: copy memory from source to destination.""" + def __call__(self, emu, api, argv): if fu.contains_funcname(api, ("memcpy", "memmove")): dst, src, count = argv @@ -257,6 +293,8 @@ def __call__(self, emu, api, argv): class StrlenHook: + """Hook calls to string length functions: return the length of the string.""" + def __call__(self, emu, api, argv): if fu.contains_funcname(api, ("strlen", "lstrlena")): string_va = argv[0] @@ -276,6 +314,8 @@ def __call__(self, emu, api, argv): class StrncmpHook: + """Hook calls to string compare functions: compare two strings.""" + def __call__(self, emu, api, argv): if fu.contains_funcname(api, ("strncmp",)): s1va, s2va, num = argv @@ -284,6 +324,15 @@ def __call__(self, emu, api, argv): s2 = fu.readStringAtRva(emu, s2va, maxsize=num) def cmp(a, b): + """Compare two strings. + + Args: + a: The first string. + b: The second string. + + Returns: + int: -1 if a < b, 0 if a == b, 1 if a > b. 
+ """ return (a > b) - (a < b) result = cmp(s1, s2) @@ -292,6 +341,8 @@ def cmp(a, b): class MemchrHook: + """Hook calls to memchr: search for a character in a memory block.""" + def __call__(self, emu, api, argv): if fu.contains_funcname(api, ("memchr",)): ptr, value, num = argv @@ -307,6 +358,8 @@ def __call__(self, emu, api, argv): class MemsetHook: + """Hook calls to memset: fill memory with a constant byte.""" + def __call__(self, emu, api, argv): if fu.contains_funcname(api, ("memset",)): ptr, value, num = argv @@ -318,6 +371,8 @@ def __call__(self, emu, api, argv): class PrintfHook: + """Hook calls to printf: write formatted data to stdout.""" + # TODO disabled for now as incomplete (need to implement string format) and could result in FP strings as is def __call__(self, emu, api, argv): # TODO vfprintf, vfwprintf, vfprintf_s, vfwprintf_s, vsnprintf, vsnwprintf, etc. @@ -330,6 +385,8 @@ def __call__(self, emu, api, argv): class ExitExceptionHook: + """Hook calls to exit and raise exception.""" + def __call__(self, emu, api, argv): if fu.contains_funcname(api, ("ExitProcess", "RaiseException")): raise viv_utils.emulator_drivers.StopEmulation() @@ -340,14 +397,27 @@ def __call__(self, emu, api, argv): class SehPrologEpilogHook: + """Hook calls to SEH prolog and epilog functions and return success.""" + def __call__(self, emu, api, argv): - if fu.contains_funcname(api, ("__EH_prolog", "__EH_prolog3", "__SEH_prolog4", "seh4_prolog", "__SEH_epilog4")): + if fu.contains_funcname( + api, + ( + "__EH_prolog", + "__EH_prolog3", + "__SEH_prolog4", + "seh4_prolog", + "__SEH_epilog4", + ), + ): # nop fu.call_return(emu, api, argv, 0) return True class SecurityCheckCookieHook: + """Hook calls to __security_check_cookie and return success.""" + def __call__(self, emu, api, argv): if fu.contains_funcname(api, ("__security_check_cookie", "@__security_check_cookie@4")): # nop @@ -356,6 +426,8 @@ def __call__(self, emu, api, argv): class GetLastErrorHook: + """Hook calls to GetLastError and return success.""" + def __call__(self, emu, api, argv): if fu.contains_funcname(api, ("GetLastError",)): # always assuming success @@ -365,6 +437,8 @@ def __call__(self, emu, api, argv): class GetCurrentProcessHook: + """Hook calls to GetCurrentProcess and return a fake process handle.""" + def __call__(self, emu, api, argv): if fu.contains_funcname(api, ("GetCurrentProcess",)): fu.call_return(emu, api, argv, CURRENT_PROCESS_ID) @@ -372,6 +446,8 @@ def __call__(self, emu, api, argv): class CriticalSectionHook: + """Hook calls to InitializeCriticalSection and return a fake critical section handle.""" + def __call__(self, emu, api, argv): if fu.contains_funcname(api, ("InitializeCriticalSection",)): (hsection,) = argv @@ -403,14 +479,16 @@ def __call__(self, emu, api, argv): @contextlib.contextmanager def defaultHooks(driver): - """ - Install and remove the default set of hooks to handle common functions. + """Install and remove the default set of hooks to handle common functions. intended usage: with defaultHooks(driver): driver.runFunction() ... + + Args: + driver: The emulator driver. """ try: for hook in DEFAULT_HOOKS: diff --git a/floss/decoding_manager.py b/floss/decoding_manager.py index 02e951c65..1c5841077 100644 --- a/floss/decoding_manager.py +++ b/floss/decoding_manager.py @@ -19,8 +19,14 @@ def is_import(emu, va): - """ - Return True if the given VA is that of an imported function. + """Check if the given address is an import. + + Args: + emu: The emulator. + va: The address to check. 
+ + Returns: + bool: True if the address is an import, False otherwise. """ # TODO: also check location type t = emu.getVivTaint(va) @@ -59,8 +65,7 @@ def is_import(emu, va): @dataclass class Snapshot: - """ - A snapshot of the state of the CPU and memory. + """A snapshot of the state of the CPU and memory. Attributes: memory: a snapshot of the memory contents @@ -74,6 +79,14 @@ class Snapshot: def get_map_size(emu): + """Get the total size of all memory maps in the emulator. + + Args: + emu: The emulator. + + Returns: + int: The total size of all memory maps. + """ size = 0 for mapva, mapsize, mperm, mfname in emu.getMemoryMaps(): mapsize += size @@ -81,12 +94,19 @@ def get_map_size(emu): class MapsTooLargeError(Exception): + """Exception raised when the emulator has mapped too much memory.""" + pass def make_snapshot(emu: Emulator) -> Snapshot: - """ - Create a snapshot of the current CPU and memory. + """Create a snapshot of the current CPU and memory. + + Args: + emu: The emulator. + + Returns: + Snapshot: The snapshot of the emulator state. """ if get_map_size(emu) > MAX_MAPS_SIZE: logger.debug("emulator mapped too much memory: 0x%x", get_map_size(emu)) @@ -96,9 +116,9 @@ def make_snapshot(emu: Emulator) -> Snapshot: @dataclass class Delta: - """ - a pair of snapshots from before and after an operation. - facilitates diffing the state of an emulator. + """a pair of snapshots from before and after an operation. + + Facilitates diffing the state of an emulator. """ pre: Snapshot @@ -106,9 +126,7 @@ class Delta: class DeltaCollectorHook(viv_utils.emulator_drivers.Hook): - """ - hook that collects Deltas at each imported API call. - """ + """hook that collects Deltas at each imported API call.""" def __init__(self, pre_snap: Snapshot): super().__init__() @@ -127,15 +145,21 @@ def __call__(self, emu, api, argv): self.deltas.append(Delta(self._pre_snap, make_snapshot(emu))) except MapsTooLargeError: _, _, _, name, _ = api - logger.debug("despite call to import %s, maps too large, not extracting strings", name) + logger.debug( + "despite call to import %s, maps too large, not extracting strings", + name, + ) pass def emulate_function( - emu: Emulator, function_index, fva: int, return_address: int, max_instruction_count: int + emu: Emulator, + function_index, + fva: int, + return_address: int, + max_instruction_count: int, ) -> List[Delta]: - """ - Emulate a function and collect snapshots at each interesting place. + """Emulate a function and collect snapshots at each interesting place. These interesting places include calls to imported API functions and the final state of the emulator. Emulation continues until the return address is hit, or @@ -148,12 +172,15 @@ def emulate_function( - AllocateHeap - malloc - :type function_index: viv_utils.FunctionIndex - :param fva: The start address of the function to emulate. - :param return_address: The expected return address of the function. - Emulation stops here. - :param max_instruction_count: The max number of instructions to emulate. - This helps avoid unexpected infinite loops. + Args: + emu: The emulator. + function_index: The index of the function to emulate. + fva: The address of the function to emulate. + return_address: The address to stop emulation at. + max_instruction_count: The maximum number of instructions to emulate. + + Returns: + List[Delta]: A list of Deltas representing the emulator state at each interesting place. 
""" try: pre_snap = make_snapshot(emu) @@ -166,7 +193,10 @@ def emulate_function( try: logger.debug("Emulating function at 0x%08x", fva) driver = viv_utils.emulator_drivers.DebuggerEmulatorDriver( - emu, repmax=256, max_hit=DS_MAX_ADDRESS_REVISITS_EMULATION, max_insn=max_instruction_count + emu, + repmax=256, + max_hit=DS_MAX_ADDRESS_REVISITS_EMULATION, + max_insn=max_instruction_count, ) monitor = api_hooks.ApiMonitor(function_index) driver.add_monitor(monitor) @@ -180,11 +210,20 @@ def emulate_function( if e.reason == "max_insn": logger.debug("Halting as emulation has escaped!") except envi.InvalidInstruction as e: - logger.debug("vivisect encountered an invalid instruction. will continue processing. %s", e) + logger.debug( + "vivisect encountered an invalid instruction. will continue processing. %s", + e, + ) except envi.UnsupportedInstruction as e: - logger.debug("vivisect encountered an unsupported instruction. will continue processing. %s", e) + logger.debug( + "vivisect encountered an unsupported instruction. will continue processing. %s", + e, + ) except envi.BreakpointHit as e: - logger.debug("vivisect encountered an unexpected emulation breakpoint. will continue processing. %s", e) + logger.debug( + "vivisect encountered an unexpected emulation breakpoint. will continue processing. %s", + e, + ) except envi.exc.SegmentationViolation as e: tos_val = floss.utils.get_stack_value(emu, 0) logger.debug("%s: top of stack (return address): 0x%x", e, tos_val) @@ -194,7 +233,10 @@ def emulate_function( pass except Exception: # we cheat here a bit and skip over various errors, check this for improvements and debugging - logger.debug("vivisect encountered an unexpected exception. will continue processing.", exc_info=True) + logger.debug( + "vivisect encountered an unexpected exception. will continue processing.", + exc_info=True, + ) logger.debug("Ended emulation at 0x%08x", emu.getProgramCounter()) deltas = delta_collector.deltas diff --git a/floss/features/extract.py b/floss/features/extract.py index 06fc91068..c58ca6e47 100644 --- a/floss/features/extract.py +++ b/floss/features/extract.py @@ -35,9 +35,15 @@ def extract_insn_nzxor(f, bb, insn): - """ - parse non-zeroing XOR instruction from the given instruction. - ignore expected non-zeroing XORs, e.g. security cookies. + """Analyzes a given instruction within a function's basic block to identify non-zeroing XOR operations that are not associated with security cookie checks. + + Args: + f: The current function being analyzed. + bb: The basic block that contains the instruction. + insn: The specific instruction to analyze. + + Returns: + Nzxor: Yields a Nzxor feature if a relevant XOR instruction is found. """ if insn.opcode != INS_XOR: return @@ -52,8 +58,15 @@ def extract_insn_nzxor(f, bb, insn): def is_security_cookie(f, bb, insn) -> bool: - """ - check if an instruction is related to security cookie checks + """Determines if the given instruction is related to security cookie checks. + + Args: + f: The function object being analyzed. + bb: The basic block that contains the instruction. + insn: The instruction object to analyze. + + Returns: + bool: True if the instruction is related to security cookie checks, False otherwise. """ # security cookie check should use SP or BP oper = insn.opers[1] @@ -80,11 +93,31 @@ def is_security_cookie(f, bb, insn) -> bool: def extract_insn_shift(f, bb, insn): + """Extracts shift or rotate instructions from the given instruction within a basic block. + + Args: + f: The function object being analyzed. 
+ bb: The basic block containing the instruction. + insn: The instruction object to analyze. + + Returns: + Iterator[Shift]: An iterator over Shift features if shift or rotate instructions are found. + """ if insn.opcode in SHIFT_ROTATE_INS: yield Shift(insn) def extract_insn_mov(f, bb, insn): + """Identifies MOV instructions that write to memory in a given basic block. + + Args: + f: The function object being analyzed. + bb: The basic block containing the instruction. + insn: The instruction object to analyze. + + Returns: + Iterator[Mov]: An iterator over Mov features if relevant MOV instructions are found. + """ # identify register dereferenced writes to memory # mov byte [eax], cl # mov dword [edx], eax @@ -115,13 +148,25 @@ def extract_insn_mov(f, bb, insn): def extract_function_calls_to(f): + """Identifies all function calls within the given function. + + Args: + f: The function object being analyzed. + + Returns: + An iterator over CallsTo features, each representing a call made from the given function. + """ yield CallsTo(f.vw, [x[0] for x in f.vw.getXrefsTo(f.va, rtype=vivisect.const.REF_CODE)]) def extract_function_kinda_tight_loop(f): - """ - Yields tight loop features in the provided function - Algorithm by Blaine S. + """Identifies tight and kinda tight loops within the provided function using a specific algorithm. + + Args: + f: The function object to analyze for loop structures. + + Returns: + An iterator over TightLoop or KindaTightLoop features identified within the function. """ try: cfg = viv_utils.CFG(f) @@ -213,6 +258,15 @@ def extract_function_kinda_tight_loop(f): def skip_tightloop(bb: BasicBlock, loop_bb: BasicBlock) -> bool: + """Determines whether a tight loop should be skipped based on the presence of function calls and memory writes. + + Args: + bb: The basic block being analyzed. + loop_bb: The loop basic block to compare against. + + Returns: + True if the loop should be skipped, otherwise False. + """ # ignore tight loops that call other functions if contains_call(bb) or contains_call(loop_bb): return True @@ -225,6 +279,14 @@ def skip_tightloop(bb: BasicBlock, loop_bb: BasicBlock) -> bool: def contains_call(bb): + """Checks if the given basic block contains any call instructions. + + Args: + bb: The basic block to inspect. + + Returns: + True if a call instruction is found, otherwise False. + """ for insn in bb.instructions: if insn.opcode == INS_CALL: return True @@ -232,6 +294,14 @@ def contains_call(bb): def writes_memory(bb): + """Determines if any instruction within the basic block writes to memory. + + Args: + bb: The basic block to check. + + Returns: + bool: True if at least one instruction writes to memory, False otherwise. + """ for insn in bb.instructions: # don't handle len(ops) == 0 for `rep movsb` or other unexpected instructions if len(insn.opers) < 1: @@ -251,6 +321,16 @@ def writes_memory(bb): def abstract_nzxor_tightloop(features): + """ + Abstracts tight loop patterns with non-zeroing XOR operations within the features. + + Args: + features: A list of features extracted from a function. + + Returns: + An iterator over NzxorTightLoop features for each identified pattern. + """ + for tl in filter(lambda f: isinstance(f, TightLoop), features): for nzxor in filter(lambda f: isinstance(f, Nzxor), features): if tl.startva <= nzxor.insn.va <= tl.endva: @@ -258,13 +338,28 @@ def abstract_nzxor_tightloop(features): def abstract_nzxor_loop(features): + """ + Abstracts loop patterns with non-zeroing XOR operations within the features. 
+ + Args: + features: A list of features extracted from a function. + + Returns: + An iterator over NzxorLoop features for each identified pattern. + """ if any(isinstance(f, Nzxor) for f in features) and any(isinstance(f, Loop) for f in features): yield NzxorLoop() def abstract_tightfunction(features): """ - (Kinda) TightLoop and only a few basic blocks + Abstracts functions that are tight or kinda tight and contain a small number of basic blocks. + + Args: + features: A list of features extracted from a function. + + Returns: + An iterator over TightFunction features for functions meeting the criteria. """ if any(filter(lambda f: isinstance(f, (TightLoop, KindaTightLoop)), features)): for block_count in filter(lambda f: isinstance(f, BlockCount), features): @@ -275,7 +370,13 @@ def abstract_tightfunction(features): def extract_function_loop(f): """ - parse if a function has a loop + Identifies loop structures within a function. + + Args: + f: The function object to analyze. + + Returns: + An iterator over Loop features for each loop structure identified within the function. """ edges = [] @@ -314,6 +415,15 @@ def extract_function_loop(f): def extract_function_features(f): + """ + Extracts various features from a function, including function calls, loops, and tight loops. + + Args: + f: The function object to analyze. + + Returns: + An iterator over various features extracted from the function. + """ for func_handler in FUNCTION_HANDLERS: for feature in func_handler(f): yield feature @@ -324,6 +434,15 @@ def extract_function_features(f): def extract_basic_block_features(f: Any, bb: Any) -> Iterator: + """Extracts features from a given basic block within a function. + + Args: + f: The function object containing the basic block. + bb: The basic block to analyze. + + Returns: + An iterator over features extracted from the basic block. + """ for bb_handler in BASIC_BLOCK_HANDLERS: for feature in bb_handler(f, bb): yield feature @@ -337,6 +456,16 @@ def extract_basic_block_features(f: Any, bb: Any) -> Iterator: def extract_insn_features(f, bb, insn): + """Extracts features from a given instruction within a basic block of a function. + + Args: + f: The function object containing the basic block and instruction. + bb: The basic block containing the instruction. + insn: The instruction to analyze. + + Returns: + An iterator over features extracted from the instruction. + """ for insn_handler in INSTRUCTION_HANDLERS: for feature in insn_handler(f, bb, insn): yield feature @@ -350,6 +479,14 @@ def extract_insn_features(f, bb, insn): def abstract_features(features): + """Abstracts higher-level features from a collection of lower-level features. + + Args: + features: A list of features extracted from a function. + + Returns: + An iterator over abstracted features based on the provided features. + """ for abst_handler in ABSTRACTION_HANDLERS: for feature in abst_handler(features): yield feature diff --git a/floss/features/features.py b/floss/features/features.py index 5f6c460ff..00ea44dc4 100644 --- a/floss/features/features.py +++ b/floss/features/features.py @@ -10,7 +10,19 @@ class Feature: + """A base class for defining features in code analysis, encapsulating common properties and methods. + + Attributes: + name (str): Automatically derived from the class name. + value: The specific value of the feature being analyzed. + """ + def __init__(self, value): + """Initializes the Feature instance. + + Args: + value: The value associated with the feature. 
+ """ super(Feature, self).__init__() self.name = self.__class__.__name__ @@ -18,15 +30,30 @@ def __init__(self, value): @property def weight(self) -> float: + """The importance weight of the feature. Must be implemented by subclasses. + + Raises: + NotImplementedError: If the subclass does not implement this property. + """ # feature weight LOW, MEDIUM, ... (less to more important) raise NotImplementedError def score(self) -> float: + """Calculates a score for the feature based on its value. + + Raises: + NotImplementedError: If the subclass does not implement this method. + """ # returns a value between 0.0 and 1.0 (less likely to almost certain) # can be negative to exclude functions based on a feature raise NotImplementedError def weighted_score(self): + """Computes the weighted score of the feature by multiplying its weight with its score. + + Returns: + float: The weighted score of the feature. + """ return self.weight * self.score() def __str__(self): @@ -39,12 +66,28 @@ def __repr__(self): class BlockCount(Feature): + """A feature representing the count of blocks in a function, influencing its analysis score. + + Inherits from Feature. + """ + weight = LOW def __init__(self, block_count): + """Initializes the BlockCount feature with the number of blocks. + + Args: + block_count (int): The count of blocks in the function. + """ super(BlockCount, self).__init__(block_count) def score(self): + """Determines the score based on the block count. Specific ranges of block count + influence the score differently. + + Returns: + float: A score indicating the likelihood of a function being a string decoding function. + """ if self.value > 30: # a function with >30 basic blocks is unlikely a string decoding function return 0.1 @@ -57,12 +100,29 @@ def score(self): class InstructionCount(Feature): + """Represents the instruction count of a function, contributing to its analysis score. + + Attributes: + weight (float): Importance of instruction count, predefined as LOW. + """ + weight = LOW def __init__(self, instruction_count): + """Initializes the InstructionCount feature with the number of instructions. + + Args: + instruction_count (int): The total instruction count in the function. + """ super(InstructionCount, self).__init__(instruction_count) def score(self): + """Calculates the score based on the instruction count. More instructions generally imply a higher likelihood + of the function being significant, up to a point. + + Returns: + float: Score based on the number of instructions. + """ if self.value > 10: return 0.8 else: @@ -70,6 +130,12 @@ def score(self): class Arguments(Feature): + """Represents the number of arguments in a function, affecting its evaluation. + + Attributes: + weight (float): Importance of the argument count, predefined as LOW. + """ + weight = LOW def __init__(self, args): @@ -78,6 +144,11 @@ def __init__(self, args): self.args = args def score(self): + """Scores the feature based on the optimal argument count for identification purposes. + + Returns: + float: Score reflecting the appropriateness of the argument count. + """ if 1 <= self.value <= 4: return 1.0 elif 5 <= self.value <= 6: @@ -87,6 +158,12 @@ def score(self): class TightLoop(Feature): + """Identifies a tight loop within a function, indicating high importance. + + Attributes: + weight (float): Importance of this feature, predefined as HIGH. 
+ """ + # basic block (BB) that jumps to itself weight = HIGH @@ -97,15 +174,24 @@ def __init__(self, startva, endva): self.endva = endva def score(self): + """Returns a perfect score, indicating a significant feature of analysis. + + Returns: + float: A static score of 1.0, due to the high relevance of tight loops. + """ return 1.0 class KindaTightLoop(TightLoop): + """Identifies a tight loop within a function, but with an intermediate BB.""" + # BB that jumps to itself via one intermediate BB pass class TightFunction(Feature): + """A feature representing a tight function, indicating high importance.""" + # function that basically just wraps a tight loop weight = SEVERE @@ -113,33 +199,56 @@ def __init__(self): super(TightFunction, self).__init__(True) def score(self): + """Returns a perfect score, indicating a significant feature of analysis.""" # score 0 because we emulate all tight functions anyway return 0.0 class Mnem(Feature): + """Represents a specific mnemonic instruction within a function, influencing its analysis score.""" + def __init__(self, insn): super(Mnem, self).__init__(f"0x{insn.va:x} {insn}") self.insn = insn def score(self): + """Scores the feature based on the mnemonic instruction.""" return 1.0 class Nzxor(Mnem): + """Represents the non-zeroing XOR operation within a function, influencing its analysis score.""" + weight = HIGH class Shift(Mnem): + """Represents the shift operation within a function, influencing its analysis score.""" + weight = HIGH class Mov(Mnem): + """Represents the move operation within a function, influencing its analysis score.""" + weight = MEDIUM class CallsTo(Feature): + """Represents the number of calls to external locations from within a function. + + This feature calculates its score based on the proportion of calls made to the total possible calls identified in the analysis, helping to assess the connectivity and complexity of functions. + + Attributes: + weight (float): The importance weight of this feature, predefined as MEDIUM. + max_calls_to (float): The maximum number of calls to any single location, used to normalize scores. + + Args: + vw: The vivisect workspace instance for analysis. + locations (list): A list of locations (addresses) where calls are made. + """ + weight = MEDIUM max_calls_to = None @@ -153,10 +262,25 @@ def __init__(self, vw, locations): self.locations = locations def score(self): + """Calculates the feature's score as the ratio of observed calls to the maximum number of calls to any location. + + Returns: + float: The normalized score, indicating the frequency of calls relative to the maximum observed. + """ return float(self.value / self.max_calls_to) class Loop(Feature): + """Represents loop structures within a function, evaluated for their impact on the function's behavior. + + Attributes: + weight (float): Assigned importance of loop features, set to MEDIUM. + comp: The components forming the loop within the function. + + Args: + comp: A collection representing the loop's components. + """ + weight = MEDIUM def __init__(self, comp): @@ -165,24 +289,47 @@ def __init__(self, comp): self.comp = comp def score(self): + """ """ return 1.0 class NzxorTightLoop(Feature): + """Identifies tight loops combined with non-zeroing XOR operations, indicating complex obfuscation or encoding routines. + + Attributes: + weight (float): The severity of this feature, set to SEVERE. 
+ """ + weight = SEVERE def __init__(self): super(NzxorTightLoop, self).__init__(True) def score(self): + """Provides a static score for nzxor tight loop features. + + Returns: + float: A static score of 1.0, reflecting the high importance of this feature. + """ return 1.0 class NzxorLoop(Feature): + """Similar to NzxorTightLoop but for more general loop structures combined with non-zeroing XOR operations. + + Attributes: + weight (float): The severity of this feature, also set to SEVERE. + """ + weight = SEVERE def __init__(self): super(NzxorLoop, self).__init__(True) def score(self): + """Gives a static score for nzxor loop features. + + Returns: + float: A static score of 1.0, denoting the critical nature of this feature. + """ return 1.0 diff --git a/floss/function_argument_getter.py b/floss/function_argument_getter.py index a828fb5eb..db1257c2b 100644 --- a/floss/function_argument_getter.py +++ b/floss/function_argument_getter.py @@ -28,6 +28,13 @@ def __init__(self, call_site_va: int): self.function_contexts: List[FunctionContext] = list() def prehook(self, emu, op, pc): + """collect function contexts at call sites + + Args: + emu: The emulator. + op: The operation. + pc: The program counter. + """ logger.trace("%s: %s", hex(pc), op) if pc == self.call_site_va: # strictly calls here, return address should always be next instruction @@ -35,11 +42,18 @@ def prehook(self, emu, op, pc): self.function_contexts.append(FunctionContext(emu.getEmuSnap(), return_address, pc)) def get_contexts(self) -> List[FunctionContext]: + """return the collected function contexts""" return self.function_contexts @contextlib.contextmanager def installed_monitor(driver, monitor): + """install a monitor on an emulator driver for the duration of a context + + Args: + driver: + monitor: + """ try: driver.add_monitor(monitor) yield @@ -48,13 +62,22 @@ def installed_monitor(driver, monitor): def extract_decoding_contexts( - vw: vivisect.VivWorkspace, decoder_fva: int, index: viv_utils.InstructionFunctionIndex + vw: vivisect.VivWorkspace, + decoder_fva: int, + index: viv_utils.InstructionFunctionIndex, ) -> List[FunctionContext]: - """ - Extract the CPU and memory contexts of all calls to the given function. + """Extract the CPU and memory contexts of all calls to the given function. Under the hood, we brute-force emulate all code paths to extract the state of the stack, registers, and global memory at each call to the given address. + + Args: + vw: vivisect.VivWorkspace: + decoder_fva: int: + index: viv_utils.InstructionFunctionIndex: + + Returns: + List[FunctionContext]: """ logger.trace("Getting function context for function at 0x%08x...", decoder_fva) @@ -70,8 +93,16 @@ def extract_decoding_contexts( def get_caller_vas(vw, fva) -> Set[int]: - """ - return all unique VAs where function is called from + """Finds the virtual addresses of functions that call a specified function. + + Analyzes a workspace to identify instructions that call the function at the provided virtual address (`fva`). Handles filtering of non-call instructions and recursive calls. + + Args: + vw: A Vivisect workspace object. + fva: The virtual address of the function being analyzed. + + Returns: + Set[int]: A set of virtual addresses representing the callers of the function. """ caller_vas = set() for caller_va in vw.getCallers(fva): @@ -85,6 +116,17 @@ def get_caller_vas(vw, fva) -> Set[int]: def is_call(vw: vivisect.VivWorkspace, va: int) -> bool: + """Determines if an instruction at a virtual address is a call instruction. 
+
+    Attempts to parse an instruction and checks if the instruction flags indicate a call type.
+
+    Args:
+        vw: A Vivisect workspace object.
+        va: The virtual address of the instruction.
+
+    Returns:
+        bool: True if the instruction is a call, False otherwise.
+    """
     try:
         op = vw.parseOpcode(va)
     except (envi.UnsupportedInstruction, envi.InvalidInstruction) as e:
@@ -99,8 +141,18 @@ def is_call(vw: vivisect.VivWorkspace, va: int) -> bool:

 def get_contexts_via_monitor(driver, caller_va, decoder_fva: int, index: viv_utils.InstructionFunctionIndex):
-    """
-    run the given function while collecting arguments to a target function
+    """Collects function call context information via dynamic monitoring.
+
+    This function sets up a monitor to intercept calls to a target function (`decoder_fva`) made from within a caller function (`caller_va`). It achieves this by emulating the caller function and collecting data about the arguments passed to the target function.
+
+    Args:
+        driver: The emulator driver used to run the caller function.
+        caller_va: The virtual address of the caller function.
+        decoder_fva: The virtual address of the target function to be monitored.
+        index: A viv_utils InstructionFunctionIndex that maps instruction addresses to their containing function.
+
+    Returns:
+        List[FunctionContext]: A list of FunctionContext objects representing intercepted call contexts.
     """
     try:
         caller_fva = index[caller_va]
diff --git a/floss/identify.py b/floss/identify.py
index e219d12a4..94ed9b6ed 100644
--- a/floss/identify.py
+++ b/floss/identify.py
@@ -24,6 +24,21 @@

 def get_function_api(f):
+    """Retrieves API metadata for a function using Vivisect.
+
+    Queries the Vivisect workspace for information about a function's return type, name, calling convention, and arguments.
+
+    Args:
+        f: The function object, backed by a Vivisect workspace.
+
+    Returns:
+        dict: A dictionary containing the extracted API metadata:
+            * ret_type: The function's return type.
+            * ret_name: The name of the return value (if any).
+            * call_conv: The function's calling convention.
+            * func_name: The function's name.
+            * arguments: A list of argument descriptions.
+    """
     ret_type, ret_name, call_conv, func_name, args = f.vw.getFunctionApi(int(f))

     return {
@@ -36,6 +51,19 @@

 def get_function_meta(f):
+    """Retrieves metadata for a function using Vivisect.
+
+    Queries the Vivisect workspace for information about a function's size, block count, and instruction count.
+
+    Args:
+        f: The function object, backed by a Vivisect workspace.
+
+    Returns:
+        dict: A dictionary containing the extracted metadata:
+            * size: The function's size in bytes.
+            * block_count: The number of basic blocks in the function.
+            * instruction_count: The number of instructions in the function.
+    """
     meta = f.vw.getFunctionMetaDict(int(f))

     return {
@@ -47,6 +75,16 @@

 def get_max_calls_to(vw, skip_thunks=True, skip_libs=True):
+    """Retrieves the maximum number of calls to a function in a Vivisect workspace.
+
+    Args:
+        vw: The Vivisect workspace.
+        skip_thunks: Whether to skip thunk functions.
+        skip_libs: Whether to skip library functions.
+
+    Returns:
+        int: The maximum number of calls to a function in the workspace.
+ """ calls_to = set() for fva in vw.getFunctions(): @@ -62,15 +100,46 @@ def get_max_calls_to(vw, skip_thunks=True, skip_libs=True): def get_function_score_weighted(features): - return round(sum(feature.weighted_score() for feature in features) / sum(feature.weight for feature in features), 3) + """Calculates a weighted score for a function based on its features. + + Args: + features: The features of the function. + + Returns: + float: The weighted score of the function. + """ + return round( + sum(feature.weighted_score() for feature in features) / sum(feature.weight for feature in features), + 3, + ) def get_top_functions(candidate_functions, count=20) -> List[Dict[int, Dict]]: - return sorted(candidate_functions.items(), key=lambda x: operator.getitem(x[1], "score"), reverse=True)[:count] + """Retrieves the top scoring functions from a set of candidate functions. + + Args: + candidate_functions: A dictionary of candidate functions and their scores. + count: The number of top functions to retrieve. + + Returns: + List[Dict[int, Dict]]: A list of the top scoring functions. + """ + return sorted( + candidate_functions.items(), + key=lambda x: operator.getitem(x[1], "score"), + reverse=True, + )[:count] def get_tight_function_fvas(decoding_function_features) -> List[int]: - """return offsets of identified tight functions""" + """Retrieves the function virtual addresses of functions with tight loops. + + Args: + decoding_function_features: A dictionary of decoding function features. + + Returns: + List[int]: A list of function virtual addresses. + """ tight_function_fvas = list() for fva, function_data in decoding_function_features.items(): if any(filter(lambda f: isinstance(f, TightFunction), function_data["features"])): @@ -79,6 +148,15 @@ def get_tight_function_fvas(decoding_function_features) -> List[int]: def append_unique(fvas, fvas_to_append): + """Appends unique function virtual addresses to a list. + + Args: + fvas: The list of function virtual addresses. + fvas_to_append: The list of function virtual addresses to append. + + Returns: + List[int]: The updated list of function virtual addresses. + """ for fva in fvas_to_append: if fva not in fvas: fvas.append(fva) @@ -86,16 +164,41 @@ def append_unique(fvas, fvas_to_append): def get_function_fvas(functions) -> List[int]: + """Retrieves the function virtual addresses from a dictionary of functions. + + Args: + functions: A dictionary of functions. + + Returns: + List[int]: A list of function virtual addresses. + """ return list(map(lambda p: p[0], functions)) def get_functions_with_tightloops(functions): + """Retrieves functions with tight loops from a dictionary of functions. + + Args: + functions: A dictionary of functions. + + Returns: + Dict[int, List]: A dictionary of functions with tight loops. + """ return get_functions_with_features( - functions, (floss.features.features.TightLoop, floss.features.features.KindaTightLoop) + functions, + (floss.features.features.TightLoop, floss.features.features.KindaTightLoop), ) def get_functions_without_tightloops(functions): + """Retrieves functions without tight loops from a dictionary of functions. + + Args: + functions: A dictionary of functions. + + Returns: + Dict[int, List]: A dictionary of functions without tight loops. 
+ """ tloop_functions = get_functions_with_tightloops(functions) no_tloop_funcs = copy.copy(functions) for fva, _ in tloop_functions.items(): @@ -104,6 +207,15 @@ def get_functions_without_tightloops(functions): def get_functions_with_features(functions, features) -> Dict[int, List]: + """Retrieves functions with specified features from a dictionary of functions. + + Args: + functions: A dictionary of functions. + features: The features to search for. + + Returns: + Dict[int, List]: A dictionary of functions with specified features. + """ functions_by_features = dict() for fva, function_data in functions.items(): func_features = list(filter(lambda f: isinstance(f, features), function_data["features"])) @@ -113,6 +225,16 @@ def get_functions_with_features(functions, features) -> Dict[int, List]: def find_decoding_function_features(vw, functions, disable_progress=False) -> Tuple[Dict[int, Dict], Dict[int, str]]: + """Identifies decoding function features from a set of functions. + + Args: + vw: The Vivisect workspace. + functions: The set of functions to analyze. + disable_progress: Whether to disable progress output. + + Returns: + Tuple[Dict[int, Dict], Dict[int, str]]: A tuple containing the decoding function features and library functions. + """ decoding_candidate_functions: DefaultDict[int, Dict] = collections.defaultdict(dict) library_functions: Dict[int, str] = dict() @@ -128,7 +250,10 @@ def find_decoding_function_features(vw, functions, disable_progress=False) -> Tu n_funcs = len(functions) pb = pbar( - functions, desc="finding decoding function features", unit=" functions", postfix="skipped 0 library functions" + functions, + desc="finding decoding function features", + unit=" functions", + postfix="skipped 0 library functions", ) with logging_redirect_tqdm(), redirecting_print_to_tqdm(): for f in pb: @@ -141,7 +266,11 @@ def find_decoding_function_features(vw, functions, disable_progress=False) -> Tu # TODO handle j_j_j__free_base (lib function wrappers), e.g. 0x140035AF0 in d2ca76... 
# TODO ignore function called to by library functions function_name = viv_utils.get_function_name(vw, function_address) - logger.debug("skipping library function 0x%x (%s)", function_address, function_name) + logger.debug( + "skipping library function 0x%x (%s)", + function_address, + function_name, + ) library_functions[function_address] = function_name n_libs = len(library_functions) percentage = 100 * (n_libs / n_funcs) @@ -178,7 +307,11 @@ def find_decoding_function_features(vw, functions, disable_progress=False) -> Tu function_data["score"] = get_function_score_weighted(function_data["features"]) - logger.debug("analyzed function 0x%x - total score: %.3f", function_address, function_data["score"]) + logger.debug( + "analyzed function 0x%x - total score: %.3f", + function_address, + function_data["score"], + ) for feat in function_data["features"]: logger.trace(" %s", feat) diff --git a/floss/language/go/coverage.py b/floss/language/go/coverage.py index 94b7fa29d..deb78171c 100644 --- a/floss/language/go/coverage.py +++ b/floss/language/go/coverage.py @@ -18,6 +18,7 @@ def main(): + """Parses command-line arguments, sets up logging, and coordinates string extraction.""" parser = argparse.ArgumentParser(description="Get Go strings") parser.add_argument("path", help="file or path to analyze") parser.add_argument( diff --git a/floss/language/go/extract.py b/floss/language/go/extract.py index c1e5bf95f..c687c30b5 100644 --- a/floss/language/go/extract.py +++ b/floss/language/go/extract.py @@ -27,8 +27,21 @@ def find_stack_strings_with_regex( extract_stackstring_pattern, section_data, offset, min_length ) -> Iterable[StaticString]: - """ - Find stack strings using a regex pattern. + """Finds potential stack strings within a binary section using a regular expression. + + This function searches for assembly instruction patterns that are commonly associated with stack string manipulation. It extracts potential strings, handles encoding and construction of StaticString objects. + + Args: + extract_stackstring_pattern: A compiled regular expression pattern used for matching assembly instructions related to stack strings. + section_data: The binary data of the section to search within. + offset: An offset value likely used for address calculations within the section. + min_length: The minimum length for a string to be considered valid. + + Yields: + Iterable[StaticString]: An iterator of StaticString objects representing potential strings found on the stack. + + Note: + The specific logic for identifying stack strings relies on assembly-level instructions and regular expression matching. See inline code comments or related documentation for more in-depth implementation details. """ for m in extract_stackstring_pattern.finditer(section_data): for i in range(1, 8): @@ -54,8 +67,7 @@ def find_stack_strings_with_regex( def find_amd64_stackstrings(section_data, offset, min_length): - """ - Stackstrings in amd64 architecture are found + """Stackstrings in amd64 architecture are found by searching for the following pattern: .text:000000000048FFA9 48 83 FB 0F cmp rbx, 0Fh @@ -69,6 +81,14 @@ def find_amd64_stackstrings(section_data, offset, min_length): .text:000000000048FFCD 75 49 jnz short loc_490018 .text:000000000048FFCF 80 78 0E 6B cmp byte ptr [rax+0Eh], 6Bh ; 'k' .text:000000000048FFD3 75 43 jnz short loc_490018 + + Args: + section_data: The binary data of the section to analyze. + offset: An offset value likely used for address calculations within the section. 
+ min_length: The minimum length for a string to be considered valid. + + Yields: + Iterable[StaticString]: An iterator of StaticString objects representing potential strings found on the stack. """ extract_stackstring_pattern = re.compile( b"\x48\xba(........)|\x48\xb8(........)|\x81\x78\x08(....)|\x81\x79\x08(....)|\x66\x81\x78\x0c(..)|\x66\x81\x79\x0c(..)|\x80\x78\x0e(.)|\x80\x79\x0e(.)" @@ -78,8 +98,7 @@ def find_amd64_stackstrings(section_data, offset, min_length): def find_i386_stackstrings(section_data, offset, min_length): - """ - Stackstrings in i386 architecture are found + """Stackstrings in i386 architecture are found by searching for the following pattern: .text:0048CED3 75 6D jnz short loc_48CF42 @@ -89,6 +108,14 @@ def find_i386_stackstrings(section_data, offset, min_length): .text:0048CEE4 75 5C jnz short loc_48CF42 .text:0048CEE6 80 7D 06 72 cmp byte ptr [ebp+6], 72h ; 'r' .text:0048CEEA 75 56 jnz short loc_48CF42 + + Args: + section_data (bytes): The binary data of the section to search within. + offset (int): The offset within `section_data` to start the search from. + min_length (int): The minimum length of a stackstring to consider. + + Yields: + Iterator: An iterator over found stackstrings matching the search criteria. """ extract_stackstring_pattern = re.compile( b"\x81\xf9(....)|\x81\x38(....)|\x81\x7d\x00(....)|\x81\x3B(....)|\x66\x81\xf9(..)|\x66\x81\x7b\x04(..)|\x66\x81\x78\x04(..)|\x66\x81\x7d\x04(..)|\x80\x7b\x06(.)|\x80\x7d\x06(.)|\x80\xf8(.)|\x80\x78\x06(.)", @@ -99,13 +126,21 @@ def find_i386_stackstrings(section_data, offset, min_length): def get_stackstrings(pe: pefile.PE, min_length: int) -> Iterable[StaticString]: - """ - Find stackstrings in the given PE file. + """Find stackstrings in the given PE file. TODO(mr-tz): algorithms need improvements / rethinking of approach https://github.com/mandiant/flare-floss/issues/828 - """ + Args: + pe (pefile.PE): The PE file object to analyze. + min_length (int): The minimum length of stackstrings to be considered. + + Yields: + Iterable[StaticString]: An iterable of found stackstrings. + + Raises: + ValueError: If the PE file's architecture is neither x86 nor AMD64. + """ for section in pe.sections: if not section.IMAGE_SCN_MEM_EXECUTE: continue @@ -128,8 +163,7 @@ def get_stackstrings(pe: pefile.PE, min_length: int) -> Iterable[StaticString]: def find_longest_monotonically_increasing_run(l: List[int]) -> Tuple[int, int]: - """ - for the given sorted list of values, + """for the given sorted list of values, find the (start, end) indices of the longest run of values such that each value is greater than or equal to the previous value. @@ -137,6 +171,12 @@ def find_longest_monotonically_increasing_run(l: List[int]) -> Tuple[int, int]: [4, 4, 1, 2, 3, 0, 0] -> (2, 4) ^^^^^^^ + + Args: + l (List[int]): The sorted list of integers to analyze. + + Returns: + Tuple[int, int]: The start and end indices of the longest monotonically increasing sequence. """ max_run_length = 0 max_run_end_index = 0 @@ -162,10 +202,16 @@ def find_longest_monotonically_increasing_run(l: List[int]) -> Tuple[int, int]: def read_struct_string(pe: pefile.PE, instance: StructString) -> str: - """ - read the string for the given struct String instance, + """read the string for the given struct String instance, validating that it looks like UTF-8, or raising a ValueError. + + Args: + pe (pefile.PE): A parsed PE file object. + instance (StructString): A struct String instance to read. + + Returns: + str: The extracted string. 
""" image_base = pe.OPTIONAL_HEADER.ImageBase @@ -195,8 +241,7 @@ def read_struct_string(pe: pefile.PE, instance: StructString) -> str: def find_string_blob_range(pe: pefile.PE, struct_strings: List[StructString]) -> Tuple[VA, VA]: - """ - find the range of the string blob, as loaded in memory. + """find the range of the string blob, as loaded in memory. the current algorithm relies on the fact that the Go compiler stores the strings in length-sorted order, from shortest to longest. @@ -213,6 +258,24 @@ def find_string_blob_range(pe: pefile.PE, struct_strings: List[StructString]) -> there might be hundreds of thousands and takes many minutes. note: this algorithm relies heavily on the strings being stored in length-sorted order. + + Args: + pe: A parsed PE file object (pefile.PE). + struct_strings: A list of potential StructString instances found within the file. + + Returns: + Tuple[VA, VA]: A tuple representing the starting and ending virtual addresses (VA) of the estimated string blob region. + + Assumptions: + * The Go compiler stores strings in the blob in length-sorted order (shortest to longest). + * The string blob is significantly larger than other sequences of strings within the file. + + Algorithm: + + 1. Sorts StructString instances by address. + 2. Finds the longest sequence of strings with monotonically increasing lengths. + 3. Extracts a string from the middle of this sequence for analysis. + 4. Locates the string's section and finds the surrounding null byte sequences (`00 00 00 00`) to approximate the blob boundaries. """ image_base = pe.OPTIONAL_HEADER.ImageBase @@ -251,8 +314,7 @@ def find_string_blob_range(pe: pefile.PE, struct_strings: List[StructString]) -> def get_string_blob_strings(pe: pefile.PE, min_length) -> Iterable[StaticString]: - """ - for the given PE file compiled by Go, + """for the given PE file compiled by Go, find the string blob and then extract strings from it. we rely on code and memory scanning techniques to identify @@ -269,6 +331,18 @@ def get_string_blob_strings(pe: pefile.PE, min_length) -> Iterable[StaticString] https://github.com/golang/go/blob/36ea4f9680f8296f1c7d0cf7dbb1b3a9d572754a/src/builtin/builtin.go#L70-L73 its still the best we can do, though. + + Args: + pe: A parsed PE file object (pefile.PE). + min_length: The minimum length for a string to be considered valid. + + Yields: + Iterable[StaticString]: An iterator of StaticString objects representing potential strings extracted from the string blob. + + Important Notes: + * Relies on assumptions about how the Go compiler stores and organizes strings. + * Handles potential non-UTF-8 sequences within the blob. + * Employs heuristics to refine string extraction and address potential edge cases. """ image_base = pe.OPTIONAL_HEADER.ImageBase @@ -368,7 +442,9 @@ def get_string_blob_strings(pe: pefile.PE, min_length) -> Iterable[StaticString] else: try: string = StaticString.from_utf8( - last_buf[:size], pe.get_offset_from_rva(last_pointer - image_base), min_length + last_buf[:size], + pe.get_offset_from_rva(last_pointer - image_base), + min_length, ) yield string except ValueError: @@ -377,10 +453,17 @@ def get_string_blob_strings(pe: pefile.PE, min_length) -> Iterable[StaticString] def extract_go_strings(sample, min_length) -> List[StaticString]: - """ - extract Go strings from the given PE file - """ + """Extracts potential Go strings from a PE file. + + This function combines techniques to locate strings within a PE file that are likely associated with a Go-compiled binary. 
It searches for both strings in the string blob and strings located on the stack. + + Args: + sample: The path to the PE file. + min_length: The minimum length for a string to be considered valid. + Returns: + List[StaticString]: A list of extracted StaticString objects. + """ p = pathlib.Path(sample) buf = p.read_bytes() pe = pefile.PE(data=buf, fast_load=True) @@ -393,6 +476,17 @@ def extract_go_strings(sample, min_length) -> List[StaticString]: def get_static_strings_from_blob_range(sample: pathlib.Path, static_strings: List[StaticString]) -> List[StaticString]: + """Filters a list of StaticString objects to include only those within the Go string blob. + + This function assumes the string blob has already been located within the PE file. + + Args: + sample: The path to the PE file. + static_strings: A list of StaticString objects. + + Returns: + List[StaticString]: A filtered list of StaticString objects that fall within the string blob's memory range. + """ pe = pefile.PE(data=pathlib.Path(sample).read_bytes(), fast_load=True) struct_strings = list(sorted(set(get_struct_string_candidates(pe)), key=lambda s: s.address)) @@ -412,6 +506,13 @@ def get_static_strings_from_blob_range(sample: pathlib.Path, static_strings: Lis def main(argv=None): + """Parses command-line arguments and coordinates Go string extraction. + + Sets up logging, parses arguments, extracts strings using `extract_go_strings`, and displays the results. + + Args: + argv: command-line arguments (Default: None) + """ parser = argparse.ArgumentParser(description="Get Go strings") parser.add_argument("path", help="file or path to analyze") parser.add_argument( diff --git a/floss/language/identify.py b/floss/language/identify.py index d337a6166..338d82304 100644 --- a/floss/language/identify.py +++ b/floss/language/identify.py @@ -19,6 +19,8 @@ class Language(Enum): + """Enumerates programming languages that can be identified in binary samples.""" + GO = "go" RUST = "rust" DOTNET = "dotnet" @@ -27,6 +29,16 @@ class Language(Enum): def identify_language_and_version(sample: Path, static_strings: Iterable[StaticString]) -> Tuple[Language, str]: + """Identifies the programming language and version of a given binary sample based on static strings found within. + + Args: + sample (Path): The path to the binary sample to be analyzed. + static_strings (Iterable[StaticString]): An iterable of static strings extracted from the binary sample. + + Returns: + Tuple[Language, str]: A tuple containing the identified programming language and its version. If the language + cannot be identified, returns (Language.UNKNOWN, "unknown"). + """ is_rust, version = get_if_rust_and_version(static_strings) if is_rust: logger.info("Rust binary found with version: %s", version) @@ -53,11 +65,15 @@ def identify_language_and_version(sample: Path, static_strings: Iterable[StaticS def get_if_rust_and_version(static_strings: Iterable[StaticString]) -> Tuple[bool, str]: - """ - Return if the binary given is compiled with Rust compiler and its version - reference: https://github.com/mandiant/flare-floss/issues/766 - """ + """Determines if a binary sample is written in Rust and identifies its version. + Args: + static_strings (Iterable[StaticString]): An iterable of static strings extracted from the binary sample. + + Returns: + Tuple[bool, str]: A tuple where the first element is a boolean indicating whether the sample is identified as Rust, + and the second element is the version of Rust identified. If the version cannot be determined, returns "unknown". 
+ """ # Check if the binary contains the rustc/commit-hash string # matches strings like "rustc/commit-hash[40 characters]/library" e.g. "rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library" @@ -86,14 +102,21 @@ def get_if_rust_and_version(static_strings: Iterable[StaticString]) -> Tuple[boo def get_if_go_and_version(pe: pefile.PE) -> Tuple[bool, str]: - """ - Return if the binary given is compiled with Go compiler and its version - this checks the magic header of the pclntab structure -pcHeader- - the magic values varies through the version - reference: + """Determines if the provided PE file is compiled with Go and identifies the Go version. + + Args: + pe (pefile.PE): The PE file to be analyzed. + + Returns: + Tuple[bool, str]: A tuple containing a boolean indicating if the file is compiled with Go, + and a string representing the version of Go, or 'VERSION_UNKNOWN_OR_NA' if the version cannot be determined. + + This function checks the pclntab structure's magic header -pcHeader- to identify the Go version. + The magic values vary with the version. It first searches the .rdata section, then all available sections for magic headers and common Go functions. + + Reference: https://github.com/0xjiayu/go_parser/blob/865359c297257e00165beb1683ef6a679edc2c7f/pclntbl.py#L46 """ - go_magic = [ b"\xf0\xff\xff\xff\x00\x00", b"\xfb\xff\xff\xff\x00\x00", @@ -165,8 +188,14 @@ def get_if_go_and_version(pe: pefile.PE) -> Tuple[bool, str]: def get_go_version(magic): - """get the version of the go compiler used to compile the binary""" + """Determines the Go compiler version used to compile the binary based on the magic header. + Args: + magic (bytes): The magic header bytes found in the binary. + + Returns: + str: The identified Go version, or VERSION_UNKNOWN_OR_NA if the version cannot be determined. + """ MAGIC_112 = b"\xfb\xff\xff\xff\x00\x00" # Magic Number from version 1.12 MAGIC_116 = b"\xfa\xff\xff\xff\x00\x00" # Magic Number from version 1.16 MAGIC_118 = b"\xf0\xff\xff\xff\x00\x00" # Magic Number from version 1.18 @@ -185,9 +214,14 @@ def get_go_version(magic): def verify_pclntab(section, pclntab_va: int) -> bool: - """ - Parse headers of pclntab to verify it is legit - used in go parser itself https://go.dev/src/debug/gosym/pclntab.go + """Verifies the legitimacy of the pclntab section by parsing its headers. + + Args: + section: The section object from pefile where pclntab is located. + pclntab_va (int): The virtual address of the pclntab header. + + Returns: + bool: True if the pclntab header is valid, False otherwise. """ try: pc_quanum = section.get_data(pclntab_va + 6, 1)[0] @@ -199,10 +233,13 @@ def verify_pclntab(section, pclntab_va: int) -> bool: def is_dotnet_bin(pe: pefile.PE) -> bool: - """ - Check if the binary is .net or not - Checks the IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR entry in the OPTIONAL_HEADER of the file. - If the entry is not found, or if its size is 0, the file is not a .net file. + """Checks whether the binary is a .NET assembly. + + Args: + pe (pefile.PE): The PE file to check. + + Returns: + bool: True if the file is a .NET assembly, False otherwise. """ try: directory_index = pefile.DIRECTORY_ENTRY["IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR"] diff --git a/floss/language/rust/coverage.py b/floss/language/rust/coverage.py index d36992f3a..003cac872 100644 --- a/floss/language/rust/coverage.py +++ b/floss/language/rust/coverage.py @@ -18,6 +18,10 @@ def main(): + """Parses command-line arguments and coordinates Rust string extraction. 
+ + Sets up logging, parses arguments, analyzes a potential PE file, extracts both generic and Rust-specific strings, and displays statistics. + """ parser = argparse.ArgumentParser(description="Get Rust strings") parser.add_argument("path", help="file or path to analyze") parser.add_argument( diff --git a/floss/language/rust/extract.py b/floss/language/rust/extract.py index 4d40c3af9..69502e321 100644 --- a/floss/language/rust/extract.py +++ b/floss/language/rust/extract.py @@ -24,8 +24,22 @@ def fix_b2s_wide_strings( - strings: List[Tuple[str, str, Tuple[int, int], bool]], min_length: int, buffer: bytes + strings: List[Tuple[str, str, Tuple[int, int], bool]], + min_length: int, + buffer: bytes, ) -> List[Tuple[str, str, Tuple[int, int], bool]]: + """Handles potential misidentification of UTF-16 strings during extraction. + + This function attempts to correct cases where wide strings (likely UTF-16 encoded) have been incorrectly parsed as UTF-8 strings. It does this by re-encoding and re-extracting the string. + + Args: + strings: A list of tuples containing extracted strings, their types, offsets, and other metadata. + min_length: The minimum length for a string to be considered valid. + buffer: The raw byte buffer being analyzed. + + Returns: + List[Tuple[str, str, Tuple[int, int], bool]]: A modified list of string tuples, potentially with corrected strings. + """ # TODO(mr-tz): b2s may parse wide strings where there really should be utf-8 strings # handle special cases here until fixed # https://github.com/mandiant/flare-floss/issues/867 @@ -62,6 +76,17 @@ def filter_and_transform_utf8_strings( strings: List[Tuple[str, str, Tuple[int, int], bool]], start_rdata: int, ) -> List[StaticString]: + """Filters extracted strings, transforms UTF-8 strings, and creates StaticString objects. + + This function focuses on UTF-8 encoded strings. It removes newline characters, calculates the correct offsets within the file, and constructs StaticString objects. + + Args: + strings: A list of tuples containing extracted strings, their types, offsets, and other metadata. + start_rdata: The starting offset of the .rdata section within the file. + + Returns: + List[StaticString]: A list of StaticString objects representing the filtered and transformed UTF-8 strings. + """ transformed_strings = [] for string in strings: @@ -80,9 +105,14 @@ def filter_and_transform_utf8_strings( def split_strings(static_strings: List[StaticString], address: int, min_length: int) -> None: - """ - if address is in between start and end of a string in ref data then split the string - this modifies the elements of the static strings list directly + """Splits StaticString objects if an address falls within their string data. + + This function operates directly on the provided `static_strings` list. It checks if a given address lies within an existing StaticString. If so, it splits the string into two, preserving both parts if they meet the minimum length requirement. + + Args: + static_strings: A list of StaticString objects. + address: The address to check against the string boundaries. + min_length: The minimum length for a string to be considered valid. 
""" for string in static_strings: @@ -92,7 +122,11 @@ def split_strings(static_strings: List[StaticString], address: int, min_length: if len(rust_string) >= min_length: static_strings.append( - StaticString(string=rust_string, offset=string.offset, encoding=StringEncoding.UTF8) + StaticString( + string=rust_string, + offset=string.offset, + encoding=StringEncoding.UTF8, + ) ) if len(rest) >= min_length: static_strings.append(StaticString(string=rest, offset=address, encoding=StringEncoding.UTF8)) @@ -107,8 +141,16 @@ def split_strings(static_strings: List[StaticString], address: int, min_length: def extract_rust_strings(sample: pathlib.Path, min_length: int) -> List[StaticString]: - """ - Extract Rust strings from a sample + """Extracts potential Rust strings from a file. + + This function likely employs heuristics and techniques tailored to identifying strings that are typically present in Rust-compiled binaries. It leverages the `get_string_blob_strings` function, implying a focus on the string blob region. + + Args: + sample: The path to the file to analyze. + min_length: The minimum length for a string to be considered valid. + + Returns: + List[StaticString]: A list of extracted StaticString objects. """ p = pathlib.Path(sample) @@ -122,6 +164,17 @@ def extract_rust_strings(sample: pathlib.Path, min_length: int) -> List[StaticSt def get_static_strings_from_rdata(sample, static_strings) -> List[StaticString]: + """Filters StaticString objects based on the .rdata section of a PE file. + + This function assumes the existence of a pre-populated list of StaticString objects. It filters these strings, keeping only those whose offsets fall within the boundaries of the .rdata section of a PE file. + + Args: + sample: The path to the PE file. + static_strings: A list of StaticString objects. + + Returns: + List[StaticString]: A filtered list of StaticString objects that are located within the .rdata section. + """ pe = pefile.PE(data=pathlib.Path(sample).read_bytes(), fast_load=True) try: @@ -136,6 +189,20 @@ def get_static_strings_from_rdata(sample, static_strings) -> List[StaticString]: def get_string_blob_strings(pe: pefile.PE, min_length: int) -> Iterable[StaticString]: + """Extracts strings from the .rdata section of a PE file, focusing on UTF-8 strings with a minimum length. + + This function handles architecture-specific xrefs to find strings efficiently without reading all candidate strings, which may be numerous. It's tailored for Rust binaries but applicable to other PE files. + + Args: + pe (pefile.PE): The PE file from which to extract strings. + min_length (int): The minimum length of strings to extract. + + Returns: + Iterable[StaticString]: An iterable of `StaticString` objects found within the .rdata section of the given PE file. + + Note: + The function prioritizes performance and accuracy by leveraging specific characteristics of Rust binaries and PE file structure. + """ image_base = pe.OPTIONAL_HEADER.ImageBase try: @@ -195,6 +262,13 @@ def get_string_blob_strings(pe: pefile.PE, min_length: int) -> Iterable[StaticSt def main(argv=None): + """Parses command-line arguments, coordinates Rust string extraction, and displays results. + + Sets up logging, parses arguments, extracts strings using the `extract_rust_strings` function, sorts the results, and prints them to the console. 
+ + Args: + argv: Command-line arguments (Default: None) + """ parser = argparse.ArgumentParser(description="Get Rust strings") parser.add_argument("path", help="file or path to analyze") parser.add_argument( diff --git a/floss/language/utils.py b/floss/language/utils.py index 124b0ecad..55ea423ec 100644 --- a/floss/language/utils.py +++ b/floss/language/utils.py @@ -18,8 +18,7 @@ @dataclass(frozen=True) class StructString: - """ - a struct String instance. + """a struct String instance. ```go @@ -53,7 +52,6 @@ class StructString: We only use pointer and length data https://github.com/rust-lang/rust/blob/3911a63b7777e19dad4043542f908018e70c0bdd/library/alloc/src/string.rs - """ address: VA @@ -76,11 +74,7 @@ def get_image_range(pe: pefile.PE) -> Tuple[VA, VA]: def find_amd64_lea_xrefs(buf: bytes, base_addr: VA) -> Iterable[VA]: - """ - scan the given data found at the given base address - to find all the 64-bit RIP-relative LEA instructions, - extracting the target virtual address. - """ + """scan the given data found at the given base address to find all the 64-bit RIP-relative LEA instructions, extracting the target virtual address.""" rip_relative_insn_length = 7 rip_relative_insn_re = re.compile( # use rb, or else double escape the term "\x0D", or else beware! @@ -115,11 +109,7 @@ def find_amd64_lea_xrefs(buf: bytes, base_addr: VA) -> Iterable[VA]: def find_i386_lea_xrefs(buf: bytes) -> Iterable[VA]: - """ - scan the given data - to find all the 32-bit absolutely addressed LEA instructions, - extracting the target virtual address. - """ + """scan the given data to find all the 32-bit absolutely addressed LEA instructions, extracting the target virtual address.""" absolute_insn_re = re.compile( rb""" ( @@ -143,11 +133,7 @@ def find_i386_lea_xrefs(buf: bytes) -> Iterable[VA]: def find_lea_xrefs(pe: pefile.PE) -> Iterable[VA]: - """ - scan the executable sections of the given PE file - for LEA instructions that reference valid memory addresses, - yielding the virtual addresses. - """ + """scan the executable sections of the given PE file for LEA instructions that reference valid memory addresses, yielding the virtual addresses.""" low, high = get_image_range(pe) for section in pe.sections: @@ -169,11 +155,7 @@ def find_lea_xrefs(pe: pefile.PE) -> Iterable[VA]: def find_i386_push_xrefs(buf: bytes) -> Iterable[VA]: - """ - scan the given data found at the given base address - to find all the 32-bit PUSH instructions, - extracting the target virtual address. - """ + """scan the given data found at the given base address to find all the 32-bit PUSH instructions, extracting the target virtual address.""" push_insn_re = re.compile( rb""" ( @@ -192,11 +174,7 @@ def find_i386_push_xrefs(buf: bytes) -> Iterable[VA]: def find_amd64_push_xrefs(buf: bytes) -> Iterable[VA]: - """ - scan the given data found at the given base address - to find all the 64-bit PUSH instructions, - extracting the target virtual address. - """ + """scan the given data found at the given base address to find all the 64-bit PUSH instructions, extracting the target virtual address.""" push_insn_re = re.compile( rb""" ( @@ -215,11 +193,7 @@ def find_amd64_push_xrefs(buf: bytes) -> Iterable[VA]: def find_push_xrefs(pe: pefile.PE) -> Iterable[VA]: - """ - scan the executable sections of the given PE file - for PUSH instructions that reference valid memory addresses, - yielding the virtual addresses. 
- """ + """scan the executable sections of the given PE file for PUSH instructions that reference valid memory addresses, yielding the virtual addresses.""" low, high = get_image_range(pe) for section in pe.sections: @@ -241,11 +215,7 @@ def find_push_xrefs(pe: pefile.PE) -> Iterable[VA]: def find_i386_mov_xrefs(buf: bytes) -> Iterable[VA]: - """ - scan the given data found at the given base address - to find all the 32-bit MOV instructions, - extracting the target virtual address. - """ + """scan the given data found at the given base address to find all the 32-bit MOV instructions, extracting the target virtual address.""" mov_insn_re = re.compile( rb""" ( @@ -269,11 +239,7 @@ def find_i386_mov_xrefs(buf: bytes) -> Iterable[VA]: def find_amd64_mov_xrefs(buf: bytes) -> Iterable[VA]: - """ - scan the given data found at the given base address - to find all the 64-bit MOV instructions, - extracting the target virtual address. - """ + """scan the given data found at the given base address to find all the 64-bit MOV instructions, extracting the target virtual address.""" mov_insn_re = re.compile( rb""" ( @@ -298,11 +264,7 @@ def find_amd64_mov_xrefs(buf: bytes) -> Iterable[VA]: def find_mov_xrefs(pe: pefile.PE) -> Iterable[VA]: - """ - scan the executable sections of the given PE file - for MOV instructions that reference valid memory addresses, - yielding the virtual addresses. - """ + """scan the executable sections of the given PE file for MOV instructions that reference valid memory addresses, yielding the virtual addresses.""" low, high = get_image_range(pe) for section in pe.sections: @@ -329,9 +291,7 @@ def get_max_section_size(pe: pefile.PE) -> int: def get_struct_string_candidates_with_pointer_size(pe: pefile.PE, buf: bytes, psize: int) -> Iterable[StructString]: - """ - scan through the given bytes looking for pairs of machine words (address, length) - that might potentially be struct String instances. + """scan through the given bytes looking for pairs of machine words (address, length) that might potentially be struct String instances. we do some initial validation, like checking that the address is valid and the length is reasonable; however, we don't validate the encoded string data. @@ -383,8 +343,7 @@ def get_i386_struct_string_candidates(pe: pefile.PE, buf: bytes) -> Iterable[Str def get_struct_string_candidates(pe: pefile.PE) -> Iterable[StructString]: - """ - find candidate struct String instances in the given PE file. + """find candidate struct String instances in the given PE file. we do some initial validation, like checking that the address is valid and the length is reasonable; however, we don't validate the encoded string data. 
@@ -474,7 +433,11 @@ def get_struct_string_candidates(pe: pefile.PE) -> Iterable[StructString]: def get_extract_stats( - pe: pefile, all_ss_strings: List[StaticString], lang_strings: List[StaticString], min_len: int, min_blob_len=0 + pe: pefile, + all_ss_strings: List[StaticString], + lang_strings: List[StaticString], + min_len: int, + min_blob_len=0, ) -> float: # min_blob_len: this is the minimum length of a string blob in binary file to be considered for extraction all_strings = list() @@ -548,7 +511,18 @@ def get_extract_stats( lang_str_found.append(lang_str) if replaced_len < min_len: - results.append((secname, s_id, s_range, False, "missing", s, orig_len - replaced_len, lang_str)) + results.append( + ( + secname, + s_id, + s_range, + False, + "missing", + s, + orig_len - replaced_len, + lang_str, + ) + ) break if not found: diff --git a/floss/logging_.py b/floss/logging_.py index 863a228f4..ccf3bfa79 100644 --- a/floss/logging_.py +++ b/floss/logging_.py @@ -9,6 +9,8 @@ class DebugLevel(int, Enum): + """Enumerates debug output verbosity levels for FLOSS.""" + NONE = 0 DEFAULT = 1 TRACE = 2 @@ -25,6 +27,16 @@ class DebugLevel(int, Enum): def make_format(color): + """Constructs a log message format string with color formatting. + + Inserts a color code, along with placeholders for log level, logger name, and the log message itself. + + Args: + color: The color code to be inserted into the format string. + + Returns: + str: A formatted string suitable for use with Python's logging module. + """ return f"{color}%(levelname)s{RESET}: %(name)s: %(message)s" @@ -41,18 +53,36 @@ def make_format(color): class ColorFormatter(logging.Formatter): - """ - Logging Formatter to add colors and count warning / errors + """Logging Formatter to add colors and count warning / errors via: https://stackoverflow.com/a/56944256/87207 """ def format(self, record): + """Format the log record using the formatter registered for its log level. + + Args: + record: The log record to format. + + Returns: + str: The formatted log message. + """ return FORMATTERS[record.levelno].format(record) class LoggerWithTrace(logging.getLoggerClass()): # type: ignore + """A custom logger class that adds a TRACE logging level.""" + def trace(self, msg, *args, **kwargs): + """Log a message with severity 'TRACE' on this logger. + + Args: + msg: The message to log. + *args: Additional positional arguments to be passed to the log message. + **kwargs: Additional keyword arguments to be passed to the log message. + """ self.log(TRACE, msg, *args, **kwargs) @@ -60,8 +90,7 @@ def trace(self, msg, *args, **kwargs): def getLogger(name) -> LoggerWithTrace: - """ - a logging constructor that guarantees that the TRACE level is available. + """a logging constructor that guarantees that the TRACE level is available. use this just like `logging.getLogger`. because we patch stdlib logging upon import of this module (side-effect), @@ -69,5 +98,11 @@ def getLogger(name) -> LoggerWithTrace: then we want to provide a way to ensure that callers can access TRACE consistently. if callers use `floss.logging.getLogger()` instead of `logging.getLogger()`, then they'll be guaranteed to have access to TRACE. + + Args: + name: The name of the logger to retrieve. + + Returns: + LoggerWithTrace: The logger object.
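The TRACE plumbing in this hunk follows the stdlib pattern of registering a custom level and a logger subclass that exposes it. A standalone approximation (the level value 5 is an assumption; FLOSS defines its own TRACE constant):

```python
import logging

TRACE = logging.DEBUG - 5  # assumed value; FLOSS defines its own TRACE constant
logging.addLevelName(TRACE, "TRACE")

class TracingLogger(logging.getLoggerClass()):  # type: ignore
    def trace(self, msg, *args, **kwargs):
        if self.isEnabledFor(TRACE):
            self.log(TRACE, msg, *args, **kwargs)

logging.setLoggerClass(TracingLogger)
logging.basicConfig(level=TRACE)

logger = logging.getLogger("demo")
logger.setLevel(TRACE)
logger.trace("super-verbose details: %s", "0x401000")
```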
""" return logging.getLogger(name) # type: ignore diff --git a/floss/main.py b/floss/main.py index 38f0d027a..d6c7863bc 100644 --- a/floss/main.py +++ b/floss/main.py @@ -64,6 +64,8 @@ class StringType(str, Enum): + """Enumerates the types of strings that FLOSS can extract from a binary.""" + STATIC = "static" STACK = "stack" TIGHT = "tight" @@ -71,16 +73,22 @@ class StringType(str, Enum): class WorkspaceLoadError(ValueError): + """Indicates an error occurred while loading a workspace. + + This exception inherits from ValueError, making it suitable for signaling issues encountered during the process of loading or initializing a workspace (e.g., in an analysis tool). + """ + pass class ArgumentValueError(ValueError): + """Indicates an error occurred while parsing command-line arguments.""" + pass class ArgumentParser(argparse.ArgumentParser): - """ - argparse will call sys.exit upon parsing invalid arguments. + """argparse will call sys.exit upon parsing invalid arguments. we don't want that, because we might be parsing args within test cases, run as a module, etc. so, we override the behavior to raise a ArgumentValueError instead. @@ -88,12 +96,25 @@ class ArgumentParser(argparse.ArgumentParser): """ def error(self, message): + """override the default behavior to raise an exception instead of calling sys.exit. + + Args: + message: The error message to display. + """ self.print_usage(sys.stderr) args = {"prog": self.prog, "message": message} raise ArgumentValueError("%(prog)s: error: %(message)s" % args) def make_parser(argv): + """Create the command-line argument parser for FLOSS. + + Args: + argv: The command-line arguments. + + Returns: + ArgumentParser: The command-line argument parser for FLOSS. + """ desc = ( "The FLARE team's open-source tool to extract ALL strings from malware.\n" f" %(prog)s {__version__} - https://github.com/mandiant/flare-floss/\n\n" @@ -192,7 +213,7 @@ def make_parser(argv): "--format", choices=[f[0] for f in formats], default="auto", - help="select sample format, %s" % format_help if show_all_options else argparse.SUPPRESS, + help=("select sample format, %s" % format_help if show_all_options else argparse.SUPPRESS), ) advanced_group.add_argument( "--language", @@ -209,7 +230,7 @@ def make_parser(argv): "-l", "--load", action="store_true", - help="load from existing FLOSS results document" if show_all_options else argparse.SUPPRESS, + help=("load from existing FLOSS results document" if show_all_options else argparse.SUPPRESS), ) advanced_group.add_argument( "--functions", @@ -293,7 +314,10 @@ def make_parser(argv): help="enable debugging output on STDERR, specify multiple times to increase verbosity", ) logging_group.add_argument( - "-q", "--quiet", action="store_true", help="disable all status output on STDOUT except fatal errors" + "-q", + "--quiet", + action="store_true", + help="disable all status output on STDOUT except fatal errors", ) logging_group.add_argument( "--color", @@ -307,6 +331,12 @@ def make_parser(argv): def set_log_config(debug, quiet): + """Set the logging configuration for FLOSS. + + Args: + debug: The debug level. + quiet: Whether to suppress all status output except fatal errors. + """ if quiet: log_level = logging.WARNING elif debug >= DebugLevel.TRACE: @@ -346,16 +376,14 @@ def set_log_config(debug, quiet): def select_functions(vw, asked_functions: Optional[List[int]]) -> Set[int]: - """ - Given a workspace and an optional list of function addresses, - collect the set of valid functions, - or all valid function addresses. 
+ """Given a workspace and an optional list of function addresses, collect the set of valid functions, or all valid function addresses. - arguments: - asked_functions: the functions a user wants, or None. + Args: + vw: The vivisect workspace. + asked_functions: The list of function addresses to analyze. - raises: - ValueError: if an asked for function does not exist in the workspace. + Returns: + Set[int]: The set of valid function addresses. """ functions = set(vw.getFunctions()) if not asked_functions: @@ -371,16 +399,22 @@ def select_functions(vw, asked_functions: Optional[List[int]]) -> Set[int]: raise ValueError("failed to find functions: %s" % (", ".join(map(hex, sorted(missing_functions))))) logger.debug("selected %d functions", len(asked_functions_)) - logger.trace("selected the following functions: %s", ", ".join(map(hex, sorted(asked_functions_)))) + logger.trace( + "selected the following functions: %s", + ", ".join(map(hex, sorted(asked_functions_))), + ) return asked_functions_ def is_supported_file_type(sample_file_path: Path): - """ - Return if FLOSS supports the input file type, based on header bytes - :param sample_file_path: - :return: True if file type is supported, False otherwise + """Return if FLOSS supports the input file type, based on header bytes + + Args: + sample_file_path: The path to the sample file. + + Returns: + bool: True if the file type is supported, False otherwise. """ with sample_file_path.open("rb") as f: magic = f.read(2) @@ -397,6 +431,17 @@ def load_vw( sigpaths: List[Path], should_save_workspace: bool = False, ) -> VivWorkspace: + """Load a Vivisect workspace from a file. + + Args: + sample_path: The path to the sample file. + format: The format of the sample file. + sigpaths: The list of paths to signature files. + should_save_workspace: Whether to save the workspace. + + Returns: + VivWorkspace: The Vivisect workspace. + """ if format not in ("sc32", "sc64"): if not is_supported_file_type(sample_path): raise WorkspaceLoadError( @@ -434,18 +479,22 @@ def load_vw( def is_running_standalone() -> bool: - """ - are we running from a PyInstaller'd executable? + """are we running from a PyInstaller'd executable? if so, then we'll be able to access `sys._MEIPASS` for the packaged resources. + + Returns: + bool: True if running standalone, False otherwise. """ return hasattr(sys, "frozen") and hasattr(sys, "_MEIPASS") def get_default_root() -> Path: - """ - get the file system path to the default resources directory. + """get the file system path to the default resources directory. under PyInstaller, this comes from _MEIPASS. under source, this is the root directory of the project. + + Returns: + Path: The file system path to the default resources directory. """ if is_running_standalone(): # pylance/mypy don't like `sys._MEIPASS` because this isn't standard. @@ -457,6 +506,14 @@ def get_default_root() -> Path: def get_signatures(sigs_path: Path) -> List[Path]: + """Get the paths to the signature files. + + Args: + sigs_path: The path to the signature files. + + Returns: + List[Path]: The paths to the signature files. + """ if not sigs_path.exists(): raise IOError("signatures path %s does not exist or cannot be accessed" % str(sigs_path)) @@ -485,9 +542,13 @@ def get_signatures(sigs_path: Path) -> List[Path]: def main(argv=None) -> int: - """ - arguments: - argv: the command line arguments + """The main entry point for FLOSS. + + Args: + argv: The command-line arguments. + + Returns: + int: The return code. 
""" # use rich as default Traceback handler rich.traceback.install(show_locals=True) @@ -558,7 +619,10 @@ def main(argv=None) -> int: return 0 - results = ResultDocument(metadata=Metadata(file_path=str(sample), min_length=args.min_length), analysis=analysis) + results = ResultDocument( + metadata=Metadata(file_path=str(sample), min_length=args.min_length), + analysis=analysis, + ) sample_size = sample.stat().st_size if sample_size > sys.maxsize: @@ -666,7 +730,9 @@ def main(argv=None) -> int: # here currently only focus on strings in string blob range string_blob_strings = floss.language.go.extract.get_static_strings_from_blob_range(sample, static_strings) results.strings.language_strings_missed = floss.language.utils.get_missed_strings( - string_blob_strings, results.strings.language_strings, args.min_length + string_blob_strings, + results.strings.language_strings, + args.min_length, ) elif results.metadata.language == Language.RUST.value: @@ -701,7 +767,13 @@ def main(argv=None) -> int: sigpaths = get_signatures(args.signatures) - should_save_workspace = os.environ.get("FLOSS_SAVE_WORKSPACE") not in ("0", "no", "NO", "n", None) + should_save_workspace = os.environ.get("FLOSS_SAVE_WORKSPACE") not in ( + "0", + "no", + "NO", + "n", + None, + ) try: with halo.Halo( text="analyzing program", @@ -783,10 +855,8 @@ def main(argv=None) -> int: else: logger.debug("identified %d candidate decoding functions", len(fvas_to_emulate)) for fva in fvas_to_emulate: - score = decoding_function_features[fva]["score"] - xrefs_to = decoding_function_features[fva]["xrefs_to"] - results.analysis.functions.decoding_function_scores[fva] = {"score": score, "xrefs_to": xrefs_to} - logger.debug(" - 0x%x: score: %.3f, xrefs to: %d", fva, score, xrefs_to) + results.analysis.functions.decoding_function_scores[fva] = decoding_function_features[fva]["score"] + logger.debug(" - 0x%x: %.3f", fva, decoding_function_features[fva]["score"]) # TODO filter out strings decoded in library function or function only called by library function(s) results.strings.decoded_strings = decode_strings( diff --git a/floss/render/default.py b/floss/render/default.py index eedc46f3f..f68aab82c 100644 --- a/floss/render/default.py +++ b/floss/render/default.py @@ -27,17 +27,41 @@ def heading_style(s: str): + """Adds cyan color formatting to a string (likely for headings). + + Args: + s: The string to be formatted. + + Returns: + str: The formatted string with color markup. + """ colored_string = "[cyan]" + escape(s) + "[/cyan]" return colored_string def string_style(s: str): + """Adds green color formatting to a string (likely for strings). + + Args: + s: The string to be formatted. + + Returns: + str: The formatted string with color markup. + """ colored_string = "[green]" + escape(s) + " [/green]" return colored_string def width(s: str, character_count: int) -> str: - """pad the given string to at least `character_count`""" + """Pads a string with spaces to a specified length. + + Args: + s: The string to be padded. + character_count: The desired length of the string. + + Returns: + str: The padded string. + """ if len(s) < character_count: return s + " " * (character_count - len(s)) else: @@ -45,6 +69,15 @@ def width(s: str, character_count: int) -> str: def render_meta(results: ResultDocument, console, verbose): + """Formats analysis results and metadata for display. + + Prepares metadata extracted from a file and analysis statistics into a structured table-like format. 
It adjusts the level of detail based on the provided verbosity setting. + + Args: + results: A ResultDocument object containing analysis metadata and results. + console: An object used for output to the terminal (likely a wrapper). + verbose: Verbosity level influencing the amount of detail displayed. + """ rows: List[Tuple[str, str]] = list() lang = f"{results.metadata.language}" if results.metadata.language else "" @@ -57,13 +90,24 @@ def render_meta(results: ResultDocument, console, verbose): language_value = f"{lang}{lang_v}{lang_s}" if verbose == Verbosity.DEFAULT: - rows.append((width("file path", MIN_WIDTH_LEFT_COL), width(results.metadata.file_path, MIN_WIDTH_RIGHT_COL))) + rows.append( + ( + width("file path", MIN_WIDTH_LEFT_COL), + width(results.metadata.file_path, MIN_WIDTH_RIGHT_COL), + ) + ) rows.append(("identified language", language_value)) else: rows.extend( [ - (width("file path", MIN_WIDTH_LEFT_COL), width(results.metadata.file_path, MIN_WIDTH_RIGHT_COL)), - ("start date", results.metadata.runtime.start_date.strftime("%Y-%m-%d %H:%M:%S")), + ( + width("file path", MIN_WIDTH_LEFT_COL), + width(results.metadata.file_path, MIN_WIDTH_RIGHT_COL), + ), + ( + "start date", + results.metadata.runtime.start_date.strftime("%Y-%m-%d %H:%M:%S"), + ), ("runtime", strtime(results.metadata.runtime.total)), ("version", results.metadata.version), ("identified language", language_value), @@ -84,6 +128,16 @@ def render_meta(results: ResultDocument, console, verbose): def render_string_type_rows(results: ResultDocument) -> List[Tuple[str, str]]: + """Formats analysis results for display. + + Prepares analysis statistics into a structured table-like format. + + Args: + results: A ResultDocument object containing analysis metadata and results. + + Returns: + List[Tuple[str, str]]: A list of tuples containing the analysis statistics. + """ len_ss = len(results.strings.static_strings) len_ls = len(results.strings.language_strings) len_chars_ss = sum([len(s.string) for s in results.strings.static_strings]) @@ -107,20 +161,30 @@ def render_string_type_rows(results: ResultDocument) -> List[Tuple[str, str]]: ), ( " stack strings", - str(len(results.strings.stack_strings)) if results.analysis.enable_stack_strings else DISABLED, + (str(len(results.strings.stack_strings)) if results.analysis.enable_stack_strings else DISABLED), ), ( " tight strings", - str(len(results.strings.tight_strings)) if results.analysis.enable_tight_strings else DISABLED, + (str(len(results.strings.tight_strings)) if results.analysis.enable_tight_strings else DISABLED), ), ( " decoded strings", - str(len(results.strings.decoded_strings)) if results.analysis.enable_decoded_strings else DISABLED, + (str(len(results.strings.decoded_strings)) if results.analysis.enable_decoded_strings else DISABLED), ), ] def render_function_analysis_rows(results) -> List[Tuple[str, str]]: + """Formats function analysis results for display. + + Prepares function analysis statistics into a structured table-like format. + + Args: + results: A ResultDocument object containing analysis metadata and results. + + Returns: + List[Tuple[str, str]]: A list of tuples containing the function analysis statistics. 
+ """ if results.metadata.runtime.vivisect == 0: return [("analyzed functions", DISABLED)] @@ -134,7 +198,12 @@ def render_function_analysis_rows(results) -> List[Tuple[str, str]]: if results.analysis.enable_tight_strings: rows.append((" tight strings", str(results.analysis.functions.analyzed_tight_strings))) if results.analysis.enable_decoded_strings: - rows.append((" decoded strings", str(results.analysis.functions.analyzed_decoded_strings))) + rows.append( + ( + " decoded strings", + str(results.analysis.functions.analyzed_decoded_strings), + ) + ) if results.analysis.functions.decoding_function_scores: rows.append( ( @@ -154,13 +223,45 @@ def render_function_analysis_rows(results) -> List[Tuple[str, str]]: def strtime(seconds): + """Converts seconds to a human-readable time format. + + Args: + seconds: The number of seconds to be converted. + + Returns: + str: The human-readable time format. + """ m, s = divmod(seconds, 60) return f"{m:02.0f}:{s:02.0f}" -def render_language_strings(language, language_strings, language_strings_missed, console, verbose, disable_headers): +def render_language_strings( + language, + language_strings, + language_strings_missed, + console, + verbose, + disable_headers, +): + """Displays language-specific strings to the console. + + Sorts the provided strings, optionally displays a heading, and then prints each string to the console. Formatting (e.g., colors) and string sanitation are controlled by verbosity settings. + + Args: + language: The programming language the strings are associated with. + language_strings: A list of extracted strings. + language_strings_missed: Potentially a list of strings that were not fully extracted. + console: An object used for output to the terminal. + verbose: Verbosity level influencing formatting. + disable_headers: A flag to suppress the display of headers. + """ strings = sorted(language_strings + language_strings_missed, key=lambda s: s.offset) - render_heading(f"FLOSS {language.upper()} STRINGS ({len(strings)})", console, verbose, disable_headers) + render_heading( + f"FLOSS {language.upper()} STRINGS ({len(strings)})", + console, + verbose, + disable_headers, + ) offset_len = len(f"{strings[-1].offset}") for s in strings: if verbose == Verbosity.DEFAULT: @@ -171,6 +272,18 @@ def render_language_strings(language, language_strings, language_strings_missed, def render_static_substrings(strings, encoding, offset_len, console, verbose, disable_headers): + """Displays static strings with their encoding information to the console. + + Optionally displays a heading, and then prints each string with its offset to the console. Formatting of strings is influenced by verbosity settings. + + Args: + strings: A list of static strings. + encoding: The encoding type of the strings. + offset_len: The length of the offset field for formatting. + console: An object used for output to the terminal. + verbose: Verbosity level influencing formatting. + disable_headers: A flag to suppress the display of headers. + """ if verbose != Verbosity.DEFAULT: encoding = heading_style(encoding) render_sub_heading(f"FLOSS STATIC STRINGS: {encoding}", len(strings), console, disable_headers) @@ -183,6 +296,16 @@ def render_static_substrings(strings, encoding, offset_len, console, verbose, di def render_staticstrings(strings, console, verbose, disable_headers): + """Displays static strings to the console. + + Sorts the provided strings, optionally displays a heading, and then prints each string to the console. 
Formatting (e.g., colors) and string sanitation are controlled by verbosity settings. + + Args: + strings: A list of extracted strings. + console: An object used for output to the terminal. + verbose: Verbosity level influencing formatting. + disable_headers: A flag to suppress the display of headers. + """ render_heading(f"FLOSS STATIC STRINGS ({len(strings)})", console, verbose, disable_headers) ascii_strings = list(filter(lambda s: s.encoding == StringEncoding.ASCII, strings)) @@ -202,8 +325,21 @@ def render_staticstrings(strings, console, verbose, disable_headers): def render_stackstrings( - strings: Union[List[StackString], List[TightString]], console, verbose: bool, disable_headers: bool + strings: Union[List[StackString], List[TightString]], + console, + verbose: bool, + disable_headers: bool, ): + """Renders the results of the stack string extraction phase. + + Optionally displays a heading, and then prints each string with its offset to the console. Formatting of strings is influenced by verbosity settings. + + Args: + strings: A list of extracted strings. + console: An object used for output to the terminal. + verbose: Verbosity level influencing formatting. + disable_headers: A flag to suppress the display of headers. + """ if verbose == Verbosity.DEFAULT: for s in strings: console.print(sanitize(s.string), markup=False) @@ -230,8 +366,15 @@ def render_stackstrings( def render_decoded_strings(decoded_strings: List[DecodedString], console, verbose, disable_headers): - """ - Render results of string decoding phase. + """Renders the results of the string decoding phase. + + Optionally displays a heading, and then prints each string with its offset to the console. Formatting of strings is influenced by verbosity settings. + + Args: + decoded_strings: A list of extracted strings. + console: An object used for output to the terminal. + verbose: Verbosity level influencing formatting. + disable_headers: A flag to suppress the display of headers. """ if verbose == Verbosity.DEFAULT: for ds in decoded_strings: @@ -242,7 +385,12 @@ def render_decoded_strings(decoded_strings: List[DecodedString], console, verbos strings_by_functions[ds.decoding_routine].append(ds) for fva, data in strings_by_functions.items(): - render_sub_heading(" FUNCTION at " + heading_style(f"0x{fva:x}"), len(data), console, disable_headers) + render_sub_heading( + " FUNCTION at " + heading_style(f"0x{fva:x}"), + len(data), + console, + disable_headers, + ) rows = [] for ds in data: if ds.address_type == AddressType.STACK: @@ -251,11 +399,22 @@ def render_decoded_strings(decoded_strings: List[DecodedString], console, verbos offset_string = escape("[heap]") else: offset_string = hex(ds.address or 0) - rows.append((offset_string, hex(ds.decoded_at), string_style(sanitize(ds.string)))) + rows.append( + ( + offset_string, + hex(ds.decoded_at), + string_style(sanitize(ds.string)), + ) + ) if rows: table = Table( - "Offset", "Called At", "String", show_header=not (disable_headers), box=box.ASCII2, show_edge=False + "Offset", + "Called At", + "String", + show_header=not (disable_headers), + box=box.ASCII2, + show_edge=False, ) for row in rows: table.add_row(row[0], row[1], row[2]) @@ -264,12 +423,20 @@ def render_decoded_strings(decoded_strings: List[DecodedString], console, verbos def render_heading(heading, console, verbose, disable_headers): - """ - example:: + """example:: ───────────────────────── FLOSS TIGHT STRINGS (0) ───────────────────────── + Displays a prominent heading for a section of the report. 
+ + Constructs a single-row table with horizontal borders to visually distinguish a heading. Formatting (e.g., color) is influenced by the verbosity setting. + + Args: + heading: The text of the heading. + console: An object used for output to the terminal. + verbose: Verbosity level influencing formatting. + disable_headers: A flag to suppress the display of the heading entirely. """ if disable_headers: return @@ -286,12 +453,21 @@ def render_heading(heading, console, verbose, disable_headers): def render_sub_heading(heading, n, console, disable_headers): - """ - example:: + """example:: +-----------------------------------+ | FLOSS STATIC STRINGS: ASCII (862) | +-----------------------------------+ + + Displays a subheading with a count for a section of the report. + + Constructs a single-row table with more prominent borders than the primary heading, visually differentiating a subheading. Includes a count associated with the section. + + Args: + heading: The text of the subheading. + n: The count associated with the section. + console: An object used for output to the terminal. + disable_headers: A flag to suppress the display of the subheading entirely. """ if disable_headers: return @@ -302,6 +478,14 @@ def render_sub_heading(heading, n, console, disable_headers): def get_color(color): + """Converts a string color setting to a rich color system. + + Args: + color: A string representing a color setting. + + Returns: + str: A string representing a rich color system. + """ if color == "always": color_system = "256" elif color == "auto": @@ -315,8 +499,24 @@ def get_color(color): def render(results: floss.results.ResultDocument, verbose, disable_headers, color): + """Renders analysis results to a string. + + Args: + results: A ResultDocument object containing analysis metadata and results. + verbose: Verbosity level influencing the amount of detail displayed. + disable_headers: A flag to suppress the display of headers. + color: A string representing a color setting. + + Returns: + str: A string containing the formatted analysis results. 
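The boxed subheading shown in the render_sub_heading example can be reproduced with a one-row rich Table, using only rich APIs already referenced in this diff (box.ASCII2, show_header):

```python
from rich import box
from rich.console import Console
from rich.table import Table

def print_boxed_heading(console: Console, heading: str, n: int) -> None:
    # a single-row table draws the +---+ border around "HEADING (n)",
    # much like the subheading example shown above
    table = Table(box=box.ASCII2, show_header=False)
    table.add_row(f"{heading} ({n})")
    console.print(table)

print_boxed_heading(Console(), "FLOSS STATIC STRINGS: ASCII", 862)
```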
+ """ sys.__stdout__.reconfigure(encoding="utf-8") - console = Console(file=io.StringIO(), color_system=get_color(color), highlight=False, soft_wrap=True) + console = Console( + file=io.StringIO(), + color_system=get_color(color), + highlight=False, + soft_wrap=True, + ) if not disable_headers: console.print("\n") @@ -347,18 +547,31 @@ def render(results: floss.results.ResultDocument, verbose, disable_headers, colo console.print("\n") if results.analysis.enable_stack_strings: - render_heading(f"FLOSS STACK STRINGS ({len(results.strings.stack_strings)})", console, verbose, disable_headers) + render_heading( + f"FLOSS STACK STRINGS ({len(results.strings.stack_strings)})", + console, + verbose, + disable_headers, + ) render_stackstrings(results.strings.stack_strings, console, verbose, disable_headers) console.print("\n") if results.analysis.enable_tight_strings: - render_heading(f"FLOSS TIGHT STRINGS ({len(results.strings.tight_strings)})", console, verbose, disable_headers) + render_heading( + f"FLOSS TIGHT STRINGS ({len(results.strings.tight_strings)})", + console, + verbose, + disable_headers, + ) render_stackstrings(results.strings.tight_strings, console, verbose, disable_headers) console.print("\n") if results.analysis.enable_decoded_strings: render_heading( - f"FLOSS DECODED STRINGS ({len(results.strings.decoded_strings)})", console, verbose, disable_headers + f"FLOSS DECODED STRINGS ({len(results.strings.decoded_strings)})", + console, + verbose, + disable_headers, ) render_decoded_strings(results.strings.decoded_strings, console, verbose, disable_headers) diff --git a/floss/render/json.py b/floss/render/json.py index d3e87b391..ff9241649 100644 --- a/floss/render/json.py +++ b/floss/render/json.py @@ -8,14 +8,23 @@ class FlossJSONEncoder(json.JSONEncoder): - """ - serializes FLOSS data structures into JSON. - specifically: - - dataclasses into their dict representation - - datetimes to ISO8601 strings + """Custom JSON encoder for serializing FLOSS data structures. + + Handles the following special cases: + + * Dataclasses: Converts dataclass instances into their dictionary representations. + * Datetimes: Encodes datetime objects into ISO 8601 formatted strings (with timezone information). """ def default(self, o): + """Overrides the default JSON encoding behavior to handle dataclasses and datetime objects. + + Args: + o: The object to encode. + + Returns: + The JSON-serializable representation of the object. + """ if dataclasses.is_dataclass(o): return dataclasses.asdict(o) if isinstance(o, datetime.datetime): @@ -24,6 +33,16 @@ def default(self, o): def render(doc: ResultDocument) -> str: + """Serializes a ResultDocument into a JSON string. + + Uses the custom `FlossJSONEncoder` to ensure correct handling of dataclasses and datetime objects within the analysis results. + + Args: + doc: The ResultDocument object containing analysis results. + + Returns: + str: A JSON-formatted string representation of the ResultDocument. + """ return json.dumps( doc, cls=FlossJSONEncoder, diff --git a/floss/render/sanitize.py b/floss/render/sanitize.py index 584614015..5b3611fd1 100644 --- a/floss/render/sanitize.py +++ b/floss/render/sanitize.py @@ -4,8 +4,14 @@ def sanitize(s: str, is_ascii_only=True) -> str: - """ - Return sanitized string for printing to cli. + """Sanitize a string for printing. + + Args: + s: The string to sanitize. + is_ascii_only: Whether to only allow ASCII characters. + + Returns: + The sanitized string. 
""" s = s.replace("\n", "\\n") s = s.replace("\r", "\\r") diff --git a/floss/results.py b/floss/results.py index bd8b90f1a..493e26670 100644 --- a/floss/results.py +++ b/floss/results.py @@ -28,14 +28,20 @@ class InvalidResultsFile(Exception): + """Indicates that a results file is invalid, corrupt, or in an incompatible format.""" + pass class InvalidLoadConfig(Exception): + """Indicates that the load configuration is invalid.""" + pass class StringEncoding(str, Enum): + """Enumeration of string encodings.""" + ASCII = "ASCII" UTF16LE = "UTF-16LE" UTF8 = "UTF-8" @@ -43,12 +49,11 @@ class StringEncoding(str, Enum): @dataclass(frozen=True) class StackString: - """ - here's what the following members represent: - - + """here's what the following members represent: + + [smaller addresses] - + +---------------+ <- stack_pointer (top of stack) | | \ +---------------+ | offset @@ -62,19 +67,8 @@ class StackString: +---------------+ | | | / +---------------+ <- original_stack_pointer (bottom of stack, probably bp) - + [bigger addresses] - - - Attributes: - function: the address of the function from which the stackstring was extracted - string: the extracted string - encoding: string encoding - program_counter: the program counter at the moment the string was extracted - stack_pointer: the stack counter at the moment the string was extracted - original_stack_pointer: the initial stack counter when the function was entered - offset: the offset into the stack from at which the stack string was found - frame_offset: the offset from the function frame at which the stack string was found """ function: int @@ -88,10 +82,14 @@ class StackString: class TightString(StackString): + """A string that is tightly packed in memory.""" + pass class AddressType(str, Enum): + """Enumeration of address types.""" + STACK = "STACK" GLOBAL = "GLOBAL" HEAP = "HEAP" @@ -99,17 +97,7 @@ class AddressType(str, Enum): @dataclass(frozen=True) class DecodedString: - """ - A decoding string and details about where it was found. - - Attributes: - address: address of the string in memory - address_type: type of the address of the string in memory - string: the decoded string - encoding: the string encoding, like ASCII or unicode - decoded_at: the address at which the decoding routine is called - decoding_routine: the address of the decoding routine - """ + """A decoding string and details about where it was found.""" address: int address_type: AddressType @@ -121,14 +109,7 @@ class DecodedString: @dataclass(frozen=True) class StaticString: - """ - A string extracted from the raw bytes of the input. - - Attributes: - string: the string - offset: the offset into the input where the string is found - encoding: the string encoding, like ASCII or unicode - """ + """A string extracted from the raw bytes of the input.""" string: str offset: int @@ -136,6 +117,16 @@ class StaticString: @classmethod def from_utf8(cls, buf, addr, min_length): + """Create a StaticString from a buffer of bytes. + + Args: + buf: The buffer of bytes. + addr: The address of the buffer. + min_length: The minimum length of the string. + + Returns: + StaticString: The created string. 
+ """ try: decoded_string = buf.decode("utf-8") except UnicodeDecodeError: @@ -151,6 +142,8 @@ def from_utf8(cls, buf, addr, min_length): @dataclass class Runtime: + """The runtime of the analysis.""" + start_date: datetime.datetime = datetime.datetime.now() total: float = 0 vivisect: float = 0 @@ -164,6 +157,8 @@ class Runtime: @dataclass class Functions: + """The functions that were analyzed.""" + discovered: int = 0 library: int = 0 analyzed_stack_strings: int = 0 @@ -174,6 +169,8 @@ class Functions: @dataclass class Analysis: + """The analysis configuration.""" + enable_static_strings: bool = True enable_stack_strings: bool = True enable_tight_strings: bool = True @@ -186,6 +183,8 @@ class Analysis: @dataclass class Metadata: + """Metadata about the analysis.""" + file_path: str version: str = __version__ imagebase: int = 0 @@ -198,6 +197,8 @@ class Metadata: @dataclass class Strings: + """The strings that were found.""" + stack_strings: List[StackString] = field(default_factory=list) tight_strings: List[TightString] = field(default_factory=list) decoded_strings: List[DecodedString] = field(default_factory=list) @@ -208,16 +209,33 @@ class Strings: @dataclass class ResultDocument: + """The result document.""" + metadata: Metadata analysis: Analysis = field(default_factory=Analysis) strings: Strings = field(default_factory=Strings) @classmethod - def parse_file(cls, path: Path) -> "ResultDocument": - return TypeAdapter(cls).validate_json(path.read_text(encoding="utf-8")) + def parse_file(cls, path): + """Parse a result document from a file. + + Args: + path: The path to the file. + + Returns: + ResultDocument: The parsed result document. + """ + # We're ignoring the following mypy error since this field is guaranteed by the Pydantic dataclass. + return cls.__pydantic_model__.parse_file(path) # type: ignore def log_result(decoded_string, verbosity): + """Log a decoded string. + + Args: + decoded_string: The decoded string. + verbosity: The verbosity level. + """ string = sanitize(decoded_string.string) if verbosity < Verbosity.VERBOSE: logger.info("%s", string) @@ -243,6 +261,17 @@ def log_result(decoded_string, verbosity): def load(sample: Path, analysis: Analysis, functions: List[int], min_length: int) -> ResultDocument: + """Load a result document from a file, applying filters as needed. + + Args: + sample: Path: + analysis: Analysis: + functions: List[int]: + min_length: int: + + Returns: + ResultDocument: The loaded result document. + """ logger.debug("loading results document: %s", str(sample)) results = read(sample) results.metadata.file_path = f"{sample}\n{results.metadata.file_path}" @@ -256,6 +285,19 @@ def load(sample: Path, analysis: Analysis, functions: List[int], min_length: int def read(sample: Path) -> ResultDocument: + """Loads a ResultDocument from a file. + + Attempts to read a file as JSON and deserialize it into a ResultDocument object. Handles potential JSON decoding errors, Unicode-related errors, and validation failures. + + Args: + sample: A Path object representing the file to load. + + Returns: + ResultDocument: The deserialized ResultDocument. + + Raises: + InvalidResultsFile: If the file cannot be loaded as a valid ResultDocument (e.g., due to incorrect formatting or validation errors). 
+ """ try: with sample.open("rb") as f: results = json.loads(f.read().decode("utf-8")) @@ -271,6 +313,14 @@ def read(sample: Path) -> ResultDocument: def check_set_string_types(results: ResultDocument, wanted_analysis: Analysis) -> None: + """Ensures consistency in string type analysis settings between loaded results and desired analysis. + + This function checks if specific string analysis types were enabled in a desired analysis configuration (`wanted_analysis`) but are missing from the loaded analysis results (`results`). If found, it issues warnings and updates the `results` object to match the `wanted_analysis` settings. + + Args: + results: A ResultDocument object containing loaded analysis results. + wanted_analysis: An Analysis object representing the desired analysis configuration. + """ for string_type in STRING_TYPE_FIELDS: if getattr(wanted_analysis, string_type) and not getattr(results.analysis, string_type): logger.warning(f"{string_type} not in loaded data, use --only/--no to enable/disable type(s)") @@ -278,6 +328,17 @@ def check_set_string_types(results: ResultDocument, wanted_analysis: Analysis) - def filter_functions(results: ResultDocument, functions: List[int]) -> None: + """Updates a ResultDocument to include analysis data only from specified functions. + + Removes function-related data from the `results` object if the function's address (virtual address) is not present in the provided `functions` list. + + Args: + results: A ResultDocument object containing analysis results. + functions: A list of function virtual addresses to keep in the results. + + Raises: + InvalidLoadConfig: If a specified function address is not found in the loaded results. + """ filtered_scores = dict() for fva in functions: try: @@ -298,6 +359,14 @@ def filter_functions(results: ResultDocument, functions: List[int]) -> None: def filter_string_len(results: ResultDocument, min_length: int) -> None: + """Removes strings shorter than a specified length from the ResultDocument. + + Filters various string collections within the `results` object, keeping only strings that meet the minimum length criterion. + + Args: + results: A ResultDocument object containing analysis results. + min_length: The minimum length a string must have to be retained. + """ results.strings.static_strings = list(filter(lambda s: len(s.string) >= min_length, results.strings.static_strings)) results.strings.stack_strings = list(filter(lambda s: len(s.string) >= min_length, results.strings.stack_strings)) results.strings.tight_strings = list(filter(lambda s: len(s.string) >= min_length, results.strings.tight_strings)) diff --git a/floss/stackstrings.py b/floss/stackstrings.py index afe91b7c1..78eb28ccc 100644 --- a/floss/stackstrings.py +++ b/floss/stackstrings.py @@ -23,8 +23,7 @@ @dataclass(frozen=True) class CallContext: - """ - Context for stackstring extraction. + """Context for stackstring extraction. Attributes: pc: the current program counter @@ -42,10 +41,10 @@ class CallContext: class StackstringContextMonitor(viv_utils.emulator_drivers.Monitor): - """ - Observes emulation and extracts the active stack frame contents: - - at each function call in a function, and - - based on heuristics looking for mov instructions to a hardcoded buffer. + """Observes emulation and extracts the active stack frame contents: + + - at each function call in a function, and + - based on heuristics looking for mov instructions to a hardcoded buffer. 
""" def __init__(self, init_sp, bb_ends): @@ -71,9 +70,21 @@ def update_contexts(self, emu, va) -> None: # TODO get va here from emu? def get_call_context(self, emu, va, pre_ctx_strings: Optional[Set[str]] = None) -> CallContext: - """ - Returns a context with the bytes on the stack between the base pointer - (specifically, stack pointer at function entry), and stack pointer. + """Collects context information related to a function call. + + Retrieves the stack boundaries, reads the stack memory, and creates a `CallContext` object to encapsulate the extracted information. Optionally integrates pre-existing context strings. + + Args: + self: Likely a reference to an analysis object or a context tracker. + emu: The Vivisect emulator object. + va: The virtual address of the function call. + pre_ctx_strings: An optional set of strings for filtering or refining context generation. + + Returns: + CallContext: An object representing the context of the function call. + + Raises: + ValueError: If the calculated stack size exceeds a maximum threshold (`MAX_STACK_SIZE`). """ stack_top = emu.getStackCounter() stack_bottom = self._init_sp @@ -92,8 +103,12 @@ def posthook(self, emu, op, endpc): self.check_mov_heuristics(emu, op, endpc) def check_mov_heuristics(self, emu, op, endpc): - """ - Extract contexts at end of a basic block (bb) if bb contains enough movs to a harcoded buffer. + """Extract contexts at end of a basic block (bb) if bb contains enough movs to a harcoded buffer. + + Args: + emu: The Vivisect emulator object. + op: The current instruction. + endpc: The virtual address of the end of the basic block. """ # TODO check number of written bytes via writelog? # count movs, shortcut if this basic block has enough writes to trigger context extraction already @@ -107,6 +122,14 @@ def check_mov_heuristics(self, emu, op, endpc): self._mov_count = 0 def is_stack_mov(self, op): + """Check if the given instruction is a move to a stack address. + + Args: + op: The current instruction. + + Returns: + bool: True if the instruction is a move to a stack address, False otherwise. + """ if not op.mnem.startswith("mov"): return False @@ -121,6 +144,16 @@ def is_stack_mov(self, op): def extract_call_contexts(vw, fva, bb_ends): + """Extracts call contexts from a function. + + Args: + vw: The vivisect workspace. + fva: The function virtual address. + bb_ends: The set of virtual addresses that are the last instructions of basic blocks. + + Returns: + List[CallContext]: A list of call contexts. + """ emu = floss.utils.make_emulator(vw) monitor = StackstringContextMonitor(emu.getStackCounter(), bb_ends) driver = viv_utils.emulator_drivers.FullCoverageEmulatorDriver(emu, repmax=256) @@ -134,8 +167,13 @@ def extract_call_contexts(vw, fva, bb_ends): def get_basic_block_ends(vw): - """ - Return the set of VAs that are the last instructions of basic blocks. + """Return the set of VAs that are the last instructions of basic blocks. + + Args: + vw: The vivisect workspace. + + Returns: + Set[int]: A set of virtual addresses. """ index = set([]) for funcva in vw.getFunctions(): @@ -148,16 +186,23 @@ def get_basic_block_ends(vw): def extract_stackstrings( - vw, selected_functions, min_length, verbosity=Verbosity.DEFAULT, disable_progress=False + vw, + selected_functions, + min_length, + verbosity=Verbosity.DEFAULT, + disable_progress=False, ) -> List[StackString]: - """ - Extracts the stackstrings from functions in the given workspace. + """Extracts the stackstrings from functions in the given workspace. 
+
+    Args:
+        vw: The vivisect workspace.
+        selected_functions: A list of virtual addresses of functions to analyze.
+        min_length: The minimum length of a string to extract.
+        verbosity: The verbosity level.
+        disable_progress: A flag to disable the progress bar.
-    :param vw: The vivisect workspace from which to extract stackstrings.
-    :param selected_functions: list of selected functions
-    :param min_length: minimum string length
-    :param verbosity: verbosity level
-    :param disable_progress: do NOT show progress bar
+
+    Returns:
+        List[StackString]: A list of stackstrings.
     """
     logger.info("extracting stackstrings from %d functions", len(selected_functions))
@@ -165,7 +210,10 @@
     bb_ends = get_basic_block_ends(vw)
     pb = floss.utils.get_progress_bar(
-        selected_functions, disable_progress, desc="extracting stackstrings", unit=" functions"
+        selected_functions,
+        disable_progress,
+        desc="extracting stackstrings",
+        unit=" functions",
     )
     with tqdm.contrib.logging.logging_redirect_tqdm(), floss.utils.redirecting_print_to_tqdm():
         for fva in pb:
@@ -174,7 +222,9 @@
             ctxs = extract_call_contexts(vw, fva, bb_ends)
             for n, ctx in enumerate(ctxs, 1):
                 logger.trace(
-                    "extracting stackstrings at checkpoint: 0x%x stacksize: 0x%x", ctx.pc, ctx.init_sp - ctx.sp
+                    "extracting stackstrings at checkpoint: 0x%x stacksize: 0x%x",
+                    ctx.pc,
+                    ctx.init_sp - ctx.sp,
                 )
                 for s in extract_strings(ctx.stack_memory, min_length, seen):
                     frame_offset = (ctx.init_sp - ctx.sp) - s.offset - getPointerSize(vw)
diff --git a/floss/string_decoder.py b/floss/string_decoder.py
index 6d6534bef..30d82a66c 100644
--- a/floss/string_decoder.py
+++ b/floss/string_decoder.py
@@ -29,17 +29,16 @@
 
 def memdiff_search(bytes1, bytes2):
-    """
-    Use binary searching to find the offset of the first difference
+    """Use binary searching to find the offset of the first difference
     between two strings.
 
-    :param bytes1: The original sequence of bytes
-    :param bytes2: A sequence of bytes to compare with bytes1
-    :type bytes1: str
-    :type bytes2: str
-    :rtype: int offset of the first location a and b differ, None if strings match
-    """
+    Args:
+        bytes1: The first sequence of bytes.
+        bytes2: The second sequence of bytes.
+
+    Returns:
+        int: The offset of the first difference between the two strings, or None if they match.
+    """
     # Prevent infinite recursion on inputs with length of one
     half = (len(bytes1) // 2) or 1
@@ -57,15 +56,14 @@
 
 def memdiff(bytes1, bytes2):
-    """
-    Find all differences between two input strings.
-
-    :param bytes1: The original sequence of bytes
-    :param bytes2: The sequence of bytes to compare to
-    :type bytes1: str
-    :type bytes2: str
-    :rtype: list of (offset, length) tuples indicating locations bytes1 and
-        bytes2 differ
+    """Find all differences between two input strings.
+
+    Args:
+        bytes1: The first sequence of bytes.
+        bytes2: The second sequence of bytes.
+
+    Returns:
+        list: A list of tuples, where each tuple contains the offset and length of a difference between the two strings.
     """
     # Shortcut matching inputs
     if bytes1 == bytes2:
@@ -105,6 +103,17 @@
 
 def should_shortcut(fva: int, n: int, n_calls: int, found_strings: int) -> bool:
+    """Determine if the emulation of a decoding function should be shortcut.
+
+    Args:
+        fva: The address of the decoding function.
+        n: The current call number.
+        n_calls: The total number of calls to the decoding function.
+        found_strings: The number of strings found so far.
+
+    Returns:
+        bool: True if the emulation of the decoding function should be shortcut, False otherwise.
+    """
     if n_calls < DS_FUNCTION_CALLS_RARE:
         # don't shortcut
         return False
@@ -116,7 +125,10 @@ def should_shortcut(fva: int, n: int, n_calls: int, found_strings: int) -> bool:
     if n >= shortcut_threshold and found_strings <= DS_FUNCTION_MIN_DECODED_STRINGS:
         logger.debug(
-            "only %d results after emulating %d contexts, shortcutting emulation of 0x%x", found_strings, n, fva
+            "only %d results after emulating %d contexts, shortcutting emulation of 0x%x",
+            found_strings,
+            n,
+            fva,
         )
         return True
     return False
@@ -130,16 +142,18 @@ def decode_strings(
     verbosity: int = Verbosity.DEFAULT,
     disable_progress: bool = False,
 ) -> List[DecodedString]:
-    """
-    FLOSS string decoding algorithm
-
-    arguments:
-        vw: the workspace
-        functions: addresses of the candidate decoding routines
-        min_length: minimum string length
-        max_insn_count: max number of instructions to emulate per function
-        verbosity: verbosity level
-        disable_progress: no progress bar
+    """FLOSS string decoding algorithm.
+
+    Args:
+        vw: The vivisect workspace.
+        functions: A list of virtual addresses of the candidate decoding routines to emulate.
+        min_length: The minimum length of string to consider.
+        max_insn_count: The maximum number of instructions to emulate per function.
+        verbosity: The verbosity level.
+        disable_progress: Whether to disable progress bars.
+
+    Returns:
+        list: A list of DecodedString objects representing the decoded strings.
     """
     logger.info("decoding strings")
@@ -177,8 +191,7 @@ def decode_strings(
 
 def emulate_decoding_routine(vw, function_index, function: int, context, max_instruction_count: int) -> List[Delta]:
-    """
-    Emulate a function with a given context and extract the CPU and
+    """Emulate a function with a given context and extract the CPU and
     memory contexts at interesting points during emulation.
     These "interesting points" include calls to other functions and the final state.
@@ -188,15 +201,15 @@
     This prevents unexpected infinite loops.
     This number is taken from emulating the decoding of "Hello world" using RC4.
 
+    Args:
+        vw: The vivisect workspace.
+        function_index: The viv_utils.FunctionIndex for the workspace.
+        function: The address of the function to emulate.
+        context: The initial state of the CPU and memory prior to the function being called.
+        max_instruction_count: The maximum number of instructions to emulate.
-    :param vw: The vivisect workspace in which the function is defined.
-    :type function_index: viv_utils.FunctionIndex
-    :param function: The address of the function to emulate.
-    :type context: funtion_argument_getter.FunctionContext
-    :param context: The initial state of the CPU and memory
-        prior to the function being called.
-    :param max_instruction_count: The maximum number of instructions to emulate per function.
-    :rtype: Sequence[decoding_manager.Delta]
+
+    Returns:
+        List[Delta]: A list of Deltas representing the emulator state at each interesting place.
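+
+    Example:
+        A sketch of a typical invocation; ``vw``, ``function_index``, ``fva``, and
+        ``ctx`` are assumed to exist, and the instruction budget shown is illustrative
+        (this example is not part of the original patch):
+
+        >>> deltas = emulate_decoding_routine(vw, function_index, fva, ctx, 20000)  # doctest: +SKIP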
""" emu = floss.utils.make_emulator(vw) emu.setEmuSnap(context.emu_snap) @@ -214,6 +227,8 @@ def emulate_decoding_routine(vw, function_index, function: int, context, max_ins @dataclass class DeltaBytes: + """ """ + address: int address_type: AddressType bytes: bytes @@ -222,14 +237,15 @@ class DeltaBytes: def extract_delta_bytes(delta: Delta, decoded_at_va: int, source_fva: int = 0x0) -> List[DeltaBytes]: - """ - Extract the sequence of byte sequences that differ from before - and after snapshots. + """Extract the sequence of byte sequences that differ from before and after snapshots. + + Args: + delta: The delta object. + decoded_at_va: The address at which the decoding occurred. + source_fva: The address of the source function. - :param delta: The before and after snapshots of memory to diff. - :param decoded_at_va: The virtual address of a specific call to - the decoding function candidate that resulted in a memory diff - :param source_fva: function VA of the decoding routine candidate + Returns: + List[DeltaBytes]: A list of DeltaBytes objects representing the byte sequences that differ from before and after snapshots. """ delta_bytes = [] @@ -255,7 +271,13 @@ def extract_delta_bytes(delta: Delta, decoded_at_va: int, source_fva: int = 0x0) location_type = AddressType.HEAP if not is_all_zeros(bytes_after): delta_bytes.append( - DeltaBytes(section_after_start, location_type, bytes_after, decoded_at_va, source_fva) + DeltaBytes( + section_after_start, + location_type, + bytes_after, + decoded_at_va, + source_fva, + ) ) continue diff --git a/floss/strings.py b/floss/strings.py index a11836397..c2f5330ec 100644 --- a/floss/strings.py +++ b/floss/strings.py @@ -16,6 +16,17 @@ def buf_filled_with(buf, character): + """Determines if a buffer is entirely filled with a specified character. + + Checks the buffer in chunks and compares them against a reference chunk created from the provided character. + + Args: + buf: The buffer to be analyzed. + character: The character to check for. + + Returns: + bool: True if the buffer is filled with the given character, False otherwise. + """ dupe_chunk = character * SLICE_SIZE for offset in range(0, len(buf), SLICE_SIZE): new_chunk = buf[offset : offset + SLICE_SIZE] @@ -25,18 +36,24 @@ def buf_filled_with(buf, character): def extract_ascii_unicode_strings(buf, n=MIN_LENGTH) -> Iterable[StaticString]: + """Extract ASCII and Unicode strings from the given binary data. + + Args: + buf: A bytestring. + n: The minimum length of strings to extract. (Default value = MIN_LENGTH) + """ yield from chain(extract_ascii_strings(buf, n), extract_unicode_strings(buf, n)) def extract_ascii_strings(buf, n=MIN_LENGTH) -> Iterable[StaticString]: - """ - Extract ASCII strings from the given binary data. + """Extract ASCII strings from the given binary data. + + Args: + buf: A bytestring. + n: The minimum length of strings to extract. (Default value = MIN_LENGTH) - :param buf: A bytestring. - :type buf: str - :param n: The minimum length of strings to extract. - :type n: int - :rtype: Sequence[StaticString] + Returns: + Iterable[StaticString]: An iterable of StaticString objects representing the extracted strings. 
""" if not buf: @@ -52,18 +69,22 @@ def extract_ascii_strings(buf, n=MIN_LENGTH) -> Iterable[StaticString]: reg = rb"([%s]{%d,})" % (ASCII_BYTE, n) r = re.compile(reg) for match in r.finditer(buf): - yield StaticString(string=match.group().decode("ascii"), offset=match.start(), encoding=StringEncoding.ASCII) + yield StaticString( + string=match.group().decode("ascii"), + offset=match.start(), + encoding=StringEncoding.ASCII, + ) def extract_unicode_strings(buf, n=MIN_LENGTH) -> Iterable[StaticString]: - """ - Extract naive UTF-16 strings from the given binary data. + """Extract naive UTF-16 strings from the given binary data. + + Args: + buf: A bytestring. + n: The minimum length of strings to extract. (Default value = MIN_LENGTH) - :param buf: A bytestring. - :type buf: str - :param n: The minimum length of strings to extract. - :type n: int - :rtype: Sequence[StaticString] + Returns: + Iterable[StaticString]: An iterable of StaticString objects representing the extracted strings. """ if not buf: @@ -80,13 +101,16 @@ def extract_unicode_strings(buf, n=MIN_LENGTH) -> Iterable[StaticString]: for match in r.finditer(buf): try: yield StaticString( - string=match.group().decode("utf-16"), offset=match.start(), encoding=StringEncoding.UTF16LE + string=match.group().decode("utf-16"), + offset=match.start(), + encoding=StringEncoding.UTF16LE, ) except UnicodeDecodeError: pass def main(): + """Main function for standalone usage.""" import sys with open(sys.argv[1], "rb") as f: diff --git a/floss/tightstrings.py b/floss/tightstrings.py index 89260ae0c..c3ec17652 100644 --- a/floss/tightstrings.py +++ b/floss/tightstrings.py @@ -20,6 +20,11 @@ class TightstringContextMonitor(StackstringContextMonitor): + """Observes emulation and extracts the active stack frame contents: + - at each function call in a function, and + - based on heuristics looking for mov instructions to a hardcoded buffer. + """ + def __init__(self, sp, min_length): super().__init__(sp, []) self.min_length = min_length @@ -44,6 +49,17 @@ def get_context(self, emu, va, pre_ctx_strings: Optional[Set[str]]) -> Iterator[ def extract_tightstring_contexts(vw, fva, min_length, tloops) -> Iterator[CallContext]: + """Extracts tightstring contexts from a function containing tight loops. + + Args: + vw: The vivisect workspace + fva: The function address + min_length: The minimum string length + tloops: The tight loops in the function + + Returns: + Iterator[CallContext]: An iterator of CallContext objects representing the extracted tightstring contexts. 
+ """ emu = floss.utils.make_emulator(vw) monitor = TightstringContextMonitor(emu.getStackCounter(), min_length) driver_single_path = viv_utils.emulator_drivers.SinglePathEmulatorDriver(emu, repmax=256) @@ -66,33 +82,54 @@ def extract_tightstring_contexts(vw, fva, min_length, tloops) -> Iterator[CallCo # emulate tight loop driver.run_to_va(t.endva) except viv_utils.emulator_drivers.BreakpointHit as e: - logger.debug("hit breakpoint at 0x%x (reason: %s) in function 0x%x", e.va, e.reason, fva) + logger.debug( + "hit breakpoint at 0x%x (reason: %s) in function 0x%x", + e.va, + e.reason, + fva, + ) except Exception as e: - logger.debug("error emulating tight loop starting at 0x%x in function 0x%x: %s", t.startva, fva, e) + logger.debug( + "error emulating tight loop starting at 0x%x in function 0x%x: %s", + t.startva, + fva, + e, + ) yield from monitor.get_context(emu, t.startva, pre_ctx_strings) def extract_tightstrings( - vw, tightloop_functions, min_length, verbosity=Verbosity.DEFAULT, disable_progress=False + vw, + tightloop_functions, + min_length, + verbosity=Verbosity.DEFAULT, + disable_progress=False, ) -> List[TightString]: - """ - Extracts tightstrings from functions that contain tight loops. + """Extracts tightstrings from functions that contain tight loops. + Tightstrings are a special form of stackstrings. Their bytes are loaded on the stack and then modified in a tight loop. To extract tightstrings we use a mix between the string decoding and stackstring algorithms. To reduce computation time we only run this on previously identified functions that contain tight loops. - :param vw: The vivisect workspace - :param tightloop_functions: functions containing tight loops - :param min_length: minimum string length - :param verbosity: verbosity level - :param disable_progress: do NOT show progress bar + Args: + vw: The vivisect workspace + tightloop_functions: A dictionary of functions containing tight loops + min_length: The minimum string length + verbosity: The verbosity level + disable_progress: A flag to disable the progress bar + + Returns: + List[TightString]: A list of TightString objects representing the extracted tightstrings. 
""" logger.info("extracting tightstrings from %d functions...", len(tightloop_functions)) tight_strings = list() pb = floss.utils.get_progress_bar( - tightloop_functions.items(), disable_progress, desc="extracting tightstrings", unit=" functions" + tightloop_functions.items(), + disable_progress, + desc="extracting tightstrings", + unit=" functions", ) with tqdm.contrib.logging.logging_redirect_tqdm(), floss.utils.redirecting_print_to_tqdm(): for fva, tloops in pb: @@ -104,7 +141,9 @@ def extract_tightstrings( ctxs = extract_tightstring_contexts(vw, fva, min_length, tloops) for n, ctx in enumerate(ctxs, 1): logger.trace( - "extracting tightstring at checkpoint: 0x%x stacksize: 0x%x", ctx.pc, ctx.init_sp - ctx.sp + "extracting tightstring at checkpoint: 0x%x stacksize: 0x%x", + ctx.pc, + ctx.init_sp - ctx.sp, ) logger.trace("pre_ctx strings: %s", ctx.pre_ctx_strings) for s in extract_strings(ctx.stack_memory, min_length, exclude=ctx.pre_ctx_strings): diff --git a/floss/utils.py b/floss/utils.py index e74e5a177..3a20538cb 100644 --- a/floss/utils.py +++ b/floss/utils.py @@ -35,6 +35,8 @@ class ExtendAction(argparse.Action): + """ """ + # stores a list, and extends each argument value to the list # Since Python 3.8 argparse supports this # TODO: remove this code when only supporting Python 3.8+ @@ -101,6 +103,11 @@ def __call__(self, parser, namespace, values, option_string=None): def set_vivisect_log_level(level) -> None: + """Set the log level for vivisect and related modules. + + Args: + level: The log level to set. + """ logging.getLogger("vivisect").setLevel(level) logging.getLogger("vivisect.base").setLevel(level) logging.getLogger("vivisect.impemu").setLevel(level) @@ -110,8 +117,13 @@ def set_vivisect_log_level(level) -> None: def make_emulator(vw) -> Emulator: - """ - create an emulator using consistent settings. + """create an emulator using consistent settings. + + Args: + vw: The vivisect workspace. + + Returns: + Emulator: The emulator instance. """ emu = vw.getEmulator(logwrite=True, taintbyte=b"\xFE") remove_stack_memory(emu) @@ -124,6 +136,11 @@ def make_emulator(vw) -> Emulator: def remove_stack_memory(emu: Emulator): + """Remove the stack memory from the emulator. + + Args: + emu: The emulator instance. + """ # TODO this is a hack while vivisect's initStackMemory() has a bug memory_snap = emu.getMemorySnap() for i in range((len(memory_snap) - 1), -1, -1): @@ -137,9 +154,13 @@ def remove_stack_memory(emu: Emulator): def dump_stack(emu): - """ - Convenience debugging routine for showing - state current state of the stack. + """Convenience debugging routine for showing state current state of the stack. + + Args: + emu: The emulator instance. + + Returns: + str: The stack state. """ esp = emu.getStackCounter() stack_str = "" @@ -148,16 +169,38 @@ def dump_stack(emu): sp = "<= SP" else: sp = "%02x" % (-i) - stack_str = "%s\n0x%08x - 0x%08x %s" % (stack_str, (esp - i), floss.utils.get_stack_value(emu, -i), sp) + stack_str = "%s\n0x%08x - 0x%08x %s" % ( + stack_str, + (esp - i), + floss.utils.get_stack_value(emu, -i), + sp, + ) logger.trace(stack_str) return stack_str def get_stack_value(emu, offset): + """Get the value from the stack at the given offset. + + Args: + emu: The emulator instance. + offset: The offset from the stack pointer. + + Returns: + int: The value from the stack. + """ return emu.readMemoryFormat(emu.getStackCounter() + offset, " Iterable[StaticString]: + """Extracts potential strings from a buffer and applies filtering. 
+
+    Initial filtering includes length checks, common false-positive patterns, and optional exclusion based on a provided set. Extracted strings are then stripped or sanitized before yielding.
+
+    Args:
+        buffer: The byte buffer to analyze.
+        min_length: The minimum length for a string to be considered valid.
+        exclude: An optional set of strings to exclude from the results.
+
+    Yields:
+        StaticString: The filtered and sanitized strings.
+    """
     if len(buffer) < min_length:
         return
@@ -327,10 +417,13 @@
 
 def strip_string(s) -> str:
-    """
-    Return string stripped from false positive (FP) pre- or suffixes.
-    :param s: input string
-    :return: string stripped from FP pre- or suffixes
+    """Return string stripped from false positive (FP) pre- or suffixes.
+
+    Args:
+        s: The string to strip.
+
+    Returns:
+        str: The stripped string.
     """
     for reg in (
         FP_FILTER_PREFIX_1,
@@ -352,8 +445,7 @@ def strip_string(s) -> str:
 
 @contextlib.contextmanager
 def redirecting_print_to_tqdm():
-    """
-    tqdm (progress bar) expects to have fairly tight control over console output.
+    """tqdm (progress bar) expects to have fairly tight control over console output.
     so calls to `print()` will break the progress bar and make things look bad.
     so, this context manager temporarily replaces the `print` implementation
     with one that is compatible with tqdm.
@@ -362,6 +454,10 @@ def redirecting_print_to_tqdm():
     old_print = print
 
     def new_print(*args, **kwargs):
+        """Provides a flexible printing function, prioritizing tqdm progress bars.
+
+        Attempts to print using `tqdm.tqdm.write` for integration with progress bars. If that fails, it falls back to the standard built-in `print` function.
+        """
         # If tqdm.tqdm.write raises error, use builtin print
         try:
             tqdm.tqdm.write(*args, **kwargs)
@@ -378,6 +474,7 @@ def new_print(*args, **kwargs):
 
 @contextlib.contextmanager
 def timing(msg):
+    """A context manager for timing a block of code."""
     t0 = time.time()
     yield
     t1 = time.time()
@@ -385,14 +482,41 @@ def timing(msg):
 
 def get_runtime_diff(time0):
+    """Get the runtime difference from the given time.
+
+    Args:
+        time0: The start time.
+
+    Returns:
+        float: The elapsed time in seconds, rounded to four decimal places.
+    """
     return round(time.time() - time0, 4)
 
 def is_all_zeros(buffer: bytes):
+    """Determines if a buffer is entirely filled with null bytes.
+
+    Args:
+        buffer: The buffer to analyze.
+
+    Returns:
+        bool: True if the buffer is entirely filled with null bytes, False otherwise.
+    """
    return all([b == 0 for b in buffer])
 
 def get_progress_bar(functions, disable_progress, desc="", unit=""):
+    """Get a progress bar for the given functions.
+
+    Args:
+        functions: The functions to process.
+        disable_progress: Whether to disable the progress bar.
+        desc: The description for the progress bar.
+        unit: The unit for the progress bar.
+
+    Returns:
+        The functions wrapped in a tqdm progress bar, or unwrapped if progress is disabled.
+    """
     pbar = tqdm.tqdm
     if disable_progress:
         # do not use tqdm to avoid unnecessary side effects when caller intends
@@ -402,12 +526,27 @@ def get_progress_bar(functions, disable_progress, desc="", unit=""):
 
 def is_thunk_function(vw, function_address):
+    """Determines if a function is a thunk.
+
+    Args:
+        vw: The vivisect workspace.
+        function_address: The address of the function.
+
+    Returns:
+        bool: True if the function is a thunk, False otherwise.
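+
+    Example:
+        Illustrative; the address is hypothetical and this example is not part of the
+        original patch:
+
+        >>> is_thunk_function(vw, 0x401000)  # doctest: +SKIP
+        False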
+ """ return vw.getFunctionMetaDict(function_address).get("Thunk", False) def round_(i: int, size: int) -> int: - """ - Round `i` to the nearest greater-or-equal-to multiple of `size`. + """Round `i` to the nearest greater-or-equal-to multiple of `size`. + + Args: + i: The integer to round. + size: The size of the multiple. + + Returns: + int: The rounded integer. """ if i % size == 0: return i @@ -415,13 +554,18 @@ def round_(i: int, size: int) -> int: def readStringAtRva(emu, rva, maxsize=None, charsize=1): - """ - Borrowed from vivisect/PE/__init__.py - :param emu: emulator - :param rva: virtual address of string - :param maxsize: maxsize of string - :param charsize: size of character (2 for wide string) - :return: the read string + """Reads a null-terminated string from an emulator at a specified RVA. + + Reads bytes sequentially from the emulator until a null terminator is encountered or the maximum size is reached. + + Args: + emu: The Vivisect emulator object. + rva: The starting RVA (Relative Virtual Address) of the string. + maxsize: An optional maximum number of bytes to read. + charsize: The width of a single character (e.g., 1 for ASCII, 2 for UTF-16). + + Returns: + bytes: The extracted string. """ ret = bytearray() # avoid infinite loop @@ -439,9 +583,15 @@ def readStringAtRva(emu, rva, maxsize=None, charsize=1): def contains_funcname(api, function_names: Tuple[str, ...]): - """ - Returns True if the function name from the call API is part of any of the `function_names` + """Returns True if the function name from the call API is part of any of the `function_names` This ignores casing and underscore prefixes like `_malloc` or `__malloc` + + Args: + api: The call API. + function_names: The function names to check. + + Returns: + bool: True if the function name from the call API is part of any of the `function_names`, False otherwise. """ funcname = get_call_funcname(api) if not funcname or funcname in ("UnknownApi", "?"): @@ -451,20 +601,43 @@ def contains_funcname(api, function_names: Tuple[str, ...]): def call_return(emu, api, argv, value): + """Call the return function for the given emulator, API, arguments, and value. + + Args: + emu: The Vivisect emulator object. + api: The call API. + argv: The arguments. + value: The value to return. + + Returns: + None + """ call_conv = get_call_conv(api) cconv = emu.getCallingConvention(call_conv) cconv.execCallReturn(emu, value, len(argv)) def get_call_conv(api): + """Get the calling convention for the given API.""" return api[2] def get_call_funcname(api): + """Get the function name for the given API.""" return api[3] def is_string_type_enabled(type_, disabled_types, enabled_types): + """Determine if a string type is enabled. + + Args: + type_: The string type. + disabled_types: The disabled types. + enabled_types: The enabled types. + + Returns: + bool: True if the string type is enabled, False otherwise. + """ if disabled_types: return type_ not in disabled_types elif enabled_types: @@ -474,6 +647,17 @@ def is_string_type_enabled(type_, disabled_types, enabled_types): def get_max_size(size: int, max_: int, api: Optional[Tuple] = None, argv: Optional[Tuple] = None) -> int: + """Get the maximum size for the given size. + + Args: + size: The size. + max_: The maximum size. + api: The call API. + argv: The arguments. + + Returns: + int: The maximum size. 
+ """ if size > max_: post = "" if api: @@ -486,6 +670,17 @@ def get_max_size(size: int, max_: int, api: Optional[Tuple] = None, argv: Option def get_referenced_strings(vw: vivisect.VivWorkspace, fva: int) -> Set[str]: + """Collects potential string references from instructions within a function. + + Analyzes instructions within the specified function, seeking operands that might be addresses referencing strings within the workspace. Leverages Vivisect functionality to attempt string extraction. + + Args: + vw: A Vivisect workspace object. + fva: The function virtual address (FVA) to analyze. + + Returns: + Set[str]: A set of potential strings extracted from instruction operands. + """ # modified from capa f: viv_utils.Function = viv_utils.Function(vw, fva) strings: Set[str] = set() @@ -517,11 +712,17 @@ def get_referenced_strings(vw: vivisect.VivWorkspace, fva: int) -> Set[str]: def derefs(vw, p): - """ - recursively follow the given pointer, yielding the valid memory addresses along the way. + """recursively follow the given pointer, yielding the valid memory addresses along the way. useful when you may have a pointer to string, or pointer to pointer to string, etc. this is a "do what i mean" type of helper function. + + Args: + vw: A Vivisect workspace object. + p: The initial pointer address. + + Yields: + Valid memory addresses encountered during dereferencing. """ depth = 0 while True: @@ -549,6 +750,20 @@ def derefs(vw, p): def read_string(vw, offset: int) -> str: + """Attempts to read a string from a Vivisect workspace at the specified offset. + + Handles potential encoding types (UTF-8, UTF-16), segmentation violations, and works around possible Vivisect quirks when detecting string boundaries. + + Args: + vw: A Vivisect workspace object. + offset: The memory offset where the string is suspected to begin. + + Returns: + str: The extracted string. + + Raises: + ValueError: If a valid string cannot be extracted at the given offset. + """ try: alen = vw.detectString(offset) except envi.exc.SegmentationViolation: @@ -580,6 +795,16 @@ def read_string(vw, offset: int) -> str: def read_memory(vw, va: int, size: int) -> bytes: + """Read memory from a Vivisect workspace at the specified virtual address. + + Args: + vw: A Vivisect workspace object. + va: The virtual address to read from. + size: The number of bytes to read. + + Returns: + bytes: The extracted memory. + """ # as documented in #176, vivisect will not readMemory() when the section is not marked readable. # # but here, we don't care about permissions. @@ -596,8 +821,14 @@ def read_memory(vw, va: int, size: int) -> bytes: def get_static_strings(sample: Path, min_length: int) -> list: - """ - Returns list of static strings from the file which are above the minimum length + """Returns list of static strings from the file which are above the minimum length + + Args: + sample: The file to analyze. + min_length: The minimum length of strings to extract. + + Returns: + list: A list of extracted static strings. 
""" if sample.stat().st_size == 0: diff --git a/scripts/extract_rust_hashes.py b/scripts/extract_rust_hashes.py index c50ba520b..52c9b9f57 100644 --- a/scripts/extract_rust_hashes.py +++ b/scripts/extract_rust_hashes.py @@ -34,7 +34,8 @@ r = requests.get("https://github.com/rust-lang/rust/releases?page={}".format(page_number)) soup = BeautifulSoup(r.text, "html.parser") tables = soup.find_all( - "div", class_="col-md-2 d-flex flex-md-column flex-row flex-wrap pr-md-6 mb-2 mb-md-0 flex-items-start pt-md-4" + "div", + class_="col-md-2 d-flex flex-md-column flex-row flex-wrap pr-md-6 mb-2 mb-md-0 flex-items-start pt-md-4", ) # if there are no more tables, means we have reached the end of the page, break diff --git a/scripts/idaplugin.py b/scripts/idaplugin.py index 6c7bad60c..3c3a16225 100644 --- a/scripts/idaplugin.py +++ b/scripts/idaplugin.py @@ -118,7 +118,10 @@ def apply_decoded_strings(decoded_strings: List[DecodedString]) -> None: def apply_stack_strings( - stack_strings: List[StackString], tight_strings: List[TightString], lvar_cmt: bool = True, cmt: bool = True + stack_strings: List[StackString], + tight_strings: List[TightString], + lvar_cmt: bool = True, + cmt: bool = True, ) -> None: """ lvar_cmt: apply stack variable comment @@ -130,7 +133,10 @@ def apply_stack_strings( continue logger.info( - "decoded stack/tight string in function 0x%x (pc: 0x%x): %s", s.function, s.program_counter, s.string + "decoded stack/tight string in function 0x%x (pc: 0x%x): %s", + s.function, + s.program_counter, + s.string, ) if lvar_cmt: try: @@ -178,7 +184,11 @@ def main(argv=None): logger.info("extracting stackstrings...") selected_functions = floss.identify.get_functions_without_tightloops(decoding_function_features) stack_strings = floss.stackstrings.extract_stackstrings( - vw, selected_functions, MIN_LENGTH, verbosity=floss.render.Verbosity.VERBOSE, disable_progress=True + vw, + selected_functions, + MIN_LENGTH, + verbosity=floss.render.Verbosity.VERBOSE, + disable_progress=True, ) logger.info("decoded %d stack strings", len(stack_strings)) diff --git a/scripts/render-binja-import-script.py b/scripts/render-binja-import-script.py index 24ccd83ac..b24a283b3 100644 --- a/scripts/render-binja-import-script.py +++ b/scripts/render-binja-import-script.py @@ -132,7 +132,10 @@ def main(): logging_group.add_argument("-d", "--debug", action="store_true", help="enable debugging output on STDERR") logging_group.add_argument( - "-q", "--quiet", action="store_true", help="disable all status output except fatal errors" + "-q", + "--quiet", + action="store_true", + help="disable all status output except fatal errors", ) args = parser.parse_args() diff --git a/scripts/render-ghidra-import-script.py b/scripts/render-ghidra-import-script.py index 49cc18959..77c9fc7d5 100644 --- a/scripts/render-ghidra-import-script.py +++ b/scripts/render-ghidra-import-script.py @@ -120,7 +120,10 @@ def main(): logging_group.add_argument("-d", "--debug", action="store_true", help="enable debugging output on STDERR") logging_group.add_argument( - "-q", "--quiet", action="store_true", help="disable all status output except fatal errors" + "-q", + "--quiet", + action="store_true", + help="disable all status output except fatal errors", ) args = parser.parse_args() diff --git a/scripts/render-ida-import-script.py b/scripts/render-ida-import-script.py index f2267204e..51b260def 100644 --- a/scripts/render-ida-import-script.py +++ b/scripts/render-ida-import-script.py @@ -126,7 +126,10 @@ def main(): logging_group.add_argument("-d", 
"--debug", action="store_true", help="enable debugging output on STDERR") logging_group.add_argument( - "-q", "--quiet", action="store_true", help="disable all status output except fatal errors" + "-q", + "--quiet", + action="store_true", + help="disable all status output except fatal errors", ) args = parser.parse_args() diff --git a/scripts/render-r2-import-script.py b/scripts/render-r2-import-script.py index 635d84757..703b2e5f1 100644 --- a/scripts/render-r2-import-script.py +++ b/scripts/render-r2-import-script.py @@ -75,7 +75,10 @@ def main(): logging_group.add_argument("-d", "--debug", action="store_true", help="enable debugging output on STDERR") logging_group.add_argument( - "-q", "--quiet", action="store_true", help="disable all status output except fatal errors" + "-q", + "--quiet", + action="store_true", + help="disable all status output except fatal errors", ) args = parser.parse_args() diff --git a/scripts/render-x64dbg-database.py b/scripts/render-x64dbg-database.py index 6dbbd9c7b..99665b614 100644 --- a/scripts/render-x64dbg-database.py +++ b/scripts/render-x64dbg-database.py @@ -86,7 +86,10 @@ def main(): logging_group.add_argument("-d", "--debug", action="store_true", help="enable debugging output on STDERR") logging_group.add_argument( - "-q", "--quiet", action="store_true", help="disable all status output except fatal errors" + "-q", + "--quiet", + action="store_true", + help="disable all status output except fatal errors", ) args = parser.parse_args() diff --git a/tests/conftest.py b/tests/conftest.py index 9d4022e5d..ccdf9e3b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -76,7 +76,12 @@ def collect(self): filepath = test_dir / filename if filepath.exists(): yield FLOSSTest.from_parent( - self, path=str(filepath), platform=platform, arch=arch, filename=filename, spec=spec + self, + path=str(filepath), + platform=platform, + arch=arch, + filename=filename, + spec=spec, ) diff --git a/tests/test_language_extract_go.py b/tests/test_language_extract_go.py index 8978fa8d8..79dca8975 100644 --- a/tests/test_language_extract_go.py +++ b/tests/test_language_extract_go.py @@ -145,10 +145,20 @@ def test_strings_with_newline_char_0A(request, string, offset, encoding, go_stri [ # .idata:000000000062232A word_62232A dw 0 ; DATA XREF: .idata:0000000000622480↓o # .idata:000000000062232C db 'AddVectoredExceptionHandler',0 mov [rax+8], rcx - pytest.param("AddVectoredExceptionHandler", 0x1C5B2C, StringEncoding.ASCII, "go_strings64"), + pytest.param( + "AddVectoredExceptionHandler", + 0x1C5B2C, + StringEncoding.ASCII, + "go_strings64", + ), # .idata:005E531E word_5E531E dw 0 ; DATA XREF: .idata:005E53D4↓o # .idata:005E5320 db 'AddVectoredExceptionHandler',0 mov [eax+8], ecx - pytest.param("AddVectoredExceptionHandler", 0x1B5120, StringEncoding.ASCII, "go_strings32"), + pytest.param( + "AddVectoredExceptionHandler", + 0x1B5120, + StringEncoding.ASCII, + "go_strings32", + ), ], ) def test_import_data(request, string, offset, encoding, go_strings): diff --git a/tests/test_language_id.py b/tests/test_language_id.py index c0382aa14..b4d4993ff 100644 --- a/tests/test_language_id.py +++ b/tests/test_language_id.py @@ -18,8 +18,16 @@ ), ("data/language/rust/rust-hello/bin/rust-hello.exe", Language.RUST, "1.69.0"), ("data/test-decode-to-stack.exe", Language.UNKNOWN, VERSION_UNKNOWN_OR_NA), - ("data/language/dotnet/dotnet-hello/bin/dotnet-hello.exe", Language.DOTNET, VERSION_UNKNOWN_OR_NA), - ("data/src/shellcode-stackstrings/bin/shellcode-stackstrings.bin", Language.UNKNOWN, 
VERSION_UNKNOWN_OR_NA), + ( + "data/language/dotnet/dotnet-hello/bin/dotnet-hello.exe", + Language.DOTNET, + VERSION_UNKNOWN_OR_NA, + ), + ( + "data/src/shellcode-stackstrings/bin/shellcode-stackstrings.bin", + Language.UNKNOWN, + VERSION_UNKNOWN_OR_NA, + ), ], ) def test_language_detection(binary_file, expected_result, expected_version):