diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7954ff92..65b66948 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -47,4 +47,4 @@ jobs: - name: Test with pytest run: | - cd test && pytest --ignore=other_tests + cd test/amd64 && pytest --ignore=other_tests diff --git a/README.md b/README.md index 452d85e4..66f29519 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ ![logo](https://github.com/libdebug/libdebug/blob/dev/media/libdebug_header.png?raw=true) -# libdebug +# libdebug [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.13151549.svg)](https://doi.org/10.5281/zenodo.13151549) + libdebug is an open source Python library to automate the debugging of a binary executable. With libdebug you have full control of the flow of your debugged executable. With it you can: @@ -128,5 +129,6 @@ If you intend to use libdebug in your work, please cite this repository using th publisher = {libdebug.org}, author = {Digregorio, Gabriele and Bertolini, Roberto Alessandro and Panebianco, Francesco and Polino, Mario}, year = {2024}, + doi = {10.5281/zenodo.13151549}, } ``` diff --git a/docs/source/basic_features.rst b/docs/source/basic_features.rst index b2435edf..7a9ab062 100644 --- a/docs/source/basic_features.rst +++ b/docs/source/basic_features.rst @@ -36,6 +36,8 @@ After creating the debugger object, you can start the execution of the program u The `run()` command returns a `PipeManager` object, which you can use to interact with the program's standard input, output, and error. To read more about the PipeManager interface, please refer to the PipeManager documentation :class:`libdebug.utils.pipe_manager.PipeManager`. Please note that breakpoints are not kept between different runs of the program. If you want to set a breakpoint again, you should do so after the program has restarted. +Any process will be automatically killed when the debugging script exits. If you want to prevent this behavior, you can set the `kill_on_exit` parameter to False when creating the debugger object, or set the companion attribute `kill_on_exit` to False at runtime. + The command queue ----------------- Control flow commands, register access and memory access are all done through the command queue. This is a FIFO queue of commands that are executed in order. @@ -83,8 +85,11 @@ Register Access =============== .. _register-access-paragraph: -libdebug offers a simple register access interface for supported architectures. The registers are accessed through the regs attribute of the debugger object. The field includes both general purpose and special registers, as well as the flags register. Effectively, any register that can be accessed by an assembly instruction, can also be accessed through the regs attribute. The debugger specifically exposes properties of the main thread, including the registers. See :doc:`multithreading` to learn how to access registers and other properties from different threads. +libdebug offers a simple register access interface for supported architectures. The registers are accessed through the `regs`` attribute of the debugger object. The field includes both general purpose and special registers, as well as the flags register. Effectively, any register that can be accessed by an assembly instruction, can also be accessed through the regs attribute. The debugger specifically exposes properties of the main thread, including the registers. See :doc:`multithreading` to learn how to access registers and other properties from different threads. +Floating point and vector registers are available as well. The syntax is identical to the one used for integer registers. +For amd64, the list of available AVX registers is determined during installation by checking the CPU capabilities, thus special registers, such as `zmm0` to `zmm31`, are available only on CPUs that support the specific ISA extension. +If you believe that your target CPU supports AVX registers, but they are not available during debugging, please file an issue on the GitHub repository and include your precise hardware details, so that we can investigate and resolve the issue. Memory Access ==================================== @@ -158,6 +163,23 @@ If you specify a full or a substring of a file name, libdebug will search for th You can also use the wildcard string "binary" to use the base address of the binary as the base address for the relative addressing. The same behavior is applied if you pass a string corresponding to the binary name. +Faster Memory Access +------------------- + +By default, libdebug uses the kernel's ptrace interface to access memory. This is guaranteed to work, but it might be slow during large memory transfers. +To speed up memory access, we provide a secondary system that relies on /proc/$pid/mem for read and write operations. You can enable this feature by setting `fast_memory` to True when instancing the debugger. +The final behavior is identical, but the speed is significantly improved. + +Additionally, you can mix the two memory access methods by changing the `fast_memory` attribute of the debugger at runtime: + +.. code-block:: python + + d.fast_memory = True + + # ... + + d.fast_memory = False + Control Flow Commands ==================================== @@ -219,6 +241,15 @@ The available heuristics are: The default heuristic when none is specified is "backtrace". +Next +^^^^ + +The `next` command is similar to the `step` command, but when a ``call`` instruction is found, it will continue until the end of the function being called or until the process stops for other reasons. The syntax is as follows: + +.. code-block:: python + + d.next() + Detach and GDB Migration ==================================== @@ -246,6 +277,9 @@ An alternative to running the program from the beginning and to resume libdebug d.attach(pid) +Do note that libdebug automatically kills any running process when the debugging script exits, even if the debugger has detached from it. +If you want to prevent this behavior, you can set the `kill_on_exit` parameter to False when creating the debugger object, or set the companion attribute `kill_on_exit` to False at runtime. + Graceful Termination ==================== @@ -291,4 +325,4 @@ You can also access registers after the process has died. This is useful for *po Supported Architectures ======================= -libdebug currently only supports Linux under the x86_64 (AMD64) architecture. Support for other architectures is planned for future releases. Stay tuned. \ No newline at end of file +libdebug currently only supports Linux under the x86_64 (AMD64) and AArch64 (ARM64) architectures. Support for other architectures is planned for future releases. Stay tuned. diff --git a/docs/source/breakpoints.rst b/docs/source/breakpoints.rst index 926dde11..88833437 100644 --- a/docs/source/breakpoints.rst +++ b/docs/source/breakpoints.rst @@ -14,7 +14,7 @@ libdebug provides a simple API to set breakpoints in your debugged program. The .. code-block:: python - from libdebug import Debugger + from libdebug import debugger d = debugger("./test_program") @@ -102,20 +102,27 @@ Features of watchpoints are shared with breakpoints, so you can set callbacks, c callback=...) -> Breakpoint: Again, the position can be specified both as a relative address or as a symbol. -The condition parameter specifies the type of access that triggers the watchpoint. The following values are supported: +The condition parameter specifies the type of access that triggers the watchpoint. The following values are supported in all architectures: - ``"w"``: write access - ``"rw"``: read/write access - ``"x"``: execute access +AArch64 additionally supports: + +- ``"r"``: read access + By default, the watchpoint is triggered only on write access. -The length parameter specifies the size of the word being watched. The following values are supported: +The length parameter specifies the size of the word being watched. +In x86_64 (amd64) the following values are supported: - ``1``: byte - ``2``: word - ``4``: dword - ``8``: qword +AArch64 supports any length from 1 to 8 bytes. + By default, the watchpoint is set to watch a byte. diff --git a/docs/source/conf.py b/docs/source/conf.py index adc39e57..85d0b263 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -14,7 +14,7 @@ project = 'libdebug' copyright = '2024, Gabriele Digregorio, Roberto Alessandro Bertolini, Francesco Panebianco' author = 'JinBlack, Io_no, MrIndeciso, Frank01001' -release = '0.5.4' +release = '0.6.0' # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/source/index.rst b/docs/source/index.rst index 38b221b8..60aca123 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -9,6 +9,9 @@ ---- +.. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.13151549.svg + :target: https://doi.org/10.5281/zenodo.13151549 + Quick Start ==================================== Welcome to libdebug! This powerful Python library can be used to debug your binary executables programmatically, providing a robust, user-friendly interface. @@ -25,7 +28,7 @@ e.g, for version 0.5.0, go to https://docs.libdebug.org/archive/0.5.0 Supported Architectures ----------------------- -libdebug currently supports Linux under the x86_64 architecture. +libdebug currently supports Linux under the x86_64 and AArch64 architectures. Other operating systems and architectures are not supported at this time. @@ -106,7 +109,7 @@ Now that you have libdebug installed, you can start using it in your scripts. He The above script will run the binary `test` in the working directory and stop at the function corresponding to the symbol "function". It will then print the value of the RAX register and kill the process. -Conflicts with other python packages +Conflicts with other Python packages ------------------------------------ The current version of libdebug is incompatible with https://github.com/Gallopsled/pwntools. @@ -120,6 +123,22 @@ Examples of some known issues include: - Attaching libdebug to a process that was started with pwntools with ``shell=True`` will cause the process to attach to the shell process instead. This behavior is described in https://github.com/libdebug/libdebug/issues/57. +Cite Us +------- +Need to cite libdebug in your research? Use the following BibTeX entry: + +.. code-block:: bibtex + + @software{libdebug_2024, + title = {libdebug: {Build} {Your} {Own} {Debugger}}, + copyright = {MIT Licence}, + url = {https://libdebug.org}, + publisher = {libdebug.org}, + author = {Digregorio, Gabriele and Bertolini, Roberto Alessandro and Panebianco, Francesco and Polino, Mario}, + year = {2024}, + doi = {10.5281/zenodo.13151549}, + } + .. toctree:: :maxdepth: 2 :caption: Contents: diff --git a/docs/source/multithreading.rst b/docs/source/multithreading.rst index 483963dc..b1287335 100644 --- a/docs/source/multithreading.rst +++ b/docs/source/multithreading.rst @@ -33,6 +33,7 @@ The following is a list of behaviors to keep in mind when using control flow fun - `cont` will continue all threads. - `step` and `step_until` will step the selected thread. +- `next` will step on the selected thread or, if a call function is found, continue on all threads until the end of the called function or another stopping event. - `finish` will have different behavior depending on the selected heuristic. - `backtrace` will continue on all threads but will stop at any breakpoint that any of the threads hit. - `step-mode` will step exclusively on the thread that has been specified. diff --git a/libdebug/architectures/aarch64/__init__.py b/libdebug/architectures/aarch64/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libdebug/architectures/aarch64/aarch64_breakpoint_validator.py b/libdebug/architectures/aarch64/aarch64_breakpoint_validator.py new file mode 100644 index 00000000..b7c5f4b8 --- /dev/null +++ b/libdebug/architectures/aarch64/aarch64_breakpoint_validator.py @@ -0,0 +1,24 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from libdebug.data.breakpoint import Breakpoint + + +def validate_breakpoint_aarch64(bp: Breakpoint) -> None: + """Validate a hardware breakpoint for the AARCH64 architecture.""" + if bp.condition not in ["r", "w", "rw", "x"]: + raise ValueError("Invalid condition for watchpoints. Supported conditions are 'r', 'w', 'rw', 'x'.") + + if not (1 <= bp.length <= 8): + raise ValueError("Invalid length for watchpoints. Supported lengths are between 1 and 8.") + + if bp.condition != "x" and bp.address & 0x7: + raise ValueError("Watchpoint address must be aligned to 8 bytes on aarch64. This is a kernel limitation.") diff --git a/libdebug/architectures/aarch64/aarch64_call_utilities.py b/libdebug/architectures/aarch64/aarch64_call_utilities.py new file mode 100644 index 00000000..cb4f9959 --- /dev/null +++ b/libdebug/architectures/aarch64/aarch64_call_utilities.py @@ -0,0 +1,38 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +from __future__ import annotations + +from libdebug.architectures.call_utilities_manager import CallUtilitiesManager + + +class Aarch64CallUtilities(CallUtilitiesManager): + """Class that provides call utilities for the AArch64 architecture.""" + + def is_call(self: Aarch64CallUtilities, opcode_window: bytes) -> bool: + """Check if the current instruction is a call instruction.""" + # Check for BL instruction + if (opcode_window[3] & 0xFC) == 0x94: + return True + + # Check for BLR instruction + if opcode_window[3] == 0xD6 and (opcode_window[2] & 0x3F) == 0x3F: + return True + + return False + + def compute_call_skip(self: Aarch64CallUtilities, opcode_window: bytes) -> int: + """Compute the instruction size of the current call instruction.""" + # Check for BL instruction + if self.is_call(opcode_window): + return 4 + + return 0 + + def get_call_and_skip_amount(self: Aarch64CallUtilities, opcode_window: bytes) -> tuple[bool, int]: + """Check if the current instruction is a call instruction and compute the instruction size.""" + skip = self.compute_call_skip(opcode_window) + return skip != 0, skip diff --git a/libdebug/architectures/aarch64/aarch64_ptrace_register_holder.py b/libdebug/architectures/aarch64/aarch64_ptrace_register_holder.py new file mode 100644 index 00000000..6a52ecbf --- /dev/null +++ b/libdebug/architectures/aarch64/aarch64_ptrace_register_holder.py @@ -0,0 +1,219 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +from __future__ import annotations + +import sys +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from libdebug.architectures.aarch64.aarch64_registers import Aarch64Registers +from libdebug.ptrace.ptrace_register_holder import PtraceRegisterHolder + +if TYPE_CHECKING: + from libdebug.state.thread_context import ThreadContext + +AARCH64_GP_REGS = ["x", "w"] + + +def _get_property_64(name: str) -> property: + def getter(self: Aarch64Registers) -> int: + self._internal_debugger._ensure_process_stopped() + return getattr(self.register_file, name) + + def setter(self: Aarch64Registers, value: int) -> None: + self._internal_debugger._ensure_process_stopped() + setattr(self.register_file, name, value) + + return property(getter, setter, None, name) + + +def _get_property_32(name: str) -> property: + def getter(self: Aarch64Registers) -> int: + self._internal_debugger._ensure_process_stopped() + return getattr(self.register_file, name) & 0xFFFFFFFF + + # https://developer.arm.com/documentation/102374/0101/Registers-in-AArch64---general-purpose-registers + # When a W register is written the top 32 bits of the 64-bit register are zeroed. + def setter(self: Aarch64Registers, value: int) -> None: + self._internal_debugger._ensure_process_stopped() + return setattr(self.register_file, name, value & 0xFFFFFFFF) + + return property(getter, setter, None, name) + + +def _get_property_zr(name: str) -> property: + def getter(_: Aarch64Registers) -> int: + return 0 + + def setter(_: Aarch64Registers, __: int) -> None: + pass + + return property(getter, setter, None, name) + + +def _get_property_fp_8(name: str, index: int) -> property: + def getter(self: Aarch64Registers) -> int: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + return int.from_bytes(self._fp_register_file.vregs[index].data, sys.byteorder) & 0xFF + + def setter(self: Aarch64Registers, value: int) -> None: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + data = value.to_bytes(1, sys.byteorder) + self._fp_register_file.vregs[index].data = data + self._fp_register_file.dirty = True + + return property(getter, setter, None, name) + + +def _get_property_fp_16(name: str, index: int) -> property: + def getter(self: Aarch64Registers) -> int: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + return int.from_bytes(self._fp_register_file.vregs[index].data, sys.byteorder) & 0xFFFF + + def setter(self: Aarch64Registers, value: int) -> None: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + data = value.to_bytes(2, sys.byteorder) + self._fp_register_file.vregs[index].data = data + self._fp_register_file.dirty = True + + return property(getter, setter, None, name) + + +def _get_property_fp_32(name: str, index: int) -> property: + def getter(self: Aarch64Registers) -> int: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + return int.from_bytes(self._fp_register_file.vregs[index].data, sys.byteorder) & 0xFFFFFFFF + + def setter(self: Aarch64Registers, value: int) -> None: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + data = value.to_bytes(4, sys.byteorder) + self._fp_register_file.vregs[index].data = data + self._fp_register_file.dirty = True + + return property(getter, setter, None, name) + + +def _get_property_fp_64(name: str, index: int) -> property: + def getter(self: Aarch64Registers) -> int: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + return int.from_bytes(self._fp_register_file.vregs[index].data, sys.byteorder) & 0xFFFFFFFFFFFFFFFF + + def setter(self: Aarch64Registers, value: int) -> None: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + data = value.to_bytes(8, sys.byteorder) + self._fp_register_file.vregs[index].data = data + self._fp_register_file.dirty = True + + return property(getter, setter, None, name) + + +def _get_property_fp_128(name: str, index: int) -> property: + def getter(self: Aarch64Registers) -> int: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + return int.from_bytes(self._fp_register_file.vregs[index].data, sys.byteorder) + + def setter(self: Aarch64Registers, value: int) -> None: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + data = value.to_bytes(16, sys.byteorder) + self._fp_register_file.vregs[index].data = data + self._fp_register_file.dirty = True + + return property(getter, setter, None, name) + + +def _get_property_syscall_num() -> property: + def getter(self: Aarch64Registers) -> int: + self._internal_debugger._ensure_process_stopped() + return self.register_file.x8 + + def setter(self: Aarch64Registers, value: int) -> None: + self._internal_debugger._ensure_process_stopped() + self.register_file.x8 = value + self.register_file.override_syscall_number = True + + return property(getter, setter, None, "syscall_number") + + +@dataclass +class Aarch64PtraceRegisterHolder(PtraceRegisterHolder): + """A class that provides views and setters for the register of an aarch64 process.""" + + def provide_regs_class(self: Aarch64PtraceRegisterHolder) -> type: + """Provide a class to hold the register accessors.""" + return Aarch64Registers + + def apply_on_regs(self: Aarch64PtraceRegisterHolder, target: Aarch64Registers, target_class: type) -> None: + """Apply the register accessors to the Aarch64Registers class.""" + target.register_file = self.register_file + target._fp_register_file = self.fp_register_file + + if hasattr(target_class, "w0"): + return + + for i in range(31): + name_64 = f"x{i}" + name_32 = f"w{i}" + + setattr(target_class, name_64, _get_property_64(name_64)) + setattr(target_class, name_32, _get_property_32(name_64)) + + # setup the floating point registers + for i in range(32): + name_v = f"v{i}" + name_128 = f"q{i}" + name_64 = f"d{i}" + name_32 = f"s{i}" + name_16 = f"h{i}" + name_8 = f"b{i}" + setattr(target_class, name_v, _get_property_fp_128(name_v, i)) + setattr(target_class, name_128, _get_property_fp_128(name_128, i)) + setattr(target_class, name_64, _get_property_fp_64(name_64, i)) + setattr(target_class, name_32, _get_property_fp_32(name_32, i)) + setattr(target_class, name_16, _get_property_fp_16(name_16, i)) + setattr(target_class, name_8, _get_property_fp_8(name_8, i)) + + # setup special aarch64 registers + target_class.pc = _get_property_64("pc") + target_class.sp = _get_property_64("sp") + target_class.lr = _get_property_64("x30") + target_class.fp = _get_property_64("x29") + target_class.xzr = _get_property_zr("xzr") + target_class.wzr = _get_property_zr("wzr") + + def apply_on_thread(self: Aarch64PtraceRegisterHolder, target: ThreadContext, target_class: type) -> None: + """Apply the register accessors to the thread class.""" + target.register_file = self.register_file + + # If the accessors are already defined, we don't need to redefine them + if hasattr(target_class, "instruction_pointer"): + return + + # setup generic "instruction_pointer" property + target_class.instruction_pointer = _get_property_64("pc") + + # setup generic syscall properties + target_class.syscall_return = _get_property_64("x0") + target_class.syscall_arg0 = _get_property_64("x0") + target_class.syscall_arg1 = _get_property_64("x1") + target_class.syscall_arg2 = _get_property_64("x2") + target_class.syscall_arg3 = _get_property_64("x3") + target_class.syscall_arg4 = _get_property_64("x4") + target_class.syscall_arg5 = _get_property_64("x5") + + # syscall number handling is special on aarch64, as the original number is stored in x8 + # but writing to x8 isn't enough to change the actual called syscall + target_class.syscall_number = _get_property_syscall_num() diff --git a/libdebug/architectures/aarch64/aarch64_registers.py b/libdebug/architectures/aarch64/aarch64_registers.py new file mode 100644 index 00000000..d2609de1 --- /dev/null +++ b/libdebug/architectures/aarch64/aarch64_registers.py @@ -0,0 +1,22 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +from __future__ import annotations + +from dataclasses import dataclass + +from libdebug.data.registers import Registers +from libdebug.debugger.internal_debugger_instance_manager import get_global_internal_debugger + + +@dataclass +class Aarch64Registers(Registers): + """This class holds the state of the architectural-dependent registers of a process.""" + + def __init__(self: Aarch64Registers, thread_id: int) -> None: + """Initializes the Registers object.""" + self._internal_debugger = get_global_internal_debugger() + self._thread_id = thread_id diff --git a/libdebug/architectures/aarch64/aarch64_stack_unwinder.py b/libdebug/architectures/aarch64/aarch64_stack_unwinder.py new file mode 100644 index 00000000..7a9b5065 --- /dev/null +++ b/libdebug/architectures/aarch64/aarch64_stack_unwinder.py @@ -0,0 +1,84 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2023-2024 Roberto Alessandro Bertolini, Francesco Panebianco, Gabriele Digregorio. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from libdebug.architectures.stack_unwinding_manager import StackUnwindingManager +from libdebug.liblog import liblog + +if TYPE_CHECKING: + from libdebug.data.memory_map import MemoryMap + from libdebug.state.thread_context import ThreadContext + + +class Aarch64StackUnwinder(StackUnwindingManager): + """Class that provides stack unwinding for the AArch64 architecture.""" + + def unwind(self: Aarch64StackUnwinder, target: ThreadContext) -> list: + """Unwind the stack of a process. + + Args: + target (ThreadContext): The target ThreadContext. + + Returns: + list: A list of return addresses. + """ + assert hasattr(target.regs, "pc") + + frame_pointer = target.regs.x29 + + vmaps = target._internal_debugger.debugging_interface.maps() + initial_link_register = None + + try: + initial_link_register = self.get_return_address(target, vmaps) + except ValueError: + liblog.warning( + "Failed to get the return address. Check stack frame registers (e.g., base pointer). The stack trace may be incomplete.", + ) + + stack_trace = [target.regs.pc, initial_link_register] if initial_link_register else [target.regs.pc] + + # Follow the frame chain + while frame_pointer: + try: + link_register = int.from_bytes(target.memory[frame_pointer + 8, 8, "absolute"], byteorder="little") + frame_pointer = int.from_bytes(target.memory[frame_pointer, 8, "absolute"], byteorder="little") + + if not any(vmap.start <= link_register < vmap.end for vmap in vmaps): + break + + # Leaf functions don't set the previous stack frame pointer + # But they set the link register to the return address + # Non-leaf functions set both + if initial_link_register and link_register == initial_link_register: + initial_link_register = None + continue + + stack_trace.append(link_register) + except (OSError, ValueError): + break + + return stack_trace + + def get_return_address(self: Aarch64StackUnwinder, target: ThreadContext, vmaps: list[MemoryMap]) -> int: + """Get the return address of the current function. + + Args: + target (ThreadContext): The target ThreadContext. + vmaps (list[MemoryMap]): The memory maps of the process. + + Returns: + int: The return address. + """ + return_address = target.regs.x30 + + if not any(vmap.start <= return_address < vmap.end for vmap in vmaps): + raise ValueError("Return address not in any valid memory map") + + return return_address diff --git a/libdebug/architectures/amd64/amd64_breakpoint_validator.py b/libdebug/architectures/amd64/amd64_breakpoint_validator.py new file mode 100644 index 00000000..cf54a53d --- /dev/null +++ b/libdebug/architectures/amd64/amd64_breakpoint_validator.py @@ -0,0 +1,21 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from libdebug.data.breakpoint import Breakpoint + + +def validate_breakpoint_amd64(bp: Breakpoint) -> None: + """Validate a hardware breakpoint for the AMD64 architecture.""" + if bp.condition not in ["w", "rw", "x"]: + raise ValueError("Invalid condition for watchpoints. Supported conditions are 'w', 'rw', 'x'.") + + if bp.length not in [1, 2, 4, 8]: + raise ValueError("Invalid length for watchpoints. Supported lengths are 1, 2, 4, 8.") diff --git a/libdebug/architectures/amd64/amd64_call_utilities.py b/libdebug/architectures/amd64/amd64_call_utilities.py new file mode 100644 index 00000000..b89d1940 --- /dev/null +++ b/libdebug/architectures/amd64/amd64_call_utilities.py @@ -0,0 +1,65 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Francesco Panebianco. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +from __future__ import annotations + +from libdebug.architectures.call_utilities_manager import CallUtilitiesManager + + +class Amd64CallUtilities(CallUtilitiesManager): + """Class that provides call utilities for the x86_64 architecture.""" + + def is_call(self, opcode_window: bytes) -> bool: + """Check if the current instruction is a call instruction.""" + # Check for direct CALL (E8 xx xx xx xx) + if opcode_window[0] == 0xE8: + return True + + # Check for indirect CALL using ModR/M (FF /2) + if opcode_window[0] == 0xFF: + # Extract ModR/M byte + modRM = opcode_window[1] + reg = (modRM >> 3) & 0x07 # Middle three bits + + if reg == 2: + return True + + return False + + def compute_call_skip(self, opcode_window: bytes) -> int: + """Compute the instruction size of the current call instruction.""" + # Check for direct CALL (E8 xx xx xx xx) + if opcode_window[0] == 0xE8: + return 5 # Direct CALL + + # Check for indirect CALL using ModR/M (FF /2) + if opcode_window[0] == 0xFF: + # Extract ModR/M byte + modRM = opcode_window[1] + mod = (modRM >> 6) & 0x03 # First two bits + reg = (modRM >> 3) & 0x07 # Next three bits + + # Check if reg field is 010 (indirect CALL) + if reg == 2: + if mod == 0: + if (modRM & 0x07) == 4: + return 3 + (4 if opcode_window[2] == 0x25 else 0) # SIB byte + optional disp32 + elif (modRM & 0x07) == 5: + return 6 # disp32 + return 2 # No displacement + elif mod == 1: + return 3 # disp8 + elif mod == 2: + return 6 # disp32 + elif mod == 3: + return 2 # Register direct + + return 0 # Not a CALL + + def get_call_and_skip_amount(self, opcode_window: bytes) -> tuple[bool, int]: + """Check if the current instruction is a call instruction and compute the instruction size.""" + skip = self.compute_call_skip(opcode_window) + return skip != 0, skip diff --git a/libdebug/architectures/amd64/amd64_ptrace_hw_bp_helper.py b/libdebug/architectures/amd64/amd64_ptrace_hw_bp_helper.py deleted file mode 100644 index 30d66747..00000000 --- a/libdebug/architectures/amd64/amd64_ptrace_hw_bp_helper.py +++ /dev/null @@ -1,166 +0,0 @@ -# -# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). -# Copyright (c) 2023-2024 Roberto Alessandro Bertolini. All rights reserved. -# Licensed under the MIT license. See LICENSE file in the project root for details. -# - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from libdebug.architectures.ptrace_hardware_breakpoint_manager import ( - PtraceHardwareBreakpointManager, -) -from libdebug.liblog import liblog - -if TYPE_CHECKING: - from collections.abc import Callable - - from libdebug.data.breakpoint import Breakpoint - from libdebug.state.thread_context import ThreadContext - - -AMD64_DBGREGS_OFF = { - "DR0": 0x350, - "DR1": 0x358, - "DR2": 0x360, - "DR3": 0x368, - "DR4": 0x370, - "DR5": 0x378, - "DR6": 0x380, - "DR7": 0x388, -} -AMD64_DBGREGS_CTRL_LOCAL = {"DR0": 1 << 0, "DR1": 1 << 2, "DR2": 1 << 4, "DR3": 1 << 6} -AMD64_DBGREGS_CTRL_COND = {"DR0": 16, "DR1": 20, "DR2": 24, "DR3": 28} -AMD64_DBGREGS_CTRL_COND_VAL = {"x": 0, "w": 1, "rw": 3} -AMD64_DBGREGS_CTRL_LEN = {"DR0": 18, "DR1": 22, "DR2": 26, "DR3": 30} -AMD64_DBGREGS_CTRL_LEN_VAL = {1: 0, 2: 1, 8: 2, 4: 3} - -AMD64_DBREGS_COUNT = 4 - - -class Amd64PtraceHardwareBreakpointManager(PtraceHardwareBreakpointManager): - """A hardware breakpoint manager for the amd64 architecture. - - Attributes: - thread (ThreadContext): The target thread. - peek_user (callable): A function that reads a number of bytes from the target thread registers. - poke_user (callable): A function that writes a number of bytes to the target thread registers. - breakpoint_count (int): The number of hardware breakpoints set. - """ - - def __init__( - self: Amd64PtraceHardwareBreakpointManager, - thread: ThreadContext, - peek_user: Callable[[int, int], int], - poke_user: Callable[[int, int, int], None], - ) -> None: - """Initializes the hardware breakpoint manager.""" - super().__init__(thread, peek_user, poke_user) - self.breakpoint_registers: dict[str, Breakpoint | None] = { - "DR0": None, - "DR1": None, - "DR2": None, - "DR3": None, - } - - def install_breakpoint(self: Amd64PtraceHardwareBreakpointManager, bp: Breakpoint) -> None: - """Installs a hardware breakpoint at the provided location.""" - if self.breakpoint_count >= AMD64_DBREGS_COUNT: - raise RuntimeError("No more hardware breakpoints available.") - - # Find the first available breakpoint register - register = next(reg for reg, bp in self.breakpoint_registers.items() if bp is None) - liblog.debugger(f"Installing hardware breakpoint on register {register}.") - - # Write the breakpoint address in the register - self.poke_user(self.thread.thread_id, AMD64_DBGREGS_OFF[register], bp.address) - - # Set the breakpoint control register - ctrl = ( - AMD64_DBGREGS_CTRL_LOCAL[register] - | (AMD64_DBGREGS_CTRL_COND_VAL[bp.condition] << AMD64_DBGREGS_CTRL_COND[register]) - | (AMD64_DBGREGS_CTRL_LEN_VAL[bp.length] << AMD64_DBGREGS_CTRL_LEN[register]) - ) - - # Read the current value of the register - current_ctrl = self.peek_user(self.thread.thread_id, AMD64_DBGREGS_OFF["DR7"]) - - # Clear condition and length fields for the current register - current_ctrl &= ~(0x3 << AMD64_DBGREGS_CTRL_COND[register]) - current_ctrl &= ~(0x3 << AMD64_DBGREGS_CTRL_LEN[register]) - - # Set the new value of the register - current_ctrl |= ctrl - - # Write the new value of the register - self.poke_user(self.thread.thread_id, AMD64_DBGREGS_OFF["DR7"], current_ctrl) - - # Save the breakpoint - self.breakpoint_registers[register] = bp - - liblog.debugger(f"Hardware breakpoint installed on register {register}.") - - self.breakpoint_count += 1 - - def remove_breakpoint(self: Amd64PtraceHardwareBreakpointManager, bp: Breakpoint) -> None: - """Removes a hardware breakpoint at the provided location.""" - if self.breakpoint_count <= 0: - raise RuntimeError("No more hardware breakpoints to remove.") - - # Find the breakpoint register - register = next(reg for reg, bp_ in self.breakpoint_registers.items() if bp_ == bp) - - if register is None: - raise RuntimeError("Hardware breakpoint not found.") - - liblog.debugger(f"Removing hardware breakpoint on register {register}.") - - # Clear the breakpoint address in the register - self.poke_user(self.thread.thread_id, AMD64_DBGREGS_OFF[register], 0) - - # Read the current value of the control register - current_ctrl = self.peek_user(self.thread.thread_id, AMD64_DBGREGS_OFF["DR7"]) - - # Clear the breakpoint control register - current_ctrl &= ~AMD64_DBGREGS_CTRL_LOCAL[register] - - # Write the new value of the register - self.poke_user(self.thread.thread_id, AMD64_DBGREGS_OFF["DR7"], current_ctrl) - - # Remove the breakpoint - self.breakpoint_registers[register] = None - - liblog.debugger(f"Hardware breakpoint removed from register {register}.") - - self.breakpoint_count -= 1 - - def available_breakpoints(self: Amd64PtraceHardwareBreakpointManager) -> int: - """Returns the number of available hardware breakpoint registers.""" - return AMD64_DBREGS_COUNT - self.breakpoint_count - - def is_watchpoint_hit(self: Amd64PtraceHardwareBreakpointManager) -> Breakpoint | None: - """Checks if a watchpoint has been hit. - - Returns: - Breakpoint | None: The watchpoint that has been hit, or None if no watchpoint has been hit. - """ - dr6 = self.peek_user(self.thread.thread_id, AMD64_DBGREGS_OFF["DR6"]) - - watchpoint: Breakpoint | None = None - - # Check the DR6 register to see which watchpoint has been hit - if dr6 & 0x1: - watchpoint = self.breakpoint_registers["DR0"] - elif dr6 & 0x2: - watchpoint = self.breakpoint_registers["DR1"] - elif dr6 & 0x4: - watchpoint = self.breakpoint_registers["DR2"] - elif dr6 & 0x8: - watchpoint = self.breakpoint_registers["DR3"] - - if watchpoint is not None and watchpoint.condition == "x": - # It is a breakpoint, we do not care here - watchpoint = None - - return watchpoint diff --git a/libdebug/architectures/amd64/amd64_ptrace_register_holder.py b/libdebug/architectures/amd64/amd64_ptrace_register_holder.py index 537e1c1a..44597dba 100644 --- a/libdebug/architectures/amd64/amd64_ptrace_register_holder.py +++ b/libdebug/architectures/amd64/amd64_ptrace_register_holder.py @@ -115,6 +115,120 @@ def setter(self: Amd64Registers, value: int) -> None: return property(getter, setter, None, name) +def _get_property_fp_xmm0(name: str, index: int) -> property: + def getter(self: Amd64Registers) -> int: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + return int.from_bytes(self._fp_register_file.xmm0[index].data, "little") + + def setter(self: Amd64Registers, value: int) -> None: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + data = value.to_bytes(16, "little") + self._fp_register_file.xmm0[index].data = data + self._fp_register_file.dirty = True + + return property(getter, setter, None, name) + + +def _get_property_fp_ymm0(name: str, index: int) -> property: + def getter(self: Amd64Registers) -> int: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + xmm0 = int.from_bytes(self._fp_register_file.xmm0[index].data, "little") + ymm0 = int.from_bytes(self._fp_register_file.ymm0[index].data, "little") + return (ymm0 << 128) | xmm0 + + def setter(self: Amd64Registers, value: int) -> None: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + new_xmm0 = value & ((1 << 128) - 1) + new_ymm0 = value >> 128 + self._fp_register_file.xmm0[index].data = new_xmm0.to_bytes(16, "little") + self._fp_register_file.ymm0[index].data = new_ymm0.to_bytes(16, "little") + self._fp_register_file.dirty = True + + return property(getter, setter, None, name) + + +def _get_property_fp_zmm0(name: str, index: int) -> property: + def getter(self: Amd64Registers) -> int: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + zmm0 = int.from_bytes(self._fp_register_file.zmm0[index].data, "little") + ymm0 = int.from_bytes(self._fp_register_file.ymm0[index].data, "little") + xmm0 = int.from_bytes(self._fp_register_file.xmm0[index].data, "little") + return (zmm0 << 256) | (ymm0 << 128) | xmm0 + + def setter(self: Amd64Registers, value: int) -> None: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + new_xmm0 = value & ((1 << 128) - 1) + new_ymm0 = (value >> 128) & ((1 << 128) - 1) + new_zmm0 = value >> 256 + self._fp_register_file.xmm0[index].data = new_xmm0.to_bytes(16, "little") + self._fp_register_file.ymm0[index].data = new_ymm0.to_bytes(16, "little") + self._fp_register_file.zmm0[index].data = new_zmm0.to_bytes(32, "little") + self._fp_register_file.dirty = True + + return property(getter, setter, None, name) + + +def _get_property_fp_xmm1(name: str, index: int) -> property: + def getter(self: Amd64Registers) -> int: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + zmm1 = int.from_bytes(self._fp_register_file.zmm1[index].data, "little") + return zmm1 & ((1 << 128) - 1) + + def setter(self: Amd64Registers, value: int) -> None: + # We do not clear the upper 384 bits of the register + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + previous_value = int.from_bytes(self._fp_register_file.zmm1[index].data, "little") + + new_value = (previous_value & ~((1 << 128) - 1)) | (value & ((1 << 128) - 1)) + self._fp_register_file.zmm1[index].data = new_value.to_bytes(64, "little") + self._fp_register_file.dirty = True + + return property(getter, setter, None, name) + + +def _get_property_fp_ymm1(name: str, index: int) -> property: + def getter(self: Amd64Registers) -> int: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + zmm1 = int.from_bytes(self._fp_register_file.zmm1[index].data, "little") + return zmm1 & ((1 << 256) - 1) + + def setter(self: Amd64Registers, value: int) -> None: + # We do not clear the upper 256 bits of the register + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + previous_value = self._fp_register_file.zmm1[index] + + new_value = (previous_value & ~((1 << 256) - 1)) | (value & ((1 << 256) - 1)) + self._fp_register_file.zmm1[index].data = new_value.to_bytes(64, "little") + self._fp_register_file.dirty = True + + return property(getter, setter, None, name) + + +def _get_property_fp_zmm1(name: str, index: int) -> property: + def getter(self: Amd64Registers) -> int: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + return int.from_bytes(self._fp_register_file.zmm1[index].data, "little") + + def setter(self: Amd64Registers, value: int) -> None: + if not self._fp_register_file.fresh: + self._internal_debugger._fetch_fp_registers(self) + self._fp_register_file.zmm1[index].data = value.to_bytes(64, "little") + self._fp_register_file.dirty = True + + return property(getter, setter, None, name) + + @dataclass class Amd64PtraceRegisterHolder(PtraceRegisterHolder): """A class that provides views and setters for the registers of an x86_64 process.""" @@ -126,6 +240,7 @@ def provide_regs_class(self: Amd64PtraceRegisterHolder) -> type: def apply_on_regs(self: Amd64PtraceRegisterHolder, target: Amd64Registers, target_class: type) -> None: """Apply the register accessors to the Amd64Registers class.""" target.register_file = self.register_file + target._fp_register_file = self.fp_register_file # If the accessors are already defined, we don't need to redefine them if hasattr(target_class, "rip"): @@ -170,6 +285,20 @@ def apply_on_regs(self: Amd64PtraceRegisterHolder, target: Amd64Registers, targe # setup special registers target_class.rip = _get_property_64("rip") + # setup floating-point registers + # see libdebug/cffi/ptrace_cffi_build.py for the possible values of fp_register_file.type + match self.fp_register_file.type: + case 0: + self._handle_fp_512(target_class) + case 1: + self._handle_fp_896(target_class) + case 2: + self._handle_fp_2696(target_class) + case _: + raise NotImplementedError( + f"Floating-point register file type {self.fp_register_file.type} not available.", + ) + def apply_on_thread(self: Amd64PtraceRegisterHolder, target: ThreadContext, target_class: type) -> None: """Apply the register accessors to the thread class.""" target.register_file = self.register_file @@ -190,3 +319,45 @@ def apply_on_thread(self: Amd64PtraceRegisterHolder, target: ThreadContext, targ target_class.syscall_arg3 = _get_property_64("r10") target_class.syscall_arg4 = _get_property_64("r8") target_class.syscall_arg5 = _get_property_64("r9") + + def _handle_fp_512(self: Amd64PtraceRegisterHolder, target_class: type) -> None: + """Handle the case where the xsave area is 512 bytes long, which means we just have the xmm registers.""" + for index in range(16): + name = f"xmm{index}" + setattr(target_class, name, _get_property_fp_xmm0(name, index)) + + def _handle_fp_896(self: Amd64PtraceRegisterHolder, target_class: type) -> None: + """Handle the case where the xsave area is 896 bytes long, which means we have the xmm and ymm registers.""" + for index in range(16): + name = f"xmm{index}" + setattr(target_class, name, _get_property_fp_xmm0(name, index)) + + for index in range(16): + name = f"ymm{index}" + setattr(target_class, name, _get_property_fp_ymm0(name, index)) + + def _handle_fp_2696(self: Amd64PtraceRegisterHolder, target_class: type) -> None: + """Handle the case where the xsave area is 2696 bytes long, which means we have 32 zmm registers.""" + for index in range(16): + name = f"xmm{index}" + setattr(target_class, name, _get_property_fp_xmm0(name, index)) + + for index in range(16): + name = f"ymm{index}" + setattr(target_class, name, _get_property_fp_ymm0(name, index)) + + for index in range(16): + name = f"zmm{index}" + setattr(target_class, name, _get_property_fp_zmm0(name, index)) + + for index in range(16): + name = f"xmm{index + 16}" + setattr(target_class, name, _get_property_fp_xmm1(name, index)) + + for index in range(16): + name = f"ymm{index + 16}" + setattr(target_class, name, _get_property_fp_ymm1(name, index)) + + for index in range(16): + name = f"zmm{index + 16}" + setattr(target_class, name, _get_property_fp_zmm1(name, index)) diff --git a/libdebug/architectures/amd64/amd64_registers.py b/libdebug/architectures/amd64/amd64_registers.py index 17fc7734..a8dbf72a 100644 --- a/libdebug/architectures/amd64/amd64_registers.py +++ b/libdebug/architectures/amd64/amd64_registers.py @@ -16,6 +16,7 @@ class Amd64Registers(Registers): """This class holds the state of the architectural-dependent registers of a process.""" - def __init__(self: Amd64Registers) -> None: + def __init__(self: Amd64Registers, thread_id: int) -> None: """Initializes the Registers object.""" self._internal_debugger = get_global_internal_debugger() + self._thread_id = thread_id diff --git a/libdebug/architectures/amd64/amd64_stack_unwinder.py b/libdebug/architectures/amd64/amd64_stack_unwinder.py index 0a82e8b4..6cec1798 100644 --- a/libdebug/architectures/amd64/amd64_stack_unwinder.py +++ b/libdebug/architectures/amd64/amd64_stack_unwinder.py @@ -9,13 +9,13 @@ from typing import TYPE_CHECKING from libdebug.architectures.stack_unwinding_manager import StackUnwindingManager -from libdebug.liblog import logging +from libdebug.liblog import liblog if TYPE_CHECKING: + from libdebug.data.memory_map import MemoryMap from libdebug.state.thread_context import ThreadContext - class Amd64StackUnwinder(StackUnwindingManager): """Class that provides stack unwinding for the x86_64 architecture.""" @@ -39,13 +39,13 @@ def unwind(self: Amd64StackUnwinder, target: ThreadContext) -> list: while current_rbp: try: # Read the return address - return_address = int.from_bytes(target.memory[current_rbp + 8, 8], byteorder="little") + return_address = int.from_bytes(target.memory[current_rbp + 8, 8, "absolute"], byteorder="little") if not any(vmap.start <= return_address < vmap.end for vmap in vmaps): break # Read the previous rbp and set it as the current one - current_rbp = int.from_bytes(target.memory[current_rbp, 8], byteorder="little") + current_rbp = int.from_bytes(target.memory[current_rbp, 8, "absolute"], byteorder="little") stack_trace.append(return_address) except (OSError, ValueError): @@ -54,39 +54,48 @@ def unwind(self: Amd64StackUnwinder, target: ThreadContext) -> list: # If we are in the prolouge of a function, we need to get the return address from the stack # using a slightly more complex method try: - first_return_address = self.get_return_address(target) + first_return_address = self.get_return_address(target, vmaps) - if first_return_address != stack_trace[1]: - stack_trace.insert(1, first_return_address) + if len(stack_trace) > 1: + if first_return_address != stack_trace[1]: + stack_trace.insert(1, first_return_address) + else: + stack_trace.append(first_return_address) except (OSError, ValueError): - logging.WARNING( - "Failed to get the return address from the stack. Check stack frame registers (e.g., base pointer). The stack trace may be incomplete.", + liblog.warning( + "Failed to get the return address. Check stack frame registers (e.g., base pointer). The stack trace may be incomplete.", ) return stack_trace - def get_return_address(self: Amd64StackUnwinder, target: ThreadContext) -> int: + def get_return_address(self: Amd64StackUnwinder, target: ThreadContext, vmaps: list[MemoryMap]) -> int: """Get the return address of the current function. Args: target (ThreadContext): The target ThreadContext. + vmaps (list[MemoryMap]): The memory maps of the process. Returns: int: The return address. """ - instruction_window = target.memory[target.regs.rip, 4] + instruction_window = target.memory[target.regs.rip, 4, "absolute"] # Check if the instruction window is a function preamble and handle each case return_address = None if self._preamble_state(instruction_window) == 0: - return_address = target.memory[target.regs.rbp + 8, 8] + return_address = target.memory[target.regs.rbp + 8, 8, "absolute"] elif self._preamble_state(instruction_window) == 1: - return_address = target.memory[target.regs.rsp, 8] + return_address = target.memory[target.regs.rsp, 8, "absolute"] else: - return_address = target.memory[target.regs.rsp + 8, 8] + return_address = target.memory[target.regs.rsp + 8, 8, "absolute"] + + return_address = int.from_bytes(return_address, byteorder="little") + + if not any(vmap.start <= return_address < vmap.end for vmap in vmaps): + raise ValueError("Return address not in any valid memory map") - return int.from_bytes(return_address, byteorder="little") + return return_address def _preamble_state(self: Amd64StackUnwinder, instruction_window: bytes) -> int: """Check if the instruction window is a function preamble and if so at what stage. diff --git a/libdebug/architectures/breakpoint_validator.py b/libdebug/architectures/breakpoint_validator.py new file mode 100644 index 00000000..6804f497 --- /dev/null +++ b/libdebug/architectures/breakpoint_validator.py @@ -0,0 +1,24 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from libdebug.architectures.aarch64.aarch64_breakpoint_validator import validate_breakpoint_aarch64 +from libdebug.architectures.amd64.amd64_breakpoint_validator import validate_breakpoint_amd64 + +if TYPE_CHECKING: + from libdebug.data.breakpoint import Breakpoint + +def validate_hardware_breakpoint(arch: str, bp: Breakpoint) -> None: + """Validate a hardware breakpoint for the specified architecture.""" + if arch == "aarch64": + validate_breakpoint_aarch64(bp) + elif arch == "amd64": + validate_breakpoint_amd64(bp) + else: + raise ValueError(f"Architecture {arch} not supported") diff --git a/libdebug/architectures/call_utilities_manager.py b/libdebug/architectures/call_utilities_manager.py new file mode 100644 index 00000000..09d26b9c --- /dev/null +++ b/libdebug/architectures/call_utilities_manager.py @@ -0,0 +1,24 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Francesco Panebianco. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +from __future__ import annotations + +from abc import ABC, abstractmethod + +class CallUtilitiesManager(ABC): + """An architecture-independent interface for call instruction utilities.""" + + @abstractmethod + def is_call(self: CallUtilitiesManager, opcode_window: bytes) -> bool: + """Check if the current instruction is a call instruction.""" + + @abstractmethod + def compute_call_skip(self: CallUtilitiesManager, opcode_window: bytes) -> int: + """Compute the address where to skip after the current call instruction.""" + + @abstractmethod + def get_call_and_skip_amount(self, opcode_window: bytes) -> tuple[bool, int]: + """Check if the current instruction is a call instruction and compute the instruction size.""" \ No newline at end of file diff --git a/libdebug/architectures/call_utilities_provider.py b/libdebug/architectures/call_utilities_provider.py new file mode 100644 index 00000000..4bee02f8 --- /dev/null +++ b/libdebug/architectures/call_utilities_provider.py @@ -0,0 +1,27 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Francesco Panebianco, Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +from libdebug.architectures.aarch64.aarch64_call_utilities import ( + Aarch64CallUtilities, +) +from libdebug.architectures.amd64.amd64_call_utilities import ( + Amd64CallUtilities, +) +from libdebug.architectures.call_utilities_manager import CallUtilitiesManager + +_aarch64_call_utilities = Aarch64CallUtilities() +_amd64_call_utilities = Amd64CallUtilities() + + +def call_utilities_provider(architecture: str) -> CallUtilitiesManager: + """Returns an instance of the call utilities provider to be used by the `_InternalDebugger` class.""" + match architecture: + case "amd64": + return _amd64_call_utilities + case "aarch64": + return _aarch64_call_utilities + case _: + raise NotImplementedError(f"Architecture {architecture} not available.") diff --git a/libdebug/architectures/ptrace_hardware_breakpoint_manager.py b/libdebug/architectures/ptrace_hardware_breakpoint_manager.py deleted file mode 100644 index 1328c090..00000000 --- a/libdebug/architectures/ptrace_hardware_breakpoint_manager.py +++ /dev/null @@ -1,59 +0,0 @@ -# -# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). -# Copyright (c) 2023-2024 Roberto Alessandro Bertolini. All rights reserved. -# Licensed under the MIT license. See LICENSE file in the project root for details. -# - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from collections.abc import Callable - - from libdebug.data.breakpoint import Breakpoint - from libdebug.state.thread_context import ThreadContext - - -class PtraceHardwareBreakpointManager(ABC): - """An architecture-independent interface for managing hardware breakpoints. - - Attributes: - thread (ThreadContext): The target thread. - peek_user (callable): A function that reads a number of bytes from the target thread registers. - poke_user (callable): A function that writes a number of bytes to the target thread registers. - breakpoint_count (int): The number of hardware breakpoints set. - """ - - def __init__( - self: PtraceHardwareBreakpointManager, - thread: ThreadContext, - peek_user: Callable[[int, int], int], - poke_user: Callable[[int, int, int], None], - ) -> None: - """Initializes the hardware breakpoint manager.""" - self.thread = thread - self.peek_user = peek_user - self.poke_user = poke_user - self.breakpoint_count = 0 - - @abstractmethod - def install_breakpoint(self: PtraceHardwareBreakpointManager, bp: Breakpoint) -> None: - """Installs a hardware breakpoint at the provided location.""" - - @abstractmethod - def remove_breakpoint(self: PtraceHardwareBreakpointManager, bp: Breakpoint) -> None: - """Removes a hardware breakpoint at the provided location.""" - - @abstractmethod - def available_breakpoints(self: PtraceHardwareBreakpointManager) -> int: - """Returns the number of available hardware breakpoint registers.""" - - @abstractmethod - def is_watchpoint_hit(self: PtraceHardwareBreakpointManager) -> Breakpoint | None: - """Checks if a watchpoint has been hit. - - Returns: - Breakpoint | None: The watchpoint that has been hit, or None if no watchpoint has been hit. - """ diff --git a/libdebug/architectures/ptrace_hardware_breakpoint_provider.py b/libdebug/architectures/ptrace_hardware_breakpoint_provider.py deleted file mode 100644 index 9be66609..00000000 --- a/libdebug/architectures/ptrace_hardware_breakpoint_provider.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). -# Copyright (c) 2023-2024 Roberto Alessandro Bertolini. All rights reserved. -# Licensed under the MIT license. See LICENSE file in the project root for details. -# - -from collections.abc import Callable - -from libdebug.architectures.amd64.amd64_ptrace_hw_bp_helper import ( - Amd64PtraceHardwareBreakpointManager, -) -from libdebug.architectures.ptrace_hardware_breakpoint_manager import ( - PtraceHardwareBreakpointManager, -) -from libdebug.state.thread_context import ThreadContext -from libdebug.utils.libcontext import libcontext - - -def ptrace_hardware_breakpoint_manager_provider( - thread: ThreadContext, - peek_user: Callable[[int, int], int], - poke_user: Callable[[int, int, int], None], -) -> PtraceHardwareBreakpointManager: - """Returns an instance of the hardware breakpoint manager to be used by the `_InternalDebugger` class.""" - architecture = libcontext.arch - - match architecture: - case "amd64": - return Amd64PtraceHardwareBreakpointManager(thread, peek_user, poke_user) - case _: - raise NotImplementedError(f"Architecture {architecture} not available.") diff --git a/libdebug/architectures/ptrace_software_breakpoint_patcher.py b/libdebug/architectures/ptrace_software_breakpoint_patcher.py index 726e6637..54f2d267 100644 --- a/libdebug/architectures/ptrace_software_breakpoint_patcher.py +++ b/libdebug/architectures/ptrace_software_breakpoint_patcher.py @@ -4,15 +4,13 @@ # Licensed under the MIT license. See LICENSE file in the project root for details. # -from libdebug.utils.libcontext import libcontext - -def software_breakpoint_byte_size() -> int: +def software_breakpoint_byte_size(architecture: str) -> int: """Return the size of a software breakpoint instruction.""" - match libcontext.arch: - case "amd64": - return 1 - case "x86": + match architecture: + case "amd64" | "i386": return 1 + case "aarch64": + return 4 case _: - raise ValueError(f"Unsupported architecture: {libcontext.arch}") + raise ValueError(f"Unsupported architecture: {architecture}") diff --git a/libdebug/architectures/register_helper.py b/libdebug/architectures/register_helper.py index 1bec5112..95e20048 100644 --- a/libdebug/architectures/register_helper.py +++ b/libdebug/architectures/register_helper.py @@ -4,25 +4,25 @@ # Licensed under the MIT license. See LICENSE file in the project root for details. # -from collections.abc import Callable - +from libdebug.architectures.aarch64.aarch64_ptrace_register_holder import ( + Aarch64PtraceRegisterHolder, +) from libdebug.architectures.amd64.amd64_ptrace_register_holder import ( Amd64PtraceRegisterHolder, ) from libdebug.data.register_holder import RegisterHolder -from libdebug.utils.libcontext import libcontext def register_holder_provider( + architecture: str, register_file: object, - _: Callable[[], object] | None = None, - __: Callable[[object], None] | None = None, + fp_register_file: object, ) -> RegisterHolder: """Returns an instance of the register holder to be used by the `_InternalDebugger` class.""" - architecture = libcontext.arch - match architecture: - case "amd64": - return Amd64PtraceRegisterHolder(register_file) + case "amd64" | "i386": + return Amd64PtraceRegisterHolder(register_file, fp_register_file) + case "aarch64": + return Aarch64PtraceRegisterHolder(register_file, fp_register_file) case _: raise NotImplementedError(f"Architecture {architecture} not available.") diff --git a/libdebug/architectures/stack_unwinding_manager.py b/libdebug/architectures/stack_unwinding_manager.py index 110dc9b1..fb0f8f02 100644 --- a/libdebug/architectures/stack_unwinding_manager.py +++ b/libdebug/architectures/stack_unwinding_manager.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from libdebug.data.memory_map import MemoryMap from libdebug.state.thread_context import ThreadContext @@ -21,5 +22,5 @@ def unwind(self: StackUnwindingManager, target: ThreadContext) -> list: """Unwind the stack of the target process.""" @abstractmethod - def get_return_address(self: StackUnwindingManager, target: ThreadContext) -> int: + def get_return_address(self: StackUnwindingManager, target: ThreadContext, vmaps: list[MemoryMap]) -> int: """Get the return address of the current function.""" diff --git a/libdebug/architectures/stack_unwinding_provider.py b/libdebug/architectures/stack_unwinding_provider.py index 052e06de..51778dec 100644 --- a/libdebug/architectures/stack_unwinding_provider.py +++ b/libdebug/architectures/stack_unwinding_provider.py @@ -4,21 +4,24 @@ # Licensed under the MIT license. See LICENSE file in the project root for details. # +from libdebug.architectures.aarch64.aarch64_stack_unwinder import ( + Aarch64StackUnwinder, +) from libdebug.architectures.amd64.amd64_stack_unwinder import ( Amd64StackUnwinder, ) from libdebug.architectures.stack_unwinding_manager import StackUnwindingManager -from libdebug.utils.libcontext import libcontext +_aarch64_stack_unwinder = Aarch64StackUnwinder() _amd64_stack_unwinder = Amd64StackUnwinder() -def stack_unwinding_provider() -> StackUnwindingManager: +def stack_unwinding_provider(architecture: str) -> StackUnwindingManager: """Returns an instance of the stack unwinding provider to be used by the `_InternalDebugger` class.""" - architecture = libcontext.arch - match architecture: case "amd64": return _amd64_stack_unwinder + case "aarch64": + return _aarch64_stack_unwinder case _: raise NotImplementedError(f"Architecture {architecture} not available.") diff --git a/libdebug/architectures/amd64/amd64_syscall_hijacker.py b/libdebug/architectures/syscall_hijacker.py similarity index 89% rename from libdebug/architectures/amd64/amd64_syscall_hijacker.py rename to libdebug/architectures/syscall_hijacker.py index bea4ef4f..99a4f4b6 100644 --- a/libdebug/architectures/amd64/amd64_syscall_hijacker.py +++ b/libdebug/architectures/syscall_hijacker.py @@ -1,6 +1,6 @@ # # This file is part of libdebug Python library (https://github.com/libdebug/libdebug). -# Copyright (c) 2024 Gabriele Digregorio. All rights reserved. +# Copyright (c) 2024 Gabriele Digregorio, Roberto Alessandro Bertolini. All rights reserved. # Licensed under the MIT license. See LICENSE file in the project root for details. # @@ -8,15 +8,13 @@ from typing import TYPE_CHECKING -from libdebug.architectures.syscall_hijacking_manager import SyscallHijackingManager - if TYPE_CHECKING: from collections.abc import Callable from libdebug.state.thread_context import ThreadContext -class Amd64SyscallHijacker(SyscallHijackingManager): +class SyscallHijacker: """Class that provides syscall hijacking for the x86_64 architecture.""" # Allowed arguments for the hijacker @@ -33,7 +31,7 @@ class Amd64SyscallHijacker(SyscallHijackingManager): ) def create_hijacker( - self: Amd64SyscallHijacker, + self: SyscallHijacker, new_syscall: int, **kwargs: int, ) -> Callable[[ThreadContext, int], None]: @@ -51,7 +49,7 @@ def hijack_on_enter_wrapper(d: ThreadContext, _: int) -> None: return hijack_on_enter_wrapper def _hijack_on_enter( - self: Amd64SyscallHijacker, + self: SyscallHijacker, d: ThreadContext, new_syscall: int, **kwargs: int, diff --git a/libdebug/architectures/syscall_hijacking_manager.py b/libdebug/architectures/syscall_hijacking_manager.py deleted file mode 100644 index 697a3f5f..00000000 --- a/libdebug/architectures/syscall_hijacking_manager.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). -# Copyright (c) 2024 Gabriele Digregorio. All rights reserved. -# Licensed under the MIT license. See LICENSE file in the project root for details. -# - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from collections.abc import Callable - - from libdebug.state.thread_context import ThreadContext - - -class SyscallHijackingManager(ABC): - """An architecture-independent interface for syscall hijacking.""" - - @abstractmethod - def create_hijacker( - self: SyscallHijackingManager, - new_syscall: int, - **kwargs: int, - ) -> Callable[[ThreadContext, int], None]: - """Create a new hijacker for the given syscall.""" - - @abstractmethod - def _hijack_on_enter(self: SyscallHijackingManager, d: ThreadContext, new_syscall: int, **kwargs: int) -> None: - """Hijack the syscall on enter.""" diff --git a/libdebug/architectures/syscall_hijacking_provider.py b/libdebug/architectures/syscall_hijacking_provider.py deleted file mode 100644 index 4f2886a7..00000000 --- a/libdebug/architectures/syscall_hijacking_provider.py +++ /dev/null @@ -1,24 +0,0 @@ -# -# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). -# Copyright (c) 2024 Gabriele Digregorio. All rights reserved. -# Licensed under the MIT license. See LICENSE file in the project root for details. -# - -from libdebug.architectures.amd64.amd64_syscall_hijacker import ( - Amd64SyscallHijacker, -) -from libdebug.architectures.syscall_hijacking_manager import SyscallHijackingManager -from libdebug.utils.libcontext import libcontext - -_amd64_syscall_hijacker = Amd64SyscallHijacker() - - -def syscall_hijacking_provider() -> SyscallHijackingManager: - """Returns an instance of the syscall hijacking provider to be used by the `_InternalDebugger` class.""" - architecture = libcontext.arch - - match architecture: - case "amd64": - return _amd64_syscall_hijacker - case _: - raise NotImplementedError(f"Architecture {architecture} not available.") diff --git a/libdebug/builtin/pretty_print_syscall_handler.py b/libdebug/builtin/pretty_print_syscall_handler.py index c9b204f2..4dca9e41 100644 --- a/libdebug/builtin/pretty_print_syscall_handler.py +++ b/libdebug/builtin/pretty_print_syscall_handler.py @@ -26,8 +26,8 @@ def pprint_on_enter(d: ThreadContext, syscall_number: int, **kwargs: int) -> Non syscall_number (int): the syscall number. **kwargs (bool): the keyword arguments. """ - syscall_name = resolve_syscall_name(syscall_number) - syscall_args = resolve_syscall_arguments(syscall_number) + syscall_name = resolve_syscall_name(d._internal_debugger.arch, syscall_number) + syscall_args = resolve_syscall_arguments(d._internal_debugger.arch, syscall_number) values = [ d.syscall_arg0, diff --git a/libdebug/cffi/debug_sym_cffi_source.c b/libdebug/cffi/debug_sym_cffi_source.c index 1e660b46..75fc93b5 100644 --- a/libdebug/cffi/debug_sym_cffi_source.c +++ b/libdebug/cffi/debug_sym_cffi_source.c @@ -23,6 +23,8 @@ typedef struct SymbolInfo struct SymbolInfo *next; } SymbolInfo; +void process_symbol_tables(Elf *elf); + // Function to add new symbol info to the linked list SymbolInfo *add_symbol_info(SymbolInfo **head, const char *name, Dwarf_Addr low_pc, Dwarf_Addr high_pc) { diff --git a/libdebug/cffi/debug_sym_cffi_source_legacy.c b/libdebug/cffi/debug_sym_cffi_source_legacy.c index 39b18bce..8d1a7567 100644 --- a/libdebug/cffi/debug_sym_cffi_source_legacy.c +++ b/libdebug/cffi/debug_sym_cffi_source_legacy.c @@ -23,6 +23,8 @@ typedef struct SymbolInfo struct SymbolInfo *next; } SymbolInfo; +void process_symbol_tables(Elf *elf); + // Function to add new symbol info to the linked list SymbolInfo *add_symbol_info(SymbolInfo **head, const char *name, Dwarf_Addr low_pc, Dwarf_Addr high_pc) { diff --git a/libdebug/cffi/ptrace_cffi_build.py b/libdebug/cffi/ptrace_cffi_build.py index 95f8d0a3..101044ae 100644 --- a/libdebug/cffi/ptrace_cffi_build.py +++ b/libdebug/cffi/ptrace_cffi_build.py @@ -5,12 +5,147 @@ # import platform +from pathlib import Path from cffi import FFI -if platform.machine() == "x86_64": - user_regs_struct = """ - struct user_regs_struct + +architecture = platform.machine() + +if architecture == "x86_64": + # We need to determine if we have AVX, AVX2, AVX512, etc. + path = Path("/proc/cpuinfo") + + try: + with path.open() as f: + cpuinfo = f.read() + except OSError as e: + raise RuntimeError("Cannot read /proc/cpuinfo. Are you running on Linux?") from e + + if "avx512" in cpuinfo: + fp_regs_struct = """ + struct reg_128 + { + unsigned char data[16]; + }; + + struct reg_256 + { + unsigned char data[32]; + }; + + struct reg_512 + { + unsigned char data[64]; + }; + + // For details about the layout of the xsave structure, see Intel's Architecture Instruction Set Extensions Programming Reference + // Chapter 3.2.4 "The Layout of XSAVE Save Area" + // https://www.intel.com/content/dam/develop/external/us/en/documents/319433-024-697869.pdf + #pragma pack(push, 1) + struct fp_regs_struct + { + unsigned long type; + _Bool dirty; // true if the debugging script has modified the state of the registers + _Bool fresh; // true if the registers have already been fetched for this state + unsigned char bool_padding[6]; + unsigned char padding0[32]; + struct reg_128 st[8]; + struct reg_128 xmm0[16]; + unsigned char padding1[96]; + // end of the 512 byte legacy region + unsigned char padding2[64]; + // ymm0 starts at offset 576 + struct reg_128 ymm0[16]; + unsigned char padding3[320]; + // zmm0 starts at offset 1152 + struct reg_256 zmm0[16]; + // zmm1 starts at offset 1664 + struct reg_512 zmm1[16]; + unsigned char padding4[8]; + }; + #pragma pack(pop) + """ + + fpregs_define = """ + #define FPREGS_AVX 2 + """ + elif "avx" in cpuinfo: + fp_regs_struct = """ + struct reg_128 + { + unsigned char data[16]; + }; + + // For details about the layout of the xsave structure, see Intel's Architecture Instruction Set Extensions Programming Reference + // Chapter 3.2.4 "The Layout of XSAVE Save Area" + // https://www.intel.com/content/dam/develop/external/us/en/documents/319433-024-697869.pdf + #pragma pack(push, 1) + struct fp_regs_struct + { + unsigned long type; + _Bool dirty; // true if the debugging script has modified the state of the registers + _Bool fresh; // true if the registers have already been fetched for this state + unsigned char bool_padding[6]; + unsigned char padding0[32]; + struct reg_128 st[8]; + struct reg_128 xmm0[16]; + unsigned char padding1[96]; + // end of the 512 byte legacy region + unsigned char padding2[64]; + // ymm0 starts at offset 576 + struct reg_128 ymm0[16]; + unsigned char padding3[64]; + }; + #pragma pack(pop) + """ + + fpregs_define = """ + #define FPREGS_AVX 1 + """ + else: + fp_regs_struct = """ + struct reg_128 + { + unsigned char data[16]; + }; + + // For details about the layout of the xsave structure, see Intel's Architecture Instruction Set Extensions Programming Reference + // Chapter 3.2.4 "The Layout of XSAVE Save Area" + // https://www.intel.com/content/dam/develop/external/us/en/documents/319433-024-697869.pdf + #pragma pack(push, 1) + struct fp_regs_struct + { + unsigned long type; + _Bool dirty; // true if the debugging script has modified the state of the registers + _Bool fresh; // true if the registers have already been fetched for this state + unsigned char bool_padding[6]; + unsigned char padding0[32]; + struct reg_128 st[8]; + struct reg_128 xmm0[16]; + unsigned char padding1[96]; + }; + #pragma pack(pop) + """ + + fpregs_define = """ + #define FPREGS_AVX 0 + """ + + if "xsave" not in cpuinfo: + fpregs_define += """ + #define XSAVE 0 + """ + + # We don't support non-XSAVE architectures + raise NotImplementedError("XSAVE not supported. Please open an issue on GitHub and include your hardware details.") + else: + fpregs_define += """ + #define XSAVE 1 + """ + + ptrace_regs_struct = """ + struct ptrace_regs_struct { unsigned long r15; unsigned long r14; @@ -42,6 +177,10 @@ }; """ + arch_define = """ + #define ARCH_AMD64 + """ + breakpoint_define = """ #define INSTRUCTION_POINTER(regs) (regs.rip) #define INSTALL_BREAKPOINT(instruction) ((instruction & 0xFFFFFFFFFFFFFF00) | 0xCC) @@ -49,10 +188,10 @@ #define IS_SW_BREAKPOINT(instruction) (instruction == 0xCC) """ - finish_define = """ + control_flow_define = """ + // X86_64 Architecture specific #define IS_RET_INSTRUCTION(instruction) (instruction == 0xC3 || instruction == 0xCB || instruction == 0xC2 || instruction == 0xCA) - // X86_64 Architecture specific int IS_CALL_INSTRUCTION(uint8_t* instr) { // Check for direct CALL (E8 xx xx xx xx) @@ -74,14 +213,109 @@ return 0; // Not a CALL } """ +elif architecture == "aarch64": + fp_regs_struct = """ + struct reg_128 + { + unsigned char data[16]; + }; + + // /usr/include/aarch64-linux-gnu/asm/ptrace.h + #pragma pack(push, 1) + struct fp_regs_struct + { + _Bool dirty; // true if the debugging script has modified the state of the registers + _Bool fresh; // true if the registers have already been fetched for this state + unsigned char bool_padding[2]; + struct reg_128 vregs[32]; + unsigned int fpsr; + unsigned int fpcr; + unsigned long padding; + }; + #pragma pack(pop) + """ + + fpregs_define = "" + + ptrace_regs_struct = """ + struct ptrace_regs_struct + { + unsigned long x0; + unsigned long x1; + unsigned long x2; + unsigned long x3; + unsigned long x4; + unsigned long x5; + unsigned long x6; + unsigned long x7; + unsigned long x8; + unsigned long x9; + unsigned long x10; + unsigned long x11; + unsigned long x12; + unsigned long x13; + unsigned long x14; + unsigned long x15; + unsigned long x16; + unsigned long x17; + unsigned long x18; + unsigned long x19; + unsigned long x20; + unsigned long x21; + unsigned long x22; + unsigned long x23; + unsigned long x24; + unsigned long x25; + unsigned long x26; + unsigned long x27; + unsigned long x28; + unsigned long x29; + unsigned long x30; + unsigned long sp; + unsigned long pc; + unsigned long pstate; + _Bool override_syscall_number; + }; + """ + + arch_define = """ + #define ARCH_AARCH64 + """ + + breakpoint_define = """ + #define INSTRUCTION_POINTER(regs) (regs.pc) + #define INSTALL_BREAKPOINT(instruction) ((instruction & 0xFFFFFFFF00000000) | 0xD4200000) + #define BREAKPOINT_SIZE 4 + #define IS_SW_BREAKPOINT(instruction) (instruction == 0xD4200000) + """ + + control_flow_define = """ + #define IS_RET_INSTRUCTION(instruction) (instruction == 0xD65F03C0) + + // AARCH64 Architecture specific + int IS_CALL_INSTRUCTION(uint8_t* instr) + { + // Check for direct CALL (BL) + if ((instr[3] & 0xFC) == 0x94) { + return 1; // It's a CALL + } + + // Check for indirect CALL (BLR) + if ((instr[3] == 0xD6 && (instr[2] & 0x3F) == 0x3F)) { + return 1; // It's a CALL + } + + return 0; // Not a CALL + } + """ else: raise NotImplementedError(f"Architecture {platform.machine()} not available.") ffibuilder = FFI() -ffibuilder.cdef( - user_regs_struct - + """ +ffibuilder.cdef(ptrace_regs_struct) +ffibuilder.cdef(fp_regs_struct, packed=True) +ffibuilder.cdef(""" struct ptrace_hit_bp { int pid; uint64_t addr; @@ -97,9 +331,19 @@ struct software_breakpoint *next; }; + struct hardware_breakpoint { + uint64_t addr; + int tid; + char enabled; + char type[2]; + char len; + struct hardware_breakpoint *next; + }; + struct thread { int tid; - struct user_regs_struct regs; + struct ptrace_regs_struct regs; + struct fp_regs_struct fpregs; int signal_to_forward; struct thread *next; }; @@ -113,7 +357,8 @@ struct global_state { struct thread *t_HEAD; struct thread *dead_t_HEAD; - struct software_breakpoint *b_HEAD; + struct software_breakpoint *sw_b_HEAD; + struct hardware_breakpoint *hw_b_HEAD; _Bool handle_syscall_enabled; }; @@ -129,8 +374,9 @@ uint64_t ptrace_peekdata(int pid, uint64_t addr); uint64_t ptrace_pokedata(int pid, uint64_t addr, uint64_t data); - uint64_t ptrace_peekuser(int pid, uint64_t addr); - uint64_t ptrace_pokeuser(int pid, uint64_t addr, uint64_t data); + struct fp_regs_struct *get_thread_fp_regs(struct global_state *state, int tid); + void get_fp_regs(int tid, struct fp_regs_struct *fpregs); + void set_fp_regs(int tid, struct fp_regs_struct *fpregs); uint64_t ptrace_geteventmsg(int pid); @@ -144,7 +390,7 @@ struct thread_status *wait_all_and_update_regs(struct global_state *state, int pid); void free_thread_status_list(struct thread_status *head); - struct user_regs_struct* register_thread(struct global_state *state, int tid); + struct ptrace_regs_struct* register_thread(struct global_state *state, int tid); void unregister_thread(struct global_state *state, int tid); void free_thread_list(struct global_state *state); @@ -152,6 +398,15 @@ void unregister_breakpoint(struct global_state *state, uint64_t address); void enable_breakpoint(struct global_state *state, uint64_t address); void disable_breakpoint(struct global_state *state, uint64_t address); + + void register_hw_breakpoint(struct global_state *state, int tid, uint64_t address, char type[2], char len); + void unregister_hw_breakpoint(struct global_state *state, int tid, uint64_t address); + void enable_hw_breakpoint(struct global_state *state, int tid, uint64_t address); + void disable_hw_breakpoint(struct global_state *state, int tid, uint64_t address); + unsigned long get_hit_hw_breakpoint(struct global_state *state, int tid); + int get_remaining_hw_breakpoint_count(struct global_state *state, int tid); + int get_remaining_hw_watchpoint_count(struct global_state *state, int tid); + void free_breakpoints(struct global_state *state); """ ) @@ -159,7 +414,13 @@ with open("libdebug/cffi/ptrace_cffi_source.c") as f: ffibuilder.set_source( "libdebug.cffi._ptrace_cffi", - breakpoint_define + finish_define + f.read(), + ptrace_regs_struct + + arch_define + + fp_regs_struct + + fpregs_define + + breakpoint_define + + control_flow_define + + f.read(), libraries=[], ) diff --git a/libdebug/cffi/ptrace_cffi_source.c b/libdebug/cffi/ptrace_cffi_source.c index 8ee7da4e..bfd42780 100644 --- a/libdebug/cffi/ptrace_cffi_source.c +++ b/libdebug/cffi/ptrace_cffi_source.c @@ -4,6 +4,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for details. // +#include #include #include #include @@ -11,9 +12,31 @@ #include #include #include +#include #include #include +// Run some static assertions to ensure that the fp types are correct +#ifdef ARCH_AMD64 + #ifndef FPREGS_AVX + #error "FPREGS_AVX must be defined" + #endif + + #ifndef XSAVE + #error "XSAVE must be defined" + #endif + + #if (FPREGS_AVX == 0) + _Static_assert((sizeof(struct fp_regs_struct) - offsetof(struct fp_regs_struct, padding0)) == 512, "user_fpregs_struct size is not 512 bytes"); + #elif (FPREGS_AVX == 1) + _Static_assert((sizeof(struct fp_regs_struct) - offsetof(struct fp_regs_struct, padding0)) == 896, "user_fpregs_struct size is not 896 bytes"); + #elif (FPREGS_AVX == 2) + _Static_assert((sizeof(struct fp_regs_struct) - offsetof(struct fp_regs_struct, padding0)) == 2696, "user_fpregs_struct size is not 2696 bytes"); + #else + #error "FPREGS_AVX must be 0, 1 or 2" + #endif +#endif + struct ptrace_hit_bp { int pid; uint64_t addr; @@ -29,9 +52,19 @@ struct software_breakpoint { struct software_breakpoint *next; }; +struct hardware_breakpoint { + uint64_t addr; + int tid; + char enabled; + char type[2]; + char len; + struct hardware_breakpoint *next; +}; + struct thread { int tid; - struct user_regs_struct regs; + struct ptrace_regs_struct regs; + struct fp_regs_struct fpregs; int signal_to_forward; struct thread *next; }; @@ -45,11 +78,412 @@ struct thread_status { struct global_state { struct thread *t_HEAD; struct thread *dead_t_HEAD; - struct software_breakpoint *b_HEAD; + struct software_breakpoint *sw_b_HEAD; + struct hardware_breakpoint *hw_b_HEAD; _Bool handle_syscall_enabled; }; -struct user_regs_struct *register_thread(struct global_state *state, int tid) +#ifdef ARCH_AMD64 +int getregs(int tid, struct ptrace_regs_struct *regs) +{ + return ptrace(PTRACE_GETREGS, tid, NULL, regs); +} + +int setregs(int tid, struct ptrace_regs_struct *regs) +{ + return ptrace(PTRACE_SETREGS, tid, NULL, regs); +} +#endif + +#ifdef ARCH_AARCH64 +int getregs(int tid, struct ptrace_regs_struct *regs) +{ + regs->override_syscall_number = 0; + + struct iovec iov; + iov.iov_base = regs; + iov.iov_len = sizeof(struct ptrace_regs_struct); + return ptrace(PTRACE_GETREGSET, tid, NT_PRSTATUS, &iov); +} + +int setregs(int tid, struct ptrace_regs_struct *regs) +{ + struct iovec iov; + + if (regs->override_syscall_number) { + iov.iov_base = ®s->x8; + iov.iov_len = sizeof(regs->x8); + ptrace(PTRACE_SETREGSET, tid, NT_ARM_SYSTEM_CALL, &iov); + regs->override_syscall_number = 0; + } + + iov.iov_base = regs; + iov.iov_len = sizeof(struct ptrace_regs_struct); + return ptrace(PTRACE_SETREGSET, tid, NT_PRSTATUS, &iov); +} +#endif + +#ifdef ARCH_AMD64 + +#define DR_BASE 0x350 +#define DR_SIZE 0x8 +#define CTRL_LOCAL(x) (1 << (2 * x)) +#define CTRL_COND(x) (16 + (4 * x)) +#define CTRL_COND_VAL(x) (x == 'x' ? 0 : (x == 'w' ? 1 : 3)) +#define CTRL_LEN(x) (18 + (4 * x)) +#define CTRL_LEN_VAL(x) (x == 1 ? 0 : (x == 2 ? 1 : (x == 8 ? 2 : 3))) + +void install_hardware_breakpoint(struct hardware_breakpoint *bp) +{ + // find a free debug register + int i; + for (i = 0; i < 4; i++) { + unsigned long address = ptrace(PTRACE_PEEKUSER, bp->tid, DR_BASE + i * DR_SIZE); + + if (!address) + break; + } + + if (i == 4) { + perror("No debug registers available"); + return; + } + + unsigned long ctrl = CTRL_LOCAL(i) | CTRL_COND_VAL(bp->type[0]) << CTRL_COND(i) | CTRL_LEN_VAL(bp->len) << CTRL_LEN(i); + + // read the state from DR7 + unsigned long state = ptrace(PTRACE_PEEKUSER, bp->tid, DR_BASE + 7 * DR_SIZE); + + // reset the state, for good measure + state &= ~(3 << CTRL_COND(i)); + state &= ~(3 << CTRL_LEN(i)); + + // register the breakpoint + state |= ctrl; + + // write the address and the state + ptrace(PTRACE_POKEUSER, bp->tid, DR_BASE + i * DR_SIZE, bp->addr); + ptrace(PTRACE_POKEUSER, bp->tid, DR_BASE + 7 * DR_SIZE, state); +} + + +void remove_hardware_breakpoint(struct hardware_breakpoint *bp) +{ + // find the register + int i; + for (i = 0; i < 4; i++) { + unsigned long address = ptrace(PTRACE_PEEKUSER, bp->tid, DR_BASE + i * DR_SIZE); + + if (address == bp->addr) + break; + } + + if (i == 4) { + perror("Breakpoint not found"); + return; + } + + // read the state from DR7 + unsigned long state = ptrace(PTRACE_PEEKUSER, bp->tid, DR_BASE + 7 * DR_SIZE); + + // reset the state + state &= ~(3 << CTRL_COND(i)); + state &= ~(3 << CTRL_LEN(i)); + + // write the state + ptrace(PTRACE_POKEUSER, bp->tid, DR_BASE + 7 * DR_SIZE, state); + + // clear the address + ptrace(PTRACE_POKEUSER, bp->tid, DR_BASE + i * DR_SIZE, 0); +} + +int is_breakpoint_hit(struct hardware_breakpoint *bp) +{ + unsigned long status = ptrace(PTRACE_PEEKUSER, bp->tid, DR_BASE + 6 * DR_SIZE); + + int index; + if (status & 0x1) + index = 0; + else if (status & 0x2) + index = 1; + else if (status & 0x4) + index = 2; + else if (status & 0x8) + index = 3; + else + return 0; + + unsigned long address = ptrace(PTRACE_PEEKUSER, bp->tid, DR_BASE + index * DR_SIZE); + + if (address == bp->addr) + return 1; + + return 0; +} + +int get_remaining_hw_breakpoint_count(struct global_state *state, int tid) +{ + int i; + for (i = 0; i < 4; i++) { + unsigned long address = ptrace(PTRACE_PEEKUSER, tid, DR_BASE + i * DR_SIZE); + + if (!address) + break; + } + + return 4 - i; +} + +int get_remaining_hw_watchpoint_count(struct global_state *state, int tid) +{ + return get_remaining_hw_breakpoint_count(state, tid); +} +#endif + +#ifdef ARCH_AARCH64 +struct user_hwdebug_state { + unsigned int dbg_info; + unsigned int pad; + struct { + unsigned long addr; + unsigned int ctrl; + unsigned int pad; + } dbg_regs[16]; +}; + +int get_breakpoint_type(char type[2]) +{ + if (type[0] == 'r') { + if (type[1] == 'w') { + return 3; + } else { + return 1; + } + } else if (type[0] == 'w') { + return 2; + } else if (type[0] == 'x') { + return 0; + } else { + return -1; + } +} + +void install_hardware_breakpoint(struct hardware_breakpoint *bp) +{ + // find a free debug register + struct user_hwdebug_state state = {0}; + + struct iovec iov; + iov.iov_base = &state; + iov.iov_len = sizeof state; + + unsigned long command = get_breakpoint_type(bp->type) == 0 ? NT_ARM_HW_BREAK : NT_ARM_HW_WATCH; + + ptrace(PTRACE_GETREGSET, bp->tid, command, &iov); + + int i; + for (i = 0; i < 16; i++) { + if (!state.dbg_regs[i].addr) + break; + } + + if (i == 16) { + perror("No debug registers available"); + return; + } + + if (bp->type[0] == 'x') { + // Hardware breakpoint can only be of length 4 + bp->len = 4; + } + + unsigned int length = (1 << bp->len) - 1; + unsigned int condition = get_breakpoint_type(bp->type); + unsigned int control = (length << 5) | (condition << 3) | (2 << 1) | 1; + + state.dbg_regs[i].addr = bp->addr; + state.dbg_regs[i].ctrl = control; + + ptrace(PTRACE_SETREGSET, bp->tid, command, &iov); +} + +void remove_hardware_breakpoint(struct hardware_breakpoint *bp) +{ + struct user_hwdebug_state state = {0}; + + struct iovec iov; + iov.iov_base = &state; + iov.iov_len = sizeof state; + + unsigned long command = get_breakpoint_type(bp->type) == 0 ? NT_ARM_HW_BREAK : NT_ARM_HW_WATCH; + + ptrace(PTRACE_GETREGSET, bp->tid, command, &iov); + + int i; + for (i = 0; i < 16; i++) { + if (state.dbg_regs[i].addr == bp->addr) + break; + } + + if (i == 16) { + perror("Breakpoint not found"); + return; + } + + state.dbg_regs[i].addr = 0; + state.dbg_regs[i].ctrl = 0; + + ptrace(PTRACE_SETREGSET, bp->tid, command, &iov); +} + +int is_breakpoint_hit(struct hardware_breakpoint *bp) +{ + siginfo_t si; + + if (ptrace(PTRACE_GETSIGINFO, bp->tid, NULL, &si) == -1) { + return 0; + } + + // Check that the signal is a SIGTRAP and the code is 0x4 + if (!(si.si_signo == SIGTRAP && si.si_code == 0x4)) { + return 0; + } + + unsigned long addr = (unsigned long) si.si_addr; + + if (addr == bp->addr) { + return 1; + } + + return 0; +} + +int _get_remaining_count(struct global_state *state, int tid, int command) +{ + struct user_hwdebug_state dbg_state = {0}; + + struct iovec iov; + iov.iov_base = &dbg_state; + iov.iov_len = sizeof dbg_state; + + ptrace(PTRACE_GETREGSET, tid, command, &iov); + + return dbg_state.dbg_info & 0xff; +} + +int get_remaining_hw_breakpoint_count(struct global_state *state, int tid) +{ + return _get_remaining_count(state, tid, NT_ARM_HW_BREAK); +} + +int get_remaining_hw_watchpoint_count(struct global_state *state, int tid) +{ + return _get_remaining_count(state, tid, NT_ARM_HW_WATCH); +} +#endif + +struct thread *get_thread(struct global_state *state, int tid) +{ + struct thread *t = state->t_HEAD; + while (t != NULL) { + if (t->tid == tid) return t; + t = t->next; + } + + return NULL; +} + +struct fp_regs_struct *get_thread_fp_regs(struct global_state *state, int tid) +{ + struct thread *t = get_thread(state, tid); + + if (t) { + return &t->fpregs; + } + + return NULL; +} + +#ifdef ARCH_AMD64 +void get_fp_regs(int tid, struct fp_regs_struct *fpregs) +{ + #if (XSAVE == 0) + + #else + struct iovec iov; + + iov.iov_base = (unsigned char *)(fpregs) + offsetof(struct fp_regs_struct, padding0); + iov.iov_len = sizeof(struct fp_regs_struct) - offsetof(struct fp_regs_struct, padding0); + + if (ptrace(PTRACE_GETREGSET, tid, NT_X86_XSTATE, &iov) == -1) { + perror("ptrace_getregset_xstate"); + } + #endif + + fpregs->fresh = 1; +} + +void set_fp_regs(int tid, struct fp_regs_struct *fpregs) +{ + #if (XSAVE == 0) + + #else + struct iovec iov; + + iov.iov_base = (unsigned char *)(fpregs) + offsetof(struct fp_regs_struct, padding0); + iov.iov_len = sizeof(struct fp_regs_struct) - offsetof(struct fp_regs_struct, padding0); + + if (ptrace(PTRACE_SETREGSET, tid, NT_X86_XSTATE, &iov) == -1) { + perror("ptrace_setregset_xstate"); + } + #endif + + fpregs->dirty = 0; + fpregs->fresh = 0; +} +#endif + +#ifdef ARCH_AARCH64 +void get_fp_regs(int tid, struct fp_regs_struct *fpregs) +{ + struct iovec iov; + + iov.iov_base = (unsigned char *)(fpregs) + offsetof(struct fp_regs_struct, vregs); + iov.iov_len = sizeof(struct fp_regs_struct) - offsetof(struct fp_regs_struct, vregs); + + if (ptrace(PTRACE_GETREGSET, tid, NT_FPREGSET, &iov) == -1) { + perror("ptrace_getregset_xstate"); + } + + fpregs->fresh = 1; +} + +void set_fp_regs(int tid, struct fp_regs_struct *fpregs) +{ + struct iovec iov; + + iov.iov_base = (unsigned char *)(fpregs) + offsetof(struct fp_regs_struct, vregs); + iov.iov_len = sizeof(struct fp_regs_struct) - offsetof(struct fp_regs_struct, vregs); + + if (ptrace(PTRACE_SETREGSET, tid, NT_FPREGSET, &iov) == -1) { + perror("ptrace_setregset_xstate"); + } + + fpregs->dirty = 0; + fpregs->fresh = 0; +} +#endif + +void check_and_set_fp_regs(struct thread *t) +{ + if (t->fpregs.dirty) { + set_fp_regs(t->tid, &t->fpregs); + } + + t->fpregs.fresh = 0; +} + +struct ptrace_regs_struct *register_thread(struct global_state *state, int tid) { // Verify if the thread is already registered struct thread *t = state->t_HEAD; @@ -62,7 +496,13 @@ struct user_regs_struct *register_thread(struct global_state *state, int tid) t->tid = tid; t->signal_to_forward = 0; - ptrace(PTRACE_GETREGS, tid, NULL, &t->regs); +#ifdef ARCH_AMD64 + t->fpregs.type = FPREGS_AVX; +#endif + t->fpregs.dirty = 0; + t->fpregs.fresh = 0; + + getregs(tid, &t->regs); t->next = state->t_HEAD; state->t_HEAD = t; @@ -132,7 +572,7 @@ void ptrace_detach_for_kill(struct global_state *state, int pid) // note that the order is important: the main thread must be detached last while (t != NULL) { // let's attempt to read the registers of the thread - if (ptrace(PTRACE_GETREGS, t->tid, NULL, &t->regs)) { + if (getregs(t->tid, &t->regs)) { // if we can't read the registers, the thread is probably still running // ensure that the thread is stopped tgkill(pid, t->tid, SIGSTOP); @@ -162,7 +602,7 @@ void ptrace_detach_for_migration(struct global_state *state, int pid) while (t != NULL) { // the user might have modified the state of the registers // so we use SETREGS to check if the process is running - if (ptrace(PTRACE_SETREGS, t->tid, NULL, &t->regs)) { + if (setregs(t->tid, &t->regs)) { // if we can't read the registers, the thread is probably still running // ensure that the thread is stopped tgkill(pid, t->tid, SIGSTOP); @@ -171,7 +611,8 @@ void ptrace_detach_for_migration(struct global_state *state, int pid) waitpid(t->tid, NULL, 0); // set the registers again, as the first time it failed - ptrace(PTRACE_SETREGS, t->tid, NULL, &t->regs); + setregs(t->tid, &t->regs); + check_and_set_fp_regs(t); } // Be sure that the thread will not run during gdb reattachment @@ -195,7 +636,7 @@ void ptrace_reattach_from_gdb(struct global_state *state, int pid) fprintf(stderr, "ptrace_attach failed for thread %d: %s\\n", t->tid, strerror(errno)); - if (ptrace(PTRACE_GETREGS, t->tid, NULL, &t->regs)) + if (getregs(t->tid, &t->regs)) fprintf(stderr, "ptrace_getregs failed for thread %d: %s\\n", t->tid, strerror(errno)); @@ -233,20 +674,6 @@ uint64_t ptrace_pokedata(int pid, uint64_t addr, uint64_t data) return ptrace(PTRACE_POKEDATA, pid, (void *)addr, data); } -uint64_t ptrace_peekuser(int pid, uint64_t addr) -{ - // Since the value returned by a successful PTRACE_PEEK* - // request may be -1, the caller must clear errno before the call, - errno = 0; - - return ptrace(PTRACE_PEEKUSER, pid, addr, NULL); -} - -uint64_t ptrace_pokeuser(int pid, uint64_t addr, uint64_t data) -{ - return ptrace(PTRACE_POKEUSER, pid, addr, data); -} - uint64_t ptrace_geteventmsg(int pid) { uint64_t data = 0; @@ -262,8 +689,11 @@ long singlestep(struct global_state *state, int tid) struct thread *t = state->t_HEAD; int signal_to_forward = 0; while (t != NULL) { - if (ptrace(PTRACE_SETREGS, t->tid, NULL, &t->regs)) + if (setregs(t->tid, &t->regs)) perror("ptrace_setregs"); + + check_and_set_fp_regs(t); + if (t->tid == tid) { signal_to_forward = t->signal_to_forward; t->signal_to_forward = 0; @@ -271,7 +701,27 @@ long singlestep(struct global_state *state, int tid) t = t->next; } +#ifdef ARCH_AMD64 + return ptrace(PTRACE_SINGLESTEP, tid, NULL, signal_to_forward); +#endif + +#ifdef ARCH_AARCH64 + // Cannot singlestep if we are stopped on a hardware breakpoint + // So we have to check for this, remove it, singlestep and then re-add it + struct hardware_breakpoint *bp = state->hw_b_HEAD; + + while (bp != NULL) { + if (bp->tid == tid && bp->enabled && is_breakpoint_hit(bp)) { + remove_hardware_breakpoint(bp); + long ret = ptrace(PTRACE_SINGLESTEP, tid, NULL, signal_to_forward); + install_hardware_breakpoint(bp); + return ret; + } + bp = bp->next; + } + return ptrace(PTRACE_SINGLESTEP, tid, NULL, signal_to_forward); +#endif } int step_until(struct global_state *state, int tid, uint64_t addr, int max_steps) @@ -279,9 +729,11 @@ int step_until(struct global_state *state, int tid, uint64_t addr, int max_steps // flush any register changes struct thread *t = state->t_HEAD, *stepping_thread = NULL; while (t != NULL) { - if (ptrace(PTRACE_SETREGS, t->tid, NULL, &t->regs)) + if (setregs(t->tid, &t->regs)) perror("ptrace_setregs"); + check_and_set_fp_regs(t); + if (t->tid == tid) stepping_thread = t; @@ -296,6 +748,16 @@ int step_until(struct global_state *state, int tid, uint64_t addr, int max_steps return -1; } + // remove any hardware breakpoint that might be set on the stepping thread + struct hardware_breakpoint *bp = state->hw_b_HEAD; + + while (bp != NULL) { + if (bp->tid == tid && bp->enabled) { + remove_hardware_breakpoint(bp); + } + bp = bp->next; + } + while (max_steps == -1 || count < max_steps) { if (ptrace(PTRACE_SINGLESTEP, tid, NULL, NULL)) return -1; @@ -305,7 +767,7 @@ int step_until(struct global_state *state, int tid, uint64_t addr, int max_steps previous_ip = INSTRUCTION_POINTER(stepping_thread->regs); // update the registers - ptrace(PTRACE_GETREGS, tid, NULL, &stepping_thread->regs); + getregs(tid, &stepping_thread->regs); if (INSTRUCTION_POINTER(stepping_thread->regs) == addr) break; @@ -316,6 +778,16 @@ int step_until(struct global_state *state, int tid, uint64_t addr, int max_steps count++; } + // re-add the hardware breakpoints + bp = state->hw_b_HEAD; + + while (bp != NULL) { + if (bp->tid == tid && bp->enabled) { + install_hardware_breakpoint(bp); + } + bp = bp->next; + } + return 0; } @@ -326,9 +798,12 @@ int prepare_for_run(struct global_state *state, int pid) // flush any register changes struct thread *t = state->t_HEAD; while (t != NULL) { - if (ptrace(PTRACE_SETREGS, t->tid, NULL, &t->regs)) + if (setregs(t->tid, &t->regs)) fprintf(stderr, "ptrace_setregs failed for thread %d: %s\\n", t->tid, strerror(errno)); + + check_and_set_fp_regs(t); + t = t->next; } @@ -342,7 +817,7 @@ int prepare_for_run(struct global_state *state, int pid) t_hit = 0; uint64_t ip = INSTRUCTION_POINTER(t->regs); - b = state->b_HEAD; + b = state->sw_b_HEAD; while (b != NULL && !t_hit) { if (b->addr == ip) // we hit a software breakpoint on this thread @@ -369,8 +844,47 @@ int prepare_for_run(struct global_state *state, int pid) t = t->next; } +#ifdef ARCH_AARCH64 + // iterate over all the threads and check if any of them has hit a hardware + // breakpoint + t = state->t_HEAD; + struct hardware_breakpoint *bp; + int bp_hit; + + while (t != NULL) { + bp_hit = 0; + + bp = state->hw_b_HEAD; + while (bp != NULL && !bp_hit) { + if (bp->tid == t->tid && bp->enabled && is_breakpoint_hit(bp)) { + // we hit a hardware breakpoint on this thread + bp_hit = 1; + break; + } + + bp = bp->next; + } + + if (bp_hit) { + // remove the breakpoint + remove_hardware_breakpoint(bp); + + // step over the breakpoint + if (ptrace(PTRACE_SINGLESTEP, t->tid, NULL, NULL)) return -1; + + // wait for the child + waitpid(t->tid, &status, 0); + + // re-add the breakpoint + install_hardware_breakpoint(bp); + } + + t = t->next; + } +#endif + // Reset any software breakpoint - b = state->b_HEAD; + b = state->sw_b_HEAD; while (b != NULL) { if (b->enabled) { ptrace(PTRACE_POKEDATA, pid, (void *)b->addr, @@ -423,7 +937,7 @@ struct thread_status *wait_all_and_update_regs(struct global_state *state, int p if (t->tid != head->tid) { // If GETREGS succeeds, the thread is already stopped, so we must // not "stop" it again - if (ptrace(PTRACE_GETREGS, t->tid, NULL, &t->regs) == -1) { + if (getregs(t->tid, &t->regs) == -1) { // Stop the thread with a SIGSTOP tgkill(pid, t->tid, SIGSTOP); // Wait for the thread to stop @@ -453,12 +967,12 @@ struct thread_status *wait_all_and_update_regs(struct global_state *state, int p // Update the registers of all the threads t = state->t_HEAD; while (t) { - ptrace(PTRACE_GETREGS, t->tid, NULL, &t->regs); + getregs(t->tid, &t->regs); t = t->next; } // Restore any software breakpoint - struct software_breakpoint *b = state->b_HEAD; + struct software_breakpoint *b = state->sw_b_HEAD; while (b != NULL) { if (b->enabled) { @@ -491,7 +1005,7 @@ void register_breakpoint(struct global_state *state, int pid, uint64_t address) ptrace(PTRACE_POKEDATA, pid, (void *)address, patched_instruction); - struct software_breakpoint *b = state->b_HEAD; + struct software_breakpoint *b = state->sw_b_HEAD; while (b != NULL) { if (b->addr == address) { @@ -509,13 +1023,13 @@ void register_breakpoint(struct global_state *state, int pid, uint64_t address) // Breakpoints should be inserted ordered by address, increasing // This is important, because we don't want a breakpoint patching another - if (state->b_HEAD == NULL || state->b_HEAD->addr > address) { - b->next = state->b_HEAD; - state->b_HEAD = b; + if (state->sw_b_HEAD == NULL || state->sw_b_HEAD->addr > address) { + b->next = state->sw_b_HEAD; + state->sw_b_HEAD = b; return; } else { - struct software_breakpoint *prev = state->b_HEAD; - struct software_breakpoint *next = state->b_HEAD->next; + struct software_breakpoint *prev = state->sw_b_HEAD; + struct software_breakpoint *next = state->sw_b_HEAD->next; while (next != NULL && next->addr < address) { prev = next; @@ -529,13 +1043,13 @@ void register_breakpoint(struct global_state *state, int pid, uint64_t address) void unregister_breakpoint(struct global_state *state, uint64_t address) { - struct software_breakpoint *b = state->b_HEAD; + struct software_breakpoint *b = state->sw_b_HEAD; struct software_breakpoint *prev = NULL; while (b != NULL) { if (b->addr == address) { if (prev == NULL) { - state->b_HEAD = b->next; + state->sw_b_HEAD = b->next; } else { prev->next = b->next; } @@ -549,7 +1063,7 @@ void unregister_breakpoint(struct global_state *state, uint64_t address) void enable_breakpoint(struct global_state *state, uint64_t address) { - struct software_breakpoint *b = state->b_HEAD; + struct software_breakpoint *b = state->sw_b_HEAD; while (b != NULL) { if (b->addr == address) { @@ -567,7 +1081,7 @@ void enable_breakpoint(struct global_state *state, uint64_t address) void disable_breakpoint(struct global_state *state, uint64_t address) { - struct software_breakpoint *b = state->b_HEAD; + struct software_breakpoint *b = state->sw_b_HEAD; while (b != NULL) { if (b->addr == address) { @@ -585,7 +1099,7 @@ void disable_breakpoint(struct global_state *state, uint64_t address) void free_breakpoints(struct global_state *state) { - struct software_breakpoint *b = state->b_HEAD; + struct software_breakpoint *b = state->sw_b_HEAD; struct software_breakpoint *next; while (b != NULL) { @@ -594,7 +1108,18 @@ void free_breakpoints(struct global_state *state) b = next; } - state->b_HEAD = NULL; + state->sw_b_HEAD = NULL; + + struct hardware_breakpoint *h = state->hw_b_HEAD; + struct hardware_breakpoint *next_h; + + while (h != NULL) { + next_h = h->next; + free(h); + h = next_h; + } + + state->hw_b_HEAD = NULL; } int stepping_finish(struct global_state *state, int tid) @@ -616,7 +1141,7 @@ int stepping_finish(struct global_state *state, int tid) } uint64_t previous_ip, current_ip; - uint64_t opcode_window, first_opcode_byte; + uint64_t opcode_window, opcode; // We need to keep track of the nested calls int nested_call_counter = 1; @@ -630,24 +1155,32 @@ int stepping_finish(struct global_state *state, int tid) previous_ip = INSTRUCTION_POINTER(stepping_thread->regs); // update the registers - ptrace(PTRACE_GETREGS, tid, NULL, &stepping_thread->regs); + getregs(tid, &stepping_thread->regs); current_ip = INSTRUCTION_POINTER(stepping_thread->regs); // Get value at current instruction pointer opcode_window = ptrace(PTRACE_PEEKDATA, tid, (void *)current_ip, NULL); - first_opcode_byte = opcode_window & 0xFF; + +#ifdef ARCH_AMD64 + // on amd64 we care only about the first byte + opcode = opcode_window & 0xFF; +#endif + +#ifdef ARCH_AARCH64 + opcode = opcode_window & 0xFFFFFFFF; +#endif // if the instruction pointer didn't change, we return // because we hit a hardware breakpoint // we do the same if we hit a software breakpoint - if (current_ip == previous_ip || IS_SW_BREAKPOINT(first_opcode_byte)) + if (current_ip == previous_ip || IS_SW_BREAKPOINT(opcode)) goto cleanup; // If we hit a call instruction, we increment the counter if (IS_CALL_INSTRUCTION((uint8_t*) &opcode_window)) nested_call_counter++; - else if (IS_RET_INSTRUCTION(first_opcode_byte)) + else if (IS_RET_INSTRUCTION(opcode)) nested_call_counter--; } while (nested_call_counter > 0); @@ -659,11 +1192,11 @@ int stepping_finish(struct global_state *state, int tid) waitpid(tid, &status, 0); // update the registers - ptrace(PTRACE_GETREGS, tid, NULL, &stepping_thread->regs); + getregs(tid, &stepping_thread->regs); cleanup: // remove any installed breakpoint - struct software_breakpoint *b = state->b_HEAD; + struct software_breakpoint *b = state->sw_b_HEAD; while (b != NULL) { if (b->enabled) { ptrace(PTRACE_POKEDATA, tid, (void *)b->addr, b->instruction); @@ -673,3 +1206,100 @@ int stepping_finish(struct global_state *state, int tid) return 0; } + +void register_hw_breakpoint(struct global_state *state, int tid, uint64_t address, char type[2], char len) +{ + struct hardware_breakpoint *b = state->hw_b_HEAD; + + while (b != NULL) { + if (b->addr == address && b->tid == tid) { + perror("Breakpoint already registered"); + return; + } + b = b->next; + } + + b = malloc(sizeof(struct hardware_breakpoint)); + b->addr = address; + b->tid = tid; + b->enabled = 1; + b->type[0] = type[0]; + b->type[1] = type[1]; + b->len = len; + + b->next = state->hw_b_HEAD; + state->hw_b_HEAD = b; + + install_hardware_breakpoint(b); +} + +void unregister_hw_breakpoint(struct global_state *state, int tid, uint64_t address) +{ + struct hardware_breakpoint *b = state->hw_b_HEAD; + struct hardware_breakpoint *prev = NULL; + + while (b != NULL) { + if (b->addr == address && b->tid == tid) { + if (prev == NULL) { + state->hw_b_HEAD = b->next; + } else { + prev->next = b->next; + } + + if (b->enabled) { + remove_hardware_breakpoint(b); + } + + free(b); + return; + } + prev = b; + b = b->next; + } +} + +void enable_hw_breakpoint(struct global_state *state, int tid, uint64_t address) +{ + struct hardware_breakpoint *b = state->hw_b_HEAD; + + while (b != NULL) { + if (b->addr == address && b->tid == tid) { + if (!b->enabled) { + install_hardware_breakpoint(b); + } + + b->enabled = 1; + } + b = b->next; + } +} + +void disable_hw_breakpoint(struct global_state *state, int tid, uint64_t address) +{ + struct hardware_breakpoint *b = state->hw_b_HEAD; + + while (b != NULL) { + if (b->addr == address && b->tid == tid) { + if (b->enabled) { + remove_hardware_breakpoint(b); + } + + b->enabled = 0; + } + b = b->next; + } +} + +unsigned long get_hit_hw_breakpoint(struct global_state *state, int tid) +{ + struct hardware_breakpoint *b = state->hw_b_HEAD; + + while (b != NULL) { + if (b->tid == tid && is_breakpoint_hit(b)) { + return b->addr; + } + b = b->next; + } + + return 0; +} diff --git a/libdebug/debugger/debugger.py b/libdebug/debugger/debugger.py index d7437be1..33f36ecb 100644 --- a/libdebug/debugger/debugger.py +++ b/libdebug/debugger/debugger.py @@ -8,6 +8,7 @@ from contextlib import contextmanager from typing import TYPE_CHECKING +from libdebug.utils.arch_mappings import map_arch from libdebug.utils.signal_utils import ( get_all_signal_numbers, resolve_signal_name, @@ -63,9 +64,9 @@ def kill(self: Debugger) -> None: self._internal_debugger.kill() def terminate(self: Debugger) -> None: - """Terminates the background thread. + """Interrupts the process, kills it and then terminates the background thread. - The debugger object cannot be used after this method is called. + The debugger object will not be usable after this method is called. This method should only be called to free up resources when the debugger object is no longer needed. """ self._internal_debugger.terminate() @@ -94,7 +95,7 @@ def breakpoint( self: Debugger, position: int | str, hardware: bool = False, - condition: str | None = None, + condition: str = "x", length: int = 1, callback: None | Callable[[ThreadContext, Breakpoint], None] = None, file: str = "hybrid", @@ -264,7 +265,7 @@ def bp( self: Debugger, position: int | str, hardware: bool = False, - condition: str | None = None, + condition: str = "x", length: int = 1, callback: None | Callable[[ThreadContext, Breakpoint], None] = None, file: str = "hybrid", @@ -317,6 +318,28 @@ def wp( file=file, ) + @property + def arch(self: Debugger) -> str: + """Get the architecture of the process.""" + return self._internal_debugger.arch + + @arch.setter + def arch(self: Debugger, value: str) -> None: + """Set the architecture of the process.""" + self._internal_debugger.arch = map_arch(value) + + @property + def kill_on_exit(self: Debugger) -> bool: + """Get whether the process will be killed when the debugger exits.""" + return self._internal_debugger.kill_on_exit + + @kill_on_exit.setter + def kill_on_exit(self: Debugger, value: bool) -> None: + if not isinstance(value, bool): + raise TypeError("kill_on_exit must be a boolean") + + self._internal_debugger.kill_on_exit = value + @property def threads(self: Debugger) -> list[ThreadContext]: """Get the list of threads in the process.""" @@ -395,7 +418,10 @@ def syscalls_to_pprint(self: Debugger) -> list[str] | None: if self._internal_debugger.syscalls_to_pprint is None: return None else: - return [resolve_syscall_name(v) for v in self._internal_debugger.syscalls_to_pprint] + return [ + resolve_syscall_name(self._internal_debugger.arch, v) + for v in self._internal_debugger.syscalls_to_pprint + ] @syscalls_to_pprint.setter def syscalls_to_pprint(self: Debugger, value: list[int | str] | None) -> None: @@ -408,7 +434,7 @@ def syscalls_to_pprint(self: Debugger, value: list[int | str] | None) -> None: self._internal_debugger.syscalls_to_pprint = None elif isinstance(value, list): self._internal_debugger.syscalls_to_pprint = [ - v if isinstance(v, int) else resolve_syscall_number(v) for v in value + v if isinstance(v, int) else resolve_syscall_number(self._internal_debugger.arch, v) for v in value ] else: raise ValueError( @@ -427,7 +453,10 @@ def syscalls_to_not_pprint(self: Debugger) -> list[str] | None: if self._internal_debugger.syscalls_to_not_pprint is None: return None else: - return [resolve_syscall_name(v) for v in self._internal_debugger.syscalls_to_not_pprint] + return [ + resolve_syscall_name(self._internal_debugger.arch, v) + for v in self._internal_debugger.syscalls_to_not_pprint + ] @syscalls_to_not_pprint.setter def syscalls_to_not_pprint(self: Debugger, value: list[int | str] | None) -> None: @@ -440,7 +469,7 @@ def syscalls_to_not_pprint(self: Debugger, value: list[int | str] | None) -> Non self._internal_debugger.syscalls_to_not_pprint = None elif isinstance(value, list): self._internal_debugger.syscalls_to_not_pprint = [ - v if isinstance(v, int) else resolve_syscall_number(v) for v in value + v if isinstance(v, int) else resolve_syscall_number(self._internal_debugger.arch, v) for v in value ] else: raise ValueError( @@ -475,6 +504,30 @@ def signals_to_block(self: Debugger, signals: list[int | str]) -> None: self._internal_debugger.signals_to_block = signals + @property + def fast_memory(self: Debugger) -> bool: + """Get the state of the fast_memory flag. + + It is used to determine if the debugger should use a faster memory access method. + + Returns: + bool: True if the debugger should use a faster memory access method, False otherwise. + """ + return self._internal_debugger.fast_memory + + @fast_memory.setter + def fast_memory(self: Debugger, value: bool) -> None: + """Set the state of the fast_memory flag. + + It is used to determine if the debugger should use a faster memory access method. + + Args: + value (bool): the value to set. + """ + if not isinstance(value, bool): + raise TypeError("fast_memory must be a boolean") + self._internal_debugger.fast_memory = value + def __getattr__(self: Debugger, name: str) -> object: """This function is called when an attribute is not found in the `Debugger` object. diff --git a/libdebug/debugger/internal_debugger.py b/libdebug/debugger/internal_debugger.py index 107d1cc9..f49a9f71 100644 --- a/libdebug/debugger/internal_debugger.py +++ b/libdebug/debugger/internal_debugger.py @@ -10,6 +10,7 @@ import functools import os import signal +import sys from pathlib import Path from queue import Queue from signal import SIGKILL, SIGSTOP, SIGTRAP @@ -19,11 +20,11 @@ import psutil -from libdebug.architectures.syscall_hijacking_provider import syscall_hijacking_provider +from libdebug.architectures.breakpoint_validator import validate_hardware_breakpoint +from libdebug.architectures.syscall_hijacker import SyscallHijacker from libdebug.builtin.antidebug_syscall_handler import on_enter_ptrace, on_exit_ptrace from libdebug.builtin.pretty_print_syscall_handler import pprint_on_enter, pprint_on_exit from libdebug.data.breakpoint import Breakpoint -from libdebug.data.memory_view import MemoryView from libdebug.data.signal_catcher import SignalCatcher from libdebug.data.syscall_handler import SyscallHandler from libdebug.debugger.internal_debugger_instance_manager import ( @@ -32,7 +33,11 @@ ) from libdebug.interfaces.interface_helper import provide_debugging_interface from libdebug.liblog import liblog +from libdebug.memory.chunked_memory_view import ChunkedMemoryView +from libdebug.memory.direct_memory_view import DirectMemoryView +from libdebug.memory.process_memory_manager import ProcessMemoryManager from libdebug.state.resume_context import ResumeContext +from libdebug.utils.arch_mappings import map_arch from libdebug.utils.debugger_wrappers import ( background_alias, change_state_function_process, @@ -44,6 +49,7 @@ resolve_symbol_in_maps, ) from libdebug.utils.libcontext import libcontext +from libdebug.utils.platform_utils import get_platform_register_size from libdebug.utils.print_style import PrintStyle from libdebug.utils.signal_utils import ( resolve_signal_name, @@ -59,12 +65,14 @@ from collections.abc import Callable from libdebug.data.memory_map import MemoryMap + from libdebug.data.registers import Registers from libdebug.interfaces.debugging_interface import DebuggingInterface + from libdebug.memory.abstract_memory_view import AbstractMemoryView from libdebug.state.thread_context import ThreadContext from libdebug.utils.pipe_manager import PipeManager THREAD_TERMINATE = -1 -GDB_GOBACK_LOCATION = str((Path(__file__).parent / "utils" / "gdb.py").resolve()) +GDB_GOBACK_LOCATION = str((Path(__file__).parent.parent / "utils" / "gdb.py").resolve()) class InternalDebugger: @@ -73,6 +81,9 @@ class InternalDebugger: aslr_enabled: bool """A flag that indicates if ASLR is enabled or not.""" + arch: str + """The architecture of the debugged process.""" + argv: list[str] """The command line arguments of the debugged process.""" @@ -82,6 +93,9 @@ class InternalDebugger: escape_antidebug: bool """A flag that indicates if the debugger should escape anti-debugging techniques.""" + fast_memory: bool + """A flag that indicates if the debugger should use a faster memory access method.""" + autoreach_entrypoint: bool """A flag that indicates if the debugger should automatically reach the entry point of the debugged process.""" @@ -109,6 +123,9 @@ class InternalDebugger: syscalls_to_not_pprint: list[int] | None """The syscalls to not pretty print.""" + kill_on_exit: bool + """A flag that indicates if the debugger should kill the debugged process when it exits.""" + threads: list[ThreadContext] """A list of all the threads of the debugged process.""" @@ -118,7 +135,7 @@ class InternalDebugger: pipe_manager: PipeManager """The pipe manager used to communicate with the debugged process.""" - memory: MemoryView + memory: AbstractMemoryView """The memory view of the debugged process.""" debugging_interface: DebuggingInterface @@ -133,18 +150,24 @@ class InternalDebugger: resume_context: ResumeContext """Context that indicates if the debugger should resume the debugged process.""" - _polling_thread: Thread | None + __polling_thread: Thread | None """The background thread used to poll the process for state change.""" - _polling_thread_command_queue: Queue | None + __polling_thread_command_queue: Queue | None """The queue used to send commands to the background thread.""" - _polling_thread_response_queue: Queue | None + __polling_thread_response_queue: Queue | None """The queue used to receive responses from the background thread.""" _is_running: bool """The overall state of the debugged process. True if the process is running, False otherwise.""" + _fast_memory: DirectMemoryView + """The memory view of the debugged process using the fast memory access method.""" + + _slow_memory: ChunkedMemoryView + """The memory view of the debugged process using the slow memory access method.""" + def __init__(self: InternalDebugger) -> None: """Initialize the context.""" # These must be reinitialized on every call to "debugger" @@ -162,10 +185,14 @@ def __init__(self: InternalDebugger) -> None: self.pprint_syscalls = False self.pipe_manager = None self.process_id = 0 - self.threads = list() + self.threads = [] self.instanced = False self._is_running = False self.resume_context = ResumeContext() + self.arch = map_arch(libcontext.platform) + self.kill_on_exit = True + self._process_memory_manager = ProcessMemoryManager() + self.fast_memory = False self.__polling_thread_command_queue = Queue() self.__polling_thread_response_queue = Queue() @@ -188,14 +215,18 @@ def clear(self: InternalDebugger) -> None: def start_up(self: InternalDebugger) -> None: """Starts up the context.""" - # The context is linked to itself link_to_internal_debugger(self, self) self.start_processing_thread() with extend_internal_debugger(self): self.debugging_interface = provide_debugging_interface() - self.memory = MemoryView(self._peek_memory, self._poke_memory) + self._fast_memory = DirectMemoryView(self._fast_read_memory, self._fast_write_memory) + self._slow_memory = ChunkedMemoryView( + self._peek_memory, + self._poke_memory, + unit_size=get_platform_register_size(libcontext.platform), + ) def start_processing_thread(self: InternalDebugger) -> None: """Starts the thread that will poll the traced process for state change.""" @@ -238,15 +269,17 @@ def run(self: InternalDebugger) -> None: self.__polling_thread_command_queue.put((self.__threaded_run, ())) + self._join_and_check_status() + if self.escape_antidebug: liblog.debugger("Enabling anti-debugging escape mechanism.") self._enable_antidebug_escaping() - self._join_and_check_status() - if not self.pipe_manager: raise RuntimeError("Something went wrong during pipe initialization.") + self._process_memory_manager.open(self.process_id) + return self.pipe_manager def attach(self: InternalDebugger, pid: int) -> None: @@ -265,6 +298,8 @@ def attach(self: InternalDebugger, pid: int) -> None: self.__polling_thread_command_queue.put((self.__threaded_attach, (pid,))) + self._process_memory_manager.open(self.process_id) + self._join_and_check_status() def detach(self: InternalDebugger) -> None: @@ -278,6 +313,8 @@ def detach(self: InternalDebugger) -> None: self._join_and_check_status() + self._process_memory_manager.close() + @background_alias(_background_invalid_call) def kill(self: InternalDebugger) -> None: """Kills the process.""" @@ -287,6 +324,8 @@ def kill(self: InternalDebugger) -> None: # This exception might occur if the process has already died liblog.debugger("OSError raised during kill") + self._process_memory_manager.close() + self.__polling_thread_command_queue.put((self.__threaded_kill, ())) self.instanced = False @@ -297,11 +336,19 @@ def kill(self: InternalDebugger) -> None: self._join_and_check_status() def terminate(self: InternalDebugger) -> None: - """Terminates the background thread. + """Interrupts the process, kills it and then terminates the background thread. - The debugger object cannot be used after this method is called. + The debugger object will not be usable after this method is called. This method should only be called to free up resources when the debugger object is no longer needed. """ + if self.instanced and self.running: + self.interrupt() + + if self.instanced: + self.kill() + + self.instanced = False + if self.__polling_thread is not None: self.__polling_thread_command_queue.put((THREAD_TERMINATE, ())) self.__polling_thread.join() @@ -362,6 +409,11 @@ def maps(self: InternalDebugger) -> list[MemoryMap]: self._ensure_process_stopped() return self.debugging_interface.maps() + @property + def memory(self: InternalDebugger) -> AbstractMemoryView: + """The memory view of the debugged process.""" + return self._fast_memory if self.fast_memory else self._slow_memory + def print_maps(self: InternalDebugger) -> None: """Prints the memory maps of the process.""" self._ensure_process_stopped() @@ -382,7 +434,7 @@ def breakpoint( self: InternalDebugger, position: int | str, hardware: bool = False, - condition: str | None = None, + condition: str = "x", length: int = 1, callback: None | Callable[[ThreadContext, Breakpoint], None] = None, file: str = "hybrid", @@ -407,26 +459,13 @@ def breakpoint( address = self.resolve_address(position, file) position = hex(address) - if condition: - if not hardware: - raise ValueError( - "Breakpoint condition is supported only for hardware watchpoints.", - ) - - if condition.lower() not in ["w", "rw", "x"]: - raise ValueError( - "Invalid condition for watchpoints. Supported conditions are 'w', 'rw', 'x'.", - ) - - if length not in [1, 2, 4, 8]: - raise ValueError( - "Invalid length for watchpoints. Supported lengths are 1, 2, 4, 8.", - ) + if condition != "x" and not hardware: + raise ValueError("Breakpoint condition is supported only for hardware watchpoints.") - if hardware and not condition: - condition = "x" + bp = Breakpoint(address, position, 0, hardware, callback, condition.lower(), length) - bp = Breakpoint(address, position, 0, hardware, callback, condition, length) + if hardware: + validate_hardware_breakpoint(self.arch, bp) link_to_internal_debugger(bp, self) @@ -470,15 +509,15 @@ def catch_signal( match signal_number: case SIGKILL.value: raise ValueError( - f"Cannot catch SIGKILL ({signal_number}) as it cannot be caught or ignored. This is a kernel restriction." + f"Cannot catch SIGKILL ({signal_number}) as it cannot be caught or ignored. This is a kernel restriction.", ) case SIGSTOP.value: raise ValueError( - f"Cannot catch SIGSTOP ({signal_number}) as it is used by the debugger or ptrace for their internal operations." + f"Cannot catch SIGSTOP ({signal_number}) as it is used by the debugger or ptrace for their internal operations.", ) case SIGTRAP.value: raise ValueError( - f"Cannot catch SIGTRAP ({signal_number}) as it is used by the debugger or ptrace for their internal operations." + f"Cannot catch SIGTRAP ({signal_number}) as it is used by the debugger or ptrace for their internal operations.", ) if signal_number in self.caught_signals: @@ -559,7 +598,7 @@ def handle_syscall( Returns: HandledSyscall: The HandledSyscall object. """ - syscall_number = resolve_syscall_number(syscall) if isinstance(syscall, str) else syscall + syscall_number = resolve_syscall_number(self.arch, syscall) if isinstance(syscall, str) else syscall if not isinstance(recursive, bool): raise TypeError("recursive must be a boolean") @@ -569,7 +608,7 @@ def handle_syscall( handler = self.handled_syscalls[syscall_number] if handler.on_enter_user or handler.on_exit_user: liblog.warning( - f"Syscall {resolve_syscall_name(syscall_number)} is already handled by a user-defined handler. Overriding it.", + f"Syscall {resolve_syscall_name(self.arch, syscall_number)} is already handled by a user-defined handler. Overriding it.", ) handler.on_enter_user = on_enter handler.on_exit_user = on_exit @@ -616,22 +655,24 @@ def hijack_syscall( Returns: HandledSyscall: The HandledSyscall object. """ - if set(kwargs) - syscall_hijacking_provider().allowed_args: + if set(kwargs) - SyscallHijacker.allowed_args: raise ValueError("Invalid keyword arguments in syscall hijack") if isinstance(original_syscall, str): - original_syscall_number = resolve_syscall_number(original_syscall) + original_syscall_number = resolve_syscall_number(self.arch, original_syscall) else: original_syscall_number = original_syscall - new_syscall_number = resolve_syscall_number(new_syscall) if isinstance(new_syscall, str) else new_syscall + new_syscall_number = ( + resolve_syscall_number(self.arch, new_syscall) if isinstance(new_syscall, str) else new_syscall + ) if original_syscall_number == new_syscall_number: raise ValueError( "The original syscall and the new syscall must be different during hijacking.", ) - on_enter = syscall_hijacking_provider().create_hijacker( + on_enter = SyscallHijacker().create_hijacker( new_syscall_number, **kwargs, ) @@ -671,7 +712,6 @@ def hijack_syscall( @change_state_function_process def gdb(self: InternalDebugger, open_in_new_process: bool = True) -> None: """Migrates the current debugging session to GDB.""" - # TODO: not needed? self.interrupt() @@ -888,13 +928,28 @@ def finish(self: InternalDebugger, thread: ThreadContext, heuristic: str = "back self._join_and_check_status() + def _background_next( + self: InternalDebugger, + thread: ThreadContext, + ) -> None: + """Executes the next instruction of the process. If the instruction is a call, the debugger will continue until the called function returns.""" + self.__threaded_next(thread) + + @background_alias(_background_next) + @change_state_function_thread + def next(self: InternalDebugger, thread: ThreadContext) -> None: + """Executes the next instruction of the process. If the instruction is a call, the debugger will continue until the called function returns.""" + self._ensure_process_stopped() + self.__polling_thread_command_queue.put((self.__threaded_next, (thread,))) + self._join_and_check_status() + def enable_pretty_print( self: InternalDebugger, ) -> SyscallHandler: """Handles a syscall in the target process to pretty prints its arguments and return value.""" self._ensure_process_stopped() - syscall_numbers = get_all_syscall_numbers() + syscall_numbers = get_all_syscall_numbers(self.arch) for syscall_number in syscall_numbers: # Check if the syscall is already handled (by the user or by the pretty print handler) @@ -1030,15 +1085,14 @@ def resolve_address( else: # If the address was not found and the backing file is not "absolute", # we have to assume it is in the main map - backing_file = self._get_process_full_path() + backing_file = self._process_full_path liblog.warning( f"No backing file specified and no corresponding absolute address found for {hex(address)}. Assuming {backing_file}.", ) - elif ( - backing_file == (full_backing_path := self._get_process_full_path()) - or backing_file == "binary" - or backing_file == self._get_process_name() - ): + elif backing_file == (full_backing_path := self._process_full_path) or backing_file in [ + "binary", + self._process_name, + ]: backing_file = full_backing_path filtered_maps = [] @@ -1056,7 +1110,7 @@ def resolve_address( if not filtered_maps: raise ValueError( - f"The specified string {backing_file} does not correspond to any backing file. The available backing files are: {', '.join(set(vmap.backing_file for vmap in maps))}." + f"The specified string {backing_file} does not correspond to any backing file. The available backing files are: {', '.join(set(vmap.backing_file for vmap in maps))}.", ) return normalize_and_validate_address(address, filtered_maps) @@ -1077,13 +1131,12 @@ def resolve_symbol(self: InternalDebugger, symbol: str, backing_file: str) -> in if backing_file == "hybrid": # If no explicit backing file is specified, we have to assume it is in the main map - backing_file = self._get_process_full_path() + backing_file = self._process_full_path liblog.debugger(f"No backing file specified for the symbol {symbol}. Assuming {backing_file}.") - elif ( - backing_file == (full_backing_path := self._get_process_full_path()) - or backing_file == "binary" - or backing_file == self._get_process_name() - ): + elif backing_file == (full_backing_path := self._process_full_path) or backing_file in [ + "binary", + self._process_name, + ]: backing_file = full_backing_path filtered_maps = [] @@ -1101,7 +1154,7 @@ def resolve_symbol(self: InternalDebugger, symbol: str, backing_file: str) -> in if not filtered_maps: raise ValueError( - f"The specified string {backing_file} does not correspond to any backing file. The available backing files are: {', '.join(set(vmap.backing_file for vmap in maps))}." + f"The specified string {backing_file} does not correspond to any backing file. The available backing files are: {', '.join(set(vmap.backing_file for vmap in maps))}.", ) return resolve_symbol_in_maps(symbol, filtered_maps) @@ -1162,8 +1215,8 @@ def _join_and_check_status(self: InternalDebugger) -> None: if response is not None: raise response - @functools.cache - def _get_process_full_path(self: InternalDebugger) -> str: + @functools.cached_property + def _process_full_path(self: InternalDebugger) -> str: """Get the full path of the process. Returns: @@ -1171,8 +1224,8 @@ def _get_process_full_path(self: InternalDebugger) -> str: """ return str(Path(f"/proc/{self.process_id}/exe").readlink()) - @functools.cache - def _get_process_name(self: InternalDebugger) -> str: + @functools.cached_property + def _process_name(self: InternalDebugger) -> str: """Get the name of the process. Returns: @@ -1287,6 +1340,11 @@ def __threaded_finish(self: InternalDebugger, thread: ThreadContext, heuristic: self.set_stopped() + def __threaded_next(self: InternalDebugger, thread: ThreadContext) -> None: + liblog.debugger("Next on thread %s.", thread.thread_id) + self.debugging_interface.next(thread) + self.set_stopped() + def __threaded_gdb(self: InternalDebugger) -> None: self.debugging_interface.migrate_to_gdb() @@ -1295,18 +1353,23 @@ def __threaded_migrate_from_gdb(self: InternalDebugger) -> None: def __threaded_peek_memory(self: InternalDebugger, address: int) -> bytes | BaseException: value = self.debugging_interface.peek_memory(address) - # TODO: this is only for amd64 - return value.to_bytes(8, "little") + return value.to_bytes(get_platform_register_size(libcontext.platform), sys.byteorder) def __threaded_poke_memory(self: InternalDebugger, address: int, data: bytes) -> None: - int_data = int.from_bytes(data, "little") + int_data = int.from_bytes(data, sys.byteorder) self.debugging_interface.poke_memory(address, int_data) + def __threaded_fetch_fp_registers(self: InternalDebugger, registers: Registers) -> None: + self.debugging_interface.fetch_fp_registers(registers) + + def __threaded_flush_fp_registers(self: InternalDebugger, registers: Registers) -> None: + self.debugging_interface.flush_fp_registers(registers) + @background_alias(__threaded_peek_memory) def _peek_memory(self: InternalDebugger, address: int) -> bytes: """Reads memory from the process.""" if not self.instanced: - raise RuntimeError("Process not running, cannot step.") + raise RuntimeError("Process not running, cannot access memory.") if self.running: # Reading memory while the process is running could lead to concurrency issues @@ -1332,11 +1395,27 @@ def _peek_memory(self: InternalDebugger, address: int) -> bytes: return value + def _fast_read_memory(self: InternalDebugger, address: int, size: int) -> bytes: + """Reads memory from the process.""" + if not self.instanced: + raise RuntimeError("Process not running, cannot access memory.") + + if self.running: + # Reading memory while the process is running could lead to concurrency issues + # and corrupted values + liblog.debugger( + "Process is running. Waiting for it to stop before reading memory.", + ) + + self._ensure_process_stopped() + + return self._process_memory_manager.read(address, size) + @background_alias(__threaded_poke_memory) def _poke_memory(self: InternalDebugger, address: int, data: bytes) -> None: """Writes memory to the process.""" if not self.instanced: - raise RuntimeError("Process not running, cannot step.") + raise RuntimeError("Process not running, cannot access memory.") if self.running: # Reading memory while the process is running could lead to concurrency issues @@ -1353,10 +1432,54 @@ def _poke_memory(self: InternalDebugger, address: int, data: bytes) -> None: self._join_and_check_status() + def _fast_write_memory(self: InternalDebugger, address: int, data: bytes) -> None: + """Writes memory to the process.""" + if not self.instanced: + raise RuntimeError("Process not running, cannot access memory.") + + if self.running: + # Reading memory while the process is running could lead to concurrency issues + # and corrupted values + liblog.debugger( + "Process is running. Waiting for it to stop before writing to memory.", + ) + + self._ensure_process_stopped() + + self._process_memory_manager.write(address, data) + + @background_alias(__threaded_fetch_fp_registers) + def _fetch_fp_registers(self: InternalDebugger, registers: Registers) -> None: + """Fetches the floating point registers of a thread.""" + if not self.instanced: + raise RuntimeError("Process not running, cannot read floating-point registers.") + + self._ensure_process_stopped() + + self.__polling_thread_command_queue.put( + (self.__threaded_fetch_fp_registers, (registers,)), + ) + + self._join_and_check_status() + + @background_alias(__threaded_flush_fp_registers) + def _flush_fp_registers(self: InternalDebugger, registers: Registers) -> None: + """Flushes the floating point registers of a thread.""" + if not self.instanced: + raise RuntimeError("Process not running, cannot write floating-point registers.") + + self._ensure_process_stopped() + + self.__polling_thread_command_queue.put( + (self.__threaded_flush_fp_registers, (registers,)), + ) + + self._join_and_check_status() + def _enable_antidebug_escaping(self: InternalDebugger) -> None: """Enables the anti-debugging escape mechanism.""" handler = SyscallHandler( - resolve_syscall_number("ptrace"), + resolve_syscall_number(self.arch, "ptrace"), on_enter_ptrace, on_exit_ptrace, None, @@ -1367,6 +1490,8 @@ def _enable_antidebug_escaping(self: InternalDebugger) -> None: self.__polling_thread_command_queue.put((self.__threaded_handle_syscall, (handler,))) + self._join_and_check_status() + # Seutp hidden state for the handler handler._traceme_called = False handler._command = None diff --git a/libdebug/debugger/internal_debugger_holder.py b/libdebug/debugger/internal_debugger_holder.py index 186a307b..8a2d6da7 100644 --- a/libdebug/debugger/internal_debugger_holder.py +++ b/libdebug/debugger/internal_debugger_holder.py @@ -1,20 +1,50 @@ # # This file is part of libdebug Python library (https://github.com/libdebug/libdebug). -# Copyright (c) 2024 Gabriele Digregorio. All rights reserved. +# Copyright (c) 2024 Gabriele Digregorio, Roberto Alessandro Bertolini. All rights reserved. # Licensed under the MIT license. See LICENSE file in the project root for details. # +from __future__ import annotations + +import atexit from dataclasses import dataclass, field from threading import Lock +from typing import TYPE_CHECKING from weakref import WeakKeyDictionary +from libdebug.liblog import liblog + +if TYPE_CHECKING: + from libdebug.debugger.internal_debugger import InternalDebugger + @dataclass class InternalDebuggerHolder: """A holder for internal debuggers.""" + internal_debuggers: WeakKeyDictionary = field(default_factory=WeakKeyDictionary) global_internal_debugger = None internal_debugger_lock = Lock() internal_debugger_holder = InternalDebuggerHolder() + + +def _cleanup_internal_debugger() -> None: + """Cleanup the internal debugger.""" + for debugger in set(internal_debugger_holder.internal_debuggers.values()): + debugger: InternalDebugger + + if debugger.instanced and debugger.kill_on_exit: + try: + debugger.interrupt() + except Exception as e: + liblog.debugger(f"Error while interrupting debuggee: {e}") + + try: + debugger.terminate() + except Exception as e: + liblog.debugger(f"Error while terminating the debugger: {e}") + + +atexit.register(_cleanup_internal_debugger) diff --git a/libdebug/interfaces/debugging_interface.py b/libdebug/interfaces/debugging_interface.py index f1463eae..4159032e 100644 --- a/libdebug/interfaces/debugging_interface.py +++ b/libdebug/interfaces/debugging_interface.py @@ -1,6 +1,6 @@ # # This file is part of libdebug Python library (https://github.com/libdebug/libdebug). -# Copyright (c) 2023-2024 Roberto Alessandro Bertolini, Gabriele Digregorio. All rights reserved. +# Copyright (c) 2023-2024 Roberto Alessandro Bertolini, Gabriele Digregorio, Francesco Panebianco. All rights reserved. # Licensed under the MIT license. See LICENSE file in the project root for details. # @@ -12,6 +12,7 @@ if TYPE_CHECKING: from libdebug.data.breakpoint import Breakpoint from libdebug.data.memory_map import MemoryMap + from libdebug.data.registers import Registers from libdebug.data.signal_catcher import SignalCatcher from libdebug.data.syscall_handler import SyscallHandler from libdebug.state.thread_context import ThreadContext @@ -20,6 +21,7 @@ class DebuggingInterface(ABC): """The interface used by `_InternalDebugger` to communicate with the available debugging backends, such as `ptrace` or `gdb`.""" + @abstractmethod def __init__(self: DebuggingInterface) -> None: """Initializes the DebuggingInterface classs.""" @@ -94,6 +96,10 @@ def finish(self: DebuggingInterface, thread: ThreadContext, heuristic: str) -> N heuristic (str, optional): The heuristic to use. Defaults to "backtrace". """ + def next(self: DebuggingInterface, thread: ThreadContext) -> None: + """Executes the next instruction of the process. If the instruction is a call, the debugger will continue until the called function returns. + """ + @abstractmethod def maps(self: DebuggingInterface) -> list[MemoryMap]: """Returns the memory maps of the process.""" @@ -165,3 +171,19 @@ def poke_memory(self: DebuggingInterface, address: int, data: int) -> None: address (int): The address to write. data (int): The value to write. """ + + @abstractmethod + def fetch_fp_registers(self: DebuggingInterface, registers: Registers) -> None: + """Fetches the floating-point registers of the specified thread. + + Args: + registers (Registers): The registers instance to update. + """ + + @abstractmethod + def flush_fp_registers(self: DebuggingInterface, registers: Registers) -> None: + """Flushes the floating-point registers of the specified thread. + + Args: + registers (Registers): The registers instance to flush. + """ diff --git a/libdebug/libdebug.py b/libdebug/libdebug.py index 55819051..5d5bcf88 100644 --- a/libdebug/libdebug.py +++ b/libdebug/libdebug.py @@ -7,6 +7,7 @@ from libdebug.debugger.debugger import Debugger from libdebug.debugger.internal_debugger import InternalDebugger +from libdebug.utils.elf_utils import elf_architecture def debugger( @@ -16,6 +17,8 @@ def debugger( escape_antidebug: bool = False, continue_to_binary_entrypoint: bool = True, auto_interrupt_on_command: bool = False, + fast_memory: bool = False, + kill_on_exit: bool = True, ) -> Debugger: """This function is used to create a new `Debugger` object. It returns a `Debugger` object. @@ -26,6 +29,8 @@ def debugger( escape_antidebug (bool): Whether to automatically attempt to patch antidebugger detectors based on the ptrace syscall. continue_to_binary_entrypoint (bool, optional): Whether to automatically continue to the binary entrypoint. Defaults to True. auto_interrupt_on_command (bool, optional): Whether to automatically interrupt the process when a command is issued. Defaults to False. + fast_memory (bool, optional): Whether to use a faster memory reading method. Defaults to False. + kill_on_exit (bool, optional): Whether to kill the debugged process when the debugger exits. Defaults to True. Returns: Debugger: The `Debugger` object. @@ -40,8 +45,14 @@ def debugger( internal_debugger.autoreach_entrypoint = continue_to_binary_entrypoint internal_debugger.auto_interrupt_on_command = auto_interrupt_on_command internal_debugger.escape_antidebug = escape_antidebug + internal_debugger.fast_memory = fast_memory + internal_debugger.kill_on_exit = kill_on_exit debugger = Debugger() debugger.post_init_(internal_debugger) + # If we are attaching, we assume the architecture is the same as the current platform + if argv: + debugger.arch = elf_architecture(argv[0]) + return debugger diff --git a/libdebug/memory/__init__.py b/libdebug/memory/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libdebug/data/memory_view.py b/libdebug/memory/abstract_memory_view.py similarity index 70% rename from libdebug/data/memory_view.py rename to libdebug/memory/abstract_memory_view.py index 44ef6793..e9d66b21 100644 --- a/libdebug/data/memory_view.py +++ b/libdebug/memory/abstract_memory_view.py @@ -6,48 +6,26 @@ from __future__ import annotations -from collections.abc import Callable, MutableSequence -from typing import TYPE_CHECKING +from abc import ABC, abstractmethod +from collections.abc import MutableSequence from libdebug.debugger.internal_debugger_instance_manager import provide_internal_debugger from libdebug.liblog import liblog -if TYPE_CHECKING: - from libdebug.debugger.internal_debugger import InternalDebugger +class AbstractMemoryView(MutableSequence, ABC): + """An abstract memory interface for the target process. -class MemoryView(MutableSequence): - """A memory interface for the target process. - - This class must be used to read and write memory of the target process. - - Attributes: - getter (Callable[[int], bytes]): A function that reads memory from the target process. - setter (Callable[[int, bytes], None]): A function that writes memory to the target process. - maps_provider (Callable[[], list[MemoryMap]]): A function that returns the memory maps of the target process. - unit_size (int, optional): The data size used by the getter and setter functions. Defaults to 8. - align_to (int, optional): The address alignment that must be used when reading and writing memory. Defaults to 1. + An implementation of class must be used to read and write memory of the target process. """ - context: InternalDebugger - """The debugging context of the target process.""" - - def __init__( - self: MemoryView, - getter: Callable[[int], bytes], - setter: Callable[[int, bytes], None], - unit_size: int = 8, - align_to: int = 1, - ) -> None: + def __init__(self: AbstractMemoryView) -> None: """Initializes the MemoryView.""" - self.getter = getter - self.setter = setter - self.unit_size = unit_size - self.align_to = align_to self._internal_debugger = provide_internal_debugger(self) self.maps_provider = self._internal_debugger.debugging_interface.maps - def read(self: MemoryView, address: int, size: int) -> bytes: + @abstractmethod + def read(self: AbstractMemoryView, address: int, size: int) -> bytes: """Reads memory from the target process. Args: @@ -57,72 +35,17 @@ def read(self: MemoryView, address: int, size: int) -> bytes: Returns: bytes: The read bytes. """ - if self.align_to == 1: - data = b"" - - remainder = size % self.unit_size - - for i in range(address, address + size - remainder, self.unit_size): - data += self.getter(i) - - if remainder: - data += self.getter(address + size - remainder)[:remainder] - - return data - else: - prefix = address % self.align_to - prefix_size = self.unit_size - prefix - - data = self.getter(address - prefix)[prefix:] - - remainder = (size - prefix_size) % self.unit_size - - for i in range( - address + prefix_size, - address + size - remainder, - self.unit_size, - ): - data += self.getter(i) - if remainder: - data += self.getter(address + size - remainder)[:remainder] - - return data - - def write(self: MemoryView, address: int, data: bytes) -> None: + @abstractmethod + def write(self: AbstractMemoryView, address: int, data: bytes) -> None: """Writes memory to the target process. Args: address (int): The address to write to. data (bytes): The data to write. """ - size = len(data) - - if self.align_to == 1: - remainder = size % self.unit_size - base = address - else: - prefix = address % self.align_to - prefix_size = self.unit_size - prefix - - prev_data = self.getter(address - prefix) - - self.setter(address - prefix, prev_data[:prefix_size] + data[:prefix]) - remainder = (size - prefix_size) % self.unit_size - base = address + prefix_size - - for i in range(base, address + size - remainder, self.unit_size): - self.setter(i, data[i - address : i - address + self.unit_size]) - - if remainder: - prev_data = self.getter(address + size - remainder) - self.setter( - address + size - remainder, - data[size - remainder :] + prev_data[remainder:], - ) - - def __getitem__(self: MemoryView, key: int | slice | str | tuple) -> bytes: + def __getitem__(self: AbstractMemoryView, key: int | slice | str | tuple) -> bytes: """Read from memory, either a single byte or a byte string. Args: @@ -130,7 +53,7 @@ def __getitem__(self: MemoryView, key: int | slice | str | tuple) -> bytes: """ return self._manage_memory_read_type(key) - def __setitem__(self: MemoryView, key: int | slice | str | tuple, value: bytes) -> None: + def __setitem__(self: AbstractMemoryView, key: int | slice | str | tuple, value: bytes) -> None: """Write to memory, either a single byte or a byte string. Args: @@ -141,7 +64,11 @@ def __setitem__(self: MemoryView, key: int | slice | str | tuple, value: bytes) raise TypeError("Invalid type for the value to write to memory. Expected bytes.") self._manage_memory_write_type(key, value) - def _manage_memory_read_type(self: MemoryView, key: int | slice | str | tuple, file: str = "hybrid") -> bytes: + def _manage_memory_read_type( + self: AbstractMemoryView, + key: int | slice | str | tuple, + file: str = "hybrid", + ) -> bytes: """Manage the read from memory, according to the typing. Args: @@ -183,7 +110,7 @@ def _manage_memory_read_type(self: MemoryView, key: int | slice | str | tuple, f else: raise TypeError("Invalid key type.") - def _manage_memory_read_tuple(self: MemoryView, key: tuple) -> bytes: + def _manage_memory_read_tuple(self: AbstractMemoryView, key: tuple) -> bytes: """Manage the read from memory, when the access is through a tuple. Args: @@ -223,7 +150,7 @@ def _manage_memory_read_tuple(self: MemoryView, key: tuple) -> bytes: raise ValueError("Invalid address.") from e def _manage_memory_write_type( - self: MemoryView, + self: AbstractMemoryView, key: int | slice | str | tuple, value: bytes, file: str = "hybrid", @@ -279,7 +206,7 @@ def _manage_memory_write_type( else: raise TypeError("Invalid key type.") - def _manage_memory_write_tuple(self: MemoryView, key: tuple, value: bytes) -> None: + def _manage_memory_write_tuple(self: AbstractMemoryView, key: tuple, value: bytes) -> None: """Manage the write to memory, when the access is through a tuple. Args: @@ -323,14 +250,14 @@ def _manage_memory_write_tuple(self: MemoryView, key: tuple, value: bytes) -> No except OSError as e: raise ValueError("Invalid address.") from e - def __delitem__(self: MemoryView, key: int | slice | str | tuple) -> None: + def __delitem__(self: AbstractMemoryView, key: int | slice | str | tuple) -> None: """MemoryView doesn't support deletion.""" raise NotImplementedError("MemoryView doesn't support deletion") - def __len__(self: MemoryView) -> None: + def __len__(self: AbstractMemoryView) -> None: """MemoryView doesn't support length.""" raise NotImplementedError("MemoryView doesn't support length") - def insert(self: MemoryView, index: int, value: int) -> None: + def insert(self: AbstractMemoryView, index: int, value: int) -> None: """MemoryView doesn't support insertion.""" raise NotImplementedError("MemoryView doesn't support insertion") diff --git a/libdebug/memory/chunked_memory_view.py b/libdebug/memory/chunked_memory_view.py new file mode 100644 index 00000000..cec0fb88 --- /dev/null +++ b/libdebug/memory/chunked_memory_view.py @@ -0,0 +1,114 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from libdebug.memory.abstract_memory_view import AbstractMemoryView + +if TYPE_CHECKING: + from collections.abc import Callable + + +class ChunkedMemoryView(AbstractMemoryView): + """A memory interface for the target process, intended for chunk-based memory access. + + Attributes: + getter (Callable[[int], bytes]): A function that reads a chunk of memory from the target process. + setter (Callable[[int, bytes], None]): A function that writes a chunk of memory to the target process. + unit_size (int, optional): The chunk size used by the getter and setter functions. Defaults to 8. + align_to (int, optional): The address alignment that must be used when reading and writing memory. Defaults to 1. + """ + + def __init__( + self: ChunkedMemoryView, + getter: Callable[[int], bytes], + setter: Callable[[int, bytes], None], + unit_size: int = 8, + align_to: int = 1, + ) -> None: + """Initializes the MemoryView.""" + super().__init__() + self.getter = getter + self.setter = setter + self.unit_size = unit_size + self.align_to = align_to + + def read(self: ChunkedMemoryView, address: int, size: int) -> bytes: + """Reads memory from the target process. + + Args: + address (int): The address to read from. + size (int): The number of bytes to read. + + Returns: + bytes: The read bytes. + """ + if self.align_to == 1: + data = b"" + + remainder = size % self.unit_size + + for i in range(address, address + size - remainder, self.unit_size): + data += self.getter(i) + + if remainder: + data += self.getter(address + size - remainder)[:remainder] + + return data + else: + prefix = address % self.align_to + prefix_size = self.unit_size - prefix + + data = self.getter(address - prefix)[prefix:] + + remainder = (size - prefix_size) % self.unit_size + + for i in range( + address + prefix_size, + address + size - remainder, + self.unit_size, + ): + data += self.getter(i) + + if remainder: + data += self.getter(address + size - remainder)[:remainder] + + return data + + def write(self: ChunkedMemoryView, address: int, data: bytes) -> None: + """Writes memory to the target process. + + Args: + address (int): The address to write to. + data (bytes): The data to write. + """ + size = len(data) + + if self.align_to == 1: + remainder = size % self.unit_size + base = address + else: + prefix = address % self.align_to + prefix_size = self.unit_size - prefix + + prev_data = self.getter(address - prefix) + + self.setter(address - prefix, prev_data[:prefix_size] + data[:prefix]) + + remainder = (size - prefix_size) % self.unit_size + base = address + prefix_size + + for i in range(base, address + size - remainder, self.unit_size): + self.setter(i, data[i - address : i - address + self.unit_size]) + + if remainder: + prev_data = self.getter(address + size - remainder) + self.setter( + address + size - remainder, + data[size - remainder :] + prev_data[remainder:], + ) diff --git a/libdebug/memory/direct_memory_view.py b/libdebug/memory/direct_memory_view.py new file mode 100644 index 00000000..5823b61e --- /dev/null +++ b/libdebug/memory/direct_memory_view.py @@ -0,0 +1,74 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from libdebug.memory.abstract_memory_view import AbstractMemoryView + +if TYPE_CHECKING: + from collections.abc import Callable + + +class DirectMemoryView(AbstractMemoryView): + """A memory interface for the target process, intended for direct memory access. + + Attributes: + getter (Callable[[int, int], bytes]): A function that reads a variable amount of data from the target's memory. + setter (Callable[[int, bytes], None]): A function that writes memory to the target process. + align_to (int, optional): The address alignment that must be used when reading and writing memory. Defaults to 1. + """ + + def __init__( + self: DirectMemoryView, + getter: Callable[[int, int], bytes], + setter: Callable[[int, bytes], None], + align_to: int = 1, + ) -> None: + """Initializes the MemoryView.""" + super().__init__() + self.getter = getter + self.setter = setter + self.align_to = align_to + + def read(self: DirectMemoryView, address: int, size: int) -> bytes: + """Reads memory from the target process. + + Args: + address (int): The address to read from. + size (int): The number of bytes to read. + + Returns: + bytes: The read bytes. + """ + if self.align_to == 1: + return self.getter(address, size) + else: + prefix = address % self.align_to + base_address = address - prefix + new_size = size + prefix + data = self.getter(base_address, new_size) + return data[prefix : prefix + size] + + def write(self: DirectMemoryView, address: int, data: bytes) -> None: + """Writes memory to the target process. + + Args: + address (int): The address to write to. + data (bytes): The data to write. + """ + size = len(data) + + if self.align_to == 1: + self.setter(address, data) + else: + prefix = address % self.align_to + base_address = address - prefix + new_size = size + prefix + prefix_data = self.getter(base_address, new_size) + new_data = prefix_data[:prefix] + data + prefix_data[prefix + size :] + self.setter(base_address, new_data) diff --git a/libdebug/memory/process_memory_manager.py b/libdebug/memory/process_memory_manager.py new file mode 100644 index 00000000..05f5cfa7 --- /dev/null +++ b/libdebug/memory/process_memory_manager.py @@ -0,0 +1,54 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +from __future__ import annotations + + +class ProcessMemoryManager: + """A class that provides accessors to the memory of a process, through /proc/pid/mem.""" + + def open(self: ProcessMemoryManager, process_id: int) -> None: + """Initializes the ProcessMemoryManager.""" + self.process_id = process_id + self._mem_file = None + + def _open(self: ProcessMemoryManager) -> None: + self._mem_file = open(f"/proc/{self.process_id}/mem", "r+b", buffering=0) + + def read(self: ProcessMemoryManager, address: int, size: int) -> bytes: + """Reads memory from the target process. + + Args: + address (int): The address to read from. + size (int): The number of bytes to read. + + Returns: + bytes: The read bytes. + """ + if not self._mem_file: + self._open() + + self._mem_file.seek(address) + return self._mem_file.read(size) + + def write(self: ProcessMemoryManager, address: int, data: bytes) -> None: + """Writes memory to the target process. + + Args: + address (int): The address to write to. + data (bytes): The data to write. + """ + if not self._mem_file: + self._open() + + self._mem_file.seek(address) + self._mem_file.write(data) + + def close(self: ProcessMemoryManager) -> None: + """Closes the memory file.""" + if self._mem_file: + self._mem_file.close() + self._mem_file = None diff --git a/libdebug/ptrace/ptrace_interface.py b/libdebug/ptrace/ptrace_interface.py index 073f06e9..f4246f82 100644 --- a/libdebug/ptrace/ptrace_interface.py +++ b/libdebug/ptrace/ptrace_interface.py @@ -13,10 +13,8 @@ from pathlib import Path from typing import TYPE_CHECKING -from libdebug.architectures.ptrace_hardware_breakpoint_provider import ( - ptrace_hardware_breakpoint_manager_provider, -) from libdebug.architectures.register_helper import register_holder_provider +from libdebug.architectures.call_utilities_provider import call_utilities_provider from libdebug.cffi import _ptrace_cffi from libdebug.data.breakpoint import Breakpoint from libdebug.debugger.internal_debugger_instance_manager import ( @@ -50,10 +48,8 @@ ) if TYPE_CHECKING: - from libdebug.architectures.ptrace_hardware_breakpoint_manager import ( - PtraceHardwareBreakpointManager, - ) from libdebug.data.memory_map import MemoryMap + from libdebug.data.registers import Registers from libdebug.data.signal_catcher import SignalCatcher from libdebug.data.syscall_handler import SyscallHandler from libdebug.debugger.internal_debugger import InternalDebugger @@ -62,9 +58,6 @@ class PtraceInterface(DebuggingInterface): """The interface used by `_InternalDebugger` to communicate with the `ptrace` debugging backend.""" - hardware_bp_helpers: dict[int, PtraceHardwareBreakpointManager] - """The hardware breakpoint managers (one for each thread).""" - process_id: int | None """The process ID of the debugged process.""" @@ -85,20 +78,18 @@ def __init__(self: PtraceInterface) -> None: self._global_state = self.ffi.new("struct global_state*") self._global_state.t_HEAD = self.ffi.NULL self._global_state.dead_t_HEAD = self.ffi.NULL - self._global_state.b_HEAD = self.ffi.NULL + self._global_state.sw_b_HEAD = self.ffi.NULL + self._global_state.hw_b_HEAD = self.ffi.NULL self.process_id = 0 self.detached = False - self.hardware_bp_helpers = {} - self._disabled_aslr = False self.reset() def reset(self: PtraceInterface) -> None: """Resets the state of the interface.""" - self.hardware_bp_helpers.clear() self.lib_trace.free_thread_list(self._global_state) self.lib_trace.free_breakpoints(self._global_state) @@ -221,7 +212,6 @@ def kill(self: PtraceInterface) -> None: def cont(self: PtraceInterface) -> None: """Continues the execution of the process.""" - # Forward signals to the threads if self._internal_debugger.resume_context.threads_with_signals_to_forward: self.forward_signal() @@ -252,6 +242,7 @@ def cont(self: PtraceInterface) -> None: self._global_state, self.process_id, ) + if result < 0: errno_val = self.ffi.errno raise OSError(errno_val, errno.errorcode[errno_val]) @@ -319,7 +310,7 @@ def finish(self: PtraceInterface, thread: ThreadContext, heuristic: str) -> None invalidate_process_cache() elif heuristic == "backtrace": # Breakpoint to return address - last_saved_instruction_pointer = thread.current_return_address() + last_saved_instruction_pointer = thread.saved_ip # If a breakpoint already exists at the return address, we don't need to set a new one found = False @@ -331,15 +322,22 @@ def finish(self: PtraceInterface, thread: ThreadContext, heuristic: str) -> None ip_breakpoint = bp break + # If we find an existing breakpoint that is disabled, we enable it + # but we need to disable it back after the command + should_disable = False + if not found: # Check if we have enough hardware breakpoints available # Otherwise we use a software breakpoint - install_hw_bp = self.hardware_bp_helpers[thread.thread_id].available_breakpoints() > 0 + install_hw_bp = ( + self.lib_trace.get_remaining_hw_breakpoint_count(self._global_state, thread.thread_id) > 0 + ) ip_breakpoint = Breakpoint(last_saved_instruction_pointer, hardware=install_hw_bp) self.set_breakpoint(ip_breakpoint) elif not ip_breakpoint.enabled: self._enable_breakpoint(ip_breakpoint) + should_disable = True self.cont() self.wait() @@ -347,9 +345,60 @@ def finish(self: PtraceInterface, thread: ThreadContext, heuristic: str) -> None # Remove the breakpoint if it was set by us if not found: self.unset_breakpoint(ip_breakpoint) + # Disable the breakpoint if it was just enabled by us + elif should_disable: + self._disable_breakpoint(ip_breakpoint) else: raise ValueError(f"Unimplemented heuristic {heuristic}") + def next(self: PtraceInterface, thread: ThreadContext) -> None: + """Executes the next instruction of the process. If the instruction is a call, the debugger will continue until the called function returns.""" + + opcode_window = thread.memory.read(thread.instruction_pointer, 8) + + # Check if the current instruction is a call and its skip amount + is_call, skip = call_utilities_provider(self._internal_debugger.arch).get_call_and_skip_amount(opcode_window) + + if is_call: + skip_address = thread.instruction_pointer + skip + + # If a breakpoint already exists at the return address, we don't need to set a new one + found = False + ip_breakpoint = self._internal_debugger.breakpoints.get(skip_address) + + if ip_breakpoint is not None: + found = True + + # If we find an existing breakpoint that is disabled, we enable it + # but we need to disable it back after the command + should_disable = False + + if not found: + # Check if we have enough hardware breakpoints available + # Otherwise we use a software breakpoint + install_hw_bp = ( + self.lib_trace.get_remaining_hw_breakpoint_count(self._global_state, thread.thread_id) > 0 + ) + ip_breakpoint = Breakpoint(skip_address, hardware=install_hw_bp) + self.set_breakpoint(ip_breakpoint) + elif not ip_breakpoint.enabled: + self._enable_breakpoint(ip_breakpoint) + should_disable = True + + self.cont() + self.wait() + + # Remove the breakpoint if it was set by us + if not found: + self.unset_breakpoint(ip_breakpoint) + # Disable the breakpoint if it was just enabled by us + elif should_disable: + self._disable_breakpoint(ip_breakpoint) + else: + # Step forward + self.step(thread) + self.wait() + def _setup_pipe(self: PtraceInterface) -> None: """Sets up the pipe manager for the child process. @@ -398,6 +447,7 @@ def wait(self: PtraceInterface) -> None: self._global_state, self.process_id, ) + cursor = result invalidate_process_cache() @@ -440,6 +490,16 @@ def forward_signal(self: PtraceInterface) -> None: def migrate_to_gdb(self: PtraceInterface) -> None: """Migrates the current process to GDB.""" + # Delete any hardware breakpoint + for bp in self._internal_debugger.breakpoints.values(): + if bp.hardware: + for thread in self._internal_debugger.threads: + self.lib_trace.unregister_hw_breakpoint( + self._global_state, + thread.thread_id, + bp.address, + ) + self.lib_trace.ptrace_detach_for_migration(self._global_state, self.process_id) def migrate_from_gdb(self: PtraceInterface) -> None: @@ -451,10 +511,15 @@ def migrate_from_gdb(self: PtraceInterface) -> None: # We have to reinstall any hardware breakpoint for bp in self._internal_debugger.breakpoints.values(): - if bp.hardware and bp.enabled: - for helper in self.hardware_bp_helpers.values(): - helper.remove_breakpoint(bp) - helper.install_breakpoint(bp) + if bp.hardware: + for thread in self._internal_debugger.threads: + self.lib_trace.register_hw_breakpoint( + self._global_state, + thread.thread_id, + bp.address, + bp.condition.encode().ljust(2, b"\x00"), + chr(bp.length).encode(), + ) def register_new_thread(self: PtraceInterface, new_thread_id: int) -> None: """Registers a new thread. @@ -468,23 +533,25 @@ def register_new_thread(self: PtraceInterface, new_thread_id: int) -> None: new_thread_id, ) - register_holder = register_holder_provider(register_file) + fp_register_file = self.lib_trace.get_thread_fp_regs(self._global_state, new_thread_id) + + register_holder = register_holder_provider(self._internal_debugger.arch, register_file, fp_register_file) with extend_internal_debugger(self._internal_debugger): thread = ThreadContext(new_thread_id, register_holder) self._internal_debugger.insert_new_thread(thread) - thread_hw_bp_helper = ptrace_hardware_breakpoint_manager_provider( - thread, - self._peek_user, - self._poke_user, - ) - self.hardware_bp_helpers[new_thread_id] = thread_hw_bp_helper # For any hardware breakpoints, we need to reapply them to the new thread for bp in self._internal_debugger.breakpoints.values(): if bp.hardware: - thread_hw_bp_helper.install_breakpoint(bp) + self.lib_trace.register_hw_breakpoint( + self._global_state, + new_thread_id, + bp.address, + bp.condition.encode().ljust(2, b"\x00"), + chr(bp.length).encode(), + ) def unregister_thread( self: PtraceInterface, @@ -503,9 +570,6 @@ def unregister_thread( self._internal_debugger.set_thread_as_dead(thread_id, exit_code=exit_code, exit_signal=exit_signal) - # Remove the hardware breakpoint manager for the thread - self.hardware_bp_helpers.pop(thread_id) - def _set_sw_breakpoint(self: PtraceInterface, bp: Breakpoint) -> None: """Sets a software breakpoint at the specified address. @@ -550,8 +614,22 @@ def set_breakpoint(self: PtraceInterface, bp: Breakpoint, insert: bool = True) - insert (bool): Whether the breakpoint has to be inserted or just enabled. """ if bp.hardware: - for helper in self.hardware_bp_helpers.values(): - helper.install_breakpoint(bp) + for thread in self._internal_debugger.threads: + if bp.condition == "x": + remaining = self.lib_trace.get_remaining_hw_breakpoint_count(self._global_state, thread.thread_id) + else: + remaining = self.lib_trace.get_remaining_hw_watchpoint_count(self._global_state, thread.thread_id) + + if not remaining: + raise ValueError("No more hardware breakpoints of this type available") + + self.lib_trace.register_hw_breakpoint( + self._global_state, + thread.thread_id, + bp.address, + bp.condition.encode().ljust(2, b"\x00"), + chr(bp.length).encode(), + ) elif insert: self._set_sw_breakpoint(bp) else: @@ -568,8 +646,12 @@ def unset_breakpoint(self: PtraceInterface, bp: Breakpoint, delete: bool = True) delete (bool): Whether the breakpoint has to be deleted or just disabled. """ if bp.hardware: - for helper in self.hardware_bp_helpers.values(): - helper.remove_breakpoint(bp) + for thread in self._internal_debugger.threads: + self.lib_trace.unregister_hw_breakpoint( + self._global_state, + thread.thread_id, + bp.address, + ) elif delete: self._unset_sw_breakpoint(bp) else: @@ -638,33 +720,22 @@ def poke_memory(self: PtraceInterface, address: int, value: int) -> None: error = self.ffi.errno raise OSError(error, errno.errorcode[error]) - def _peek_user(self: PtraceInterface, thread_id: int, address: int) -> int: - """Reads the memory at the specified address.""" - result = self.lib_trace.ptrace_peekuser(thread_id, address) - liblog.debugger( - "PEEKUSER at address %d returned with result %x", - address, - result, - ) - - error = self.ffi.errno - if error: - raise OSError(error, errno.errorcode[error]) + def fetch_fp_registers(self: PtraceInterface, registers: Registers) -> None: + """Fetches the floating-point registers of the specified thread. - return result + Args: + registers (Registers): The registers instance to update. + """ + liblog.debugger("Fetching floating-point registers for thread %d", registers._thread_id) + self.lib_trace.get_fp_regs(registers._thread_id, registers._fp_register_file) - def _poke_user(self: PtraceInterface, thread_id: int, address: int, value: int) -> None: - """Writes the memory at the specified address.""" - result = self.lib_trace.ptrace_pokeuser(thread_id, address, value) - liblog.debugger( - "POKEUSER at address %d returned with result %d", - address, - result, - ) + def flush_fp_registers(self: PtraceInterface, _: Registers) -> None: + """Flushes the floating-point registers of the specified thread. - if result == -1: - error = self.ffi.errno - raise OSError(error, errno.errorcode[error]) + Args: + registers (Registers): The registers instance to update. + """ + raise NotImplementedError("Flushing floating-point registers is automatically handled by the native code.") def _get_event_msg(self: PtraceInterface, thread_id: int) -> int: """Returns the event message.""" @@ -673,3 +744,17 @@ def _get_event_msg(self: PtraceInterface, thread_id: int) -> int: def maps(self: PtraceInterface) -> list[MemoryMap]: """Returns the memory maps of the process.""" return get_process_maps(self.process_id) + + def get_hit_watchpoint(self: PtraceInterface, thread_id: int) -> Breakpoint: + """Returns the watchpoint that has been hit.""" + address = self.lib_trace.get_hit_hw_breakpoint(self._global_state, thread_id) + + if not address: + return None + + bp = self._internal_debugger.breakpoints[address] + + if bp.condition != "x": + return bp + + return None diff --git a/libdebug/ptrace/ptrace_register_holder.py b/libdebug/ptrace/ptrace_register_holder.py index 1f09922a..aad6d592 100644 --- a/libdebug/ptrace/ptrace_register_holder.py +++ b/libdebug/ptrace/ptrace_register_holder.py @@ -25,6 +25,9 @@ class PtraceRegisterHolder(RegisterHolder): register_file: object """The register file of the target process, as returned by ptrace.""" + fp_register_file: object + """The floating-point register file of the target process, as returned by ptrace.""" + def poll(self: PtraceRegisterHolder, target: ThreadContext) -> None: """Poll the register values from the specified target.""" raise NotImplementedError("Do not call this method.") diff --git a/libdebug/ptrace/ptrace_status_handler.py b/libdebug/ptrace/ptrace_status_handler.py index 2717ac13..8ad42c47 100644 --- a/libdebug/ptrace/ptrace_status_handler.py +++ b/libdebug/ptrace/ptrace_status_handler.py @@ -84,7 +84,7 @@ def _handle_breakpoints(self: PtraceStatusHandler, thread_id: int) -> bool: else: # If the trap was caused by a software breakpoint, we need to restore the original instruction # and set the instruction pointer to the previous instruction. - ip -= software_breakpoint_byte_size() + ip -= software_breakpoint_byte_size(self.internal_debugger.arch) bp = self.internal_debugger.breakpoints.get(ip) if bp and bp.enabled and not bp._disabled_for_step: @@ -102,7 +102,7 @@ def _handle_breakpoints(self: PtraceStatusHandler, thread_id: int) -> bool: # Manage watchpoints if not bp: - bp = self.ptrace_interface.hardware_bp_helpers[thread_id].is_watchpoint_hit() + bp = self.ptrace_interface.get_hit_watchpoint(thread_id) if bp: liblog.debugger("Watchpoint hit at 0x%x", bp.address) @@ -380,13 +380,12 @@ def _internal_signal_handler( case StopEvents.FORK_EVENT: # The process has been forked liblog.warning( - f"Process {pid} forked. Continuing execution of the parent process. The child process will be stopped until the user decides to attach to it." + f"Process {pid} forked. Continuing execution of the parent process. The child process will be stopped until the user decides to attach to it.", ) self.forward_signal = False def _handle_change(self: PtraceStatusHandler, pid: int, status: int, results: list) -> None: """Handle a change in the status of a traced process.""" - # Initialize the forward_signal flag self.forward_signal = True diff --git a/libdebug/state/thread_context.py b/libdebug/state/thread_context.py index c65d0e0b..6b165a34 100644 --- a/libdebug/state/thread_context.py +++ b/libdebug/state/thread_context.py @@ -1,6 +1,6 @@ # # This file is part of libdebug Python library (https://github.com/libdebug/libdebug). -# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio. All rights reserved. +# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio, Francesco Panebianco. All rights reserved. # Licensed under the MIT license. See LICENSE file in the project root for details. # from __future__ import annotations @@ -17,10 +17,10 @@ from libdebug.utils.signal_utils import resolve_signal_name, resolve_signal_number if TYPE_CHECKING: - from libdebug.data.memory_view import MemoryView from libdebug.data.register_holder import RegisterHolder from libdebug.data.registers import Registers from libdebug.debugger.internal_debugger import InternalDebugger + from libdebug.memory.abstract_memory_view import AbstractMemoryView class ThreadContext: @@ -79,7 +79,7 @@ def __init__(self: ThreadContext, thread_id: int, registers: RegisterHolder) -> self._internal_debugger = provide_internal_debugger(self) self._thread_id = thread_id regs_class = registers.provide_regs_class() - self.regs = regs_class() + self.regs = regs_class(thread_id) registers.apply_on_regs(self.regs, regs_class) registers.apply_on_thread(self, ThreadContext) @@ -93,12 +93,12 @@ def dead(self: ThreadContext) -> bool: return self._dead @property - def memory(self: ThreadContext) -> MemoryView: + def memory(self: ThreadContext) -> AbstractMemoryView: """The memory view of the debugged process.""" return self._internal_debugger.memory @property - def mem(self: ThreadContext) -> MemoryView: + def mem(self: ThreadContext) -> AbstractMemoryView: """Alias for the `memory` property. Get the memory view of the process. @@ -167,7 +167,7 @@ def signal(self: ThreadContext, signal: str | int) -> None: self._internal_debugger._ensure_process_stopped() if self._signal_number != 0: liblog.debugger( - f"Overwriting signal {resolve_signal_name(self._signal_number)} with {resolve_signal_name(signal) if isinstance(signal, int) else signal}." + f"Overwriting signal {resolve_signal_name(self._signal_number)} with {resolve_signal_name(signal) if isinstance(signal, int) else signal}.", ) if isinstance(signal, str): signal = resolve_signal_number(signal) @@ -181,7 +181,7 @@ def backtrace(self: ThreadContext, as_symbols: bool = False) -> list: as_symbols (bool, optional): Whether to return the backtrace as symbols """ self._internal_debugger._ensure_process_stopped() - stack_unwinder = stack_unwinding_provider() + stack_unwinder = stack_unwinding_provider(self._internal_debugger.arch) backtrace = stack_unwinder.unwind(self) if as_symbols: maps = self._internal_debugger.debugging_interface.maps() @@ -191,7 +191,7 @@ def backtrace(self: ThreadContext, as_symbols: bool = False) -> list: def print_backtrace(self: ThreadContext) -> None: """Prints the current backtrace of the thread.""" self._internal_debugger._ensure_process_stopped() - stack_unwinder = stack_unwinding_provider() + stack_unwinder = stack_unwinding_provider(self._internal_debugger.arch) backtrace = stack_unwinder.unwind(self) maps = self._internal_debugger.debugging_interface.maps() for return_address in backtrace: @@ -201,16 +201,19 @@ def print_backtrace(self: ThreadContext) -> None: else: print(f"{PrintStyle.RED}{return_address:#x} <{return_address_symbol}> {PrintStyle.RESET}") - def current_return_address(self: ThreadContext) -> int: - """Returns the return address of the current function.""" + @property + def saved_ip(self: ThreadContext) -> int: + """The return address of the current function.""" self._internal_debugger._ensure_process_stopped() - stack_unwinder = stack_unwinding_provider() - + stack_unwinder = stack_unwinding_provider(self._internal_debugger.arch) + try: - return_address = stack_unwinder.get_return_address(self) + return_address = stack_unwinder.get_return_address(self, self._internal_debugger.debugging_interface.maps()) except (OSError, ValueError) as e: - raise ValueError("Failed to get the return address. Check stack frame registers (e.g., base pointer).") from e - + raise ValueError( + "Failed to get the return address. Check stack frame registers (e.g., base pointer).", + ) from e + return return_address def step(self: ThreadContext) -> None: @@ -246,6 +249,10 @@ def finish(self: ThreadContext, heuristic: str = "backtrace") -> None: """ self._internal_debugger.finish(self, heuristic=heuristic) + def next(self: ThreadContext) -> None: + """Executes the next instruction of the process. If the instruction is a call, the debugger will continue until the called function returns.""" + self._internal_debugger.next(self) + def si(self: ThreadContext) -> None: """Alias for the `step` method. @@ -279,3 +286,7 @@ def fin(self: ThreadContext, heuristic: str = "backtrace") -> None: heuristic (str, optional): The heuristic to use. Defaults to "backtrace". """ self._internal_debugger.finish(self, heuristic) + + def ni(self: ThreadContext) -> None: + """Alias for the `next` method. Executes the next instruction of the process. If the instruction is a call, the debugger will continue until the called function returns.""" + self._internal_debugger.next(self) diff --git a/libdebug/utils/arch_mappings.py b/libdebug/utils/arch_mappings.py new file mode 100644 index 00000000..bc7a7694 --- /dev/null +++ b/libdebug/utils/arch_mappings.py @@ -0,0 +1,31 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +ARCH_MAPPING = { + "x86": "i386", + "x86_64": "amd64", + "x64": "amd64", + "arm64": "aarch64", +} + + +def map_arch(arch: str) -> str: + """Map the architecture to the correct format. + + Args: + arch (str): the architecture to map. + + Returns: + str: the mapped architecture. + """ + arch = arch.lower() + + if arch in ARCH_MAPPING.values(): + return arch + elif arch in ARCH_MAPPING: + return ARCH_MAPPING[arch] + else: + raise ValueError(f"Architecture {arch} not supported.") diff --git a/libdebug/utils/debugger_wrappers.py b/libdebug/utils/debugger_wrappers.py index 9c57fff8..f56b7cb4 100644 --- a/libdebug/utils/debugger_wrappers.py +++ b/libdebug/utils/debugger_wrappers.py @@ -40,7 +40,10 @@ def change_state_function_thread(method: callable) -> callable: @wraps(method) def wrapper( - self: InternalDebugger, thread: ThreadContext, *args: ..., **kwargs: ... + self: InternalDebugger, + thread: ThreadContext, + *args: ..., + **kwargs: ..., ) -> ...: if not self.instanced: raise RuntimeError( diff --git a/libdebug/utils/elf_utils.py b/libdebug/utils/elf_utils.py index f01e989f..7a5212bb 100644 --- a/libdebug/utils/elf_utils.py +++ b/libdebug/utils/elf_utils.py @@ -211,22 +211,37 @@ def resolve_address(path: str, address: int) -> str: @functools.cache -def is_pie(path: str) -> bool: - """Returns True if the specified ELF file is position independent, False otherwise. +def parse_elf_characteristics(path: str) -> tuple[bool, int, str]: + """Returns a tuple containing the PIE flag, the entry point and the architecture of the specified ELF file. Args: path (str): The path to the ELF file. Returns: - bool: True if the specified ELF file is position independent, False otherwise. + tuple: A tuple containing the PIE flag, the entry point and the architecture of the specified ELF file. """ with Path(path).open("rb") as elf_file: elf = ELFFile(elf_file) - return elf.header.e_type == "ET_DYN" + pie = elf.header.e_type == "ET_DYN" + entry_point = elf.header.e_entry + arch = elf.get_machine_arch() + + return pie, entry_point, arch + + +def is_pie(path: str) -> bool: + """Returns True if the specified ELF file is position independent, False otherwise. + + Args: + path (str): The path to the ELF file. + + Returns: + bool: True if the specified ELF file is position independent, False otherwise. + """ + return parse_elf_characteristics(path)[0] -@functools.cache def get_entry_point(path: str) -> int: """Returns the entry point of the specified ELF file. @@ -236,7 +251,16 @@ def get_entry_point(path: str) -> int: Returns: int: The entry point of the specified ELF file. """ - with Path(path).open("rb") as elf_file: - elf = ELFFile(elf_file) + return parse_elf_characteristics(path)[1] + - return elf.header.e_entry +def elf_architecture(path: str) -> str: + """Returns the architecture of the specified ELF file. + + Args: + path (str): The path to the ELF file. + + Returns: + str: The architecture of the specified ELF file. + """ + return parse_elf_characteristics(path)[2] diff --git a/libdebug/utils/libcontext.py b/libdebug/utils/libcontext.py index f502f9cc..e1167bf7 100644 --- a/libdebug/utils/libcontext.py +++ b/libdebug/utils/libcontext.py @@ -6,11 +6,13 @@ from __future__ import annotations +import platform import sys from contextlib import contextmanager from copy import deepcopy from libdebug.liblog import liblog +from libdebug.utils.arch_mappings import map_arch class LibContext: @@ -61,7 +63,6 @@ def __init__(self: LibContext) -> None: self._general_logger = "DEBUG" self._initialized = True - self._arch = "amd64" self._terminal = [] def _set_debug_level_for_all(self: LibContext) -> None: @@ -151,21 +152,9 @@ def general_logger(self: LibContext, value: str) -> None: ) @property - def arch(self: LibContext) -> str: - """Property getter for architecture. - - Returns: - _arch (str): the current architecture. - """ - return self._arch - - @arch.setter - def arch(self: LibContext, value: str) -> None: - """Property setter for arch, ensuring it's a valid architecture.""" - if value in ["amd64"]: - self._arch = value - else: - raise RuntimeError("The specified architecture is not supported") + def platform(self: LibContext) -> str: + """Return the current platform.""" + return map_arch(platform.machine()) @property def terminal(self: LibContext) -> list[str]: diff --git a/libdebug/utils/platform_utils.py b/libdebug/utils/platform_utils.py new file mode 100644 index 00000000..f83b7e88 --- /dev/null +++ b/libdebug/utils/platform_utils.py @@ -0,0 +1,23 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + + +def get_platform_register_size(arch: str) -> int: + """Get the register size of the platform. + + Args: + arch (str): The architecture of the platform. + + Returns: + int: The register size in bytes. + """ + match arch: + case "amd64": + return 8 + case "aarch64": + return 8 + case _: + raise ValueError(f"Architecture {arch} not supported.") diff --git a/libdebug/utils/print_style.py b/libdebug/utils/print_style.py index 9fe884b7..48dca068 100644 --- a/libdebug/utils/print_style.py +++ b/libdebug/utils/print_style.py @@ -4,6 +4,7 @@ # Licensed under the MIT license. See LICENSE file in the project root for details. # + class PrintStyle: """Class to define colors for the terminal.""" diff --git a/libdebug/utils/syscall_utils.py b/libdebug/utils/syscall_utils.py index ca4bdc9c..f96cb93e 100644 --- a/libdebug/utils/syscall_utils.py +++ b/libdebug/utils/syscall_utils.py @@ -10,8 +10,6 @@ import requests -from libdebug.utils.libcontext import libcontext - SYSCALLS_REMOTE = "https://syscalls.mebeim.net/db" LOCAL_FOLDER_PATH = (Path.home() / ".cache" / "libdebug" / "syscalls").resolve() @@ -21,6 +19,8 @@ def get_remote_definition_url(arch: str) -> str: match arch: case "amd64": return f"{SYSCALLS_REMOTE}/x86/64/x64/latest/table.json" + case "aarch64": + return f"{SYSCALLS_REMOTE}/arm64/64/aarch64/latest/table.json" case _: raise ValueError(f"Architecture {arch} not supported") @@ -55,9 +55,9 @@ def get_syscall_definitions(arch: str) -> dict: @functools.cache -def resolve_syscall_number(name: str) -> int: +def resolve_syscall_number(architecture: str, name: str) -> int: """Resolve a syscall name to its number.""" - definitions = get_syscall_definitions(libcontext.arch) + definitions = get_syscall_definitions(architecture) for syscall in definitions["syscalls"]: if syscall["name"] == name: @@ -67,9 +67,9 @@ def resolve_syscall_number(name: str) -> int: @functools.cache -def resolve_syscall_name(number: int) -> str: +def resolve_syscall_name(architecture: str, number: int) -> str: """Resolve a syscall number to its name.""" - definitions = get_syscall_definitions(libcontext.arch) + definitions = get_syscall_definitions(architecture) for syscall in definitions["syscalls"]: if syscall["number"] == number: @@ -79,9 +79,9 @@ def resolve_syscall_name(number: int) -> str: @functools.cache -def resolve_syscall_arguments(number: int) -> list[str]: +def resolve_syscall_arguments(architecture: str, number: int) -> list[str]: """Resolve a syscall number to its argument definition.""" - definitions = get_syscall_definitions(libcontext.arch) + definitions = get_syscall_definitions(architecture) for syscall in definitions["syscalls"]: if syscall["number"] == number: @@ -91,8 +91,8 @@ def resolve_syscall_arguments(number: int) -> list[str]: @functools.cache -def get_all_syscall_numbers() -> list[int]: +def get_all_syscall_numbers(architecture: str) -> list[int]: """Retrieves all the syscall numbers.""" - definitions = get_syscall_definitions(libcontext.arch) + definitions = get_syscall_definitions(architecture) return [syscall["number"] for syscall in definitions["syscalls"]] diff --git a/pyproject.toml b/pyproject.toml index 761dfddf..f36dc557 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ issues = "https://github.com/libdebug/libdebug/issues" [tool.ruff] include = ["pyproject.toml", "libdebug/**/*.py"] -exclude = ["libdebug/cffi/*.py"] +exclude = ["libdebug/cffi/*.py", "test/"] line-length = 120 indent-width = 4 target-version = "py312" diff --git a/setup.py b/setup.py index 2eff8455..a337f121 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ if not ( os.path.isfile("/usr/include/sys/ptrace.h") or os.path.isfile("/usr/include/x86_64-linux-gnu/sys/ptrace.h") + or os.path.isfile("/usr/include/aarch64-linux-gnu/sys/ptrace.h") ): print("Required C libraries not found. Please install ptrace or kernel headers") exit(1) @@ -68,7 +69,7 @@ def get_outputs(self): setup( name="libdebug", - version="0.5.4", + version="0.6.0", author="JinBlack, Io_no, MrIndeciso, Frank01001", description="A library to debug binary programs", packages=find_packages(include=["libdebug", "libdebug.*"]), diff --git a/test/aarch64/binaries/antidebug_brute_test b/test/aarch64/binaries/antidebug_brute_test new file mode 100755 index 00000000..93d69473 Binary files /dev/null and b/test/aarch64/binaries/antidebug_brute_test differ diff --git a/test/aarch64/binaries/attach_test b/test/aarch64/binaries/attach_test new file mode 100755 index 00000000..2fea0a91 Binary files /dev/null and b/test/aarch64/binaries/attach_test differ diff --git a/test/aarch64/binaries/backtrace_test b/test/aarch64/binaries/backtrace_test new file mode 100755 index 00000000..6a620ee5 Binary files /dev/null and b/test/aarch64/binaries/backtrace_test differ diff --git a/test/aarch64/binaries/basic_test b/test/aarch64/binaries/basic_test new file mode 100755 index 00000000..149c055b Binary files /dev/null and b/test/aarch64/binaries/basic_test differ diff --git a/test/aarch64/binaries/basic_test_pie b/test/aarch64/binaries/basic_test_pie new file mode 100755 index 00000000..204fb256 Binary files /dev/null and b/test/aarch64/binaries/basic_test_pie differ diff --git a/test/aarch64/binaries/benchmark b/test/aarch64/binaries/benchmark new file mode 100755 index 00000000..ca47274c Binary files /dev/null and b/test/aarch64/binaries/benchmark differ diff --git a/test/aarch64/binaries/breakpoint_test b/test/aarch64/binaries/breakpoint_test new file mode 100755 index 00000000..ac5cbb2d Binary files /dev/null and b/test/aarch64/binaries/breakpoint_test differ diff --git a/test/aarch64/binaries/brute_test b/test/aarch64/binaries/brute_test new file mode 100755 index 00000000..a2993312 Binary files /dev/null and b/test/aarch64/binaries/brute_test differ diff --git a/test/aarch64/binaries/catch_signal_test b/test/aarch64/binaries/catch_signal_test new file mode 100755 index 00000000..79799718 Binary files /dev/null and b/test/aarch64/binaries/catch_signal_test differ diff --git a/test/aarch64/binaries/executable_section_test b/test/aarch64/binaries/executable_section_test new file mode 100755 index 00000000..7f2d21dc Binary files /dev/null and b/test/aarch64/binaries/executable_section_test differ diff --git a/test/aarch64/binaries/finish_test b/test/aarch64/binaries/finish_test new file mode 100755 index 00000000..f461a5d7 Binary files /dev/null and b/test/aarch64/binaries/finish_test differ diff --git a/test/aarch64/binaries/floating_point_test b/test/aarch64/binaries/floating_point_test new file mode 100755 index 00000000..428cb425 Binary files /dev/null and b/test/aarch64/binaries/floating_point_test differ diff --git a/test/aarch64/binaries/handle_syscall_test b/test/aarch64/binaries/handle_syscall_test new file mode 100755 index 00000000..0623cfd7 Binary files /dev/null and b/test/aarch64/binaries/handle_syscall_test differ diff --git a/test/aarch64/binaries/jumpstart_test b/test/aarch64/binaries/jumpstart_test new file mode 100755 index 00000000..78598581 Binary files /dev/null and b/test/aarch64/binaries/jumpstart_test differ diff --git a/test/aarch64/binaries/jumpstart_test_preload.so b/test/aarch64/binaries/jumpstart_test_preload.so new file mode 100755 index 00000000..05c57d6b Binary files /dev/null and b/test/aarch64/binaries/jumpstart_test_preload.so differ diff --git a/test/aarch64/binaries/memory_test b/test/aarch64/binaries/memory_test new file mode 100755 index 00000000..1da1a8bc Binary files /dev/null and b/test/aarch64/binaries/memory_test differ diff --git a/test/aarch64/binaries/memory_test_2 b/test/aarch64/binaries/memory_test_2 new file mode 100755 index 00000000..52f25bfe Binary files /dev/null and b/test/aarch64/binaries/memory_test_2 differ diff --git a/test/aarch64/binaries/segfault_test b/test/aarch64/binaries/segfault_test new file mode 100755 index 00000000..58701df2 Binary files /dev/null and b/test/aarch64/binaries/segfault_test differ diff --git a/test/aarch64/binaries/signals_multithread_det_test b/test/aarch64/binaries/signals_multithread_det_test new file mode 100755 index 00000000..a7c660d5 Binary files /dev/null and b/test/aarch64/binaries/signals_multithread_det_test differ diff --git a/test/aarch64/binaries/signals_multithread_undet_test b/test/aarch64/binaries/signals_multithread_undet_test new file mode 100755 index 00000000..22c78cd8 Binary files /dev/null and b/test/aarch64/binaries/signals_multithread_undet_test differ diff --git a/test/aarch64/binaries/speed_test b/test/aarch64/binaries/speed_test new file mode 100755 index 00000000..caf5bc11 Binary files /dev/null and b/test/aarch64/binaries/speed_test differ diff --git a/test/aarch64/binaries/thread_test b/test/aarch64/binaries/thread_test new file mode 100755 index 00000000..1be46c07 Binary files /dev/null and b/test/aarch64/binaries/thread_test differ diff --git a/test/aarch64/binaries/thread_test_complex b/test/aarch64/binaries/thread_test_complex new file mode 100755 index 00000000..237d9839 Binary files /dev/null and b/test/aarch64/binaries/thread_test_complex differ diff --git a/test/aarch64/binaries/watchpoint_test b/test/aarch64/binaries/watchpoint_test new file mode 100755 index 00000000..f30ad912 Binary files /dev/null and b/test/aarch64/binaries/watchpoint_test differ diff --git a/test/aarch64/run_suite.py b/test/aarch64/run_suite.py new file mode 100644 index 00000000..ebf8be5d --- /dev/null +++ b/test/aarch64/run_suite.py @@ -0,0 +1,78 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import sys +from unittest import TestLoader, TestSuite, TextTestRunner + +from scripts.attach_detach_test import AttachDetachTest +from scripts.auto_waiting_test import AutoWaitingTest +from scripts.backtrace_test import BacktraceTest +from scripts.basic_test import BasicTest +from scripts.basic_test_pie import BasicTestPie +from scripts.basic_test_hw import BasicTestHw +from scripts.breakpoint_test import BreakpointTest +from scripts.brute_test import BruteTest +from scripts.builtin_handler_test import BuiltinHandlerTest +from scripts.callback_test import CallbackTest +from scripts.catch_signal_test import CatchSignalTest +from scripts.control_flow_test import ControlFlowTest +from scripts.death_test import DeathTest +from scripts.finish_test import FinishTest +from scripts.floating_point_test import FloatingPointTest +from scripts.handle_syscall_test import HandleSyscallTest +from scripts.hijack_syscall_test import HijackSyscallTest +from scripts.jumpstart_test import JumpstartTest +from scripts.memory_test import MemoryTest +from scripts.next_test import NextTest +from scripts.signals_multithread_test import SignalMultithreadTest +from scripts.speed_test import SpeedTest +from scripts.thread_test_complex import ThreadTestComplex +from scripts.thread_test import ThreadTest +from scripts.watchpoint_test import WatchpointTest + +def fast_suite(): + suite = TestSuite() + + suite.addTest(TestLoader().loadTestsFromTestCase(AttachDetachTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(AutoWaitingTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(BacktraceTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(BasicTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(BasicTestPie)) + suite.addTest(TestLoader().loadTestsFromTestCase(BasicTestHw)) + suite.addTest(TestLoader().loadTestsFromTestCase(BreakpointTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(BruteTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(BuiltinHandlerTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(CallbackTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(CatchSignalTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(ControlFlowTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(DeathTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(FinishTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(FloatingPointTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(HandleSyscallTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(HijackSyscallTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(JumpstartTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(MemoryTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(NextTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(SignalMultithreadTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(SpeedTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(ThreadTestComplex)) + suite.addTest(TestLoader().loadTestsFromTestCase(ThreadTest)) + suite.addTest(TestLoader().loadTestsFromTestCase(WatchpointTest)) + + return suite + + +if __name__ == "__main__": + if sys.version_info >= (3, 12): + runner = TextTestRunner(verbosity=2, durations=3) + else: + runner = TextTestRunner(verbosity=2) + + suite = fast_suite() + + runner.run(suite) + + sys.exit(0) diff --git a/test/aarch64/scripts/attach_detach_test.py b/test/aarch64/scripts/attach_detach_test.py new file mode 100644 index 00000000..97311f40 --- /dev/null +++ b/test/aarch64/scripts/attach_detach_test.py @@ -0,0 +1,92 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Gabriele Digregorio, Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import logging +import unittest + +from pwn import process + +from libdebug import debugger + +logging.getLogger("pwnlib").setLevel(logging.ERROR) + + +class AttachDetachTest(unittest.TestCase): + def setUp(self): + pass + + def test_attach(self): + r = process("binaries/attach_test") + + d = debugger() + d.attach(r.pid) + bp = d.breakpoint("printName", hardware=True) + d.cont() + + r.recvuntil(b"name:") + r.sendline(b"Io_no") + + self.assertTrue(d.regs.pc == bp.address) + + d.cont() + + d.kill() + + def test_attach_and_detach_1(self): + r = process("binaries/attach_test") + + d = debugger() + + # Attach to the process + d.attach(r.pid) + d.detach() + + # Validate that, after detaching, the process is still running + r.recvuntil(b"name:", timeout=1) + r.sendline(b"Io_no") + + r.kill() + + def test_attach_and_detach_2(self): + d = debugger("binaries/attach_test") + + # Run the process + r = d.run() + d.detach() + + # Validate that, after detaching, the process is still running + r.recvuntil(b"name:", timeout=1) + r.sendline(b"Io_no") + + d.kill() + + def test_attach_and_detach_3(self): + d = debugger("binaries/attach_test") + + r = d.run() + + # We must ensure that any breakpoint is unset before detaching + d.breakpoint(0xa04, file="binary") + d.breakpoint(0xa08, hardware=True, file="binary") + + d.detach() + + # Validate that, after detaching, the process is still running + r.recvuntil(b"name:", timeout=1) + r.sendline(b"Io_no") + + d.kill() + + def test_attach_and_detach_4(self): + r = process("binaries/attach_test") + + d = debugger() + d.attach(r.pid) + d.detach() + d.kill() + + # Validate that, after detaching and killing, the process is effectively terminated + self.assertRaises(EOFError, r.sendline, b"provola") \ No newline at end of file diff --git a/test/aarch64/scripts/auto_waiting_test.py b/test/aarch64/scripts/auto_waiting_test.py new file mode 100644 index 00000000..8de2e2dc --- /dev/null +++ b/test/aarch64/scripts/auto_waiting_test.py @@ -0,0 +1,63 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import io +import logging +import unittest + +from libdebug import debugger + + +class AutoWaitingTest(unittest.TestCase): + def setUp(self): + # Redirect logging to a string buffer + self.log_capture_string = io.StringIO() + self.log_handler = logging.StreamHandler(self.log_capture_string) + self.log_handler.setLevel(logging.WARNING) + + self.logger = logging.getLogger("libdebug") + self.original_handlers = self.logger.handlers + self.logger.handlers = [] + self.logger.addHandler(self.log_handler) + self.logger.setLevel(logging.WARNING) + + def test_bps_auto_waiting(self): + d = debugger("binaries/breakpoint_test", auto_interrupt_on_command=False) + + d.run() + + bp1 = d.breakpoint("random_function") + bp2 = d.breakpoint(0x7fc, file="binary") + bp3 = d.breakpoint(0x820, file="binary") + + counter = 1 + + d.cont() + + while True: + if d.regs.pc == bp1.address: + self.assertTrue(bp1.hit_count == 1) + self.assertTrue(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + self.assertFalse(bp3.hit_on(d)) + elif d.regs.pc == bp2.address: + self.assertTrue(bp2.hit_count == counter) + self.assertTrue(bp2.hit_on(d)) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp3.hit_on(d)) + counter += 1 + elif d.regs.pc == bp3.address: + self.assertTrue(bp3.hit_count == 1) + self.assertTrue(d.regs.x1 == 45) + self.assertTrue(d.regs.w1 == 45) + self.assertTrue(bp3.hit_on(d)) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + break + + d.cont() + + d.kill() \ No newline at end of file diff --git a/test/aarch64/scripts/backtrace_test.py b/test/aarch64/scripts/backtrace_test.py new file mode 100644 index 00000000..7c497e13 --- /dev/null +++ b/test/aarch64/scripts/backtrace_test.py @@ -0,0 +1,108 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2023-2024 Gabriele Digregorio, Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import unittest + +from libdebug import debugger + + +class BacktraceTest(unittest.TestCase): + def setUp(self): + self.d = debugger("binaries/backtrace_test") + + def test_backtrace(self): + d = self.d + + d.run() + + bp0 = d.breakpoint("main+8") + bp1 = d.breakpoint("function1+8") + bp2 = d.breakpoint("function2+8") + bp3 = d.breakpoint("function3+8") + bp4 = d.breakpoint("function4+8") + bp5 = d.breakpoint("function5+8") + bp6 = d.breakpoint("function6+8") + + d.cont() + + self.assertTrue(d.regs.pc == bp0.address) + backtrace = d.backtrace(as_symbols=True) + self.assertIn("_start", backtrace.pop()) + self.assertEqual(backtrace[:1], ["main+8"]) + + d.cont() + + self.assertTrue(d.regs.pc == bp1.address) + backtrace = d.backtrace(as_symbols=True) + self.assertIn("_start", backtrace.pop()) + self.assertEqual(backtrace[:2], ["function1+8", "main+c"]) + + d.cont() + + self.assertTrue(d.regs.pc == bp2.address) + backtrace = d.backtrace(as_symbols=True) + self.assertIn("_start", backtrace.pop()) + self.assertEqual(backtrace[:3], ["function2+8", "function1+10", "main+c"]) + + d.cont() + + self.assertTrue(d.regs.pc == bp3.address) + backtrace = d.backtrace(as_symbols=True) + self.assertIn("_start", backtrace.pop()) + self.assertEqual( + backtrace[:4], ["function3+8", "function2+18", "function1+10", "main+c"] + ) + + d.cont() + + self.assertTrue(d.regs.pc == bp4.address) + backtrace = d.backtrace(as_symbols=True) + self.assertIn("_start", backtrace.pop()) + self.assertEqual( + backtrace[:5], + ["function4+8", "function3+18", "function2+18", "function1+10", "main+c"], + ) + + d.cont() + + self.assertTrue(d.regs.pc == bp5.address) + backtrace = d.backtrace(as_symbols=True) + self.assertIn("_start", backtrace.pop()) + self.assertEqual( + backtrace[:6], + [ + "function5+8", + "function4+18", + "function3+18", + "function2+18", + "function1+10", + "main+c", + ], + ) + + d.cont() + + self.assertTrue(d.regs.pc == bp6.address) + backtrace = d.backtrace(as_symbols=True) + self.assertIn("_start", backtrace.pop()) + self.assertEqual( + backtrace[:7], + [ + "function6+8", + "function5+18", + "function4+18", + "function3+18", + "function2+18", + "function1+10", + "main+c", + ], + ) + + d.kill() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/aarch64/scripts/basic_test.py b/test/aarch64/scripts/basic_test.py new file mode 100644 index 00000000..758ffb94 --- /dev/null +++ b/test/aarch64/scripts/basic_test.py @@ -0,0 +1,135 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import unittest + +from libdebug import debugger + + +class BasicTest(unittest.TestCase): + + def test_basic(self): + d = debugger("binaries/basic_test") + d.run() + bp = d.breakpoint("register_test") + d.cont() + assert bp.address == d.regs.pc + d.cont() + d.kill() + d.terminate() + + def test_registers(self): + d = debugger("binaries/basic_test") + d.run() + + bp = d.breakpoint(0x4008a4) + + d.cont() + + assert d.regs.pc == bp.address + + assert d.regs.x0 == 0x4444333322221111 + assert d.regs.x1 == 0x8888777766665555 + assert d.regs.x2 == 0xccccbbbbaaaa9999 + assert d.regs.x3 == 0x1111ffffeeeedddd + assert d.regs.x4 == 0x5555444433332222 + assert d.regs.x5 == 0x9999888877776666 + assert d.regs.x6 == 0xddddccccbbbbaaaa + assert d.regs.x7 == 0x22221111ffffeeee + assert d.regs.x8 == 0x6666555544443333 + assert d.regs.x9 == 0xaaaa999988887777 + assert d.regs.x10 == 0xeeeeddddccccbbbb + assert d.regs.x11 == 0x333322221111ffff + assert d.regs.x12 == 0x7777666655554444 + assert d.regs.x13 == 0xbbbbaaaa99998888 + assert d.regs.x14 == 0xffffeeeeddddcccc + assert d.regs.x15 == 0x4444333322221111 + assert d.regs.x16 == 0x8888777766665555 + assert d.regs.x17 == 0xccccbbbbaaaa9999 + assert d.regs.x18 == 0x1111ffffeeeedddd + assert d.regs.x19 == 0x5555444433332222 + assert d.regs.x20 == 0x9999888877776666 + assert d.regs.x21 == 0xddddccccbbbbaaaa + assert d.regs.x22 == 0x22221111ffffeeee + assert d.regs.x23 == 0x6666555544443333 + assert d.regs.x24 == 0xaaaa999988887777 + assert d.regs.x25 == 0xeeeeddddccccbbbb + assert d.regs.x26 == 0x333322221111ffff + assert d.regs.x27 == 0x7777666655554444 + assert d.regs.x28 == 0xbbbbaaaa99998888 + assert d.regs.x29 == 0xffffeeeeddddcccc + assert d.regs.x30 == 0x4444333322221111 + + assert d.regs.lr == 0x4444333322221111 + assert d.regs.fp == 0xffffeeeeddddcccc + assert d.regs.xzr == 0 + assert d.regs.wzr == 0 + + d.regs.xzr = 0x123456789abcdef0 + d.regs.wzr = 0x12345678 + + assert d.regs.xzr == 0 + assert d.regs.wzr == 0 + + assert d.regs.w0 == 0x22221111 + assert d.regs.w1 == 0x66665555 + assert d.regs.w2 == 0xaaaa9999 + assert d.regs.w3 == 0xeeeedddd + assert d.regs.w4 == 0x33332222 + assert d.regs.w5 == 0x77776666 + assert d.regs.w6 == 0xbbbbaaaa + assert d.regs.w7 == 0xffffeeee + assert d.regs.w8 == 0x44443333 + assert d.regs.w9 == 0x88887777 + assert d.regs.w10 == 0xccccbbbb + assert d.regs.w11 == 0x1111ffff + assert d.regs.w12 == 0x55554444 + assert d.regs.w13 == 0x99998888 + assert d.regs.w14 == 0xddddcccc + assert d.regs.w15 == 0x22221111 + assert d.regs.w16 == 0x66665555 + assert d.regs.w17 == 0xaaaa9999 + assert d.regs.w18 == 0xeeeedddd + assert d.regs.w19 == 0x33332222 + assert d.regs.w20 == 0x77776666 + assert d.regs.w21 == 0xbbbbaaaa + assert d.regs.w22 == 0xffffeeee + assert d.regs.w23 == 0x44443333 + assert d.regs.w24 == 0x88887777 + assert d.regs.w25 == 0xccccbbbb + assert d.regs.w26 == 0x1111ffff + assert d.regs.w27 == 0x55554444 + assert d.regs.w28 == 0x99998888 + assert d.regs.w29 == 0xddddcccc + assert d.regs.w30 == 0x22221111 + + d.cont() + + d.kill() + d.terminate() + + def test_step(self): + d = debugger("binaries/basic_test") + + d.run() + bp = d.breakpoint("register_test") + d.cont() + + assert bp.address == d.regs.pc + assert bp.hit_count == 1 + + d.step() + + assert (bp.address + 4) == d.regs.pc + assert bp.hit_count == 1 + + d.step() + + assert (bp.address + 8) == d.regs.pc + assert bp.hit_count == 1 + + d.kill() + d.terminate() diff --git a/test/aarch64/scripts/basic_test_hw.py b/test/aarch64/scripts/basic_test_hw.py new file mode 100644 index 00000000..e0320b84 --- /dev/null +++ b/test/aarch64/scripts/basic_test_hw.py @@ -0,0 +1,123 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import unittest + +from libdebug import debugger + + +class BasicTestHw(unittest.TestCase): + def test_basic(self): + d = debugger("binaries/basic_test") + d.run() + bp = d.breakpoint("register_test", hardware=True) + d.cont() + assert bp.address == d.regs.pc + d.cont() + d.kill() + d.terminate() + + def test_registers(self): + d = debugger("binaries/basic_test") + d.run() + + bp = d.breakpoint(0x4008a4, hardware=True) + + d.cont() + + assert d.regs.pc == bp.address + + assert d.regs.x0 == 0x4444333322221111 + assert d.regs.x1 == 0x8888777766665555 + assert d.regs.x2 == 0xccccbbbbaaaa9999 + assert d.regs.x3 == 0x1111ffffeeeedddd + assert d.regs.x4 == 0x5555444433332222 + assert d.regs.x5 == 0x9999888877776666 + assert d.regs.x6 == 0xddddccccbbbbaaaa + assert d.regs.x7 == 0x22221111ffffeeee + assert d.regs.x8 == 0x6666555544443333 + assert d.regs.x9 == 0xaaaa999988887777 + assert d.regs.x10 == 0xeeeeddddccccbbbb + assert d.regs.x11 == 0x333322221111ffff + assert d.regs.x12 == 0x7777666655554444 + assert d.regs.x13 == 0xbbbbaaaa99998888 + assert d.regs.x14 == 0xffffeeeeddddcccc + assert d.regs.x15 == 0x4444333322221111 + assert d.regs.x16 == 0x8888777766665555 + assert d.regs.x17 == 0xccccbbbbaaaa9999 + assert d.regs.x18 == 0x1111ffffeeeedddd + assert d.regs.x19 == 0x5555444433332222 + assert d.regs.x20 == 0x9999888877776666 + assert d.regs.x21 == 0xddddccccbbbbaaaa + assert d.regs.x22 == 0x22221111ffffeeee + assert d.regs.x23 == 0x6666555544443333 + assert d.regs.x24 == 0xaaaa999988887777 + assert d.regs.x25 == 0xeeeeddddccccbbbb + assert d.regs.x26 == 0x333322221111ffff + assert d.regs.x27 == 0x7777666655554444 + assert d.regs.x28 == 0xbbbbaaaa99998888 + assert d.regs.x29 == 0xffffeeeeddddcccc + assert d.regs.x30 == 0x4444333322221111 + + assert d.regs.w0 == 0x22221111 + assert d.regs.w1 == 0x66665555 + assert d.regs.w2 == 0xaaaa9999 + assert d.regs.w3 == 0xeeeedddd + assert d.regs.w4 == 0x33332222 + assert d.regs.w5 == 0x77776666 + assert d.regs.w6 == 0xbbbbaaaa + assert d.regs.w7 == 0xffffeeee + assert d.regs.w8 == 0x44443333 + assert d.regs.w9 == 0x88887777 + assert d.regs.w10 == 0xccccbbbb + assert d.regs.w11 == 0x1111ffff + assert d.regs.w12 == 0x55554444 + assert d.regs.w13 == 0x99998888 + assert d.regs.w14 == 0xddddcccc + assert d.regs.w15 == 0x22221111 + assert d.regs.w16 == 0x66665555 + assert d.regs.w17 == 0xaaaa9999 + assert d.regs.w18 == 0xeeeedddd + assert d.regs.w19 == 0x33332222 + assert d.regs.w20 == 0x77776666 + assert d.regs.w21 == 0xbbbbaaaa + assert d.regs.w22 == 0xffffeeee + assert d.regs.w23 == 0x44443333 + assert d.regs.w24 == 0x88887777 + assert d.regs.w25 == 0xccccbbbb + assert d.regs.w26 == 0x1111ffff + assert d.regs.w27 == 0x55554444 + assert d.regs.w28 == 0x99998888 + assert d.regs.w29 == 0xddddcccc + assert d.regs.w30 == 0x22221111 + + d.cont() + + d.kill() + d.terminate() + + def test_step(self): + d = debugger("binaries/basic_test") + + d.run() + bp = d.breakpoint("register_test", hardware=True) + d.cont() + + assert bp.address == d.regs.pc + assert bp.hit_count == 1 + + d.step() + + assert (bp.address + 4) == d.regs.pc + assert bp.hit_count == 1 + + d.step() + + assert (bp.address + 8) == d.regs.pc + assert bp.hit_count == 1 + + d.kill() + d.terminate() \ No newline at end of file diff --git a/test/aarch64/scripts/basic_test_pie.py b/test/aarch64/scripts/basic_test_pie.py new file mode 100644 index 00000000..1ff08c44 --- /dev/null +++ b/test/aarch64/scripts/basic_test_pie.py @@ -0,0 +1,24 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import unittest + +from libdebug import debugger + +class BasicTestPie(unittest.TestCase): + def test_basic(self): + d = debugger("binaries/basic_test_pie") + + d.run() + bp = d.breakpoint(0x780, file="binary") + d.cont() + + assert bp.address == d.regs.pc + assert d.regs.x0 == 0xaabbccdd11223344 + + d.cont() + d.kill() + d.terminate() \ No newline at end of file diff --git a/test/aarch64/scripts/breakpoint_test.py b/test/aarch64/scripts/breakpoint_test.py new file mode 100644 index 00000000..b95b1c50 --- /dev/null +++ b/test/aarch64/scripts/breakpoint_test.py @@ -0,0 +1,375 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import io +import logging +import unittest + +from libdebug import debugger + + +class BreakpointTest(unittest.TestCase): + def setUp(self): + # Redirect logging to a string buffer + self.log_capture_string = io.StringIO() + self.log_handler = logging.StreamHandler(self.log_capture_string) + self.log_handler.setLevel(logging.WARNING) + + self.logger = logging.getLogger("libdebug") + self.original_handlers = self.logger.handlers + self.logger.handlers = [] + self.logger.addHandler(self.log_handler) + self.logger.setLevel(logging.WARNING) + + def test_bps(self): + d = debugger("binaries/breakpoint_test") + + d.run() + + bp1 = d.breakpoint("random_function") + bp2 = d.breakpoint(0x7fc, file="binary") + bp3 = d.breakpoint(0x820, file="binary") + + counter = 1 + + d.cont() + + while True: + if d.regs.pc == bp1.address: + self.assertTrue(bp1.hit_count == 1) + self.assertTrue(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + self.assertFalse(bp3.hit_on(d)) + elif d.regs.pc == bp2.address: + self.assertTrue(bp2.hit_count == counter) + self.assertTrue(bp2.hit_on(d)) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp3.hit_on(d)) + counter += 1 + elif d.regs.pc == bp3.address: + self.assertTrue(bp3.hit_count == 1) + self.assertTrue(d.regs.x1 == 45) + self.assertTrue(d.regs.w1 == 45) + self.assertTrue(bp3.hit_on(d)) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + break + + d.cont() + + assert bp2.hit_count == 10 + + d.kill() + d.terminate() + + def test_bp_disable(self): + d = debugger("binaries/breakpoint_test") + + d.run() + + bp1 = d.breakpoint("random_function") + bp2 = d.breakpoint(0x7fc, file="binary") + bp3 = d.breakpoint(0x820, file="binary") + + counter = 1 + + d.cont() + + while True: + if d.regs.pc == bp1.address: + self.assertTrue(bp1.hit_count == 1) + self.assertTrue(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + self.assertFalse(bp3.hit_on(d)) + elif d.regs.pc == bp2.address: + self.assertTrue(bp2.hit_count == counter) + self.assertTrue(bp2.hit_on(d)) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp3.hit_on(d)) + bp2.disable() + elif d.regs.pc == bp3.address: + self.assertTrue(bp3.hit_count == 1) + self.assertTrue(d.regs.w1 == 45) + self.assertTrue(d.regs.x1 == 45) + self.assertTrue(bp3.hit_on(d)) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + break + + d.cont() + + assert bp2.hit_count == 1 + + d.kill() + d.terminate() + + def test_bp_disable_hardware(self): + d = debugger("binaries/breakpoint_test") + + d.run() + + bp1 = d.breakpoint("random_function") + bp2 = d.breakpoint(0x7fc, file="binary", hardware=True) + bp3 = d.breakpoint(0x820, file="binary") + + counter = 1 + + d.cont() + + while True: + if d.regs.pc == bp1.address: + self.assertTrue(bp1.hit_count == 1) + self.assertTrue(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + self.assertFalse(bp3.hit_on(d)) + elif d.regs.pc == bp2.address: + self.assertTrue(bp2.hit_count == counter) + self.assertTrue(bp2.hit_on(d)) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp3.hit_on(d)) + bp2.disable() + elif d.regs.pc == bp3.address: + self.assertTrue(bp3.hit_count == 1) + self.assertTrue(d.regs.w1 == 45) + self.assertTrue(d.regs.x1 == 45) + self.assertTrue(bp3.hit_on(d)) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + break + + d.cont() + + assert bp2.hit_count == 1 + + d.kill() + d.terminate() + + def test_bp_disable_reenable(self): + d = debugger("binaries/breakpoint_test") + + d.run() + + bp1 = d.breakpoint("random_function") + bp2 = d.breakpoint(0x7fc, file="binary") + bp4 = d.breakpoint(0x814, file="binary") + bp3 = d.breakpoint(0x820, file="binary") + + counter = 1 + + d.cont() + + while True: + if d.regs.pc == bp1.address: + self.assertTrue(bp1.hit_count == 1) + self.assertTrue(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + self.assertFalse(bp3.hit_on(d)) + elif d.regs.pc == bp2.address: + self.assertTrue(bp2.hit_count == counter) + self.assertTrue(bp2.hit_on(d)) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp3.hit_on(d)) + if bp4.enabled: + bp4.disable() + else: + bp4.enable() + counter += 1 + elif d.regs.pc == bp3.address: + self.assertTrue(bp3.hit_count == 1) + self.assertTrue(d.regs.w1 == 45) + self.assertTrue(d.regs.x1 == 45) + self.assertTrue(bp3.hit_on(d)) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + break + elif bp4.hit_on(d): + pass + + d.cont() + + assert bp3.hit_count == 1 + assert bp4.hit_count == (bp2.hit_count // 2 + 1) + + d.kill() + d.terminate() + + def test_bp_disable_reenable_hardware(self): + d = debugger("binaries/breakpoint_test") + + d.run() + + bp1 = d.breakpoint("random_function", hardware=True) + bp2 = d.breakpoint(0x7fc, file="binary", hardware=True) + bp4 = d.breakpoint(0x810, file="binary", hardware=True) + bp3 = d.breakpoint(0x820, file="binary", hardware=True) + + counter = 1 + + d.cont() + + for _ in range(20): + if d.regs.pc == bp1.address: + self.assertTrue(bp1.hit_count == 1) + self.assertTrue(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + self.assertFalse(bp3.hit_on(d)) + elif d.regs.pc == bp2.address: + self.assertTrue(bp2.hit_count == counter) + self.assertTrue(bp2.hit_on(d)) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp3.hit_on(d)) + if bp4.enabled: + bp4.disable() + else: + bp4.enable() + counter += 1 + elif d.regs.pc == bp3.address: + self.assertTrue(bp3.hit_count == 1) + self.assertTrue(d.regs.w1 == 45) + self.assertTrue(d.regs.x1 == 45) + self.assertTrue(bp3.hit_on(d)) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + break + elif bp4.hit_on(d): + pass + + d.cont() + + assert bp4.hit_count == (bp2.hit_count // 2 + 1) + + d.kill() + d.terminate() + + def test_bps_running(self): + d = debugger("binaries/breakpoint_test") + + d.run() + + bp1 = d.breakpoint("random_function") + bp2 = d.breakpoint(0x7fc, file="binary") + bp3 = d.breakpoint(0x820, file="binary") + + counter = 1 + + d.cont() + + while True: + if d.running: + pass + if d.regs.pc == bp1.address: + self.assertFalse(d.running) + self.assertTrue(bp1.hit_count == 1) + self.assertTrue(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + self.assertFalse(bp3.hit_on(d)) + elif d.regs.pc == bp2.address: + self.assertFalse(d.running) + self.assertTrue(bp2.hit_count == counter) + self.assertTrue(bp2.hit_on(d)) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp3.hit_on(d)) + counter += 1 + elif d.regs.pc == bp3.address: + self.assertFalse(d.running) + self.assertTrue(bp3.hit_count == 1) + self.assertTrue(d.regs.w1 == 45) + self.assertTrue(d.regs.x1 == 45) + self.assertTrue(bp3.hit_on(d)) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + break + + d.cont() + + assert bp2.hit_count == 10 + + d.kill() + d.terminate() + + def test_bp_backing_file(self): + d = debugger("binaries/executable_section_test") + + d.run() + + bp1 = d.breakpoint(0x968, file="binary") + + d.cont() + + d.wait() + + if bp1.hit_on(d): + for vmap in d.maps(): + if "x" in vmap.permissions and "anon" in vmap.backing_file: + section = vmap.backing_file + bp2 = d.breakpoint(0x10, file=section) + d.cont() + + d.wait() + + if bp2.hit_on(d): + self.assertEqual(d.memory[d.regs.pc, 4], bytes.fromhex("ff430091")) + self.assertEqual(d.regs.w0, 9) + + d.kill() + + self.assertEqual(bp1.hit_count, 1) + self.assertEqual(bp2.hit_count, 1) + + d.run() + + bp1 = d.breakpoint(0x968, file="executable_section_test") + + d.cont() + + d.wait() + + if bp1.hit_on(d): + for vmap in d.maps(): + if "x" in vmap.permissions and "anon" in vmap.backing_file: + section = vmap.backing_file + bp2 = d.breakpoint(0x10, file=section) + d.cont() + + d.wait() + + if bp2.hit_on(d): + self.assertEqual(d.memory[d.regs.pc, 4], bytes.fromhex("ff430091")) + self.assertEqual(d.regs.w0, 9) + + d.run() + + bp1 = d.breakpoint(0x968, file="hybrid") + + d.cont() + + d.wait() + + if bp1.hit_on(d): + for vmap in d.maps(): + if "x" in vmap.permissions and "anon" in vmap.backing_file: + section = vmap.backing_file + bp2 = d.breakpoint(0x10, file=section) + d.cont() + + d.wait() + + if bp2.hit_on(d): + self.assertEqual(d.memory[d.regs.pc, 4], bytes.fromhex("ff430091")) + self.assertEqual(d.regs.w0, 9) + + d.kill() + + self.assertEqual(bp1.hit_count, 1) + self.assertEqual(bp2.hit_count, 1) + + d.run() + + with self.assertRaises(ValueError): + d.breakpoint(0x968, file="absolute") + + d.kill() + d.terminate() diff --git a/test/aarch64/scripts/brute_test.py b/test/aarch64/scripts/brute_test.py new file mode 100644 index 00000000..0b4e0c82 --- /dev/null +++ b/test/aarch64/scripts/brute_test.py @@ -0,0 +1,48 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio, Francesco Panebianco. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import string +import unittest + +from libdebug import debugger + + +class BruteTest(unittest.TestCase): + def setUp(self): + pass + + def test_bruteforce(self): + flag = "" + counter = 1 + + d = debugger("binaries/brute_test") + + while not flag or flag != "BRUTINOBRUTONE": + for c in string.printable: + r = d.run() + bp = d.breakpoint(0x974, hardware=True, file="binary") + d.cont() + + r.sendlineafter(b"chars\n", (flag + c).encode()) + + while bp.address == d.regs.pc: + d.cont() + + if bp.hit_count > counter: + flag += c + counter = bp.hit_count + d.kill() + break + + message = r.recvline() + + d.kill() + + if message == b"Giusto!": + flag += c + break + + self.assertEqual(flag, "BRUTINOBRUTONE") diff --git a/test/aarch64/scripts/builtin_handler_test.py b/test/aarch64/scripts/builtin_handler_test.py new file mode 100644 index 00000000..96d0fe88 --- /dev/null +++ b/test/aarch64/scripts/builtin_handler_test.py @@ -0,0 +1,62 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import unittest +import string + +from libdebug import debugger + + +class BuiltinHandlerTest(unittest.TestCase): + def test_antidebug_escaping(self): + d = debugger("binaries/antidebug_brute_test") + + # validate that without the handler the binary cannot be debugged + r = d.run() + d.cont() + msg = r.recvline() + self.assertEqual(msg, b"Debugger detected") + d.kill() + + # validate that with the handler the binary can be debugged + d = debugger("binaries/antidebug_brute_test", escape_antidebug=True) + r = d.run() + d.cont() + msg = r.recvline() + self.assertEqual(msg, b"Write up to 64 chars") + d.interrupt() + d.kill() + + # validate that the binary still works + flag = "" + counter = 1 + + while not flag or flag != "BRUTE": + for c in string.printable: + r = d.run() + bp = d.breakpoint(0xa10, hardware=True, file="binary") + d.cont() + + r.sendlineafter(b"chars\n", (flag + c).encode()) + + while bp.address == d.regs.pc: + d.cont() + + if bp.hit_count > counter: + flag += c + counter = bp.hit_count + d.kill() + break + + message = r.recvline() + + d.kill() + + if message == b"Giusto!": + flag += c + break + + self.assertEqual(flag, "BRUTE") diff --git a/test/aarch64/scripts/callback_test.py b/test/aarch64/scripts/callback_test.py new file mode 100644 index 00000000..e8ba0ef0 --- /dev/null +++ b/test/aarch64/scripts/callback_test.py @@ -0,0 +1,264 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio, Francesco Panebianco. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import string +import unittest + +from libdebug import debugger + + +class CallbackTest(unittest.TestCase): + def setUp(self): + self.exceptions = [] + + def test_callback_simple(self): + self.exceptions.clear() + + global hit + hit = False + + d = debugger("binaries/basic_test") + + d.run() + + def callback(thread, bp): + global hit + + try: + self.assertEqual(bp.hit_count, 1) + self.assertTrue(bp.hit_on(thread)) + except Exception as e: + self.exceptions.append(e) + + hit = True + + d.breakpoint("register_test", callback=callback) + + d.cont() + + d.kill() + + self.assertTrue(hit) + + if self.exceptions: + raise self.exceptions[0] + + def test_callback_simple_hardware(self): + self.exceptions.clear() + + global hit + hit = False + + d = debugger("binaries/basic_test") + + d.run() + + def callback(thread, bp): + global hit + + try: + self.assertEqual(bp.hit_count, 1) + self.assertTrue(bp.hit_on(thread)) + except Exception as e: + self.exceptions.append(e) + + hit = True + + d.breakpoint("register_test", callback=callback, hardware=True) + + d.cont() + + d.kill() + + self.assertTrue(hit) + + if self.exceptions: + raise self.exceptions[0] + + def test_callback_memory(self): + self.exceptions.clear() + + global hit + hit = False + + d = debugger("binaries/memory_test") + + d.run() + + def callback(thread, bp): + global hit + + prev = bytes(range(256)) + try: + self.assertEqual(bp.address, thread.regs.pc) + self.assertEqual(bp.hit_count, 1) + self.assertEqual(thread.memory[thread.regs.x0, 256], prev) + + thread.memory[thread.regs.x0 + 128 :] = b"abcd123456" + prev = prev[:128] + b"abcd123456" + prev[138:] + + self.assertEqual(thread.memory[thread.regs.x0, 256], prev) + except Exception as e: + self.exceptions.append(e) + + hit = True + + d.breakpoint("change_memory", callback=callback) + + d.cont() + + d.kill() + + self.assertTrue(hit) + + if self.exceptions: + raise self.exceptions[0] + + def test_callback_bruteforce(self): + global flag + global counter + global new_counter + + flag = "" + counter = 1 + new_counter = 0 + + def brute_force(d, b): + global new_counter + try: + new_counter = b.hit_count + except Exception as e: + self.exceptions.append(e) + + d = debugger("binaries/brute_test") + while True: + end = False + for c in string.printable: + r = d.run() + + d.breakpoint(0x974, callback=brute_force, hardware=True) + d.cont() + + r.sendlineafter(b"chars\n", (flag + c).encode()) + + message = r.recvline() + + if new_counter > counter: + flag += c + counter = new_counter + d.kill() + break + d.kill() + if message == b"Giusto!": + flag += c + end = True + break + if end: + break + + self.assertEqual(flag, "BRUTINOBRUTONE") + + if self.exceptions: + raise self.exceptions[0] + + def test_callback_exception(self): + self.exceptions.clear() + + d = debugger("binaries/basic_test") + + d.run() + + def callback(thread, bp): + # This operation should not raise any exception + _ = d.regs.x0 + + d.breakpoint("register_test", callback=callback, hardware=True) + + d.cont() + + d.kill() + + def test_callback_step(self): + self.exceptions.clear() + + d = debugger("binaries/basic_test") + + d.run() + + def callback(t, bp): + self.assertEqual(t.regs.pc, bp.address) + d.step() + self.assertEqual(t.regs.pc, bp.address + 4) + + d.breakpoint("register_test", callback=callback) + + d.cont() + + d.kill() + + def test_callback_pid_accessible(self): + self.exceptions.clear() + + d = debugger("binaries/basic_test") + + d.run() + + hit = False + + def callback(t, bp): + nonlocal hit + self.assertEqual(t.process_id, d.process_id) + hit = True + + d.breakpoint("register_test", callback=callback) + + d.cont() + d.kill() + + self.assertTrue(hit) + + def test_callback_pid_accessible_alias(self): + self.exceptions.clear() + + d = debugger("binaries/basic_test") + + d.run() + + hit = False + + def callback(t, bp): + nonlocal hit + self.assertEqual(t.pid, d.pid) + self.assertEqual(t.pid, t.process_id) + hit = True + + d.breakpoint("register_test", callback=callback) + + d.cont() + d.kill() + + self.assertTrue(hit) + + def test_callback_tid_accessible_alias(self): + self.exceptions.clear() + + d = debugger("binaries/basic_test") + + d.run() + + hit = False + + def callback(t, bp): + nonlocal hit + self.assertEqual(t.tid, t.thread_id) + hit = True + + d.breakpoint("register_test", callback=callback) + + d.cont() + d.kill() + + self.assertTrue(hit) diff --git a/test/aarch64/scripts/catch_signal_test.py b/test/aarch64/scripts/catch_signal_test.py new file mode 100644 index 00000000..82e16d08 --- /dev/null +++ b/test/aarch64/scripts/catch_signal_test.py @@ -0,0 +1,1266 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Gabriele Digregorio, Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import io +import logging +import unittest + +from libdebug import debugger + + +class CatchSignalTest(unittest.TestCase): + def setUp(self): + # Redirect logging to a string buffer + self.log_capture_string = io.StringIO() + self.log_handler = logging.StreamHandler(self.log_capture_string) + self.log_handler.setLevel(logging.WARNING) + + self.logger = logging.getLogger("libdebug") + self.original_handlers = self.logger.handlers + self.logger.handlers = [] + self.logger.addHandler(self.log_handler) + self.logger.setLevel(logging.WARNING) + + def tearDown(self): + self.logger.removeHandler(self.log_handler) + self.logger.handlers = self.original_handlers + self.log_handler.close() + + def test_signal_catch_signal_block(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGQUIT_count = 0 + SIGTERM_count = 0 + SIGPIPE_count = 0 + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + + SIGUSR1_count += 1 + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + + SIGTERM_count += 1 + + def catcher_SIGINT(t, sc): + nonlocal SIGINT_count + + SIGINT_count += 1 + + def catcher_SIGQUIT(t, sc): + nonlocal SIGQUIT_count + + SIGQUIT_count += 1 + + def catcher_SIGPIPE(t, sc): + nonlocal SIGPIPE_count + + SIGPIPE_count += 1 + + d = debugger("binaries/catch_signal_test") + + d.signals_to_block = ["SIGUSR1", 15, "SIGINT", 3, 13] + + d.run() + + catcher1 = d.catch_signal(10, callback=catcher_SIGUSR1) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + catcher3 = d.catch_signal(2, callback=catcher_SIGINT) + catcher4 = d.catch_signal("SIGQUIT", callback=catcher_SIGQUIT) + catcher5 = d.catch_signal("SIGPIPE", callback=catcher_SIGPIPE) + + d.cont() + + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 2) + self.assertEqual(SIGINT_count, 2) + self.assertEqual(SIGQUIT_count, 3) + self.assertEqual(SIGPIPE_count, 3) + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + self.assertEqual(SIGQUIT_count, catcher4.hit_count) + self.assertEqual(SIGPIPE_count, catcher5.hit_count) + + def test_signal_pass_to_process(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGQUIT_count = 0 + SIGTERM_count = 0 + SIGPIPE_count = 0 + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + + SIGUSR1_count += 1 + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + + SIGTERM_count += 1 + + def catcher_SIGINT(t, sc): + nonlocal SIGINT_count + + SIGINT_count += 1 + + def catcher_SIGQUIT(t, sc): + nonlocal SIGQUIT_count + + SIGQUIT_count += 1 + + def catcher_SIGPIPE(t, sc): + nonlocal SIGPIPE_count + + SIGPIPE_count += 1 + + d = debugger("binaries/catch_signal_test") + + r = d.run() + + catcher1 = d.catch_signal("SIGUSR1", callback=catcher_SIGUSR1) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + catcher3 = d.catch_signal("SIGINT", callback=catcher_SIGINT) + catcher4 = d.catch_signal("SIGQUIT", callback=catcher_SIGQUIT) + catcher5 = d.catch_signal("SIGPIPE", callback=catcher_SIGPIPE) + + d.cont() + + SIGUSR1 = r.recvline() + SIGTERM = r.recvline() + SIGINT = r.recvline() + SIGQUIT = r.recvline() + SIGPIPE = r.recvline() + + SIGUSR1 += r.recvline() + SIGTERM += r.recvline() + SIGINT += r.recvline() + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 2) + self.assertEqual(SIGINT_count, 2) + self.assertEqual(SIGQUIT_count, 3) + self.assertEqual(SIGPIPE_count, 3) + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + self.assertEqual(SIGQUIT_count, catcher4.hit_count) + self.assertEqual(SIGPIPE_count, catcher5.hit_count) + + self.assertEqual(SIGUSR1, b"Received signal 10" * 2) + self.assertEqual(SIGTERM, b"Received signal 15" * 2) + self.assertEqual(SIGINT, b"Received signal 2" * 2) + self.assertEqual(SIGQUIT, b"Received signal 3" * 3) + self.assertEqual(SIGPIPE, b"Received signal 13" * 3) + + def test_signal_disable_catch_signal(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGQUIT_count = 0 + SIGTERM_count = 0 + SIGPIPE_count = 0 + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + + SIGUSR1_count += 1 + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + + SIGTERM_count += 1 + + def catcher_SIGINT(t, sc): + nonlocal SIGINT_count + + SIGINT_count += 1 + + def catcher_SIGQUIT(t, sc): + nonlocal SIGQUIT_count + + SIGQUIT_count += 1 + + def catcher_SIGPIPE(t, sc): + nonlocal SIGPIPE_count + + SIGPIPE_count += 1 + + d = debugger("binaries/catch_signal_test") + + r = d.run() + + catcher1 = d.catch_signal("SIGUSR1", callback=catcher_SIGUSR1) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + catcher3 = d.catch_signal("SIGINT", callback=catcher_SIGINT) + catcher4 = d.catch_signal("SIGQUIT", callback=catcher_SIGQUIT) + catcher5 = d.catch_signal("SIGPIPE", callback=catcher_SIGPIPE) + + bp = d.breakpoint(0x964) + + d.cont() + + SIGUSR1 = r.recvline() + SIGTERM = r.recvline() + SIGINT = r.recvline() + SIGQUIT = r.recvline() + SIGPIPE = r.recvline() + + SIGUSR1 += r.recvline() + SIGTERM += r.recvline() + SIGINT += r.recvline() + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + # Uncatchering signals + if bp.hit_on(d): + catcher4.disable() + catcher5.disable() + d.cont() + + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 2) + self.assertEqual(SIGINT_count, 2) + self.assertEqual(SIGQUIT_count, 2) # 1 times less because of the disable catch + self.assertEqual(SIGPIPE_count, 2) # 1 times less because of the disable catch + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + self.assertEqual(SIGQUIT_count, catcher4.hit_count) + self.assertEqual(SIGPIPE_count, catcher5.hit_count) + + self.assertEqual(SIGUSR1, b"Received signal 10" * 2) + self.assertEqual(SIGTERM, b"Received signal 15" * 2) + self.assertEqual(SIGINT, b"Received signal 2" * 2) + self.assertEqual(SIGQUIT, b"Received signal 3" * 3) + self.assertEqual(SIGPIPE, b"Received signal 13" * 3) + + def test_signal_unblock(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGQUIT_count = 0 + SIGTERM_count = 0 + SIGPIPE_count = 0 + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + + SIGUSR1_count += 1 + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + + SIGTERM_count += 1 + + def catcher_SIGINT(t, sc): + nonlocal SIGINT_count + + SIGINT_count += 1 + + def catcher_SIGQUIT(t, sc): + nonlocal SIGQUIT_count + + SIGQUIT_count += 1 + + def catcher_SIGPIPE(t, sc): + nonlocal SIGPIPE_count + + SIGPIPE_count += 1 + + d = debugger("binaries/catch_signal_test") + + r = d.run() + + d.signals_to_block = [10, 15, 2, 3, 13] + + catcher1 = d.catch_signal("SIGUSR1", callback=catcher_SIGUSR1) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + catcher3 = d.catch_signal("SIGINT", callback=catcher_SIGINT) + catcher4 = d.catch_signal("SIGQUIT", callback=catcher_SIGQUIT) + catcher5 = d.catch_signal("SIGPIPE", callback=catcher_SIGPIPE) + + bp = d.breakpoint(0x964) + + d.cont() + + # No block the signals anymore + if bp.hit_on(d): + d.signals_to_block = [] + + d.cont() + + signal_received = [] + while True: + try: + signal_received.append(r.recvline()) + except RuntimeError: + break + + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 2) + self.assertEqual(SIGINT_count, 2) + self.assertEqual(SIGQUIT_count, 3) + self.assertEqual(SIGPIPE_count, 3) + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + self.assertEqual(SIGQUIT_count, catcher4.hit_count) + self.assertEqual(SIGPIPE_count, catcher5.hit_count) + + self.assertEqual(signal_received[0], b"Received signal 3") + self.assertEqual(signal_received[1], b"Received signal 13") + self.assertEqual(signal_received[2], b"Exiting normally.") + + self.assertEqual(len(signal_received), 3) + + def test_signal_disable_catch_signal_unblock(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGQUIT_count = 0 + SIGTERM_count = 0 + SIGPIPE_count = 0 + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + + SIGUSR1_count += 1 + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + + SIGTERM_count += 1 + + def catcher_SIGINT(t, sc): + nonlocal SIGINT_count + + SIGINT_count += 1 + + def catcher_SIGQUIT(t, sc): + nonlocal SIGQUIT_count + + SIGQUIT_count += 1 + + def catcher_SIGPIPE(t, sc): + nonlocal SIGPIPE_count + + SIGPIPE_count += 1 + + d = debugger("binaries/catch_signal_test") + + r = d.run() + + d.signals_to_block = [10, 15, 2, 3, 13] + + catcher1 = d.catch_signal("SIGUSR1", callback=catcher_SIGUSR1) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + catcher3 = d.catch_signal("SIGINT", callback=catcher_SIGINT) + catcher4 = d.catch_signal("SIGQUIT", callback=catcher_SIGQUIT) + catcher5 = d.catch_signal("SIGPIPE", callback=catcher_SIGPIPE) + + bp = d.breakpoint(0x964) + + d.cont() + + # No block the signals anymore + if bp.hit_on(d): + d.signals_to_block = [] + catcher4.disable() + catcher5.disable() + + d.cont() + + signal_received = [] + while True: + try: + signal_received.append(r.recvline()) + except RuntimeError: + break + + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 2) + self.assertEqual(SIGINT_count, 2) + self.assertEqual(SIGQUIT_count, 2) # 1 times less because of the disable catch + self.assertEqual(SIGPIPE_count, 2) # 1 times less because of the disable catch + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + self.assertEqual(SIGQUIT_count, catcher4.hit_count) + self.assertEqual(SIGPIPE_count, catcher5.hit_count) + + self.assertEqual(signal_received[0], b"Received signal 3") + self.assertEqual(signal_received[1], b"Received signal 13") + self.assertEqual(signal_received[2], b"Exiting normally.") + + self.assertEqual(len(signal_received), 3) + + def test_hijack_signal_with_catch_signal(self): + def catcher_SIGUSR1(t, sc): + # Hijack to SIGTERM + t.signal = 15 + + d = debugger("binaries/catch_signal_test") + + r = d.run() + + catcher1 = d.catch_signal("SIGUSR1", callback=catcher_SIGUSR1) + + d.cont() + + SIGUSR1 = r.recvline() + SIGTERM = r.recvline() + SIGINT = r.recvline() + SIGQUIT = r.recvline() + SIGPIPE = r.recvline() + + SIGUSR1 += r.recvline() + SIGTERM += r.recvline() + SIGINT += r.recvline() + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + d.kill() + + self.assertEqual(catcher1.hit_count, 2) + + self.assertEqual(SIGUSR1, b"Received signal 15" * 2) # hijacked signal + self.assertEqual(SIGTERM, b"Received signal 15" * 2) + self.assertEqual(SIGINT, b"Received signal 2" * 2) + self.assertEqual(SIGQUIT, b"Received signal 3" * 3) + self.assertEqual(SIGPIPE, b"Received signal 13" * 3) + + def test_hijack_signal_with_api(self): + d = debugger("binaries/catch_signal_test") + + r = d.run() + + # Hijack to SIGTERM + catcher1 = d.hijack_signal("SIGUSR1", 15) + + d.cont() + + SIGUSR1 = r.recvline() + SIGTERM = r.recvline() + SIGINT = r.recvline() + SIGQUIT = r.recvline() + SIGPIPE = r.recvline() + + SIGUSR1 += r.recvline() + SIGTERM += r.recvline() + SIGINT += r.recvline() + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + d.kill() + + self.assertEqual(catcher1.hit_count, 2) + + self.assertEqual(SIGUSR1, b"Received signal 15" * 2) # hijacked signal + self.assertEqual(SIGTERM, b"Received signal 15" * 2) + self.assertEqual(SIGINT, b"Received signal 2" * 2) + self.assertEqual(SIGQUIT, b"Received signal 3" * 3) + self.assertEqual(SIGPIPE, b"Received signal 13" * 3) + + def test_recursive_true_with_catch_signal(self): + SIGUSR1_count = 0 + SIGTERM_count = 0 + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + # Hijack to SIGTERM + t.signal = 15 + + SIGUSR1_count += 1 + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + + SIGTERM_count += 1 + + d = debugger("binaries/catch_signal_test") + + r = d.run() + + catcher1 = d.catch_signal(10, callback=catcher_SIGUSR1, recursive=True) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + + d.cont() + + SIGUSR1 = r.recvline() + SIGTERM = r.recvline() + SIGINT = r.recvline() + SIGQUIT = r.recvline() + SIGPIPE = r.recvline() + + SIGUSR1 += r.recvline() + SIGTERM += r.recvline() + SIGINT += r.recvline() + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 4) # 2 times more because of the hijack + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + + self.assertEqual(SIGUSR1, b"Received signal 15" * 2) # hijacked signal + self.assertEqual(SIGTERM, b"Received signal 15" * 2) + self.assertEqual(SIGINT, b"Received signal 2" * 2) + self.assertEqual(SIGQUIT, b"Received signal 3" * 3) + self.assertEqual(SIGPIPE, b"Received signal 13" * 3) + + def test_recursive_true_with_api(self): + SIGTERM_count = 0 + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + + SIGTERM_count += 1 + + d = debugger("binaries/catch_signal_test") + + r = d.run() + + catcher1 = d.hijack_signal(10, 15, recursive=True) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + + d.cont() + + SIGUSR1 = r.recvline() + SIGTERM = r.recvline() + SIGINT = r.recvline() + SIGQUIT = r.recvline() + SIGPIPE = r.recvline() + + SIGUSR1 += r.recvline() + SIGTERM += r.recvline() + SIGINT += r.recvline() + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + d.kill() + + self.assertEqual(SIGTERM_count, 4) # 2 times more because of the hijack + self.assertEqual(catcher1.hit_count, 2) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + + self.assertEqual(SIGUSR1, b"Received signal 15" * 2) # hijacked signal + self.assertEqual(SIGTERM, b"Received signal 15" * 2) + self.assertEqual(SIGINT, b"Received signal 2" * 2) + self.assertEqual(SIGQUIT, b"Received signal 3" * 3) + self.assertEqual(SIGPIPE, b"Received signal 13" * 3) + + def test_recursive_false_with_catch_signal(self): + SIGUSR1_count = 0 + SIGTERM_count = 0 + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + # Hijack to SIGTERM + t.signal = 15 + + SIGUSR1_count += 1 + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + + SIGTERM_count += 1 + + d = debugger("binaries/catch_signal_test") + + r = d.run() + + catcher1 = d.catch_signal(10, callback=catcher_SIGUSR1, recursive=False) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + + d.cont() + + SIGUSR1 = r.recvline() + SIGTERM = r.recvline() + SIGINT = r.recvline() + SIGQUIT = r.recvline() + SIGPIPE = r.recvline() + + SIGUSR1 += r.recvline() + SIGTERM += r.recvline() + SIGINT += r.recvline() + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 2) # 2 times in total because of the recursive=False + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + + self.assertEqual(SIGUSR1, b"Received signal 15" * 2) # hijacked signal + self.assertEqual(SIGTERM, b"Received signal 15" * 2) + self.assertEqual(SIGINT, b"Received signal 2" * 2) + self.assertEqual(SIGQUIT, b"Received signal 3" * 3) + self.assertEqual(SIGPIPE, b"Received signal 13" * 3) + + def test_recursive_false_with_api(self): + SIGTERM_count = 0 + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + + SIGTERM_count += 1 + + d = debugger("binaries/catch_signal_test") + + r = d.run() + + catcher1 = d.hijack_signal(10, 15, recursive=False) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + + d.cont() + + SIGUSR1 = r.recvline() + SIGTERM = r.recvline() + SIGINT = r.recvline() + SIGQUIT = r.recvline() + SIGPIPE = r.recvline() + + SIGUSR1 += r.recvline() + SIGTERM += r.recvline() + SIGINT += r.recvline() + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + d.kill() + + self.assertEqual(catcher1.hit_count, 2) + self.assertEqual(SIGTERM_count, 2) # 2 times in total because of the recursive=False + self.assertEqual(SIGTERM_count, catcher2.hit_count) + + self.assertEqual(SIGUSR1, b"Received signal 15" * 2) # hijacked signal + self.assertEqual(SIGTERM, b"Received signal 15" * 2) + self.assertEqual(SIGINT, b"Received signal 2" * 2) + self.assertEqual(SIGQUIT, b"Received signal 3" * 3) + self.assertEqual(SIGPIPE, b"Received signal 13" * 3) + + def test_hijack_signal_with_catch_signal_loop(self): + # Let create a loop of hijacking signals + + def catcher_SIGUSR1(t, sc): + # Hijack to SIGTERM + t.signal = 15 + + def catcher_SIGTERM(t, sc): + # Hijack to SIGINT + t.signal = 10 + + d = debugger("binaries/catch_signal_test") + + d.run() + + d.catch_signal("SIGUSR1", callback=catcher_SIGUSR1, recursive=True) + d.catch_signal("SIGTERM", callback=catcher_SIGTERM, recursive=True) + + with self.assertRaises(RuntimeError): + d.cont() + d.wait() + + d.kill() + + # Now we set recursive=False to avoid the loop + d.run() + + d.catch_signal("SIGUSR1", callback=catcher_SIGUSR1, recursive=False) + d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + + d.cont() + d.kill() + + d.run() + + d.catch_signal("SIGUSR1", callback=catcher_SIGUSR1) + d.catch_signal("SIGTERM", callback=catcher_SIGTERM, recursive=False) + + d.cont() + d.kill() + + d.run() + + d.catch_signal("SIGUSR1", callback=catcher_SIGUSR1, recursive=False) + d.catch_signal("SIGTERM", callback=catcher_SIGTERM, recursive=False) + + d.cont() + d.kill() + + def test_hijack_signal_with_api_loop(self): + # Let create a loop of hijacking signals + + d = debugger("binaries/catch_signal_test") + + d.run() + + d.hijack_signal("SIGUSR1", "SIGTERM", recursive=True) + d.hijack_signal(15, 10, recursive=True) + + with self.assertRaises(RuntimeError): + d.cont() + d.wait() + + d.kill() + + # Now we set recursive=False to avoid the loop + d.run() + + d.hijack_signal("SIGUSR1", "SIGTERM", recursive=False) + d.hijack_signal(15, 10) + + d.cont() + d.kill() + + d.run() + + d.hijack_signal("SIGUSR1", "SIGTERM") + d.hijack_signal(15, 10, recursive=False) + + d.cont() + d.kill() + + d.run() + + d.hijack_signal("SIGUSR1", "SIGTERM", recursive=False) + d.hijack_signal(15, 10, recursive=False) + + d.cont() + d.kill() + + def test_signal_unhijacking(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGTERM_count = 0 + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + + SIGUSR1_count += 1 + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + + SIGTERM_count += 1 + + def catcher_SIGINT(t, sc): + nonlocal SIGINT_count + + SIGINT_count += 1 + + d = debugger("binaries/catch_signal_test") + + r = d.run() + + catcher1 = d.catch_signal("SIGUSR1", callback=catcher_SIGUSR1) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + catcher3 = d.catch_signal("SIGINT", callback=catcher_SIGINT) + catcher4 = d.hijack_signal("SIGQUIT", "SIGTERM", recursive=True) + catcher5 = d.hijack_signal("SIGPIPE", "SIGTERM", recursive=True) + + bp = d.breakpoint(0x964) + + d.cont() + + SIGUSR1 = r.recvline() + SIGTERM = r.recvline() + SIGINT = r.recvline() + SIGQUIT = r.recvline() + SIGPIPE = r.recvline() + + SIGUSR1 += r.recvline() + SIGTERM += r.recvline() + SIGINT += r.recvline() + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + # Disable catching of signals + if bp.hit_on(d): + catcher4.disable() + catcher5.disable() + d.cont() + + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 2 + 2 + 2) # 2 times more because of the hijacking * 2 (SIGQUIT and SIGPIPE) + self.assertEqual(SIGINT_count, 2) + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + + self.assertEqual(SIGUSR1, b"Received signal 10" * 2) + self.assertEqual(SIGTERM, b"Received signal 15" * 2) + self.assertEqual(SIGINT, b"Received signal 2" * 2) + self.assertEqual(SIGQUIT, b"Received signal 15" * 2 + b"Received signal 3") + self.assertEqual(SIGPIPE, b"Received signal 15" * 2 + b"Received signal 13") + + def test_override_catch_signal(self): + SIGPIPE_count_first = 0 + SIGPIPE_count_second = 0 + + def catcher_SIGPIPE_first(t, sc): + nonlocal SIGPIPE_count_first + + SIGPIPE_count_first += 1 + + def catcher_SIGPIPE_second(t, sc): + nonlocal SIGPIPE_count_second + + SIGPIPE_count_second += 1 + + d = debugger("binaries/catch_signal_test") + + r = d.run() + + catcher1 = d.catch_signal("SIGPIPE", callback=catcher_SIGPIPE_first) + + bp = d.breakpoint(0x964) + + d.cont() + + SIGUSR1 = r.recvline() + SIGTERM = r.recvline() + SIGINT = r.recvline() + SIGQUIT = r.recvline() + SIGPIPE = r.recvline() + + SIGUSR1 += r.recvline() + SIGTERM += r.recvline() + SIGINT += r.recvline() + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + # Overriding the catcher + if bp.hit_on(d): + self.assertEqual(catcher1.hit_count, 2) + catcher2 = d.catch_signal("SIGPIPE", callback=catcher_SIGPIPE_second) + d.cont() + + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + d.kill() + + self.assertEqual(SIGPIPE_count_first, 2) + self.assertEqual(SIGPIPE_count_second, 1) + + self.assertEqual(SIGPIPE_count_first, catcher1.hit_count) + self.assertEqual(SIGPIPE_count_second, catcher2.hit_count) + + self.assertEqual(SIGUSR1, b"Received signal 10" * 2) + self.assertEqual(SIGTERM, b"Received signal 15" * 2) + self.assertEqual(SIGINT, b"Received signal 2" * 2) + self.assertEqual(SIGQUIT, b"Received signal 3" * 3) + self.assertEqual(SIGPIPE, b"Received signal 13" * 3) + + self.assertEqual( + self.log_capture_string.getvalue().count("has already been caught. Overriding it."), + 1, + ) + + def test_override_hijack(self): + d = debugger("binaries/catch_signal_test") + + r = d.run() + + catcher1 = d.hijack_signal("SIGPIPE", 15) + + bp = d.breakpoint(0x964) + + d.cont() + + SIGUSR1 = r.recvline() + SIGTERM = r.recvline() + SIGINT = r.recvline() + SIGQUIT = r.recvline() + SIGPIPE = r.recvline() + + SIGUSR1 += r.recvline() + SIGTERM += r.recvline() + SIGINT += r.recvline() + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + # Overriding the catcher + if bp.hit_on(d): + self.assertEqual(catcher1.hit_count, 2) + catcher2 = d.hijack_signal("SIGPIPE", "SIGINT") + d.cont() + + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + d.kill() + + self.assertEqual(catcher1.hit_count, 2) + self.assertEqual(catcher2.hit_count, 1) + + self.assertEqual(SIGUSR1, b"Received signal 10" * 2) + self.assertEqual(SIGTERM, b"Received signal 15" * 2) + self.assertEqual(SIGINT, b"Received signal 2" * 2) + self.assertEqual(SIGQUIT, b"Received signal 3" * 3) + self.assertEqual(SIGPIPE, b"Received signal 15" * 2 + b"Received signal 2") + + self.assertEqual( + self.log_capture_string.getvalue().count("has already been caught. Overriding it."), + 1, + ) + + def test_override_hybrid(self): + SIGPIPE_count = 0 + + def catcher_SIGPIPE(t, sc): + nonlocal SIGPIPE_count + + SIGPIPE_count += 1 + + d = debugger("binaries/catch_signal_test") + + r = d.run() + + catcher1 = d.hijack_signal("SIGPIPE", 15) + + bp = d.breakpoint(0x964) + + d.cont() + + SIGUSR1 = r.recvline() + SIGTERM = r.recvline() + SIGINT = r.recvline() + SIGQUIT = r.recvline() + SIGPIPE = r.recvline() + + SIGUSR1 += r.recvline() + SIGTERM += r.recvline() + SIGINT += r.recvline() + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + # Overriding the catcher + if bp.hit_on(d): + self.assertEqual(catcher1.hit_count, 2) + catcher2 = d.catch_signal("SIGPIPE", callback=catcher_SIGPIPE) + d.cont() + + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + d.kill() + + self.assertEqual(catcher1.hit_count, 2) + self.assertEqual(catcher2.hit_count, 1) + self.assertEqual(SIGPIPE_count, 1) + + self.assertEqual(SIGUSR1, b"Received signal 10" * 2) + self.assertEqual(SIGTERM, b"Received signal 15" * 2) + self.assertEqual(SIGINT, b"Received signal 2" * 2) + self.assertEqual(SIGQUIT, b"Received signal 3" * 3) + self.assertEqual(SIGPIPE, b"Received signal 15" * 2 + b"Received signal 13") + + self.assertEqual( + self.log_capture_string.getvalue().count("has already been caught. Overriding it."), + 1, + ) + + def test_signal_get_signal(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGQUIT_count = 0 + SIGTERM_count = 0 + SIGPIPE_count = 0 + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + + self.assertEqual(t.signal, "SIGUSR1") + + SIGUSR1_count += 1 + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + + self.assertEqual(t.signal, "SIGTERM") + + SIGTERM_count += 1 + + def catcher_SIGINT(t, sc): + nonlocal SIGINT_count + + self.assertEqual(t.signal, "SIGINT") + + SIGINT_count += 1 + + def catcher_SIGQUIT(t, sc): + nonlocal SIGQUIT_count + + self.assertEqual(t.signal, "SIGQUIT") + + SIGQUIT_count += 1 + + def catcher_SIGPIPE(t, sc): + nonlocal SIGPIPE_count + + self.assertEqual(t.signal, "SIGPIPE") + + SIGPIPE_count += 1 + + d = debugger("binaries/catch_signal_test") + + d.signals_to_block = ["SIGUSR1", 15, "SIGINT", 3, 13] + + d.run() + + catcher1 = d.catch_signal(10, callback=catcher_SIGUSR1) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + catcher3 = d.catch_signal(2, callback=catcher_SIGINT) + catcher4 = d.catch_signal("SIGQUIT", callback=catcher_SIGQUIT) + catcher5 = d.catch_signal("SIGPIPE", callback=catcher_SIGPIPE) + + d.cont() + + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 2) + self.assertEqual(SIGINT_count, 2) + self.assertEqual(SIGQUIT_count, 3) + self.assertEqual(SIGPIPE_count, 3) + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + self.assertEqual(SIGQUIT_count, catcher4.hit_count) + self.assertEqual(SIGPIPE_count, catcher5.hit_count) + + def test_signal_send_signal(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGTERM_count = 0 + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + + SIGUSR1_count += 1 + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + + SIGTERM_count += 1 + + def catcher_SIGINT(t, sc): + nonlocal SIGINT_count + + SIGINT_count += 1 + + d = debugger("binaries/catch_signal_test") + + r = d.run() + + catcher1 = d.catch_signal("SIGUSR1", callback=catcher_SIGUSR1) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + catcher3 = d.catch_signal("SIGINT", callback=catcher_SIGINT) + catcher4 = d.hijack_signal("SIGQUIT", "SIGTERM", recursive=True) + catcher5 = d.hijack_signal("SIGPIPE", "SIGTERM", recursive=True) + + bp = d.breakpoint(0x964) + + d.cont() + + SIGUSR1 = r.recvline() + SIGTERM = r.recvline() + SIGINT = r.recvline() + SIGQUIT = r.recvline() + SIGPIPE = r.recvline() + + SIGUSR1 += r.recvline() + SIGTERM += r.recvline() + SIGINT += r.recvline() + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + # Uncatchering and send signals + if bp.hit_on(d): + catcher4.disable() + catcher5.disable() + d.signal = 10 + d.cont() + + SIGUSR1 += r.recvline() + SIGQUIT += r.recvline() + SIGPIPE += r.recvline() + + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 2 + 2 + 2) # 2 times more because of the hijacking * 2 (SIGQUIT and SIGPIPE) + self.assertEqual(SIGINT_count, 2) + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + + self.assertEqual(SIGUSR1, b"Received signal 10" * 3) + self.assertEqual(SIGTERM, b"Received signal 15" * 2) + self.assertEqual(SIGINT, b"Received signal 2" * 2) + self.assertEqual(SIGQUIT, b"Received signal 15" * 2 + b"Received signal 3") + self.assertEqual(SIGPIPE, b"Received signal 15" * 2 + b"Received signal 13") + + def test_signal_catch_sync_block(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGQUIT_count = 0 + SIGTERM_count = 0 + SIGPIPE_count = 0 + + d = debugger("binaries/catch_signal_test") + + d.signals_to_block = ["SIGUSR1", 15, "SIGINT", 3, 13] + + d.run() + + catcher1 = d.catch_signal(10) + catcher2 = d.catch_signal("SIGTERM") + catcher3 = d.catch_signal(2) + catcher4 = d.catch_signal("SIGQUIT") + catcher5 = d.catch_signal("SIGPIPE") + + while not d.dead: + d.cont() + d.wait() + if catcher1.hit_on(d): + SIGUSR1_count += 1 + elif catcher2.hit_on(d): + SIGTERM_count += 1 + elif catcher3.hit_on(d): + SIGINT_count += 1 + elif catcher4.hit_on(d): + SIGQUIT_count += 1 + elif catcher5.hit_on(d): + SIGPIPE_count += 1 + + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 2) + self.assertEqual(SIGINT_count, 2) + self.assertEqual(SIGQUIT_count, 3) + self.assertEqual(SIGPIPE_count, 3) + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + self.assertEqual(SIGQUIT_count, catcher4.hit_count) + self.assertEqual(SIGPIPE_count, catcher5.hit_count) + + def test_signal_catch_sync_pass(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGQUIT_count = 0 + SIGTERM_count = 0 + SIGPIPE_count = 0 + + signals = b"" + + d = debugger("binaries/catch_signal_test") + + r = d.run() + + catcher1 = d.catch_signal(10) + catcher2 = d.catch_signal("SIGTERM") + catcher3 = d.catch_signal(2) + catcher4 = d.catch_signal("SIGQUIT") + catcher5 = d.catch_signal("SIGPIPE") + + signals = b"" + while not d.dead: + d.cont() + try: + signals += r.recvline() + except: + pass + d.wait() + if catcher1.hit_on(d): + SIGUSR1_count += 1 + elif catcher2.hit_on(d): + SIGTERM_count += 1 + elif catcher3.hit_on(d): + SIGINT_count += 1 + elif catcher4.hit_on(d): + SIGQUIT_count += 1 + elif catcher5.hit_on(d): + SIGPIPE_count += 1 + + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 2) + self.assertEqual(SIGINT_count, 2) + self.assertEqual(SIGQUIT_count, 3) + self.assertEqual(SIGPIPE_count, 3) + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + self.assertEqual(SIGQUIT_count, catcher4.hit_count) + self.assertEqual(SIGPIPE_count, catcher5.hit_count) + + self.assertEqual(signals.count(b"Received signal 10"), 2) + self.assertEqual(signals.count(b"Received signal 15"), 2) + self.assertEqual(signals.count(b"Received signal 2"), 2) + self.assertEqual(signals.count(b"Received signal 3"), 3) + self.assertEqual(signals.count(b"Received signal 13"), 3) diff --git a/test/aarch64/scripts/control_flow_test.py b/test/aarch64/scripts/control_flow_test.py new file mode 100644 index 00000000..7aa3e1fb --- /dev/null +++ b/test/aarch64/scripts/control_flow_test.py @@ -0,0 +1,183 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import unittest + +from libdebug import debugger + + +class ControlFlowTest(unittest.TestCase): + def test_step_until_1(self): + d = debugger("binaries/breakpoint_test") + d.run() + + bp = d.breakpoint("main") + d.cont() + + self.assertTrue(bp.hit_on(d)) + + d.step_until(0x0000aaaaaaaa0854) + + self.assertTrue(d.regs.pc == 0x0000aaaaaaaa0854) + self.assertTrue(bp.hit_count == 1) + self.assertFalse(bp.hit_on(d)) + + d.kill() + d.terminate() + + def test_step_until_2(self): + d = debugger("binaries/breakpoint_test") + d.run() + + bp = d.breakpoint(0x7fc, hardware=True) + d.cont() + + self.assertTrue(bp.hit_on(d)) + + d.step_until(0x0000aaaaaaaa0854, max_steps=7) + + self.assertTrue(d.regs.pc == 0x0000aaaaaaaa0818) + self.assertTrue(bp.hit_count == 1) + self.assertFalse(bp.hit_on(d)) + + d.kill() + d.terminate() + + def test_step_until_3(self): + d = debugger("binaries/breakpoint_test") + d.run() + + bp = d.breakpoint(0x7fc) + + # Let's put some breakpoints in-between + d.breakpoint(0x804) + d.breakpoint(0x80c) + d.breakpoint(0x808, hardware=True) + + d.cont() + + self.assertTrue(bp.hit_on(d)) + + # trace is [0x7fc, 0x800, 0x804, 0x808, 0x80c, 0x810, 0x814, 0x818] + d.step_until(0x0000aaaaaaaa0854, max_steps=7) + + self.assertTrue(d.regs.pc == 0x0000aaaaaaaa0818) + self.assertTrue(bp.hit_count == 1) + self.assertFalse(bp.hit_on(d)) + + d.kill() + d.terminate() + + def test_step_and_cont(self): + d = debugger("binaries/breakpoint_test") + d.run() + + bp1 = d.breakpoint("main") + bp2 = d.breakpoint("random_function") + d.cont() + + self.assertTrue(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + + d.step() + self.assertTrue(d.regs.pc == 0x0000aaaaaaaa083c) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + + d.step() + self.assertTrue(d.regs.pc == 0x0000aaaaaaaa0840) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + + d.cont() + + self.assertTrue(bp2.hit_on(d)) + + d.cont() + + d.kill() + d.terminate() + + def test_step_and_cont_hardware(self): + d = debugger("binaries/breakpoint_test") + d.run() + + bp1 = d.breakpoint("main", hardware=True) + bp2 = d.breakpoint("random_function", hardware=True) + d.cont() + + self.assertTrue(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + + d.step() + self.assertTrue(d.regs.pc == 0x0000aaaaaaaa083c) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + + d.step() + self.assertTrue(d.regs.pc == 0x0000aaaaaaaa0840) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + + d.cont() + + self.assertTrue(bp2.hit_on(d)) + + d.cont() + + d.kill() + d.terminate() + + def test_step_until_and_cont(self): + d = debugger("binaries/breakpoint_test") + d.run() + + bp1 = d.breakpoint("main") + bp2 = d.breakpoint("random_function") + d.cont() + + self.assertTrue(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + + d.step_until(0x0000aaaaaaaa083c) + + self.assertTrue(d.regs.pc == 0x0000aaaaaaaa083c) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + + d.cont() + + self.assertTrue(bp2.hit_on(d)) + + d.cont() + + d.kill() + d.terminate() + + def test_step_until_and_cont_hardware(self): + d = debugger("binaries/breakpoint_test") + d.run() + + bp1 = d.breakpoint("main", hardware=True) + bp2 = d.breakpoint("random_function", hardware=True) + d.cont() + + self.assertTrue(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + + d.step_until(0x0000aaaaaaaa083c) + self.assertTrue(d.regs.pc == 0x0000aaaaaaaa083c) + self.assertFalse(bp1.hit_on(d)) + self.assertFalse(bp2.hit_on(d)) + + d.cont() + + self.assertTrue(bp2.hit_on(d)) + + d.cont() + + d.kill() + d.terminate() \ No newline at end of file diff --git a/test/aarch64/scripts/death_test.py b/test/aarch64/scripts/death_test.py new file mode 100644 index 00000000..0a228b54 --- /dev/null +++ b/test/aarch64/scripts/death_test.py @@ -0,0 +1,154 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Gabriele Digregorio, Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import io +import logging +import unittest + +from libdebug import debugger + + +class DeathTest(unittest.TestCase): + def setUp(self): + # Redirect logging to a string buffer + self.log_capture_string = io.StringIO() + self.log_handler = logging.StreamHandler(self.log_capture_string) + self.log_handler.setLevel(logging.WARNING) + + self.logger = logging.getLogger("libdebug") + self.original_handlers = self.logger.handlers + self.logger.handlers = [] + self.logger.addHandler(self.log_handler) + self.logger.setLevel(logging.WARNING) + + def tearDown(self): + self.logger.removeHandler(self.log_handler) + self.logger.handlers = self.original_handlers + self.log_handler.close() + + def test_io_death(self): + d = debugger("binaries/segfault_test") + + r = d.run() + + d.cont() + + self.assertEqual(r.recvline(), b"Hello, World!") + self.assertEqual(r.recvline(), b"Death is coming!") + + with self.assertRaises(RuntimeError): + r.recvline() + + d.kill() + + def test_cont_death(self): + d = debugger("binaries/segfault_test") + + r = d.run() + + d.cont() + + self.assertEqual(r.recvline(), b"Hello, World!") + self.assertEqual(r.recvline(), b"Death is coming!") + + d.wait() + + with self.assertRaises(RuntimeError): + d.cont() + + self.assertEqual(d.dead, True) + self.assertEqual(d.threads[0].dead, True) + + d.kill() + + def test_instr_death(self): + d = debugger("binaries/segfault_test") + + r = d.run() + + d.cont() + + self.assertEqual(r.recvline(), b"Hello, World!") + self.assertEqual(r.recvline(), b"Death is coming!") + + d.wait() + + self.assertEqual(d.regs.pc, 0xaaaaaaaa0784) + + d.kill() + + def test_exit_signal_death(self): + d = debugger("binaries/segfault_test") + + r = d.run() + + d.cont() + + self.assertEqual(r.recvline(), b"Hello, World!") + self.assertEqual(r.recvline(), b"Death is coming!") + + d.wait() + + self.assertEqual(d.exit_signal, "SIGSEGV") + self.assertEqual(d.exit_signal, d.threads[0].exit_signal) + + d.kill() + + def test_exit_code_death(self): + d = debugger("binaries/segfault_test") + + r = d.run() + + d.cont() + + self.assertEqual(r.recvline(), b"Hello, World!") + self.assertEqual(r.recvline(), b"Death is coming!") + + d.wait() + + d.exit_code + + self.assertEqual( + self.log_capture_string.getvalue().count("No exit code available."), + 1, + ) + + d.kill() + + def test_exit_code_normal(self): + d = debugger("binaries/basic_test") + + d.run() + + d.cont() + + d.wait() + + self.assertEqual(d.exit_code, 0) + + d.exit_signal + + self.assertEqual( + self.log_capture_string.getvalue().count("No exit signal available."), + 1, + ) + + d.kill() + + def test_post_mortem_after_kill(self): + d = debugger("binaries/basic_test") + + d.run() + + d.cont() + + d.interrupt() + d.kill() + + # We should be able to access the registers also after the process has been killed + d.regs.x0 + d.regs.x1 + d.regs.x2 diff --git a/test/aarch64/scripts/finish_test.py b/test/aarch64/scripts/finish_test.py new file mode 100644 index 00000000..e2d7e025 --- /dev/null +++ b/test/aarch64/scripts/finish_test.py @@ -0,0 +1,325 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Francesco Panebianco, Gabriele Digregorio, Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# +import unittest + +from libdebug import debugger +from libdebug.architectures.stack_unwinding_provider import stack_unwinding_provider + +# Addresses of the dummy functions +C_ADDRESS = 0xaaaaaaaa0914 +B_ADDRESS = 0xaaaaaaaa08fc +A_ADDRESS = 0xaaaaaaaa0814 + +# Addresses of noteworthy instructions +RETURN_POINT_FROM_C = 0xaaaaaaaa0938 +RETURN_POINT_FROM_A = 0xaaaaaaaa0908 + +class FinishTest(unittest.TestCase): + def setUp(self): + pass + + def test_finish_exact_no_auto_interrupt_no_breakpoint(self): + d = debugger("binaries/finish_test", auto_interrupt_on_command=False) + + # ------------------ Block 1 ------------------ # + # Return from the first function call # + # --------------------------------------------- # + + # Reach function c + d.run() + d.breakpoint(C_ADDRESS) + d.cont() + + self.assertEqual(d.regs.pc, C_ADDRESS) + + # Finish function c + d.finish(heuristic="step-mode") + + self.assertEqual(d.regs.pc, RETURN_POINT_FROM_C) + + d.kill() + + # ------------------ Block 2 ------------------ # + # Return from the nested function call # + # --------------------------------------------- # + + # Reach function a + d.run() + d.breakpoint(A_ADDRESS) + d.cont() + + self.assertEqual(d.regs.pc, A_ADDRESS) + + # Finish function a + d.finish(heuristic="step-mode") + + self.assertEqual(d.regs.pc, RETURN_POINT_FROM_A) + + d.kill() + + def test_finish_heuristic_no_auto_interrupt_no_breakpoint(self): + d = debugger("binaries/finish_test", auto_interrupt_on_command=False) + + # ------------------ Block 1 ------------------ # + # Return from the first function call # + # --------------------------------------------- # + + # Reach function c + d.run() + d.breakpoint(C_ADDRESS) + d.cont() + + self.assertEqual(d.regs.pc, C_ADDRESS) + + # Finish function c + d.finish(heuristic="backtrace") + + self.assertEqual(d.regs.pc, RETURN_POINT_FROM_C) + + d.kill() + + # ------------------ Block 2 ------------------ # + # Return from the nested function call # + # --------------------------------------------- # + + # Reach function a + d.run() + d.breakpoint(A_ADDRESS) + d.cont() + + self.assertEqual(d.regs.pc, A_ADDRESS) + + # Finish function a + d.finish(heuristic="backtrace") + + self.assertEqual(d.regs.pc, RETURN_POINT_FROM_A) + + d.kill() + + def test_finish_exact_auto_interrupt_no_breakpoint(self): + d = debugger("binaries/finish_test", auto_interrupt_on_command=True) + + # ------------------ Block 1 ------------------ # + # Return from the first function call # + # --------------------------------------------- # + + # Reach function c + d.run() + d.breakpoint(C_ADDRESS) + d.cont() + d.wait() + + self.assertEqual(d.regs.pc, C_ADDRESS) + + # Finish function c + d.finish(heuristic="step-mode") + + self.assertEqual(d.regs.pc, RETURN_POINT_FROM_C) + + d.kill() + + # ------------------ Block 2 ------------------ # + # Return from the nested function call # + # --------------------------------------------- # + + # Reach function a + d.run() + d.breakpoint(A_ADDRESS) + d.cont() + d.wait() + + self.assertEqual(d.regs.pc, A_ADDRESS) + + # Finish function a + d.finish(heuristic="step-mode") + + self.assertEqual(d.regs.pc, RETURN_POINT_FROM_A) + + d.kill() + + def test_finish_heuristic_auto_interrupt_no_breakpoint(self): + d = debugger("binaries/finish_test", auto_interrupt_on_command=True) + + # ------------------ Block 1 ------------------ # + # Return from the first function call # + # --------------------------------------------- # + + # Reach function c + d.run() + d.breakpoint(C_ADDRESS) + d.cont() + d.wait() + + self.assertEqual(d.regs.pc, C_ADDRESS) + + # Finish function c + d.finish(heuristic="backtrace") + + self.assertEqual(d.regs.pc, RETURN_POINT_FROM_C) + + d.kill() + + # ------------------ Block 2 ------------------ # + # Return from the nested function call # + # --------------------------------------------- # + + # Reach function a + d.run() + d.breakpoint(A_ADDRESS) + d.cont() + d.wait() + + self.assertEqual(d.regs.pc, A_ADDRESS) + + # Finish function a + d.finish(heuristic="backtrace") + + self.assertEqual(d.regs.pc, RETURN_POINT_FROM_A) + + d.kill() + + def test_finish_exact_no_auto_interrupt_breakpoint(self): + d = debugger("binaries/finish_test", auto_interrupt_on_command=False) + + # Reach function c + d.run() + d.breakpoint(C_ADDRESS) + d.cont() + + self.assertEqual(d.regs.pc, C_ADDRESS) + + d.breakpoint(A_ADDRESS) + + # Finish function c + d.finish(heuristic="step-mode") + + self.assertEqual(d.regs.pc, A_ADDRESS, f"Expected {hex(A_ADDRESS)} but got {hex(d.regs.pc)}") + + d.kill() + + def test_finish_heuristic_no_auto_interrupt_breakpoint(self): + d = debugger("binaries/finish_test", auto_interrupt_on_command=False) + + # Reach function c + d.run() + d.breakpoint(C_ADDRESS) + d.cont() + + self.assertEqual(d.regs.pc, C_ADDRESS) + + d.breakpoint(A_ADDRESS) + + # Finish function c + d.finish(heuristic="backtrace") + + self.assertEqual(d.regs.pc, A_ADDRESS) + + d.kill() + + def test_heuristic_return_address(self): + d = debugger("binaries/finish_test", auto_interrupt_on_command=False) + + # Reach function c + d.run() + d.breakpoint(C_ADDRESS) + d.cont() + + self.assertEqual(d.regs.pc, C_ADDRESS) + + stack_unwinder = stack_unwinding_provider(d._internal_debugger.arch) + + # We need to repeat the check for the three stages of the function preamble + + # Get current return address + curr_srip = d.saved_ip + self.assertEqual(curr_srip, RETURN_POINT_FROM_C) + + d.step() + + # Get current return address + curr_srip = d.saved_ip + self.assertEqual(curr_srip, RETURN_POINT_FROM_C) + + d.step() + + # Get current return address + curr_srip = d.saved_ip + self.assertEqual(curr_srip, RETURN_POINT_FROM_C) + + d.kill() + + def test_exact_breakpoint_return(self): + BREAKPOINT_LOCATION = 0xaaaaaaaa0920 + + d = debugger("binaries/finish_test", auto_interrupt_on_command=False) + + # Reach function c + d.run() + d.breakpoint(C_ADDRESS) + d.cont() + + self.assertEqual(d.regs.pc, C_ADDRESS) + + + # Place a breakpoint at a location inbetween + d.breakpoint(BREAKPOINT_LOCATION) + + # Finish function c + d.finish(heuristic="step-mode") + + self.assertEqual(d.regs.pc, BREAKPOINT_LOCATION) + + d.kill() + + def test_heuristic_breakpoint_return(self): + BREAKPOINT_LOCATION = 0xaaaaaaaa0920 + + d = debugger("binaries/finish_test", auto_interrupt_on_command=False) + + # Reach function c + d.run() + d.breakpoint(C_ADDRESS) + d.cont() + + self.assertEqual(d.regs.pc, C_ADDRESS) + + + # Place a breakpoint a location in between + d.breakpoint(BREAKPOINT_LOCATION) + + # Finish function c + d.finish(heuristic="backtrace") + + self.assertEqual(d.regs.pc, BREAKPOINT_LOCATION) + + d.kill() + + def test_breakpoint_collision(self): + d = debugger("binaries/finish_test", auto_interrupt_on_command=False) + + # Reach function c + d.run() + d.breakpoint(C_ADDRESS) + d.cont() + + self.assertEqual(d.regs.pc, C_ADDRESS) + + # Place a breakpoint at the same location as the return address + d.breakpoint(RETURN_POINT_FROM_C) + + # Finish function c + d.finish(heuristic="backtrace") + + self.assertEqual(d.regs.pc, RETURN_POINT_FROM_C) + self.assertFalse(d.running) + + d.step() + + # Check that the execution is still running and nothing has broken + self.assertFalse(d.running) + self.assertFalse(d.dead) + + d.kill() diff --git a/test/aarch64/scripts/floating_point_test.py b/test/aarch64/scripts/floating_point_test.py new file mode 100644 index 00000000..770c67b6 --- /dev/null +++ b/test/aarch64/scripts/floating_point_test.py @@ -0,0 +1,84 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import sys +import unittest +from random import randint + +from libdebug import debugger + + +class FloatingPointTest(unittest.TestCase): + def test_floating_point_reg_access(self): + d = debugger("binaries/floating_point_test") + + d.run() + + bp1 = d.bp(0xb10, file="binary") + bp2 = d.bp(0xb44, file="binary") + + d.cont() + + assert bp1.hit_on(d) + + baseval = int.from_bytes(bytes(list(range(16))), sys.byteorder) + + for i in range(16): + assert hasattr(d.regs, f"q{i}") + assert getattr(d.regs, f"q{i}") == baseval + assert getattr(d.regs, f"v{i}") == baseval + assert getattr(d.regs, f"d{i}") == baseval & 0xFFFFFFFFFFFFFFFF + assert getattr(d.regs, f"s{i}") == baseval & 0xFFFFFFFF + assert getattr(d.regs, f"h{i}") == baseval & 0xFFFF + assert getattr(d.regs, f"b{i}") == baseval & 0xFF + baseval = (baseval >> 8) + ((baseval & 255) << 120) + + baseval = int.from_bytes(bytes(list(range(128, 128 + 16, 1))), sys.byteorder) + + for i in range(16, 32, 1): + assert hasattr(d.regs, f"q{i}") + assert getattr(d.regs, f"q{i}") == baseval + assert getattr(d.regs, f"v{i}") == baseval + assert getattr(d.regs, f"d{i}") == baseval & 0xFFFFFFFFFFFFFFFF + assert getattr(d.regs, f"s{i}") == baseval & 0xFFFFFFFF + assert getattr(d.regs, f"h{i}") == baseval & 0xFFFF + assert getattr(d.regs, f"b{i}") == baseval & 0xFF + baseval = (baseval >> 8) + ((baseval & 255) << 120) + + for i in range(32): + val = randint(0, (1 << 128) - 1) + setattr(d.regs, f"q{i}", val) + assert getattr(d.regs, f"q{i}") == val + assert getattr(d.regs, f"v{i}") == val + + for i in range(32): + val = randint(0, (1 << 64) - 1) + setattr(d.regs, f"d{i}", val) + assert getattr(d.regs, f"d{i}") == val + + for i in range(32): + val = randint(0, (1 << 32) - 1) + setattr(d.regs, f"s{i}", val) + assert getattr(d.regs, f"s{i}") == val + + for i in range(32): + val = randint(0, (1 << 16) - 1) + setattr(d.regs, f"h{i}", val) + assert getattr(d.regs, f"h{i}") == val + + for i in range(32): + val = randint(0, (1 << 8) - 1) + setattr(d.regs, f"b{i}", val) + assert getattr(d.regs, f"b{i}") == val + + d.regs.q0 = 0xdeadbeefdeadbeef + + d.cont() + + assert bp2.hit_on(d) + + d.kill() + d.terminate() diff --git a/test/aarch64/scripts/handle_syscall_test.py b/test/aarch64/scripts/handle_syscall_test.py new file mode 100644 index 00000000..9ca32a1d --- /dev/null +++ b/test/aarch64/scripts/handle_syscall_test.py @@ -0,0 +1,552 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import io +import logging +import os +import sys +import unittest + +from libdebug import debugger + + +class HandleSyscallTest(unittest.TestCase): + def setUp(self): + # Redirect stdout + self.capturedOutput = io.StringIO() + sys.stdout = self.capturedOutput + sys.stderr = self.capturedOutput + + self.log_capture_string = io.StringIO() + self.log_handler = logging.StreamHandler(self.log_capture_string) + self.log_handler.setLevel(logging.WARNING) + + self.logger = logging.getLogger("libdebug") + self.original_handlers = self.logger.handlers + self.logger.handlers = [] + self.logger.addHandler(self.log_handler) + self.logger.setLevel(logging.WARNING) + + def tearDown(self): + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ + + self.logger.removeHandler(self.log_handler) + self.logger.handlers = self.original_handlers + self.log_handler.close() + + def test_handles(self): + d = debugger("binaries/handle_syscall_test") + + r = d.run() + + ptr = 0 + write_count = 0 + + def on_enter_write(d, sh): + nonlocal write_count + + if write_count == 0: + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 13], b"Hello, World!") + self.assertEqual(d.syscall_arg0, 1) + write_count += 1 + else: + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 7], b"provola") + self.assertEqual(d.syscall_arg0, 1) + write_count += 1 + + def on_exit_mmap(d, sh): + self.assertTrue(sh.syscall_number == 222) + + nonlocal ptr + + ptr = d.regs.x0 + + def on_enter_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.syscall_arg0, ptr) + + def on_exit_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.memory[ptr, 8], os.getcwd()[:8].encode()) + + handler1 = d.handle_syscall("write", on_enter_write, None) + handler2 = d.handle_syscall("mmap", None, on_exit_mmap) + handler3 = d.handle_syscall("getcwd", on_enter_getcwd, on_exit_getcwd) + + r.sendline(b"provola") + + d.cont() + d.wait() + + d.kill() + d.terminate() + + self.assertEqual(write_count, 2) + self.assertEqual(handler1.hit_count, 2) + self.assertEqual(handler2.hit_count, 1) + self.assertEqual(handler3.hit_count, 1) + + def test_handles_with_pprint(self): + d = debugger("binaries/handle_syscall_test") + + r = d.run() + + d.pprint_syscalls = True + + ptr = 0 + write_count = 0 + + def on_enter_write(d, sh): + nonlocal write_count + + if write_count == 0: + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 13], b"Hello, World!") + self.assertEqual(d.syscall_arg0, 1) + write_count += 1 + else: + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 7], b"provola") + self.assertEqual(d.syscall_arg0, 1) + write_count += 1 + + def on_exit_mmap(d, sh): + self.assertTrue(sh.syscall_number == 222) + + nonlocal ptr + + ptr = d.regs.x0 + + def on_enter_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.syscall_arg0, ptr) + + def on_exit_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.memory[ptr, 8], os.getcwd()[:8].encode()) + + handler1 = d.handle_syscall("write", on_enter_write, None) + handler2 = d.handle_syscall("mmap", None, on_exit_mmap) + handler3 = d.handle_syscall("getcwd", on_enter_getcwd, on_exit_getcwd) + + r.sendline(b"provola") + + d.cont() + d.wait() + + d.kill() + d.terminate() + + self.assertEqual(write_count, 2) + self.assertEqual(handler1.hit_count, 2) + self.assertEqual(handler2.hit_count, 1) + self.assertEqual(handler3.hit_count, 1) + + def test_handle_disabling(self): + d = debugger("binaries/handle_syscall_test") + + r = d.run() + + ptr = 0 + write_count = 0 + + def on_enter_write(d, sh): + nonlocal write_count + + if write_count == 0: + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 13], b"Hello, World!") + self.assertEqual(d.syscall_arg0, 1) + write_count += 1 + else: + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 7], b"provola") + self.assertEqual(d.syscall_arg0, 1) + write_count += 1 + + def on_exit_mmap(d, sh): + self.assertTrue(sh.syscall_number == 222) + + nonlocal ptr + + ptr = d.regs.x0 + + def on_enter_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.syscall_arg0, ptr) + + def on_exit_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.memory[ptr, 8], os.getcwd()[:8].encode()) + + handler1 = d.handle_syscall(0x40, on_enter_write, None) + handler2 = d.handle_syscall(222, None, on_exit_mmap) + handler3 = d.handle_syscall(17, on_enter_getcwd, on_exit_getcwd) + + r.sendline(b"provola") + + bp = d.breakpoint(0x9d4, file="binary") + + d.cont() + + d.wait() + + self.assertEqual(d.regs.pc, bp.address) + handler1.disable() + + d.cont() + + d.kill() + d.terminate() + + self.assertEqual(write_count, 1) + self.assertEqual(handler1.hit_count, 1) + self.assertEqual(handler2.hit_count, 1) + self.assertEqual(handler3.hit_count, 1) + + def test_handle_disabling_with_pprint(self): + d = debugger("binaries/handle_syscall_test") + + r = d.run() + + d.pprint_syscalls = True + + ptr = 0 + write_count = 0 + + def on_enter_write(d, sh): + nonlocal write_count + + if write_count == 0: + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 13], b"Hello, World!") + self.assertEqual(d.syscall_arg0, 1) + write_count += 1 + else: + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 7], b"provola") + self.assertEqual(d.syscall_arg0, 1) + write_count += 1 + + def on_exit_mmap(d, sh): + self.assertTrue(sh.syscall_number == 222) + + nonlocal ptr + + ptr = d.regs.x0 + + def on_enter_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.syscall_arg0, ptr) + + def on_exit_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.memory[ptr, 8], os.getcwd()[:8].encode()) + + handler1 = d.handle_syscall(0x40, on_enter_write, None) + handler2 = d.handle_syscall(222, None, on_exit_mmap) + handler3 = d.handle_syscall(17, on_enter_getcwd, on_exit_getcwd) + + r.sendline(b"provola") + + bp = d.breakpoint(0x9d4, file="binary") + + d.cont() + + d.wait() + + self.assertEqual(d.regs.pc, bp.address) + handler1.disable() + + d.cont() + + d.kill() + d.terminate() + + self.assertEqual(write_count, 1) + self.assertEqual(handler1.hit_count, 1) + self.assertEqual(handler2.hit_count, 1) + self.assertEqual(handler3.hit_count, 1) + + def test_handle_overwrite(self): + d = debugger("binaries/handle_syscall_test") + + r = d.run() + + ptr = 0 + write_count_first = 0 + write_count_second = 0 + + def on_enter_write_first(d, sh): + nonlocal write_count_first + + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 13], b"Hello, World!") + self.assertEqual(d.syscall_arg0, 1) + write_count_first += 1 + + def on_enter_write_second(d, sh): + nonlocal write_count_second + + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 7], b"provola") + self.assertEqual(d.syscall_arg0, 1) + write_count_second += 1 + + def on_exit_mmap(d, sh): + self.assertTrue(sh.syscall_number == 222) + + nonlocal ptr + + ptr = d.regs.x0 + + def on_enter_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.syscall_arg0, ptr) + + def on_exit_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.memory[ptr, 8], os.getcwd()[:8].encode()) + + handler1_1 = d.handle_syscall(0x40, on_enter_write_first, None) + handler2 = d.handle_syscall(222, None, on_exit_mmap) + handler3 = d.handle_syscall(17, on_enter_getcwd, on_exit_getcwd) + + r.sendline(b"provola") + + bp = d.breakpoint(0x9d4, file="binary") + + d.cont() + + d.wait() + + self.assertEqual(d.regs.pc, bp.address) + handler1_2 = d.handle_syscall(0x40, on_enter_write_second, None) + + d.cont() + + d.kill() + d.terminate() + + self.assertEqual(write_count_first, 1) + self.assertEqual(write_count_second, 1) + self.assertEqual(handler1_1.hit_count, 2) + self.assertEqual(handler1_2.hit_count, 2) + self.assertEqual(handler2.hit_count, 1) + self.assertEqual(handler3.hit_count, 1) + + self.assertIn("WARNING", self.log_capture_string.getvalue()) + self.assertIn( + "Syscall write is already handled by a user-defined handler. Overriding it.", + self.log_capture_string.getvalue(), + ) + + def test_handle_overwrite_with_pprint(self): + d = debugger("binaries/handle_syscall_test") + + r = d.run() + + d.pprint_syscalls = True + + ptr = 0 + write_count_first = 0 + write_count_second = 0 + + def on_enter_write_first(d, sh): + nonlocal write_count_first + + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 13], b"Hello, World!") + self.assertEqual(d.syscall_arg0, 1) + write_count_first += 1 + + def on_enter_write_second(d, sh): + nonlocal write_count_second + + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 7], b"provola") + self.assertEqual(d.syscall_arg0, 1) + write_count_second += 1 + + def on_exit_mmap(d, sh): + self.assertTrue(sh.syscall_number == 222) + + nonlocal ptr + + ptr = d.regs.x0 + + def on_enter_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.syscall_arg0, ptr) + + def on_exit_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.memory[ptr, 8], os.getcwd()[:8].encode()) + + handler1_1 = d.handle_syscall(0x40, on_enter_write_first, None) + handler2 = d.handle_syscall(222, None, on_exit_mmap) + handler3 = d.handle_syscall(17, on_enter_getcwd, on_exit_getcwd) + + r.sendline(b"provola") + + bp = d.breakpoint(0x9d4, file="binary") + + d.cont() + + d.wait() + + self.assertEqual(d.regs.pc, bp.address) + handler1_2 = d.handle_syscall(0x40, on_enter_write_second, None) + + d.cont() + + d.kill() + d.terminate() + + self.assertEqual(write_count_first, 1) + self.assertEqual(write_count_second, 1) + self.assertEqual(handler1_1.hit_count, 2) + self.assertEqual(handler1_2.hit_count, 2) + self.assertEqual(handler2.hit_count, 1) + self.assertEqual(handler3.hit_count, 1) + + self.assertIn("WARNING", self.log_capture_string.getvalue()) + self.assertIn( + "Syscall write is already handled by a user-defined handler. Overriding it.", + self.log_capture_string.getvalue(), + ) + + + def test_handles_sync(self): + d = debugger("binaries/handle_syscall_test") + + r = d.run() + + ptr = 0 + write_count = 0 + + def on_enter_write(d, sh): + nonlocal write_count + + if write_count == 0: + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 13], b"Hello, World!") + self.assertEqual(d.syscall_arg0, 1) + write_count += 1 + else: + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 7], b"provola") + self.assertEqual(d.syscall_arg0, 1) + write_count += 1 + + def on_exit_mmap(d, sh): + self.assertTrue(sh.syscall_number == 222) + + nonlocal ptr + + ptr = d.regs.x0 + + def on_enter_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.syscall_arg0, ptr) + + def on_exit_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.memory[ptr, 8], os.getcwd()[:8].encode()) + + handler1 = d.handle_syscall("write") + handler2 = d.handle_syscall("mmap") + handler3 = d.handle_syscall("getcwd") + + r.sendline(b"provola") + + while not d.dead: + d.cont() + d.wait() + if handler1.hit_on_enter(d): + on_enter_write(d, handler1) + elif handler2.hit_on_exit(d): + on_exit_mmap(d, handler2) + elif handler3.hit_on_enter(d): + on_enter_getcwd(d, handler3) + elif handler3.hit_on_exit(d): + on_exit_getcwd(d, handler3) + + d.kill() + d.terminate() + + self.assertEqual(write_count, 2) + self.assertEqual(handler1.hit_count, 2) + self.assertEqual(handler2.hit_count, 1) + self.assertEqual(handler3.hit_count, 1) + + def test_handles_sync_with_pprint(self): + d = debugger("binaries/handle_syscall_test") + + r = d.run() + + ptr = 0 + write_count = 0 + + def on_enter_write(d, sh): + nonlocal write_count + + if write_count == 0: + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 13], b"Hello, World!") + self.assertEqual(d.syscall_arg0, 1) + write_count += 1 + else: + self.assertTrue(sh.syscall_number == 0x40) + self.assertEqual(d.memory[d.syscall_arg1, 7], b"provola") + self.assertEqual(d.syscall_arg0, 1) + write_count += 1 + + def on_exit_mmap(d, sh): + self.assertTrue(sh.syscall_number == 222) + + nonlocal ptr + + ptr = d.regs.x0 + + def on_enter_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.syscall_arg0, ptr) + + def on_exit_getcwd(d, sh): + self.assertTrue(sh.syscall_number == 17) + self.assertEqual(d.memory[ptr, 8], os.getcwd()[:8].encode()) + + handler1 = d.handle_syscall("write") + handler2 = d.handle_syscall("mmap") + handler3 = d.handle_syscall("getcwd") + + d.pprint_syscalls = True + + r.sendline(b"provola") + + while not d.dead: + d.cont() + d.wait() + if handler1.hit_on_enter(d): + on_enter_write(d, handler1) + elif handler2.hit_on_exit(d): + on_exit_mmap(d, handler2) + elif handler3.hit_on_enter(d): + on_enter_getcwd(d, handler3) + elif handler3.hit_on_exit(d): + on_exit_getcwd(d, handler3) + + d.kill() + d.terminate() + + self.assertEqual(write_count, 2) + self.assertEqual(handler1.hit_count, 2) + self.assertEqual(handler2.hit_count, 1) + self.assertEqual(handler3.hit_count, 1) diff --git a/test/aarch64/scripts/hijack_syscall_test.py b/test/aarch64/scripts/hijack_syscall_test.py new file mode 100644 index 00000000..4e39195b --- /dev/null +++ b/test/aarch64/scripts/hijack_syscall_test.py @@ -0,0 +1,333 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Gabriele Digregorio, Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import io +import sys +import unittest + +from libdebug import debugger + + +class HijackSyscallTest(unittest.TestCase): + def setUp(self): + # Redirect stdout + self.capturedOutput = io.StringIO() + sys.stdout = self.capturedOutput + + def tearDown(self): + sys.stdout = sys.__stdout__ + + def test_hijack_syscall(self): + def on_enter_write(d, sh): + nonlocal write_count + + write_count += 1 + + d = debugger("binaries/handle_syscall_test") + + write_count = 0 + r = d.run() + + d.hijack_syscall("getcwd", "write", recursive=True) + + # recursive is on, we expect the write handler to be called three times + handler = d.handle_syscall("write", on_enter=on_enter_write, recursive=True) + + r.sendline(b"provola") + + d.cont() + + d.kill() + + self.assertEqual(write_count, handler.hit_count) + self.assertEqual(handler.hit_count, 3) + + write_count = 0 + r = d.run() + + d.hijack_syscall("getcwd", "write", recursive=False) + + # recursive is off, we expect the write handler to be called only twice + handler = d.handle_syscall("write", on_enter=on_enter_write) + + r.sendline(b"provola") + + d.cont() + + d.kill() + + self.assertEqual(write_count, handler.hit_count) + self.assertEqual(handler.hit_count, 2) + + def test_hijack_syscall_with_pprint(self): + def on_enter_write(d, sh): + nonlocal write_count + + write_count += 1 + + d = debugger("binaries/handle_syscall_test") + + write_count = 0 + r = d.run() + + d.pprint_syscalls = True + d.hijack_syscall("getcwd", "write", recursive=True) + + # recursive is on, we expect the write handler to be called three times + handler = d.handle_syscall("write", on_enter=on_enter_write, recursive=True) + + r.sendline(b"provola") + + d.cont() + + d.kill() + + self.assertEqual(write_count, handler.hit_count) + self.assertEqual(handler.hit_count, 3) + + write_count = 0 + r = d.run() + + d.pprint_syscalls = True + d.hijack_syscall("getcwd", "write", recursive=False) + + # recursive is off, we expect the write handler to be called only twice + handler = d.handle_syscall("write", on_enter=on_enter_write, recursive=False) + + r.sendline(b"provola") + + d.cont() + + d.kill() + + self.assertEqual(write_count, handler.hit_count) + self.assertEqual(handler.hit_count, 2) + + def test_hijack_handle_syscall(self): + def on_enter_write(d, sh): + nonlocal write_count + + write_count += 1 + + def on_enter_getcwd(d, sh): + d.syscall_number = 0x40 + + d = debugger("binaries/handle_syscall_test") + + write_count = 0 + r = d.run() + + d.handle_syscall("getcwd", on_enter=on_enter_getcwd, recursive=True) + + # recursive is on, we expect the write handler to be called three times + handler = d.handle_syscall("write", on_enter=on_enter_write) + + r.sendline(b"provola") + + d.cont() + + d.kill() + + self.assertEqual(write_count, handler.hit_count) + self.assertEqual(handler.hit_count, 3) + + write_count = 0 + r = d.run() + + d.handle_syscall("getcwd", on_enter=on_enter_getcwd, recursive=False) + + # recursive is off, we expect the write handler to be called only twice + handler = d.handle_syscall("write", on_enter=on_enter_write) + + r.sendline(b"provola") + + d.cont() + + d.kill() + + self.assertEqual(write_count, handler.hit_count) + self.assertEqual(handler.hit_count, 2) + + def test_hijack_handle_syscall_with_pprint(self): + def on_enter_write(d, sh): + nonlocal write_count + + write_count += 1 + + def on_enter_getcwd(d, sh): + d.syscall_number = 0x40 + + d = debugger("binaries/handle_syscall_test") + + write_count = 0 + r = d.run() + + d.pprint_syscalls = True + d.handle_syscall("getcwd", on_enter=on_enter_getcwd, recursive=True) + + # recursive hijack is on, we expect the write handler to be called three times + handler = d.handle_syscall("write", on_enter=on_enter_write, recursive=True) + + r.sendline(b"provola") + + d.cont() + + d.kill() + + self.assertEqual(write_count, handler.hit_count) + self.assertEqual(handler.hit_count, 3) + + write_count = 0 + r = d.run() + + d.pprint_syscalls = True + d.handle_syscall("getcwd", on_enter=on_enter_getcwd, recursive=False) + + # recursive is off, we expect the write handler to be called only twice + handler = d.handle_syscall("write", on_enter=on_enter_write) + + r.sendline(b"provola") + + d.cont() + + d.kill() + + self.assertEqual(write_count, handler.hit_count) + self.assertEqual(handler.hit_count, 2) + + def test_hijack_syscall_args(self): + write_buffer = None + + def on_enter_write(d, sh): + nonlocal write_buffer + nonlocal write_count + + write_buffer = d.syscall_arg1 + + write_count += 1 + + d = debugger("binaries/handle_syscall_test") + + write_count = 0 + r = d.run() + + # recursive hijack is on, we expect the write handler to be called three times + handler = d.handle_syscall("write", on_enter=on_enter_write, recursive=True) + d.breakpoint(0x9f0, file="binary") + + d.cont() + print(r.recvline()) + # Install the hijack. We expect to receive the "Hello, World!" string + + d.wait() + + d.hijack_syscall( + "read", + "write", + syscall_arg0=0x1, + syscall_arg1=write_buffer, + syscall_arg2=14, + recursive=True, + ) + + d.cont() + + print(r.recvline()) + + d.kill() + + self.assertEqual(self.capturedOutput.getvalue().count("Hello, World!"), 2) + self.assertEqual(write_count, handler.hit_count) + self.assertEqual(handler.hit_count, 3) + + def test_hijack_syscall_args_with_pprint(self): + write_buffer = None + + def on_enter_write(d, sh): + nonlocal write_buffer + nonlocal write_count + + write_buffer = d.syscall_arg1 + + write_count += 1 + + d = debugger("binaries/handle_syscall_test") + + write_count = 0 + r = d.run() + + d.pprint_syscalls = True + + # recursive hijack is on, we expect the write handler to be called three times + handler = d.handle_syscall("write", on_enter=on_enter_write, recursive=True) + d.breakpoint(0x9f0, file="binary") + + d.cont() + print(r.recvline()) + # Install the hijack. We expect to receive the "Hello, World!" string + + d.wait() + + d.hijack_syscall( + "read", + "write", + syscall_arg0=0x1, + syscall_arg1=write_buffer, + syscall_arg2=14, + recursive=True, + ) + + d.cont() + + print(r.recvline()) + + d.kill() + + self.assertEqual(self.capturedOutput.getvalue().count("Hello, World!"), 2) + self.assertEqual(self.capturedOutput.getvalue().count("write"), 3) + self.assertEqual(self.capturedOutput.getvalue().count("0xaaaaaaaa0ab0"), 3) + self.assertEqual(write_count, handler.hit_count) + self.assertEqual(handler.hit_count, 3) + + def test_hijack_syscall_wrong_args(self): + d = debugger("binaries/handle_syscall_test") + + d.run() + + with self.assertRaises(ValueError): + d.hijack_syscall("read", "write", syscall_arg26=0x1) + + d.kill() + + def loop_detection_test(self): + d = debugger("binaries/handle_syscall_test") + + r = d.run() + d.hijack_syscall("getcwd", "write", recursive=True) + d.hijack_syscall("write", "getcwd", recursive=True) + r.sendline(b"provola") + + # We expect an exception to be raised + with self.assertRaises(RuntimeError): + d.cont() + d.wait() + d.kill() + + r = d.run() + d.hijack_syscall("getcwd", "write", recursive=False) + d.hijack_syscall("write", "getcwd", recursive=True) + r.sendline(b"provola") + + # We expect no exception to be raised + d.cont() + + r = d.run() + d.hijack_syscall("getcwd", "write", recursive=True) + d.hijack_syscall("write", "getcwd", recursive=False) + r.sendline(b"provola") + + # We expect no exception to be raised + d.cont() diff --git a/test/scripts/jumpstart_test.py b/test/aarch64/scripts/jumpstart_test.py similarity index 100% rename from test/scripts/jumpstart_test.py rename to test/aarch64/scripts/jumpstart_test.py diff --git a/test/scripts/memory_test.py b/test/aarch64/scripts/memory_test.py similarity index 91% rename from test/scripts/memory_test.py rename to test/aarch64/scripts/memory_test.py index edfc3525..91daa85c 100644 --- a/test/scripts/memory_test.py +++ b/test/aarch64/scripts/memory_test.py @@ -1,11 +1,12 @@ # # This file is part of libdebug Python library (https://github.com/libdebug/libdebug). -# Copyright (c) 2023-2024 Gabriele Digregorio, Roberto Alessandro Bertolini. All rights reserved. +# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio. All rights reserved. # Licensed under the MIT license. See LICENSE file in the project root for details. # import io import logging +import sys import unittest from libdebug import debugger, libcontext @@ -35,9 +36,9 @@ def test_memory(self): d.cont() - assert d.regs.rip == bp.address + assert d.regs.pc == bp.address - address = d.regs.rdi + address = d.regs.x0 prev = bytes(range(256)) self.assertTrue(d.memory[address, 256] == prev) @@ -58,14 +59,14 @@ def test_mem_access_libs(self): d.cont() - assert d.regs.rip == bp.address + assert d.regs.pc == bp.address - address = d.regs.rdi + address = d.regs.x0 with libcontext.tmp(sym_lvl=5): arena = d.memory["main_arena", 256, "libc"] def p64(x): - return x.to_bytes(8, "little") + return x.to_bytes(8, sys.byteorder) self.assertTrue(p64(address - 0x10) in arena) @@ -86,9 +87,9 @@ def test_memory_exceptions(self): # File should start with ELF magic number self.assertTrue(file.startswith(b"\x7fELF")) - assert d.regs.rip == bp.address + assert d.regs.pc == bp.address - address = d.regs.rdi + address = d.regs.x0 prev = bytes(range(256)) self.assertTrue(d.memory[address, 256] == prev) @@ -110,9 +111,9 @@ def test_memory_multiple_runs(self): d.cont() - assert d.regs.rip == bp.address + assert d.regs.pc == bp.address - address = d.regs.rdi + address = d.regs.x0 prev = bytes(range(256)) self.assertTrue(d.memory[address, 256] == prev) @@ -134,9 +135,9 @@ def test_memory_access_while_running(self): d.cont() # Verify that memory access is only possible when the process is stopped - value = int.from_bytes(d.memory["state", 8], "little") + value = int.from_bytes(d.memory["state", 8], sys.byteorder) self.assertEqual(value, 0xDEADBEEF) - self.assertEqual(d.regs.rip, bp.address) + self.assertEqual(d.regs.pc, bp.address) d.kill() @@ -145,7 +146,7 @@ def test_memory_access_methods(self): d.run() - base = d.regs.rip & 0xFFFFFFFFFFFFF000 - 0x1000 + base = d.regs.pc & 0xFFFFFFFFFFFFF000 # Test different ways to access memory at the start of the file file_0 = d.memory[base, 256] @@ -213,7 +214,7 @@ def test_memory_access_methods_backing_file(self): d.run() - base = d.regs.rip & 0xFFFFFFFFFFFFF000 - 0x1000 + base = d.regs.pc & 0xFFFFFFFFFFFFF000 # Validate that slices work correctly file_0 = d.memory[0x0:"do_nothing", "binary"] @@ -285,7 +286,3 @@ def test_memory_access_methods_backing_file(self): d.memory["main":"main+8", "absolute"] = b"abcd1234" d.kill() - - -if __name__ == "__main__": - unittest.main() diff --git a/test/aarch64/scripts/next_test.py b/test/aarch64/scripts/next_test.py new file mode 100644 index 00000000..8ebadcb4 --- /dev/null +++ b/test/aarch64/scripts/next_test.py @@ -0,0 +1,111 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Francesco Panebianco, Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# +import unittest + +from libdebug import debugger + +TEST_ENTRYPOINT = 0xaaaaaaaa0930 + +# Addresses of the dummy functions +CALL_C_ADDRESS = 0xaaaaaaaa0934 +TEST_BREAKPOINT_ADDRESS = 0xaaaaaaaa0920 + +# Addresses of noteworthy instructions +RETURN_POINT_FROM_C = 0xaaaaaaaa0938 + +class NextTest(unittest.TestCase): + def setUp(self): + pass + + def test_next(self): + d = debugger("binaries/finish_test", auto_interrupt_on_command=False) + d.run() + + # Get to test entrypoint + entrypoint_bp = d.breakpoint(TEST_ENTRYPOINT) + d.cont() + + self.assertEqual(d.regs.pc, TEST_ENTRYPOINT) + + # -------- Block 1 ------- # + # Simple Step # + # ------------------------ # + + # Reach call of function c + d.next() + self.assertEqual(d.regs.pc, CALL_C_ADDRESS) + + # -------- Block 2 ------- # + # Skip a call # + # ------------------------ # + + d.next() + self.assertEqual(d.regs.pc, RETURN_POINT_FROM_C) + + d.kill() + d.terminate() + + def test_next_breakpoint(self): + d = debugger("binaries/finish_test", auto_interrupt_on_command=False) + d.run() + + # Get to test entrypoint + entrypoint_bp = d.breakpoint(TEST_ENTRYPOINT) + d.cont() + + self.assertEqual(d.regs.pc, TEST_ENTRYPOINT) + + # Reach call of function c + d.next() + + self.assertEqual(d.regs.pc, CALL_C_ADDRESS) + + # -------- Block 1 ------- # + # Call with breakpoint # + # ------------------------ # + + # Set breakpoint + test_breakpoint = d.breakpoint(TEST_BREAKPOINT_ADDRESS) + + d.next() + + # Check we hit the breakpoint + self.assertEqual(d.regs.pc, TEST_BREAKPOINT_ADDRESS) + self.assertEqual(test_breakpoint.hit_count, 1) + + d.kill() + d.terminate() + + def test_next_breakpoint_hw(self): + d = debugger("binaries/finish_test", auto_interrupt_on_command=False) + d.run() + + # Get to test entrypoint + entrypoint_bp = d.breakpoint(TEST_ENTRYPOINT) + d.cont() + + self.assertEqual(d.regs.pc, TEST_ENTRYPOINT) + + # Reach call of function c + d.next() + + self.assertEqual(d.regs.pc, CALL_C_ADDRESS) + + # -------- Block 1 ------- # + # Call with breakpoint # + # ------------------------ # + + # Set breakpoint + test_breakpoint = d.breakpoint(TEST_BREAKPOINT_ADDRESS, hardware=True) + + d.next() + + # Check we hit the breakpoint + self.assertEqual(d.regs.pc, TEST_BREAKPOINT_ADDRESS) + self.assertEqual(test_breakpoint.hit_count, 1) + + d.kill() + d.terminate() \ No newline at end of file diff --git a/test/aarch64/scripts/signals_multithread_test.py b/test/aarch64/scripts/signals_multithread_test.py new file mode 100644 index 00000000..392e20d4 --- /dev/null +++ b/test/aarch64/scripts/signals_multithread_test.py @@ -0,0 +1,423 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Gabriele Digregorio. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import unittest + +from libdebug import debugger + + +class SignalMultithreadTest(unittest.TestCase): + def test_signal_multithread_undet_catch_signal_block(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGQUIT_count = 0 + SIGTERM_count = 0 + SIGPIPE_count = 0 + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + + SIGUSR1_count += 1 + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + + SIGTERM_count += 1 + + def catcher_SIGINT(t, sc): + nonlocal SIGINT_count + + SIGINT_count += 1 + + def catcher_SIGQUIT(t, sc): + nonlocal SIGQUIT_count + + SIGQUIT_count += 1 + + def catcher_SIGPIPE(t, sc): + nonlocal SIGPIPE_count + + SIGPIPE_count += 1 + + d = debugger("binaries/signals_multithread_undet_test") + + r = d.run() + + catcher1 = d.catch_signal(10, callback=catcher_SIGUSR1) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + catcher3 = d.catch_signal(2, callback=catcher_SIGINT) + catcher4 = d.catch_signal("SIGQUIT", callback=catcher_SIGQUIT) + catcher5 = d.catch_signal("SIGPIPE", callback=catcher_SIGPIPE) + + d.signals_to_block = ["SIGUSR1", 15, "SIGINT", 3, 13] + + d.cont() + + r.sendline(b"sync") + r.sendline(b"sync") + + # Receive the exit message + r.recvline(2) + + d.kill() + + self.assertEqual(SIGUSR1_count, 4) + self.assertEqual(SIGTERM_count, 4) + self.assertEqual(SIGINT_count, 4) + self.assertEqual(SIGQUIT_count, 6) + self.assertEqual(SIGPIPE_count, 6) + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + self.assertEqual(SIGQUIT_count, catcher4.hit_count) + self.assertEqual(SIGPIPE_count, catcher5.hit_count) + + def test_signal_multithread_undet_pass(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGQUIT_count = 0 + SIGTERM_count = 0 + SIGPIPE_count = 0 + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + + SIGUSR1_count += 1 + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + + SIGTERM_count += 1 + + def catcher_SIGINT(t, sc): + nonlocal SIGINT_count + + SIGINT_count += 1 + + def catcher_SIGQUIT(t, sc): + nonlocal SIGQUIT_count + + SIGQUIT_count += 1 + + def catcher_SIGPIPE(t, sc): + nonlocal SIGPIPE_count + + SIGPIPE_count += 1 + + d = debugger("binaries/signals_multithread_undet_test") + + r = d.run() + + catcher1 = d.catch_signal("SIGUSR1", callback=catcher_SIGUSR1) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + catcher3 = d.catch_signal("SIGINT", callback=catcher_SIGINT) + catcher4 = d.catch_signal("SIGQUIT", callback=catcher_SIGQUIT) + catcher5 = d.catch_signal("SIGPIPE", callback=catcher_SIGPIPE) + + d.cont() + + received = [] + for _ in range(24): + received.append(r.recvline()) + + r.sendline(b"sync") + r.sendline(b"sync") + + received.append(r.recvline()) + received.append(r.recvline()) + + d.kill() + + self.assertEqual(SIGUSR1_count, 4) + self.assertEqual(SIGTERM_count, 4) + self.assertEqual(SIGINT_count, 4) + self.assertEqual(SIGQUIT_count, 6) + self.assertEqual(SIGPIPE_count, 6) + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + self.assertEqual(SIGQUIT_count, catcher4.hit_count) + self.assertEqual(SIGPIPE_count, catcher5.hit_count) + + # Count the number of times each signal was received + self.assertEqual(received.count(b"Received signal 10"), 4) + self.assertEqual(received.count(b"Received signal 15"), 4) + self.assertEqual(received.count(b"Received signal 2"), 4) + self.assertEqual(received.count(b"Received signal 3"), 6) + self.assertEqual(received.count(b"Received signal 13"), 6) + # Note: sometimes the signals are passed to ptrace once and received twice + # Maybe another ptrace/kernel/whatever problem in multithreaded programs (?) + # Using raise(sig) instead of kill(pid, sig) to send signals in the original + # program seems to mitigate the problem for whatever reason + # I will investigate this further in the future, but for now this is fine + + def test_signal_multithread_det_catch_signal_block(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGQUIT_count = 0 + SIGTERM_count = 0 + SIGPIPE_count = 0 + tids = [] + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + nonlocal tids + + SIGUSR1_count += 1 + tids.append(t.thread_id) + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + nonlocal tids + + SIGTERM_count += 1 + tids.append(t.thread_id) + + def catcher_SIGINT(t, sc): + nonlocal SIGINT_count + nonlocal tids + + SIGINT_count += 1 + tids.append(t.thread_id) + + def catcher_SIGQUIT(t, sc): + nonlocal SIGQUIT_count + nonlocal tids + + SIGQUIT_count += 1 + tids.append(t.thread_id) + + def catcher_SIGPIPE(t, sc): + nonlocal SIGPIPE_count + nonlocal tids + + SIGPIPE_count += 1 + tids.append(t.thread_id) + + d = debugger("binaries/signals_multithread_det_test") + + r = d.run() + + catcher1 = d.catch_signal(10, callback=catcher_SIGUSR1) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + catcher3 = d.catch_signal(2, callback=catcher_SIGINT) + catcher4 = d.catch_signal("SIGQUIT", callback=catcher_SIGQUIT) + catcher5 = d.catch_signal("SIGPIPE", callback=catcher_SIGPIPE) + + d.signals_to_block = ["SIGUSR1", 15, "SIGINT", 3, 13] + + d.cont() + + # Receive the exit message + r.recvline(timeout=15) + r.sendline(b"sync") + r.recvline() + + receiver = d.threads[1].thread_id + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 2) + self.assertEqual(SIGINT_count, 2) + self.assertEqual(SIGQUIT_count, 3) + self.assertEqual(SIGPIPE_count, 3) + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + self.assertEqual(SIGQUIT_count, catcher4.hit_count) + self.assertEqual(SIGPIPE_count, catcher5.hit_count) + + set_tids = set(tids) + self.assertEqual(len(set_tids), 1) + self.assertEqual(set_tids.pop(), receiver) + + def test_signal_multithread_det_pass(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGQUIT_count = 0 + SIGTERM_count = 0 + SIGPIPE_count = 0 + tids = [] + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + nonlocal tids + + SIGUSR1_count += 1 + tids.append(t.thread_id) + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + nonlocal tids + + SIGTERM_count += 1 + tids.append(t.thread_id) + + def catcher_SIGINT(t, sc): + nonlocal SIGINT_count + nonlocal tids + + SIGINT_count += 1 + tids.append(t.thread_id) + + def catcher_SIGQUIT(t, sc): + nonlocal SIGQUIT_count + nonlocal tids + + SIGQUIT_count += 1 + tids.append(t.thread_id) + + def catcher_SIGPIPE(t, sc): + nonlocal SIGPIPE_count + nonlocal tids + + SIGPIPE_count += 1 + tids.append(t.thread_id) + + d = debugger("binaries/signals_multithread_det_test") + + r = d.run() + + catcher1 = d.catch_signal("SIGUSR1", callback=catcher_SIGUSR1) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + catcher3 = d.catch_signal("SIGINT", callback=catcher_SIGINT) + catcher4 = d.catch_signal("SIGQUIT", callback=catcher_SIGQUIT) + catcher5 = d.catch_signal("SIGPIPE", callback=catcher_SIGPIPE) + + d.cont() + + received = [] + for _ in range(13): + received.append(r.recvline(timeout=5)) + + r.sendline(b"sync") + received.append(r.recvline(timeout=5)) + + receiver = d.threads[1].thread_id + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 2) + self.assertEqual(SIGINT_count, 2) + self.assertEqual(SIGQUIT_count, 3) + self.assertEqual(SIGPIPE_count, 3) + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + self.assertEqual(SIGQUIT_count, catcher4.hit_count) + self.assertEqual(SIGPIPE_count, catcher5.hit_count) + + # Count the number of times each signal was received + self.assertEqual(received.count(b"Received signal on receiver 10"), 2) + self.assertEqual(received.count(b"Received signal on receiver 15"), 2) + self.assertEqual(received.count(b"Received signal on receiver 2"), 2) + self.assertEqual(received.count(b"Received signal on receiver 3"), 3) + self.assertEqual(received.count(b"Received signal on receiver 13"), 3) + + set_tids = set(tids) + self.assertEqual(len(set_tids), 1) + self.assertEqual(set_tids.pop(), receiver) + + def test_signal_multithread_send_signal(self): + SIGUSR1_count = 0 + SIGINT_count = 0 + SIGQUIT_count = 0 + SIGTERM_count = 0 + SIGPIPE_count = 0 + tids = [] + + def catcher_SIGUSR1(t, sc): + nonlocal SIGUSR1_count + nonlocal tids + + SIGUSR1_count += 1 + tids.append(t.thread_id) + + def catcher_SIGTERM(t, sc): + nonlocal SIGTERM_count + nonlocal tids + + SIGTERM_count += 1 + tids.append(t.thread_id) + + def catcher_SIGINT(t, sc): + nonlocal SIGINT_count + nonlocal tids + + SIGINT_count += 1 + tids.append(t.thread_id) + + def catcher_SIGQUIT(t, sc): + nonlocal SIGQUIT_count + nonlocal tids + + SIGQUIT_count += 1 + tids.append(t.thread_id) + + def catcher_SIGPIPE(t, sc): + nonlocal SIGPIPE_count + nonlocal tids + + SIGPIPE_count += 1 + tids.append(t.thread_id) + + d = debugger("binaries/signals_multithread_det_test") + + # Set a breakpoint to stop the program before the end of the receiver thread + r = d.run() + + bp = d.breakpoint(0xf1c, hardware=True, file="binary") + + catcher1 = d.catch_signal("SIGUSR1", callback=catcher_SIGUSR1) + catcher2 = d.catch_signal("SIGTERM", callback=catcher_SIGTERM) + catcher3 = d.catch_signal("SIGINT", callback=catcher_SIGINT) + catcher4 = d.catch_signal("SIGQUIT", callback=catcher_SIGQUIT) + catcher5 = d.catch_signal("SIGPIPE", callback=catcher_SIGPIPE) + + d.cont() + + received = [] + for _ in range(13): + received.append(r.recvline(timeout=5)) + + r.sendline(b"sync") + + d.wait() + if bp.hit_on(d.threads[1]): + d.threads[1].signal = "SIGUSR1" + d.cont() + received.append(r.recvline(timeout=5)) + received.append(r.recvline(timeout=5)) + + receiver = d.threads[1].thread_id + d.kill() + + self.assertEqual(SIGUSR1_count, 2) + self.assertEqual(SIGTERM_count, 2) + self.assertEqual(SIGINT_count, 2) + self.assertEqual(SIGQUIT_count, 3) + self.assertEqual(SIGPIPE_count, 3) + + self.assertEqual(SIGUSR1_count, catcher1.hit_count) + self.assertEqual(SIGTERM_count, catcher2.hit_count) + self.assertEqual(SIGINT_count, catcher3.hit_count) + self.assertEqual(SIGQUIT_count, catcher4.hit_count) + self.assertEqual(SIGPIPE_count, catcher5.hit_count) + + # Count the number of times each signal was received + self.assertEqual(received.count(b"Received signal on receiver 10"), 3) + self.assertEqual(received.count(b"Received signal on receiver 15"), 2) + self.assertEqual(received.count(b"Received signal on receiver 2"), 2) + self.assertEqual(received.count(b"Received signal on receiver 3"), 3) + self.assertEqual(received.count(b"Received signal on receiver 13"), 3) + + set_tids = set(tids) + self.assertEqual(len(set_tids), 1) + self.assertEqual(set_tids.pop(), receiver) diff --git a/test/aarch64/scripts/speed_test.py b/test/aarch64/scripts/speed_test.py new file mode 100644 index 00000000..c86bf708 --- /dev/null +++ b/test/aarch64/scripts/speed_test.py @@ -0,0 +1,57 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import unittest +from time import perf_counter_ns + +from libdebug import debugger + + +class SpeedTest(unittest.TestCase): + def setUp(self): + self.d = debugger("binaries/speed_test") + + def test_speed(self): + d = self.d + + start_time = perf_counter_ns() + + d.run() + + bp = d.breakpoint("do_nothing") + + d.cont() + + for _ in range(65536): + self.assertTrue(bp.address == d.regs.pc) + d.cont() + + d.kill() + + end_time = perf_counter_ns() + + self.assertTrue((end_time - start_time) < 15 * 1e9) # 15 seconds + + def test_speed_hardware(self): + d = self.d + + start_time = perf_counter_ns() + + d.run() + + bp = d.breakpoint("do_nothing", hardware=True) + + d.cont() + + for _ in range(65536): + self.assertTrue(bp.address == d.regs.pc) + d.cont() + + d.kill() + + end_time = perf_counter_ns() + + self.assertTrue((end_time - start_time) < 15 * 1e9) # 15 seconds diff --git a/test/aarch64/scripts/thread_test.py b/test/aarch64/scripts/thread_test.py new file mode 100644 index 00000000..0a31bc68 --- /dev/null +++ b/test/aarch64/scripts/thread_test.py @@ -0,0 +1,83 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import unittest + +from libdebug import debugger + + +class ThreadTest(unittest.TestCase): + def setUp(self): + pass + + def test_thread(self): + d = debugger("binaries/thread_test") + + d.run() + + bp_t0 = d.breakpoint("do_nothing") + bp_t1 = d.breakpoint("thread_1_function") + bp_t2 = d.breakpoint("thread_2_function") + bp_t3 = d.breakpoint("thread_3_function") + + t1_done, t2_done, t3_done = False, False, False + + d.cont() + + for _ in range(150): + if bp_t0.address == d.regs.pc: + self.assertTrue(t1_done) + self.assertTrue(t2_done) + self.assertTrue(t3_done) + break + + if len(d.threads) > 1 and bp_t1.address == d.threads[1].regs.pc: + t1_done = True + if len(d.threads) > 2 and bp_t2.address == d.threads[2].regs.pc: + t2_done = True + if len(d.threads) > 3 and bp_t3.address == d.threads[3].regs.pc: + t3_done = True + + d.cont() + + d.kill() + d.terminate() + + def test_thread_hardware(self): + d = debugger("binaries/thread_test") + + d.run() + + bp_t0 = d.breakpoint("do_nothing", hardware=True) + bp_t1 = d.breakpoint("thread_1_function", hardware=True) + bp_t2 = d.breakpoint("thread_2_function", hardware=True) + bp_t3 = d.breakpoint("thread_3_function", hardware=True) + + t1_done, t2_done, t3_done = False, False, False + + d.cont() + + for _ in range(15): + if bp_t0.address == d.regs.pc: + self.assertTrue(t1_done) + self.assertTrue(t2_done) + self.assertTrue(t3_done) + break + + if len(d.threads) > 1 and bp_t1.address == d.threads[1].regs.pc: + t1_done = True + if len(d.threads) > 2 and bp_t2.address == d.threads[2].regs.pc: + t2_done = True + if len(d.threads) > 3 and bp_t3.address == d.threads[3].regs.pc: + t3_done = True + + d.cont() + + d.kill() + d.terminate() + +if __name__ == "__main__": + unittest.main() diff --git a/test/aarch64/scripts/thread_test_complex.py b/test/aarch64/scripts/thread_test_complex.py new file mode 100644 index 00000000..786f61d7 --- /dev/null +++ b/test/aarch64/scripts/thread_test_complex.py @@ -0,0 +1,67 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini, Gabriele Digregorio. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import unittest + +from libdebug import debugger + + +class ThreadTestComplex(unittest.TestCase): + def setUp(self): + pass + + def test_thread(self): + def factorial(n): + if n == 0: + return 1 + else: + return (n * factorial(n - 1)) & (2**32 - 1) + + d = debugger("binaries/thread_test_complex") + + d.run() + + bp1_t0 = d.breakpoint("do_nothing") + bp2_t1 = d.breakpoint("thread_1_function+18") + bp3_t2 = d.breakpoint("thread_2_function+24") + + bp1_hit, bp2_hit, bp3_hit = False, False, False + t1, t2 = None, None + + d.cont() + + while True: + if len(d.threads) == 2: + t1 = d.threads[1] + + if len(d.threads) == 3: + t2 = d.threads[2] + + if t1 and bp2_t1.address == t1.regs.pc: + bp2_hit = True + self.assertTrue(bp2_t1.hit_count == (t1.regs.w0 + 1)) + + if bp1_t0.address == d.regs.pc: + bp1_hit = True + self.assertTrue(bp2_hit) + self.assertEqual(bp2_t1.hit_count, 50) + self.assertFalse(bp3_hit) + self.assertEqual(bp1_t0.hit_count, 1) + + if t2 and bp3_t2.address == t2.regs.pc: + bp3_hit = True + self.assertTrue(factorial(bp3_t2.hit_count) == t2.regs.x0) + self.assertTrue(bp2_hit) + self.assertTrue(bp1_hit) + + d.cont() + + if bp3_t2.hit_count == 49: + break + + d.kill() + d.terminate() + diff --git a/test/aarch64/scripts/watchpoint_test.py b/test/aarch64/scripts/watchpoint_test.py new file mode 100644 index 00000000..326543da --- /dev/null +++ b/test/aarch64/scripts/watchpoint_test.py @@ -0,0 +1,150 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2023-2024 Francesco Panebianco, Gabriele Digregorio, Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import unittest + +from libdebug import debugger + + +class WatchpointTest(unittest.TestCase): + def test_watchpoint(self): + d = debugger("binaries/watchpoint_test", auto_interrupt_on_command=False) + + d.run() + + d.breakpoint("global_char", hardware=True, condition="rw", length=1) + d.breakpoint("global_int", hardware=True, condition="w", length=4) + d.breakpoint("global_short", hardware=True, condition="r", length=2) + d.breakpoint("global_long", hardware=True, condition="rw", length=8) + + d.cont() + + base = d.regs.pc & ~0xfff + + # strb w1, [x0] => global_char + self.assertEqual(d.regs.pc, base + 0x724) + + d.cont() + + # str w1, [x0] => global_int + self.assertEqual(d.regs.pc, base + 0x748) + + d.cont() + + # str x1, [x0] => global_long + self.assertEqual(d.regs.pc, base + 0x764) + + d.cont() + + # ldrb w0, [x0] => global_char + self.assertEqual(d.regs.pc, base + 0x780) + + d.cont() + + # ldr w0, [x0] => global_short + self.assertEqual(d.regs.pc, base + 0x790) + + d.cont() + + # ldr x0, [x0] => global_long + self.assertEqual(d.regs.pc, base + 0x7b0) + + d.cont() + + d.kill() + + def test_watchpoint_callback(self): + global_char_ip = [] + global_int_ip = [] + global_short_ip = [] + global_long_ip = [] + + def watchpoint_global_char(t, b): + nonlocal global_char_ip + + global_char_ip.append(t.regs.pc) + + def watchpoint_global_int(t, b): + nonlocal global_int_ip + + global_int_ip.append(t.regs.pc) + + def watchpoint_global_short(t, b): + nonlocal global_short_ip + + global_short_ip.append(t.regs.pc) + + def watchpoint_global_long(t, b): + nonlocal global_long_ip + + global_long_ip.append(t.regs.pc) + + d = debugger("binaries/watchpoint_test", auto_interrupt_on_command=False) + + d.run() + + base = d.regs.pc & ~0xfff + + wp1 = d.breakpoint( + "global_char", + hardware=True, + condition="rw", + length=1, + callback=watchpoint_global_char, + ) + wp2 = d.breakpoint( + "global_int", + hardware=True, + condition="w", + length=4, + callback=watchpoint_global_int, + ) + wp3 = d.breakpoint( + "global_long", + hardware=True, + condition="rw", + length=8, + callback=watchpoint_global_long, + ) + wp4 = d.breakpoint( + "global_short", + hardware=True, + condition="r", + length=2, + callback=watchpoint_global_short, + ) + + d.cont() + + d.kill() + + # strb w1, [x0] => global_char + self.assertEqual(global_char_ip[0], base + 0x724) + + # str w1, [x0] => global_int + self.assertEqual(global_int_ip[0], base + 0x748) + + # str x1, [x0] => global_long + self.assertEqual(global_long_ip[0], base + 0x764) + + # ldrb w0, [x0] => global_char + self.assertEqual(global_char_ip[1], base + 0x780) + + # ldr w0, [x0] => global_short + self.assertEqual(global_short_ip[0], base + 0x790) + + # ldr x0, [x0] => global_long + self.assertEqual(global_long_ip[1], base + 0x7b0) + + self.assertEqual(len(global_char_ip), 2) + self.assertEqual(len(global_int_ip), 1) + self.assertEqual(len(global_short_ip), 1) + self.assertEqual(len(global_long_ip), 2) + self.assertEqual(wp1.hit_count, 2) + self.assertEqual(wp2.hit_count, 1) + self.assertEqual(wp3.hit_count, 2) + self.assertEqual(wp4.hit_count, 1) + diff --git a/test/aarch64/srcs/basic_test.c b/test/aarch64/srcs/basic_test.c new file mode 100644 index 00000000..66b52dc1 --- /dev/null +++ b/test/aarch64/srcs/basic_test.c @@ -0,0 +1,168 @@ +// +// This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +// Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for details. +// + +#include +#include + +#pragma GCC optimize ("O0") + +void register_test() +{ + asm volatile ( + "sub sp, sp, #96\n\t" + "stp x19, x20, [sp, #0]\n\t" + "stp x21, x22, [sp, #16]\n\t" + "stp x23, x24, [sp, #32]\n\t" + "stp x25, x26, [sp, #48]\n\t" + "stp x27, x28, [sp, #64]\n\t" + "stp x29, x30, [sp, #80]\n\t" + "nop\n\t" + "movk x0, #0x1111, lsl #0\n\t" + "movk x0, #0x2222, lsl #16\n\t" + "movk x0, #0x3333, lsl #32\n\t" + "movk x0, #0x4444, lsl #48\n\t" + "movk x1, #0x5555, lsl #0\n\t" + "movk x1, #0x6666, lsl #16\n\t" + "movk x1, #0x7777, lsl #32\n\t" + "movk x1, #0x8888, lsl #48\n\t" + "movk x2, #0x9999, lsl #0\n\t" + "movk x2, #0xaaaa, lsl #16\n\t" + "movk x2, #0xbbbb, lsl #32\n\t" + "movk x2, #0xcccc, lsl #48\n\t" + "movk x3, #0xdddd, lsl #0\n\t" + "movk x3, #0xeeee, lsl #16\n\t" + "movk x3, #0xffff, lsl #32\n\t" + "movk x3, #0x1111, lsl #48\n\t" + "movk x4, #0x2222, lsl #0\n\t" + "movk x4, #0x3333, lsl #16\n\t" + "movk x4, #0x4444, lsl #32\n\t" + "movk x4, #0x5555, lsl #48\n\t" + "movk x5, #0x6666, lsl #0\n\t" + "movk x5, #0x7777, lsl #16\n\t" + "movk x5, #0x8888, lsl #32\n\t" + "movk x5, #0x9999, lsl #48\n\t" + "movk x6, #0xaaaa, lsl #0\n\t" + "movk x6, #0xbbbb, lsl #16\n\t" + "movk x6, #0xcccc, lsl #32\n\t" + "movk x6, #0xdddd, lsl #48\n\t" + "movk x7, #0xeeee, lsl #0\n\t" + "movk x7, #0xffff, lsl #16\n\t" + "movk x7, #0x1111, lsl #32\n\t" + "movk x7, #0x2222, lsl #48\n\t" + "movk x8, #0x3333, lsl #0\n\t" + "movk x8, #0x4444, lsl #16\n\t" + "movk x8, #0x5555, lsl #32\n\t" + "movk x8, #0x6666, lsl #48\n\t" + "movk x9, #0x7777, lsl #0\n\t" + "movk x9, #0x8888, lsl #16\n\t" + "movk x9, #0x9999, lsl #32\n\t" + "movk x9, #0xaaaa, lsl #48\n\t" + "movk x10, #0xbbbb, lsl #0\n\t" + "movk x10, #0xcccc, lsl #16\n\t" + "movk x10, #0xdddd, lsl #32\n\t" + "movk x10, #0xeeee, lsl #48\n\t" + "movk x11, #0xffff, lsl #0\n\t" + "movk x11, #0x1111, lsl #16\n\t" + "movk x11, #0x2222, lsl #32\n\t" + "movk x11, #0x3333, lsl #48\n\t" + "movk x12, #0x4444, lsl #0\n\t" + "movk x12, #0x5555, lsl #16\n\t" + "movk x12, #0x6666, lsl #32\n\t" + "movk x12, #0x7777, lsl #48\n\t" + "movk x13, #0x8888, lsl #0\n\t" + "movk x13, #0x9999, lsl #16\n\t" + "movk x13, #0xaaaa, lsl #32\n\t" + "movk x13, #0xbbbb, lsl #48\n\t" + "movk x14, #0xcccc, lsl #0\n\t" + "movk x14, #0xdddd, lsl #16\n\t" + "movk x14, #0xeeee, lsl #32\n\t" + "movk x14, #0xffff, lsl #48\n\t" + "movk x15, #0x1111, lsl #0\n\t" + "movk x15, #0x2222, lsl #16\n\t" + "movk x15, #0x3333, lsl #32\n\t" + "movk x15, #0x4444, lsl #48\n\t" + "movk x16, #0x5555, lsl #0\n\t" + "movk x16, #0x6666, lsl #16\n\t" + "movk x16, #0x7777, lsl #32\n\t" + "movk x16, #0x8888, lsl #48\n\t" + "movk x17, #0x9999, lsl #0\n\t" + "movk x17, #0xaaaa, lsl #16\n\t" + "movk x17, #0xbbbb, lsl #32\n\t" + "movk x17, #0xcccc, lsl #48\n\t" + "movk x18, #0xdddd, lsl #0\n\t" + "movk x18, #0xeeee, lsl #16\n\t" + "movk x18, #0xffff, lsl #32\n\t" + "movk x18, #0x1111, lsl #48\n\t" + "movk x19, #0x2222, lsl #0\n\t" + "movk x19, #0x3333, lsl #16\n\t" + "movk x19, #0x4444, lsl #32\n\t" + "movk x19, #0x5555, lsl #48\n\t" + "movk x20, #0x6666, lsl #0\n\t" + "movk x20, #0x7777, lsl #16\n\t" + "movk x20, #0x8888, lsl #32\n\t" + "movk x20, #0x9999, lsl #48\n\t" + "movk x21, #0xaaaa, lsl #0\n\t" + "movk x21, #0xbbbb, lsl #16\n\t" + "movk x21, #0xcccc, lsl #32\n\t" + "movk x21, #0xdddd, lsl #48\n\t" + "movk x22, #0xeeee, lsl #0\n\t" + "movk x22, #0xffff, lsl #16\n\t" + "movk x22, #0x1111, lsl #32\n\t" + "movk x22, #0x2222, lsl #48\n\t" + "movk x23, #0x3333, lsl #0\n\t" + "movk x23, #0x4444, lsl #16\n\t" + "movk x23, #0x5555, lsl #32\n\t" + "movk x23, #0x6666, lsl #48\n\t" + "movk x24, #0x7777, lsl #0\n\t" + "movk x24, #0x8888, lsl #16\n\t" + "movk x24, #0x9999, lsl #32\n\t" + "movk x24, #0xaaaa, lsl #48\n\t" + "movk x25, #0xbbbb, lsl #0\n\t" + "movk x25, #0xcccc, lsl #16\n\t" + "movk x25, #0xdddd, lsl #32\n\t" + "movk x25, #0xeeee, lsl #48\n\t" + "movk x26, #0xffff, lsl #0\n\t" + "movk x26, #0x1111, lsl #16\n\t" + "movk x26, #0x2222, lsl #32\n\t" + "movk x26, #0x3333, lsl #48\n\t" + "movk x27, #0x4444, lsl #0\n\t" + "movk x27, #0x5555, lsl #16\n\t" + "movk x27, #0x6666, lsl #32\n\t" + "movk x27, #0x7777, lsl #48\n\t" + "movk x28, #0x8888, lsl #0\n\t" + "movk x28, #0x9999, lsl #16\n\t" + "movk x28, #0xaaaa, lsl #32\n\t" + "movk x28, #0xbbbb, lsl #48\n\t" + "movk x29, #0xcccc, lsl #0\n\t" + "movk x29, #0xdddd, lsl #16\n\t" + "movk x29, #0xeeee, lsl #32\n\t" + "movk x29, #0xffff, lsl #48\n\t" + "movk x30, #0x1111, lsl #0\n\t" + "movk x30, #0x2222, lsl #16\n\t" + "movk x30, #0x3333, lsl #32\n\t" + "movk x30, #0x4444, lsl #48\n\t" + "nop\n\t" + "ldp x19, x20, [sp, #0]\n\t" + "ldp x21, x22, [sp, #16]\n\t" + "ldp x23, x24, [sp, #32]\n\t" + "ldp x25, x26, [sp, #48]\n\t" + "ldp x27, x28, [sp, #64]\n\t" + "ldp x29, x30, [sp, #80]\n\t" + "add sp, sp, #96\n\t" + : + : + : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x18" + ); +} + +int main() +{ + printf("Provola\n"); + + register_test(); + + return EXIT_SUCCESS; +} diff --git a/test/aarch64/srcs/floating_point_test.c b/test/aarch64/srcs/floating_point_test.c new file mode 100644 index 00000000..5bfdb39b --- /dev/null +++ b/test/aarch64/srcs/floating_point_test.c @@ -0,0 +1,106 @@ +// +// This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +// Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for details. +// + +void rotate(char value[16]) +{ + char temp = value[0]; + for (int i = 0; i < 15; i++) { + value[i] = value[i + 1]; + } + value[15] = temp; +} + +int main() +{ + char value[16]; + + for (int i = 0; i < 16; i++) { + value[i] = i; + } + + // aarch64 floating point registers + __asm__ __volatile__("ld1 {v0.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v1.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v2.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v3.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v4.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v5.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v6.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v7.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v8.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v9.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v10.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v11.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v12.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v13.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v14.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v15.16b}, [%0]" : : "r" (value)); + rotate(value); + + for (int i = 0; i < 16; i++) { + value[i] = 0x80 + i; + } + + __asm__ __volatile__("ld1 {v16.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v17.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v18.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v19.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v20.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v21.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v22.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v23.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v24.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v25.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v26.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v27.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v28.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v29.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v30.16b}, [%0]" : : "r" (value)); + rotate(value); + __asm__ __volatile__("ld1 {v31.16b}, [%0]" : : "r" (value)); + + __asm__ __volatile__("nop\n\t"); + + char result[16]; + __asm__ __volatile__("st1 {v0.16b}, [%0]" : : "r" (result)); + + unsigned long check = *(unsigned long*)result; + + if (check == 0xdeadbeefdeadbeef) { + __asm__ __volatile__("nop\n\t"); + } + + return 0; +} \ No newline at end of file diff --git a/test/aarch64/srcs/thread_test.c b/test/aarch64/srcs/thread_test.c new file mode 100644 index 00000000..54672484 --- /dev/null +++ b/test/aarch64/srcs/thread_test.c @@ -0,0 +1,59 @@ +// +// This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +// Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for details. +// + +#include +#include +#include + +void thread_1_function() +{ + asm volatile ( + "movk x0, #0x1111, lsl #0\n\t" + "movk x0, #0x2222, lsl #16\n\t" + "movk x0, #0x3333, lsl #32\n\t" + "movk x0, #0x4444, lsl #48\n\t" + "nop\n\t"::: "x0"); +} + +void thread_2_function() +{ + asm volatile ( + "movk x0, #0x6666, lsl #0\n\t" + "movk x0, #0x7777, lsl #16\n\t" + "movk x0, #0x8888, lsl #32\n\t" + "movk x0, #0x9999, lsl #48\n\t" + "nop\n\t"::: "x0"); +} + +void thread_3_function() +{ + asm volatile ( + "movk x0, #0xeeee, lsl #0\n\t" + "movk x0, #0xffff, lsl #16\n\t" + "movk x0, #0x1111, lsl #32\n\t" + "movk x0, #0x2222, lsl #48\n\t" + "nop\n\t"::: "x0"); +} + +void do_nothing() +{ + asm volatile ("nop\n\t"); +} + +int main() +{ + pthread_t thread_1, thread_2, thread_3; + pthread_create(&thread_1, NULL, (void *)thread_1_function, NULL); + pthread_create(&thread_2, NULL, (void *)thread_2_function, NULL); + pthread_create(&thread_3, NULL, (void *)thread_3_function, NULL); + pthread_join(thread_1, NULL); + pthread_join(thread_2, NULL); + pthread_join(thread_3, NULL); + + do_nothing(); + + return 0; +} diff --git a/test/aarch64/srcs/watchpoint_test.c b/test/aarch64/srcs/watchpoint_test.c new file mode 100644 index 00000000..4ca44540 --- /dev/null +++ b/test/aarch64/srcs/watchpoint_test.c @@ -0,0 +1,34 @@ +// +// This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +// Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for details. +// + +#include +#include +#include + +uint8_t __attribute__((aligned(8))) global_char = 0; +uint16_t __attribute__((aligned(8))) global_short = 0; +uint32_t __attribute__((aligned(8))) global_int = 0; +uint64_t __attribute__((aligned(8))) global_long = 0; + +int main() +{ + global_char = 0x01; + global_short = 0x0203; + global_int = 0x04050607; + global_long = 0x08090a0b0c0d0e0f; + + uint8_t local_char = 0; + uint16_t local_short = 0; + uint32_t local_int = 0; + uint64_t local_long = 0; + + local_char = global_char; + local_short = global_short; + local_int = global_int; + local_long = global_long; + + return 0; +} diff --git a/test/CTF/0 b/test/amd64/CTF/0 similarity index 100% rename from test/CTF/0 rename to test/amd64/CTF/0 diff --git a/test/CTF/1 b/test/amd64/CTF/1 similarity index 100% rename from test/CTF/1 rename to test/amd64/CTF/1 diff --git a/test/CTF/2 b/test/amd64/CTF/2 similarity index 100% rename from test/CTF/2 rename to test/amd64/CTF/2 diff --git a/test/CTF/deep-dive-division b/test/amd64/CTF/deep-dive-division similarity index 100% rename from test/CTF/deep-dive-division rename to test/amd64/CTF/deep-dive-division diff --git a/test/CTF/jumpout b/test/amd64/CTF/jumpout similarity index 100% rename from test/CTF/jumpout rename to test/amd64/CTF/jumpout diff --git a/test/CTF/vmwhere1 b/test/amd64/CTF/vmwhere1 similarity index 100% rename from test/CTF/vmwhere1 rename to test/amd64/CTF/vmwhere1 diff --git a/test/CTF/vmwhere1_program b/test/amd64/CTF/vmwhere1_program similarity index 100% rename from test/CTF/vmwhere1_program rename to test/amd64/CTF/vmwhere1_program diff --git a/test/Makefile b/test/amd64/Makefile similarity index 100% rename from test/Makefile rename to test/amd64/Makefile diff --git a/test/benchmarks/benchmarks.md b/test/amd64/benchmarks/benchmarks.md similarity index 100% rename from test/benchmarks/benchmarks.md rename to test/amd64/benchmarks/benchmarks.md diff --git a/test/benchmarks/breakpoint_gdb.py b/test/amd64/benchmarks/breakpoint_gdb.py similarity index 100% rename from test/benchmarks/breakpoint_gdb.py rename to test/amd64/benchmarks/breakpoint_gdb.py diff --git a/test/benchmarks/breakpoint_libdebug.py b/test/amd64/benchmarks/breakpoint_libdebug.py similarity index 100% rename from test/benchmarks/breakpoint_libdebug.py rename to test/amd64/benchmarks/breakpoint_libdebug.py diff --git a/test/benchmarks/results/breakpoint_benchmark.svg b/test/amd64/benchmarks/results/breakpoint_benchmark.svg similarity index 100% rename from test/benchmarks/results/breakpoint_benchmark.svg rename to test/amd64/benchmarks/results/breakpoint_benchmark.svg diff --git a/test/benchmarks/results/breakpoint_gdb.pkl b/test/amd64/benchmarks/results/breakpoint_gdb.pkl similarity index 100% rename from test/benchmarks/results/breakpoint_gdb.pkl rename to test/amd64/benchmarks/results/breakpoint_gdb.pkl diff --git a/test/benchmarks/results/breakpoint_libdebug.pkl b/test/amd64/benchmarks/results/breakpoint_libdebug.pkl similarity index 100% rename from test/benchmarks/results/breakpoint_libdebug.pkl rename to test/amd64/benchmarks/results/breakpoint_libdebug.pkl diff --git a/test/benchmarks/results/syscall_benchmark.svg b/test/amd64/benchmarks/results/syscall_benchmark.svg similarity index 100% rename from test/benchmarks/results/syscall_benchmark.svg rename to test/amd64/benchmarks/results/syscall_benchmark.svg diff --git a/test/benchmarks/results/syscall_gdb.pkl b/test/amd64/benchmarks/results/syscall_gdb.pkl similarity index 100% rename from test/benchmarks/results/syscall_gdb.pkl rename to test/amd64/benchmarks/results/syscall_gdb.pkl diff --git a/test/benchmarks/results/syscall_libdebug.pkl b/test/amd64/benchmarks/results/syscall_libdebug.pkl similarity index 100% rename from test/benchmarks/results/syscall_libdebug.pkl rename to test/amd64/benchmarks/results/syscall_libdebug.pkl diff --git a/test/benchmarks/syscall_gdb.py b/test/amd64/benchmarks/syscall_gdb.py similarity index 100% rename from test/benchmarks/syscall_gdb.py rename to test/amd64/benchmarks/syscall_gdb.py diff --git a/test/benchmarks/syscall_libdebug.py b/test/amd64/benchmarks/syscall_libdebug.py similarity index 100% rename from test/benchmarks/syscall_libdebug.py rename to test/amd64/benchmarks/syscall_libdebug.py diff --git a/test/binaries/antidebug_brute_test b/test/amd64/binaries/antidebug_brute_test similarity index 100% rename from test/binaries/antidebug_brute_test rename to test/amd64/binaries/antidebug_brute_test diff --git a/test/binaries/attach_test b/test/amd64/binaries/attach_test similarity index 100% rename from test/binaries/attach_test rename to test/amd64/binaries/attach_test diff --git a/test/binaries/backtrace_test b/test/amd64/binaries/backtrace_test similarity index 100% rename from test/binaries/backtrace_test rename to test/amd64/binaries/backtrace_test diff --git a/test/binaries/basic_test b/test/amd64/binaries/basic_test similarity index 100% rename from test/binaries/basic_test rename to test/amd64/binaries/basic_test diff --git a/test/binaries/basic_test_pie b/test/amd64/binaries/basic_test_pie similarity index 100% rename from test/binaries/basic_test_pie rename to test/amd64/binaries/basic_test_pie diff --git a/test/binaries/benchmark b/test/amd64/binaries/benchmark similarity index 100% rename from test/binaries/benchmark rename to test/amd64/binaries/benchmark diff --git a/test/binaries/breakpoint_test b/test/amd64/binaries/breakpoint_test similarity index 100% rename from test/binaries/breakpoint_test rename to test/amd64/binaries/breakpoint_test diff --git a/test/binaries/brute_test b/test/amd64/binaries/brute_test similarity index 100% rename from test/binaries/brute_test rename to test/amd64/binaries/brute_test diff --git a/test/binaries/catch_signal_test b/test/amd64/binaries/catch_signal_test similarity index 100% rename from test/binaries/catch_signal_test rename to test/amd64/binaries/catch_signal_test diff --git a/test/binaries/cc_workshop b/test/amd64/binaries/cc_workshop similarity index 100% rename from test/binaries/cc_workshop rename to test/amd64/binaries/cc_workshop diff --git a/test/binaries/complex_thread_test b/test/amd64/binaries/complex_thread_test similarity index 100% rename from test/binaries/complex_thread_test rename to test/amd64/binaries/complex_thread_test diff --git a/test/binaries/executable_section_test b/test/amd64/binaries/executable_section_test similarity index 100% rename from test/binaries/executable_section_test rename to test/amd64/binaries/executable_section_test diff --git a/test/binaries/finish_test b/test/amd64/binaries/finish_test similarity index 100% rename from test/binaries/finish_test rename to test/amd64/binaries/finish_test diff --git a/test/amd64/binaries/floating_point_2696_test b/test/amd64/binaries/floating_point_2696_test new file mode 100755 index 00000000..341cd335 Binary files /dev/null and b/test/amd64/binaries/floating_point_2696_test differ diff --git a/test/amd64/binaries/floating_point_512_test b/test/amd64/binaries/floating_point_512_test new file mode 100755 index 00000000..a9596012 Binary files /dev/null and b/test/amd64/binaries/floating_point_512_test differ diff --git a/test/amd64/binaries/floating_point_896_test b/test/amd64/binaries/floating_point_896_test new file mode 100755 index 00000000..4eb630bb Binary files /dev/null and b/test/amd64/binaries/floating_point_896_test differ diff --git a/test/binaries/handle_syscall_test b/test/amd64/binaries/handle_syscall_test similarity index 100% rename from test/binaries/handle_syscall_test rename to test/amd64/binaries/handle_syscall_test diff --git a/test/amd64/binaries/infinite_loop_test b/test/amd64/binaries/infinite_loop_test new file mode 100755 index 00000000..d893a624 Binary files /dev/null and b/test/amd64/binaries/infinite_loop_test differ diff --git a/test/binaries/jumpstart_test b/test/amd64/binaries/jumpstart_test similarity index 100% rename from test/binaries/jumpstart_test rename to test/amd64/binaries/jumpstart_test diff --git a/test/binaries/jumpstart_test_preload.so b/test/amd64/binaries/jumpstart_test_preload.so similarity index 99% rename from test/binaries/jumpstart_test_preload.so rename to test/amd64/binaries/jumpstart_test_preload.so index ef48e265..a68b0e47 100755 Binary files a/test/binaries/jumpstart_test_preload.so and b/test/amd64/binaries/jumpstart_test_preload.so differ diff --git a/test/binaries/math_loop_test b/test/amd64/binaries/math_loop_test similarity index 100% rename from test/binaries/math_loop_test rename to test/amd64/binaries/math_loop_test diff --git a/test/binaries/memory_test b/test/amd64/binaries/memory_test similarity index 100% rename from test/binaries/memory_test rename to test/amd64/binaries/memory_test diff --git a/test/binaries/memory_test_2 b/test/amd64/binaries/memory_test_2 similarity index 100% rename from test/binaries/memory_test_2 rename to test/amd64/binaries/memory_test_2 diff --git a/test/amd64/binaries/memory_test_3 b/test/amd64/binaries/memory_test_3 new file mode 100755 index 00000000..81032021 Binary files /dev/null and b/test/amd64/binaries/memory_test_3 differ diff --git a/test/amd64/binaries/memory_test_4 b/test/amd64/binaries/memory_test_4 new file mode 100755 index 00000000..9ec0cc40 Binary files /dev/null and b/test/amd64/binaries/memory_test_4 differ diff --git a/test/binaries/node b/test/amd64/binaries/node similarity index 100% rename from test/binaries/node rename to test/amd64/binaries/node diff --git a/test/binaries/segfault_test b/test/amd64/binaries/segfault_test similarity index 100% rename from test/binaries/segfault_test rename to test/amd64/binaries/segfault_test diff --git a/test/binaries/signals_multithread_det_test b/test/amd64/binaries/signals_multithread_det_test similarity index 100% rename from test/binaries/signals_multithread_det_test rename to test/amd64/binaries/signals_multithread_det_test diff --git a/test/binaries/signals_multithread_undet_test b/test/amd64/binaries/signals_multithread_undet_test similarity index 100% rename from test/binaries/signals_multithread_undet_test rename to test/amd64/binaries/signals_multithread_undet_test diff --git a/test/binaries/speed_test b/test/amd64/binaries/speed_test similarity index 100% rename from test/binaries/speed_test rename to test/amd64/binaries/speed_test diff --git a/test/binaries/thread_test b/test/amd64/binaries/thread_test similarity index 100% rename from test/binaries/thread_test rename to test/amd64/binaries/thread_test diff --git a/test/binaries/watchpoint_test b/test/amd64/binaries/watchpoint_test similarity index 100% rename from test/binaries/watchpoint_test rename to test/amd64/binaries/watchpoint_test diff --git a/test/dockerfiles/archlinux.Dockerfile b/test/amd64/dockerfiles/archlinux.Dockerfile similarity index 100% rename from test/dockerfiles/archlinux.Dockerfile rename to test/amd64/dockerfiles/archlinux.Dockerfile diff --git a/test/dockerfiles/archlinux.Dockerfile.dockerignore b/test/amd64/dockerfiles/archlinux.Dockerfile.dockerignore similarity index 100% rename from test/dockerfiles/archlinux.Dockerfile.dockerignore rename to test/amd64/dockerfiles/archlinux.Dockerfile.dockerignore diff --git a/test/dockerfiles/debian.Dockerfile b/test/amd64/dockerfiles/debian.Dockerfile similarity index 100% rename from test/dockerfiles/debian.Dockerfile rename to test/amd64/dockerfiles/debian.Dockerfile diff --git a/test/dockerfiles/debian.Dockerfile.dockerignore b/test/amd64/dockerfiles/debian.Dockerfile.dockerignore similarity index 100% rename from test/dockerfiles/debian.Dockerfile.dockerignore rename to test/amd64/dockerfiles/debian.Dockerfile.dockerignore diff --git a/test/dockerfiles/fedora.Dockerfile b/test/amd64/dockerfiles/fedora.Dockerfile similarity index 100% rename from test/dockerfiles/fedora.Dockerfile rename to test/amd64/dockerfiles/fedora.Dockerfile diff --git a/test/dockerfiles/fedora.Dockerfile.dockerignore b/test/amd64/dockerfiles/fedora.Dockerfile.dockerignore similarity index 100% rename from test/dockerfiles/fedora.Dockerfile.dockerignore rename to test/amd64/dockerfiles/fedora.Dockerfile.dockerignore diff --git a/test/dockerfiles/run_tests.sh b/test/amd64/dockerfiles/run_tests.sh similarity index 100% rename from test/dockerfiles/run_tests.sh rename to test/amd64/dockerfiles/run_tests.sh diff --git a/test/dockerfiles/ubuntu.Dockerfile b/test/amd64/dockerfiles/ubuntu.Dockerfile similarity index 100% rename from test/dockerfiles/ubuntu.Dockerfile rename to test/amd64/dockerfiles/ubuntu.Dockerfile diff --git a/test/dockerfiles/ubuntu.Dockerfile.dockerignore b/test/amd64/dockerfiles/ubuntu.Dockerfile.dockerignore similarity index 100% rename from test/dockerfiles/ubuntu.Dockerfile.dockerignore rename to test/amd64/dockerfiles/ubuntu.Dockerfile.dockerignore diff --git a/test/other_tests/gdb_migration_test.py b/test/amd64/other_tests/gdb_migration_test.py similarity index 100% rename from test/other_tests/gdb_migration_test.py rename to test/amd64/other_tests/gdb_migration_test.py diff --git a/test/run_containerized_tests.sh b/test/amd64/run_containerized_tests.sh similarity index 100% rename from test/run_containerized_tests.sh rename to test/amd64/run_containerized_tests.sh diff --git a/test/amd64/run_suite.py b/test/amd64/run_suite.py new file mode 100644 index 00000000..9a70dea6 --- /dev/null +++ b/test/amd64/run_suite.py @@ -0,0 +1,268 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2023-2024 Gabriele Digregorio, Roberto Alessandro Bertolini, Francesco Panebianco. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import sys +import unittest + +from scripts.alias_test import AliasTest +from scripts.atexit_handler_test import AtexitHandlerTest +from scripts.attach_detach_test import AttachDetachTest +from scripts.auto_waiting_test import AutoWaitingNlinks, AutoWaitingTest +from scripts.backtrace_test import BacktraceTest +from scripts.basic_test import BasicPieTest, BasicTest, ControlFlowTest, HwBasicTest +from scripts.breakpoint_test import BreakpointTest +from scripts.brute_test import BruteTest +from scripts.builtin_handler_test import AntidebugEscapingTest +from scripts.callback_test import CallbackTest +from scripts.catch_signal_test import SignalCatchTest +from scripts.death_test import DeathTest +from scripts.deep_dive_division_test import DeepDiveDivision +from scripts.finish_test import FinishTest +from scripts.floating_point_test import FloatingPointTest +from scripts.handle_syscall_test import HandleSyscallTest +from scripts.hijack_syscall_test import SyscallHijackTest +from scripts.jumpout_test import Jumpout +from scripts.jumpstart_test import JumpstartTest +from scripts.large_binary_sym_test import LargeBinarySymTest +from scripts.memory_test import MemoryTest +from scripts.memory_fast_test import MemoryFastTest +from scripts.multiple_debuggers_test import MultipleDebuggersTest +from scripts.next_test import NextTest +from scripts.nlinks_test import Nlinks +from scripts.pprint_syscalls_test import PPrintSyscallsTest +from scripts.signals_multithread_test import SignalMultithreadTest +from scripts.speed_test import SpeedTest +from scripts.thread_test import ComplexThreadTest, ThreadTest +from scripts.vmwhere1_test import Vmwhere1 +from scripts.waiting_test import WaitingNlinks, WaitingTest +from scripts.watchpoint_alias_test import WatchpointAliasTest +from scripts.watchpoint_test import WatchpointTest + + +def fast_suite(): + suite = unittest.TestSuite() + suite.addTest(BasicTest("test_basic")) + suite.addTest(BasicTest("test_registers")) + suite.addTest(BasicTest("test_step")) + suite.addTest(BasicTest("test_step_hardware")) + suite.addTest(BasicPieTest("test_basic")) + suite.addTest(BreakpointTest("test_bps")) + suite.addTest(BreakpointTest("test_bp_disable")) + suite.addTest(BreakpointTest("test_bp_disable_hw")) + suite.addTest(BreakpointTest("test_bp_disable_reenable")) + suite.addTest(BreakpointTest("test_bp_disable_reenable_hw")) + suite.addTest(BreakpointTest("test_bps_running")) + suite.addTest(BreakpointTest("test_bp_backing_file")) + suite.addTest(BreakpointTest("test_bp_disable_on_creation")) + suite.addTest(BreakpointTest("test_bp_disable_on_creation_2")) + suite.addTest(BreakpointTest("test_bp_disable_on_creation_hardware")) + suite.addTest(BreakpointTest("test_bp_disable_on_creation_2_hardware")) + suite.addTest(MemoryTest("test_memory")) + suite.addTest(MemoryTest("test_mem_access_libs")) + suite.addTest(MemoryTest("test_memory_access_methods_backing_file")) + suite.addTest(MemoryTest("test_memory_exceptions")) + suite.addTest(MemoryTest("test_memory_multiple_runs")) + suite.addTest(MemoryTest("test_memory_access_while_running")) + suite.addTest(MemoryTest("test_memory_access_methods")) + suite.addTest(MemoryFastTest("test_memory")) + suite.addTest(MemoryFastTest("test_mem_access_libs")) + suite.addTest(MemoryFastTest("test_memory_access_methods_backing_file")) + suite.addTest(MemoryFastTest("test_memory_exceptions")) + suite.addTest(MemoryFastTest("test_memory_multiple_runs")) + suite.addTest(MemoryFastTest("test_memory_access_while_running")) + suite.addTest(MemoryFastTest("test_memory_access_methods")) + suite.addTest(MemoryFastTest("test_memory_large_read")) + suite.addTest(MemoryFastTest("test_invalid_memory_location")) + suite.addTest(MemoryFastTest("test_memory_multiple_threads")) + suite.addTest(MemoryFastTest("test_memory_mixed_access")) + suite.addTest(HwBasicTest("test_basic")) + suite.addTest(HwBasicTest("test_registers")) + suite.addTest(BacktraceTest("test_backtrace_as_symbols")) + suite.addTest(BacktraceTest("test_backtrace")) + suite.addTest(AttachDetachTest("test_attach")) + suite.addTest(AttachDetachTest("test_attach_and_detach_1")) + suite.addTest(AttachDetachTest("test_attach_and_detach_2")) + suite.addTest(AttachDetachTest("test_attach_and_detach_3")) + suite.addTest(AttachDetachTest("test_attach_and_detach_4")) + suite.addTest(ThreadTest("test_thread")) + suite.addTest(ThreadTest("test_thread_hardware")) + suite.addTest(ComplexThreadTest("test_thread")) + suite.addTest(CallbackTest("test_callback_simple")) + suite.addTest(CallbackTest("test_callback_simple_hardware")) + suite.addTest(CallbackTest("test_callback_memory")) + suite.addTest(CallbackTest("test_callback_jumpout")) + suite.addTest(CallbackTest("test_callback_intermixing")) + suite.addTest(CallbackTest("test_callback_exception")) + suite.addTest(CallbackTest("test_callback_step")) + suite.addTest(CallbackTest("test_callback_pid_accessible")) + suite.addTest(CallbackTest("test_callback_pid_accessible_alias")) + suite.addTest(CallbackTest("test_callback_tid_accessible_alias")) + suite.addTest(FinishTest("test_finish_exact_no_auto_interrupt_no_breakpoint")) + suite.addTest(FinishTest("test_finish_heuristic_no_auto_interrupt_no_breakpoint")) + suite.addTest(FinishTest("test_finish_exact_auto_interrupt_no_breakpoint")) + suite.addTest(FinishTest("test_finish_heuristic_auto_interrupt_no_breakpoint")) + suite.addTest(FinishTest("test_finish_exact_no_auto_interrupt_breakpoint")) + suite.addTest(FinishTest("test_finish_heuristic_no_auto_interrupt_breakpoint")) + suite.addTest(FinishTest("test_heuristic_return_address")) + suite.addTest(FinishTest("test_exact_breakpoint_return")) + suite.addTest(FinishTest("test_heuristic_breakpoint_return")) + suite.addTest(FinishTest("test_breakpoint_collision")) + suite.addTest(FloatingPointTest("test_floating_point_reg_access")) + suite.addTest(Jumpout("test_jumpout")) + suite.addTest(Nlinks("test_nlinks")) + suite.addTest(JumpstartTest("test_cursed_ldpreload")) + suite.addTest(ControlFlowTest("test_step_until_1")) + suite.addTest(ControlFlowTest("test_step_until_2")) + suite.addTest(ControlFlowTest("test_step_until_3")) + suite.addTest(ControlFlowTest("test_step_and_cont")) + suite.addTest(ControlFlowTest("test_step_and_cont_hardware")) + suite.addTest(ControlFlowTest("test_step_until_and_cont")) + suite.addTest(ControlFlowTest("test_step_until_and_cont_hardware")) + suite.addTest(MultipleDebuggersTest("test_multiple_debuggers")) + suite.addTest(LargeBinarySymTest("test_large_binary_symbol_load_times")) + suite.addTest(LargeBinarySymTest("test_large_binary_demangle")) + suite.addTest(WaitingTest("test_bps_waiting")) + suite.addTest(WaitingTest("test_jumpout_waiting")) + suite.addTest(WaitingNlinks("test_nlinks")) + suite.addTest(AutoWaitingTest("test_bps_auto_waiting")) + suite.addTest(AutoWaitingTest("test_jumpout_auto_waiting")) + suite.addTest(NextTest("test_next")) + suite.addTest(NextTest("test_next_breakpoint")) + suite.addTest(NextTest("test_next_breakpoint_hw")) + suite.addTest(AutoWaitingNlinks("test_nlinks")) + suite.addTest(WatchpointTest("test_watchpoint")) + suite.addTest(WatchpointTest("test_watchpoint_callback")) + suite.addTest(WatchpointTest("test_watchpoint_disable")) + suite.addTest(WatchpointTest("test_watchpoint_disable_reenable")) + suite.addTest(WatchpointAliasTest("test_watchpoint_alias")) + suite.addTest(WatchpointAliasTest("test_watchpoint_callback")) + suite.addTest(HandleSyscallTest("test_handles")) + suite.addTest(HandleSyscallTest("test_handles_with_pprint")) + suite.addTest(HandleSyscallTest("test_handle_disabling")) + suite.addTest(HandleSyscallTest("test_handle_disabling_with_pprint")) + suite.addTest(HandleSyscallTest("test_handle_overwrite")) + suite.addTest(HandleSyscallTest("test_handle_overwrite_with_pprint")) + suite.addTest(HandleSyscallTest("test_handles_sync")) + suite.addTest(HandleSyscallTest("test_handles_sync_with_pprint")) + suite.addTest(AntidebugEscapingTest("test_antidebug_escaping")) + suite.addTest(SyscallHijackTest("test_hijack_syscall")) + suite.addTest(SyscallHijackTest("test_hijack_syscall_with_pprint")) + suite.addTest(SyscallHijackTest("test_hijack_handle_syscall")) + suite.addTest(SyscallHijackTest("test_hijack_handle_syscall_with_pprint")) + suite.addTest(SyscallHijackTest("test_hijack_syscall_args")) + suite.addTest(SyscallHijackTest("test_hijack_syscall_args_with_pprint")) + suite.addTest(SyscallHijackTest("test_hijack_syscall_wrong_args")) + suite.addTest(SyscallHijackTest("loop_detection_test")) + suite.addTest(PPrintSyscallsTest("test_pprint_syscalls_generic")) + suite.addTest(PPrintSyscallsTest("test_pprint_syscalls_with_statement")) + suite.addTest(PPrintSyscallsTest("test_pprint_handle_syscalls")) + suite.addTest(PPrintSyscallsTest("test_pprint_hijack_syscall")) + suite.addTest(PPrintSyscallsTest("test_pprint_which_syscalls_pprint_after")) + suite.addTest(PPrintSyscallsTest("test_pprint_which_syscalls_pprint_before")) + suite.addTest(PPrintSyscallsTest("test_pprint_which_syscalls_pprint_after_and_before")) + suite.addTest(PPrintSyscallsTest("test_pprint_which_syscalls_not_pprint_after")) + suite.addTest(PPrintSyscallsTest("test_pprint_which_syscalls_not_pprint_before")) + suite.addTest(PPrintSyscallsTest("test_pprint_which_syscalls_not_pprint_after_and_before")) + suite.addTest(SignalCatchTest("test_signal_catch_signal_block")) + suite.addTest(SignalCatchTest("test_signal_pass_to_process")) + suite.addTest(SignalCatchTest("test_signal_disable_catch_signal")) + suite.addTest(SignalCatchTest("test_signal_unblock")) + suite.addTest(SignalCatchTest("test_signal_disable_catch_signal_unblock")) + suite.addTest(SignalCatchTest("test_hijack_signal_with_catch_signal")) + suite.addTest(SignalCatchTest("test_hijack_signal_with_api")) + suite.addTest(SignalCatchTest("test_recursive_true_with_catch_signal")) + suite.addTest(SignalCatchTest("test_recursive_true_with_api")) + suite.addTest(SignalCatchTest("test_recursive_false_with_catch_signal")) + suite.addTest(SignalCatchTest("test_recursive_false_with_api")) + suite.addTest(SignalCatchTest("test_hijack_signal_with_catch_signal_loop")) + suite.addTest(SignalCatchTest("test_hijack_signal_with_api_loop")) + suite.addTest(SignalCatchTest("test_signal_unhijacking")) + suite.addTest(SignalCatchTest("test_override_catch_signal")) + suite.addTest(SignalCatchTest("test_override_hijack")) + suite.addTest(SignalCatchTest("test_override_hybrid")) + suite.addTest(SignalCatchTest("test_signal_get_signal")) + suite.addTest(SignalCatchTest("test_signal_send_signal")) + suite.addTest(SignalCatchTest("test_signal_catch_sync_block")) + suite.addTest(SignalCatchTest("test_signal_catch_sync_pass")) + suite.addTest(SignalMultithreadTest("test_signal_multithread_undet_catch_signal_block")) + suite.addTest(SignalMultithreadTest("test_signal_multithread_undet_pass")) + suite.addTest(SignalMultithreadTest("test_signal_multithread_det_catch_signal_block")) + suite.addTest(SignalMultithreadTest("test_signal_multithread_det_pass")) + suite.addTest(SignalMultithreadTest("test_signal_multithread_send_signal")) + suite.addTest(DeathTest("test_io_death")) + suite.addTest(DeathTest("test_cont_death")) + suite.addTest(DeathTest("test_instr_death")) + suite.addTest(DeathTest("test_exit_signal_death")) + suite.addTest(DeathTest("test_exit_code_death")) + suite.addTest(DeathTest("test_exit_code_normal")) + suite.addTest(DeathTest("test_post_mortem_after_kill")) + suite.addTest(AliasTest("test_basic_alias")) + suite.addTest(AliasTest("test_step_alias")) + suite.addTest(AliasTest("test_step_until_alias")) + suite.addTest(AliasTest("test_memory_alias")) + suite.addTest(AliasTest("test_finish_alias")) + suite.addTest(AliasTest("test_waiting_alias")) + suite.addTest(AliasTest("test_interrupt_alias")) + suite.addTest(AtexitHandlerTest("test_attach_detach_1")) + suite.addTest(AtexitHandlerTest("test_attach_detach_2")) + suite.addTest(AtexitHandlerTest("test_attach_detach_3")) + suite.addTest(AtexitHandlerTest("test_attach_detach_4")) + suite.addTest(AtexitHandlerTest("test_attach_1")) + suite.addTest(AtexitHandlerTest("test_attach_2")) + suite.addTest(AtexitHandlerTest("test_run_1")) + suite.addTest(AtexitHandlerTest("test_run_2")) + suite.addTest(AtexitHandlerTest("test_run_3")) + suite.addTest(AtexitHandlerTest("test_run_4")) + return suite + + +def complete_suite(): + suite = fast_suite() + suite.addTest(Vmwhere1("test_vmwhere1")) + suite.addTest(Vmwhere1("test_vmwhere1_callback")) + suite.addTest(BruteTest("test_bruteforce")) + suite.addTest(CallbackTest("test_callback_bruteforce")) + suite.addTest(SpeedTest("test_speed")) + suite.addTest(SpeedTest("test_speed_hardware")) + suite.addTest(DeepDiveDivision("test_deep_dive_division")) + return suite + + +def thread_stress_suite(): + suite = unittest.TestSuite() + for _ in range(1024): + suite.addTest(ThreadTest("test_thread")) + suite.addTest(ThreadTest("test_thread_hardware")) + suite.addTest(ComplexThreadTest("test_thread")) + return suite + + +if __name__ == "__main__": + if sys.version_info >= (3, 12): + runner = unittest.TextTestRunner(verbosity=2, durations=3) + else: + runner = unittest.TextTestRunner(verbosity=2) + + if len(sys.argv) > 1 and sys.argv[1].lower() == "slow": + suite = complete_suite() + elif len(sys.argv) > 1 and sys.argv[1].lower() == "thread_stress": + suite = thread_stress_suite() + runner.verbosity = 1 + else: + suite = fast_suite() + + result = runner.run(suite) + + if result.wasSuccessful(): + print("All tests passed") + else: + print("Some tests failed") + print("\nFailed Tests:") + for test, err in result.failures: + print(f"{test}: {err}") + print("\nErrors:") + for test, err in result.errors: + print(f"{test}: {err}") diff --git a/test/scripts/alias_test.py b/test/amd64/scripts/alias_test.py similarity index 100% rename from test/scripts/alias_test.py rename to test/amd64/scripts/alias_test.py diff --git a/test/amd64/scripts/atexit_handler_test.py b/test/amd64/scripts/atexit_handler_test.py new file mode 100644 index 00000000..790fcd09 --- /dev/null +++ b/test/amd64/scripts/atexit_handler_test.py @@ -0,0 +1,242 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import os +import psutil +import signal +import unittest +from pwn import process + +from libdebug import debugger +from libdebug.debugger.internal_debugger_holder import _cleanup_internal_debugger + + +class AtexitHandlerTest(unittest.TestCase): + def test_run_1(self): + d = debugger("binaries/infinite_loop_test") + + r = d.run() + + pid = d.pid + + d.cont() + + r.sendline(b"3") + + _cleanup_internal_debugger() + + # The process should have been killed + self.assertNotIn(pid, psutil.pids()) + + def test_run_2(self): + d = debugger("binaries/infinite_loop_test", kill_on_exit=False) + + r = d.run() + + pid = d.pid + + d.cont() + + r.sendline(b"3") + + _cleanup_internal_debugger() + + # The process should not have been killed + self.assertIn(pid, psutil.pids()) + + # We can actually still use the debugger + d.interrupt() + d.kill() + + # The process should now be dead + self.assertNotIn(pid, psutil.pids()) + + def test_run_3(self): + d = debugger("binaries/infinite_loop_test", kill_on_exit=False) + + r = d.run() + + pid = d.pid + + d.cont() + + r.sendline(b"3") + + d.kill_on_exit = True + + _cleanup_internal_debugger() + + # The process should have been killed + self.assertNotIn(pid, psutil.pids()) + + def test_run_4(self): + d = debugger("binaries/infinite_loop_test") + + r = d.run() + + pid = d.pid + + d.cont() + + d.kill_on_exit = False + + r.sendline(b"3") + + _cleanup_internal_debugger() + + # The process should not have been killed + self.assertIn(pid, psutil.pids()) + + # We can actually still use the debugger + d.interrupt() + d.kill() + + # The process should now be dead + self.assertNotIn(pid, psutil.pids()) + + def test_attach_detach_1(self): + p = process("binaries/infinite_loop_test") + + d = debugger() + + d.attach(p.pid) + + p.sendline(b"3") + + d.step() + d.step() + + d.detach() + + # If the process is still running, poll() should return None + self.assertIsNone(p.poll(block=False)) + + _cleanup_internal_debugger() + + # The process should now be dead + self.assertIsNotNone(p.poll(block=False)) + + def test_attach_detach_2(self): + p = process("binaries/infinite_loop_test") + + d = debugger(kill_on_exit=False) + + d.attach(p.pid) + + p.sendline(b"3") + + d.step() + d.step() + + d.detach() + + # If the process is still running, poll() should return None + self.assertIsNone(p.poll(block=False)) + + _cleanup_internal_debugger() + + # We set kill_on_exit to False, so the process should still be alive + # The process should still be alive + self.assertIsNone(p.poll(block=False)) + + p.kill() + + # The process should now be dead + self.assertIsNotNone(p.poll(block=False)) + + def test_attach_detach_3(self): + p = process("binaries/infinite_loop_test") + + d = debugger(kill_on_exit=False) + + d.attach(p.pid) + + p.sendline(b"3") + + d.step() + d.step() + + d.detach() + + # If the process is still running, poll() should return None + self.assertIsNone(p.poll(block=False)) + + d.kill_on_exit = True + + _cleanup_internal_debugger() + + # The process should now be dead + self.assertIsNotNone(p.poll(block=False)) + + def test_attach_detach_4(self): + p = process("binaries/infinite_loop_test") + + d = debugger() + + d.attach(p.pid) + + p.sendline(b"3") + + d.step() + d.step() + + d.detach() + + # If the process is still running, poll() should return None + self.assertIsNone(p.poll(block=False)) + + d.kill_on_exit = False + + _cleanup_internal_debugger() + + # We set kill_on_exit to False, so the process should still be alive + # The process should still be alive + self.assertIsNone(p.poll(block=False)) + + p.kill() + + # The process should now be dead + self.assertIsNotNone(p.poll(block=False)) + + def test_attach_1(self): + p = process("binaries/infinite_loop_test") + + d = debugger() + + d.attach(p.pid) + + p.sendline(b"3") + + d.step() + d.step() + + # If the process is still running, poll() should return None + self.assertIsNone(p.poll(block=False)) + + _cleanup_internal_debugger() + + # The process should now be dead + self.assertIsNotNone(p.poll(block=False)) + + def test_attach_2(self): + p = process("binaries/infinite_loop_test") + + d = debugger() + + d.attach(p.pid) + + p.sendline(b"3") + + d.step() + d.step() + + p.kill() + + # The process should now be dead + self.assertIsNotNone(p.poll(block=False)) + + # Even if we kill the process, the next call should not raise an exception + _cleanup_internal_debugger() diff --git a/test/scripts/attach_detach_test.py b/test/amd64/scripts/attach_detach_test.py similarity index 100% rename from test/scripts/attach_detach_test.py rename to test/amd64/scripts/attach_detach_test.py diff --git a/test/scripts/auto_waiting_test.py b/test/amd64/scripts/auto_waiting_test.py similarity index 100% rename from test/scripts/auto_waiting_test.py rename to test/amd64/scripts/auto_waiting_test.py diff --git a/test/scripts/backtrace_test.py b/test/amd64/scripts/backtrace_test.py similarity index 100% rename from test/scripts/backtrace_test.py rename to test/amd64/scripts/backtrace_test.py diff --git a/test/scripts/basic_test.py b/test/amd64/scripts/basic_test.py similarity index 100% rename from test/scripts/basic_test.py rename to test/amd64/scripts/basic_test.py diff --git a/test/scripts/breakpoint_test.py b/test/amd64/scripts/breakpoint_test.py similarity index 100% rename from test/scripts/breakpoint_test.py rename to test/amd64/scripts/breakpoint_test.py diff --git a/test/scripts/brute_test.py b/test/amd64/scripts/brute_test.py similarity index 100% rename from test/scripts/brute_test.py rename to test/amd64/scripts/brute_test.py diff --git a/test/scripts/builtin_handler_test.py b/test/amd64/scripts/builtin_handler_test.py similarity index 100% rename from test/scripts/builtin_handler_test.py rename to test/amd64/scripts/builtin_handler_test.py diff --git a/test/scripts/callback_test.py b/test/amd64/scripts/callback_test.py similarity index 89% rename from test/scripts/callback_test.py rename to test/amd64/scripts/callback_test.py index 9b95f3c6..0e716e2d 100644 --- a/test/scripts/callback_test.py +++ b/test/amd64/scripts/callback_test.py @@ -117,6 +117,45 @@ def callback(thread, bp): if self.exceptions: raise self.exceptions[0] + def test_callback_fast_memory(self): + self.exceptions.clear() + + global hit + hit = False + + d = debugger("binaries/memory_test", fast_memory=True) + + d.run() + + def callback(thread, bp): + global hit + + prev = bytes(range(256)) + try: + self.assertEqual(bp.address, thread.regs.rip) + self.assertEqual(bp.hit_count, 1) + self.assertEqual(thread.memory[thread.regs.rdi, 256], prev) + + thread.memory[thread.regs.rdi + 128 :] = b"abcd123456" + prev = prev[:128] + b"abcd123456" + prev[138:] + + self.assertEqual(thread.memory[thread.regs.rdi, 256], prev) + except Exception as e: + self.exceptions.append(e) + + hit = True + + d.breakpoint("change_memory", callback=callback) + + d.cont() + + d.kill() + + self.assertTrue(hit) + + if self.exceptions: + raise self.exceptions[0] + def test_callback_bruteforce(self): global flag global counter @@ -303,7 +342,7 @@ def callback(t, bp): d.kill() self.assertTrue(hit) - + def test_callback_pid_accessible_alias(self): self.exceptions.clear() @@ -325,7 +364,7 @@ def callback(t, bp): d.kill() self.assertTrue(hit) - + def test_callback_tid_accessible_alias(self): self.exceptions.clear() diff --git a/test/scripts/catch_signal_test.py b/test/amd64/scripts/catch_signal_test.py similarity index 100% rename from test/scripts/catch_signal_test.py rename to test/amd64/scripts/catch_signal_test.py diff --git a/test/scripts/death_test.py b/test/amd64/scripts/death_test.py similarity index 100% rename from test/scripts/death_test.py rename to test/amd64/scripts/death_test.py diff --git a/test/scripts/deep_dive_division_test.py b/test/amd64/scripts/deep_dive_division_test.py similarity index 100% rename from test/scripts/deep_dive_division_test.py rename to test/amd64/scripts/deep_dive_division_test.py diff --git a/test/scripts/finish_test.py b/test/amd64/scripts/finish_test.py similarity index 97% rename from test/scripts/finish_test.py rename to test/amd64/scripts/finish_test.py index ae120fe3..1c6337d4 100644 --- a/test/scripts/finish_test.py +++ b/test/amd64/scripts/finish_test.py @@ -229,24 +229,24 @@ def test_heuristic_return_address(self): self.assertEqual(d.regs.rip, C_ADDRESS) - stack_unwinder = stack_unwinding_provider() + stack_unwinder = stack_unwinding_provider(d._internal_debugger.arch) # We need to repeat the check for the three stages of the function preamble # Get current return address - curr_srip = stack_unwinder.get_return_address(d) + curr_srip = d.saved_ip self.assertEqual(curr_srip, RETURN_POINT_FROM_C) d.step() # Get current return address - curr_srip = stack_unwinder.get_return_address(d) + curr_srip = d.saved_ip self.assertEqual(curr_srip, RETURN_POINT_FROM_C) d.step() # Get current return address - curr_srip = stack_unwinder.get_return_address(d) + curr_srip = d.saved_ip self.assertEqual(curr_srip, RETURN_POINT_FROM_C) d.kill() diff --git a/test/amd64/scripts/floating_point_test.py b/test/amd64/scripts/floating_point_test.py new file mode 100644 index 00000000..733c725e --- /dev/null +++ b/test/amd64/scripts/floating_point_test.py @@ -0,0 +1,285 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import unittest +from pathlib import Path +from random import randint + +from libdebug import debugger + + +class FloatingPointTest(unittest.TestCase): + def test_floating_point_reg_access(self): + # This test is divided into two parts, depending on the current hardware + + # Let's check if we have AVX512 + with Path("/proc/cpuinfo").open() as f: + cpuinfo = f.read() + + if "avx512" in cpuinfo: + # Run an AVX512 test + self.avx512() + self.avx() + self.mmx() + elif "avx" in cpuinfo: + # Run an AVX test + self.avx() + self.mmx() + else: + # Run a generic test + self.mmx() + + def avx512(self): + d = debugger("binaries/floating_point_2696_test") + + d.run() + + bp1 = d.bp(0x40143E) + bp2 = d.bp(0x401467) + + d.cont() + + self.assertTrue(bp1.hit_on(d)) + + self.assertTrue(hasattr(d.regs, "xmm0")) + self.assertTrue(hasattr(d.regs, "xmm31")) + self.assertTrue(hasattr(d.regs, "ymm0")) + self.assertTrue(hasattr(d.regs, "ymm31")) + self.assertTrue(hasattr(d.regs, "zmm0")) + self.assertTrue(hasattr(d.regs, "zmm31")) + + baseval = int.from_bytes(bytes(list(range(64))), "little") + + for i in range(32): + self.assertEqual(getattr(d.regs, f"xmm{i}"), baseval & ((1 << 128) - 1)) + self.assertEqual(getattr(d.regs, f"ymm{i}"), baseval & ((1 << 256) - 1)) + self.assertEqual(getattr(d.regs, f"zmm{i}"), baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 504) + + d.regs.zmm0 = 0xDEADBEEFDEADBEEF + + d.cont() + + self.assertTrue(bp2.hit_on(d)) + + for i in range(32): + val = randint(0, 2**512 - 1) + setattr(d.regs, f"zmm{i}", val) + self.assertEqual(getattr(d.regs, f"zmm{i}"), val) + + d.kill() + + def avx(self): + d = debugger("binaries/floating_point_896_test") + + d.run() + + bp1 = d.bp(0x40159E) + bp2 = d.bp(0x4015C5) + + d.cont() + + self.assertTrue(bp1.hit_on(d)) + + self.assertTrue(hasattr(d.regs, "xmm0")) + self.assertTrue(hasattr(d.regs, "ymm0")) + self.assertTrue(hasattr(d.regs, "xmm15")) + self.assertTrue(hasattr(d.regs, "ymm15")) + + baseval = int.from_bytes(bytes(list(range(0, 256, 17)) + list(range(16))), "little") + + self.assertEqual(d.regs.ymm0, baseval) + self.assertEqual(d.regs.xmm0, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm1, baseval) + self.assertEqual(d.regs.xmm1, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm2, baseval) + self.assertEqual(d.regs.xmm2, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm3, baseval) + self.assertEqual(d.regs.xmm3, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm4, baseval) + self.assertEqual(d.regs.xmm4, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm5, baseval) + self.assertEqual(d.regs.xmm5, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm6, baseval) + self.assertEqual(d.regs.xmm6, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm7, baseval) + self.assertEqual(d.regs.xmm7, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm8, baseval) + self.assertEqual(d.regs.xmm8, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm9, baseval) + self.assertEqual(d.regs.xmm9, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm10, baseval) + self.assertEqual(d.regs.xmm10, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm11, baseval) + self.assertEqual(d.regs.xmm11, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm12, baseval) + self.assertEqual(d.regs.xmm12, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm13, baseval) + self.assertEqual(d.regs.xmm13, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm14, baseval) + self.assertEqual(d.regs.xmm14, baseval & ((1 << 128) - 1)) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + self.assertEqual(d.regs.ymm15, baseval) + self.assertEqual(d.regs.xmm15, baseval & ((1 << 128) - 1)) + + d.regs.ymm0 = 0xDEADBEEFDEADBEEF + + d.cont() + + self.assertTrue(bp2.hit_on(d)) + + for i in range(16): + val = randint(0, 2**256 - 1) + setattr(d.regs, f"ymm{i}", val) + self.assertEqual(getattr(d.regs, f"xmm{i}"), val & ((1 << 128) - 1)) + self.assertEqual(getattr(d.regs, f"ymm{i}"), val) + + # validate that register states are correctly flushed and then restored + values = [] + + for i in range(16): + val = randint(0, 2**256 - 1) + setattr(d.regs, f"ymm{i}", val) + values.append(val) + + d.step() + + for i in range(16): + self.assertEqual(getattr(d.regs, f"ymm{i}"), values[i]) + + d.regs.ymm7 = 0xDEADBEEFDEADBEEF + + for i in range(16): + if i == 7: + continue + + self.assertEqual(getattr(d.regs, f"ymm{i}"), values[i]) + + d.step() + + for i in range(16): + if i == 7: + continue + + self.assertEqual(getattr(d.regs, f"ymm{i}"), values[i]) + + self.assertEqual(d.regs.ymm7, 0xDEADBEEFDEADBEEF) + + d.kill() + + def callback(t, _): + baseval = int.from_bytes(bytes(list(range(0, 256, 17)) + list(range(16))), "little") + for i in range(16): + self.assertEqual(getattr(d.regs, f"xmm{i}"), baseval & ((1 << 128) - 1)) + self.assertEqual(getattr(d.regs, f"ymm{i}"), baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 248) + + t.regs.ymm0 = 0xDEADBEEFDEADBEEF + + d.run() + + d.bp(0x40159E, callback=callback) + bp = d.bp(0x4015C5) + + d.cont() + + self.assertTrue(bp.hit_on(d)) + + d.kill() + + def mmx(self): + d = debugger("binaries/floating_point_512_test") + + d.run() + + bp1 = d.bp(0x401372) + bp2 = d.bp(0x401399) + + d.cont() + + self.assertTrue(bp1.hit_on(d)) + + self.assertTrue(hasattr(d.regs, "xmm0")) + self.assertTrue(hasattr(d.regs, "xmm15")) + + baseval = int.from_bytes(bytes(list(range(0, 256, 17))), "little") + self.assertEqual(d.regs.xmm0, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm1, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm2, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm3, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm4, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm5, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm6, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm7, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm8, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm9, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm10, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm11, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm12, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm13, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm14, baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + self.assertEqual(d.regs.xmm15, baseval) + + d.regs.xmm0 = 0xDEADBEEFDEADBEEF + + d.cont() + + self.assertTrue(bp2.hit_on(d)) + + for i in range(16): + val = randint(0, 2**128 - 1) + setattr(d.regs, f"xmm{i}", val) + self.assertEqual(getattr(d.regs, f"xmm{i}"), val) + + d.kill() + + def callback(t, _): + baseval = int.from_bytes(bytes(list(range(0, 256, 17))), "little") + for i in range(16): + self.assertEqual(getattr(d.regs, f"xmm{i}"), baseval) + baseval = (baseval >> 8) + ((baseval & 255) << 120) + + t.regs.xmm0 = 0xDEADBEEFDEADBEEF + + d.run() + + d.bp(0x401372, callback=callback) + bp = d.bp(0x401399) + + d.cont() + + self.assertTrue(bp.hit_on(d)) + + d.kill() diff --git a/test/scripts/handle_syscall_test.py b/test/amd64/scripts/handle_syscall_test.py similarity index 100% rename from test/scripts/handle_syscall_test.py rename to test/amd64/scripts/handle_syscall_test.py diff --git a/test/scripts/hijack_syscall_test.py b/test/amd64/scripts/hijack_syscall_test.py similarity index 100% rename from test/scripts/hijack_syscall_test.py rename to test/amd64/scripts/hijack_syscall_test.py diff --git a/test/scripts/jumpout_test.py b/test/amd64/scripts/jumpout_test.py similarity index 100% rename from test/scripts/jumpout_test.py rename to test/amd64/scripts/jumpout_test.py diff --git a/test/amd64/scripts/jumpstart_test.py b/test/amd64/scripts/jumpstart_test.py new file mode 100644 index 00000000..d5e1d7a3 --- /dev/null +++ b/test/amd64/scripts/jumpstart_test.py @@ -0,0 +1,24 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import unittest + +from libdebug import debugger + +class JumpstartTest(unittest.TestCase): + + def test_cursed_ldpreload(self): + d = debugger("binaries/jumpstart_test", env={"LD_PRELOAD": "binaries/jumpstart_test_preload.so"}) + + r = d.run() + + d.cont() + + self.assertEqual(r.recvline(), b"Preload library loaded") + self.assertEqual(r.recvline(), b"Jumpstart test") + self.assertEqual(r.recvline(), b"execve(/bin/ls, (nil), (nil))") + + d.kill() diff --git a/test/scripts/large_binary_sym_test.py b/test/amd64/scripts/large_binary_sym_test.py similarity index 100% rename from test/scripts/large_binary_sym_test.py rename to test/amd64/scripts/large_binary_sym_test.py diff --git a/test/amd64/scripts/memory_fast_test.py b/test/amd64/scripts/memory_fast_test.py new file mode 100644 index 00000000..80bcd7a9 --- /dev/null +++ b/test/amd64/scripts/memory_fast_test.py @@ -0,0 +1,391 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Gabriele Digregorio, Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import unittest + +from libdebug import debugger, libcontext + + +class MemoryFastTest(unittest.TestCase): + def test_memory(self): + d = debugger("binaries/memory_test", fast_memory=True) + + d.run() + + bp = d.breakpoint("change_memory") + + d.cont() + + assert d.regs.rip == bp.address + + address = d.regs.rdi + prev = bytes(range(256)) + + self.assertTrue(d.memory[address, 256] == prev) + + d.memory[address + 128 :] = b"abcd123456" + prev = prev[:128] + b"abcd123456" + prev[138:] + + self.assertTrue(d.memory[address : address + 256] == prev) + + d.kill() + d.terminate() + + def test_mem_access_libs(self): + d = debugger("binaries/memory_test", fast_memory=True) + + d.run() + + bp = d.breakpoint("leak_address") + + d.cont() + + assert d.regs.rip == bp.address + + address = d.regs.rdi + with libcontext.tmp(sym_lvl=5): + arena = d.memory["main_arena", 256, "libc"] + + def p64(x): + return x.to_bytes(8, "little") + + self.assertTrue(p64(address - 0x10) in arena) + + d.kill() + d.terminate() + + def test_memory_exceptions(self): + d = debugger("binaries/memory_test", fast_memory=True) + + d.run() + + bp = d.breakpoint("change_memory") + + d.cont() + + # This should not raise an exception + file = d.memory[0x0, 256] + + # File should start with ELF magic number + self.assertTrue(file.startswith(b"\x7fELF")) + + assert d.regs.rip == bp.address + + address = d.regs.rdi + prev = bytes(range(256)) + + self.assertTrue(d.memory[address, 256] == prev) + + d.memory[address + 128 :] = b"abcd123456" + prev = prev[:128] + b"abcd123456" + prev[138:] + + self.assertTrue(d.memory[address : address + 256] == prev) + + d.kill() + d.terminate() + + def test_memory_multiple_runs(self): + d = debugger("binaries/memory_test", fast_memory=True) + + for _ in range(10): + d.run() + + bp = d.breakpoint("change_memory") + + d.cont() + + assert d.regs.rip == bp.address + + address = d.regs.rdi + prev = bytes(range(256)) + + self.assertTrue(d.memory[address, 256] == prev) + + d.memory[address + 128 :] = b"abcd123456" + prev = prev[:128] + b"abcd123456" + prev[138:] + + self.assertTrue(d.memory[address : address + 256] == prev) + + d.kill() + + d.terminate() + + def test_memory_access_while_running(self): + d = debugger("binaries/memory_test_2", fast_memory=True) + + d.run() + + bp = d.breakpoint("do_nothing") + + d.cont() + + # Verify that memory access is only possible when the process is stopped + value = int.from_bytes(d.memory["state", 8], "little") + self.assertEqual(value, 0xDEADBEEF) + self.assertEqual(d.regs.rip, bp.address) + + d.kill() + d.terminate() + + def test_memory_access_methods(self): + d = debugger("binaries/memory_test_2", fast_memory=True) + + d.run() + + base = d.regs.rip & 0xFFFFFFFFFFFFF000 - 0x1000 + + # Test different ways to access memory at the start of the file + file_0 = d.memory[base, 256] + file_1 = d.memory[0x0, 256] + file_2 = d.memory[0x0:0x100] + + self.assertEqual(file_0, file_1) + self.assertEqual(file_0, file_2) + + # Validate that the length of the read bytes is correct + file_0 = d.memory[0x0] + file_1 = d.memory[base] + + self.assertEqual(file_0, file_1) + self.assertEqual(len(file_0), 1) + + # Validate that slices work correctly + file_0 = d.memory[0x0:"do_nothing"] + file_1 = d.memory[base:"do_nothing"] + + self.assertEqual(file_0, file_1) + + self.assertRaises(ValueError, lambda: d.memory[0x1000:0x0]) + # _fini is after main + self.assertRaises(ValueError, lambda: d.memory["_fini":"main"]) + + # Test different ways to write memory + d.memory[0x0, 8] = b"abcd1234" + self.assertEqual(d.memory[0x0, 8], b"abcd1234") + + d.memory[0x0, 8] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory[base:] = b"abcd1234" + self.assertEqual(d.memory[base, 8], b"abcd1234") + + d.memory[base:] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory[base] = b"abcd1234" + self.assertEqual(d.memory[base, 8], b"abcd1234") + + d.memory[base] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory[0x0:0x8] = b"abcd1234" + self.assertEqual(d.memory[0x0, 8], b"abcd1234") + + d.memory[0x0, 8] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main":] = b"abcd1234" + self.assertEqual(d.memory["main", 8], b"abcd1234") + + d.memory["main":] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main"] = b"abcd1234" + self.assertEqual(d.memory["main", 8], b"abcd1234") + + d.memory["main"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main":"main+8"] = b"abcd1234" + self.assertEqual(d.memory["main", 8], b"abcd1234") + + d.kill() + d.terminate() + + def test_memory_access_methods_backing_file(self): + d = debugger("binaries/memory_test_2", fast_memory=True) + + d.run() + + base = d.regs.rip & 0xFFFFFFFFFFFFF000 - 0x1000 + + # Validate that slices work correctly + file_0 = d.memory[0x0:"do_nothing", "binary"] + file_1 = d.memory[0x0:"do_nothing", "memory_test_2"] + file_2 = d.memory[base:"do_nothing", "binary"] + file_3 = d.memory[base:"do_nothing", "memory_test_2"] + + self.assertEqual(file_0, file_1) + self.assertEqual(file_1, file_2) + self.assertEqual(file_2, file_3) + + # Test different ways to write memory + d.memory[0x0, 8, "binary"] = b"abcd1234" + self.assertEqual(d.memory[0x0, 8, "binary"], b"abcd1234") + + d.memory[0x0, 8, "binary"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory[0x0, 8, "memory_test_2"] = b"abcd1234" + self.assertEqual(d.memory[0x0, 8, "memory_test_2"], b"abcd1234") + + d.memory[0x0, 8, "memory_test_2"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory[0x0:0x8, "binary"] = b"abcd1234" + self.assertEqual(d.memory[0x0:8, "binary"], b"abcd1234") + + d.memory[0x0, 8, "binary"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory[0x0:0x8, "memory_test_2"] = b"abcd1234" + self.assertEqual(d.memory[0x0:8, "memory_test_2"], b"abcd1234") + + d.memory[0x0, 8, "memory_test_2"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main":, "binary"] = b"abcd1234" + self.assertEqual(d.memory["main", 8, "binary"], b"abcd1234") + + d.memory["main":, "binary"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main":, "memory_test_2"] = b"abcd1234" + self.assertEqual(d.memory["main", 8, "binary"], b"abcd1234") + + d.memory["main":, "memory_test_2"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main", "binary"] = b"abcd1234" + self.assertEqual(d.memory["main", 8, "binary"], b"abcd1234") + + d.memory[0x0, 8, "binary"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main", "memory_test_2"] = b"abcd1234" + self.assertEqual(d.memory["main", 8, "memory_test_2"], b"abcd1234") + + d.memory[0x0, 8, "memory_test_2"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main":"main+8", "binary"] = b"abcd1234" + self.assertEqual(d.memory["main":"main+8", "binary"], b"abcd1234") + + d.memory[0x0, 8, "binary"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main":"main+8", "memory_test_2"] = b"abcd1234" + self.assertEqual(d.memory["main":"main+8", "memory_test_2"], b"abcd1234") + + d.memory[0x0, 8, "binary"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main":"main+8", "hybrid"] = b"abcd1234" + self.assertEqual(d.memory["main":"main+8", "hybrid"], b"abcd1234") + + d.memory[0x0, 8, "binary"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + with self.assertRaises(ValueError): + d.memory["main":"main+8", "absolute"] = b"abcd1234" + + d.kill() + d.terminate() + + def test_memory_large_read(self): + d = debugger("binaries/memory_test_3", fast_memory=True) + + d.run() + + bp = d.bp("do_nothing") + + d.cont() + + assert bp.hit_on(d) + + leak = d.regs.rdi + + # Read 4MB of memory + data = d.memory[leak, 4 * 1024 * 1024] + + assert data == b"".join(x.to_bytes(4, "little") for x in range(1024 * 1024)) + + d.kill() + d.terminate() + + def test_invalid_memory_location(self): + d = debugger("binaries/memory_test", fast_memory=True) + + d.run() + + bp = d.bp("change_memory") + + d.cont() + + assert d.regs.rip == bp.address + + address = 0xDEADBEEFD00D + + with self.assertRaises(ValueError): + d.memory[address, 256, "absolute"] + + d.kill() + d.terminate() + + def test_memory_multiple_threads(self): + d = debugger("binaries/memory_test_4", fast_memory=True) + + d.run() + + leaks = [] + leak_addresses = [] + + def leak(t, _): + leaks.append(t.memory[t.regs.rdi, 16]) + leak_addresses.append(t.regs.rdi) + + d.bp("leak", callback=leak, hardware=True) + exit = d.bp("before_exit", hardware=True) + + d.cont() + d.wait() + + assert exit.hit_on(d) + + for i in range(8): + assert (chr(i).encode("latin-1") * 16) in leaks + + leaks = [d.memory[x, 16] for x in leak_addresses] + + # threads are stopped, check we correctly read the memory + for i in range(8): + assert (chr(i).encode("latin-1") * 16) in leaks + + d.kill() + d.terminate() + + def test_memory_mixed_access(self): + d = debugger("binaries/memory_test_2", fast_memory=True) + + d.run() + + base = d.regs.rip & 0xFFFFFFFFFFFFF000 - 0x1000 + + # Test different ways to access memory at the start of the file + file_0 = d.memory[base, 256] + d.fast_memory = False + file_1 = d.memory[0x0, 256] + d.fast_memory = True + file_2 = d.memory[0x0:0x100] + d.fast_memory = False + file_3 = d.memory[0x0:0x100] + + self.assertEqual(file_0, file_1) + self.assertEqual(file_0, file_2) + self.assertEqual(file_0, file_3) + + for _ in range(3): + d.step() + + d.fast_memory = False + d.memory[base] = b"abcd1234" + self.assertEqual(d.memory[base, 8], b"abcd1234") + + d.fast_memory = True + self.assertEqual(d.memory[base, 8], b"abcd1234") + d.memory[base] = b"\x01\x02\x03\x04\x05\x06\x07\x08" + self.assertEqual(d.memory[base, 8], b"\x01\x02\x03\x04\x05\x06\x07\x08") + + d.fast_memory = False + self.assertEqual(d.memory[base, 8], b"\x01\x02\x03\x04\x05\x06\x07\x08") + d.memory[base] = b"abcd1234" + self.assertEqual(d.memory[base, 8], b"abcd1234") + + d.kill() + d.terminate() diff --git a/test/amd64/scripts/memory_test.py b/test/amd64/scripts/memory_test.py new file mode 100644 index 00000000..febbf77e --- /dev/null +++ b/test/amd64/scripts/memory_test.py @@ -0,0 +1,363 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2023-2024 Gabriele Digregorio, Roberto Alessandro Bertolini. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# + +import io +import logging +import unittest + +from libdebug import debugger, libcontext + + +class MemoryTest(unittest.TestCase): + def setUp(self) -> None: + self.d = debugger("binaries/memory_test") + + # Redirect logging to a string buffer + self.log_capture_string = io.StringIO() + self.log_handler = logging.StreamHandler(self.log_capture_string) + self.log_handler.setLevel(logging.WARNING) + + self.logger = logging.getLogger("libdebug") + self.original_handlers = self.logger.handlers + self.logger.handlers = [] + self.logger.addHandler(self.log_handler) + self.logger.setLevel(logging.WARNING) + + def test_memory(self): + d = self.d + + d.run() + + bp = d.breakpoint("change_memory") + + d.cont() + + assert d.regs.rip == bp.address + + address = d.regs.rdi + prev = bytes(range(256)) + + self.assertTrue(d.memory[address, 256] == prev) + + d.memory[address + 128 :] = b"abcd123456" + prev = prev[:128] + b"abcd123456" + prev[138:] + + self.assertTrue(d.memory[address : address + 256] == prev) + + d.kill() + + def test_mem_access_libs(self): + d = self.d + + d.run() + + bp = d.breakpoint("leak_address") + + d.cont() + + assert d.regs.rip == bp.address + + address = d.regs.rdi + with libcontext.tmp(sym_lvl=5): + arena = d.memory["main_arena", 256, "libc"] + + def p64(x): + return x.to_bytes(8, "little") + + self.assertTrue(p64(address - 0x10) in arena) + + d.kill() + + def test_memory_exceptions(self): + d = self.d + + d.run() + + bp = d.breakpoint("change_memory") + + d.cont() + + # This should not raise an exception + file = d.memory[0x0, 256] + + # File should start with ELF magic number + self.assertTrue(file.startswith(b"\x7fELF")) + + assert d.regs.rip == bp.address + + address = d.regs.rdi + prev = bytes(range(256)) + + self.assertTrue(d.memory[address, 256] == prev) + + d.memory[address + 128 :] = b"abcd123456" + prev = prev[:128] + b"abcd123456" + prev[138:] + + self.assertTrue(d.memory[address : address + 256] == prev) + + d.kill() + + def test_memory_multiple_runs(self): + d = self.d + + for _ in range(10): + d.run() + + bp = d.breakpoint("change_memory") + + d.cont() + + assert d.regs.rip == bp.address + + address = d.regs.rdi + prev = bytes(range(256)) + + self.assertTrue(d.memory[address, 256] == prev) + + d.memory[address + 128 :] = b"abcd123456" + prev = prev[:128] + b"abcd123456" + prev[138:] + + self.assertTrue(d.memory[address : address + 256] == prev) + + d.kill() + + def test_memory_access_while_running(self): + d = debugger("binaries/memory_test_2") + + d.run() + + bp = d.breakpoint("do_nothing") + + d.cont() + + # Verify that memory access is only possible when the process is stopped + value = int.from_bytes(d.memory["state", 8], "little") + self.assertEqual(value, 0xDEADBEEF) + self.assertEqual(d.regs.rip, bp.address) + + d.kill() + + def test_memory_access_methods(self): + d = debugger("binaries/memory_test_2") + + d.run() + + base = d.regs.rip & 0xFFFFFFFFFFFFF000 - 0x1000 + + # Test different ways to access memory at the start of the file + file_0 = d.memory[base, 256] + file_1 = d.memory[0x0, 256] + file_2 = d.memory[0x0:0x100] + + self.assertEqual(file_0, file_1) + self.assertEqual(file_0, file_2) + + # Validate that the length of the read bytes is correct + file_0 = d.memory[0x0] + file_1 = d.memory[base] + + self.assertEqual(file_0, file_1) + self.assertEqual(len(file_0), 1) + + # Validate that slices work correctly + file_0 = d.memory[0x0:"do_nothing"] + file_1 = d.memory[base:"do_nothing"] + + self.assertEqual(file_0, file_1) + + self.assertRaises(ValueError, lambda: d.memory[0x1000:0x0]) + # _fini is after main + self.assertRaises(ValueError, lambda: d.memory["_fini":"main"]) + + # Test different ways to write memory + d.memory[0x0, 8] = b"abcd1234" + self.assertEqual(d.memory[0x0, 8], b"abcd1234") + + d.memory[0x0, 8] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory[base:] = b"abcd1234" + self.assertEqual(d.memory[base, 8], b"abcd1234") + + d.memory[base:] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory[base] = b"abcd1234" + self.assertEqual(d.memory[base, 8], b"abcd1234") + + d.memory[base] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory[0x0:0x8] = b"abcd1234" + self.assertEqual(d.memory[0x0, 8], b"abcd1234") + + d.memory[0x0, 8] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main":] = b"abcd1234" + self.assertEqual(d.memory["main", 8], b"abcd1234") + + d.memory["main":] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main"] = b"abcd1234" + self.assertEqual(d.memory["main", 8], b"abcd1234") + + d.memory["main"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main":"main+8"] = b"abcd1234" + self.assertEqual(d.memory["main", 8], b"abcd1234") + + d.kill() + + def test_memory_access_methods_backing_file(self): + d = debugger("binaries/memory_test_2") + + d.run() + + base = d.regs.rip & 0xFFFFFFFFFFFFF000 - 0x1000 + + # Validate that slices work correctly + file_0 = d.memory[0x0:"do_nothing", "binary"] + file_1 = d.memory[0x0:"do_nothing", "memory_test_2"] + file_2 = d.memory[base:"do_nothing", "binary"] + file_3 = d.memory[base:"do_nothing", "memory_test_2"] + + self.assertEqual(file_0, file_1) + self.assertEqual(file_1, file_2) + self.assertEqual(file_2, file_3) + + # Test different ways to write memory + d.memory[0x0, 8, "binary"] = b"abcd1234" + self.assertEqual(d.memory[0x0, 8, "binary"], b"abcd1234") + + d.memory[0x0, 8, "binary"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory[0x0, 8, "memory_test_2"] = b"abcd1234" + self.assertEqual(d.memory[0x0, 8, "memory_test_2"], b"abcd1234") + + d.memory[0x0, 8, "memory_test_2"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory[0x0:0x8, "binary"] = b"abcd1234" + self.assertEqual(d.memory[0x0:8, "binary"], b"abcd1234") + + d.memory[0x0, 8, "binary"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory[0x0:0x8, "memory_test_2"] = b"abcd1234" + self.assertEqual(d.memory[0x0:8, "memory_test_2"], b"abcd1234") + + d.memory[0x0, 8, "memory_test_2"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main":, "binary"] = b"abcd1234" + self.assertEqual(d.memory["main", 8, "binary"], b"abcd1234") + + d.memory["main":, "binary"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main":, "memory_test_2"] = b"abcd1234" + self.assertEqual(d.memory["main", 8, "binary"], b"abcd1234") + + d.memory["main":, "memory_test_2"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main", "binary"] = b"abcd1234" + self.assertEqual(d.memory["main", 8, "binary"], b"abcd1234") + + d.memory[0x0, 8, "binary"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main", "memory_test_2"] = b"abcd1234" + self.assertEqual(d.memory["main", 8, "memory_test_2"], b"abcd1234") + + d.memory[0x0, 8, "memory_test_2"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main":"main+8", "binary"] = b"abcd1234" + self.assertEqual(d.memory["main":"main+8", "binary"], b"abcd1234") + + d.memory[0x0, 8, "binary"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main":"main+8", "memory_test_2"] = b"abcd1234" + self.assertEqual(d.memory["main":"main+8", "memory_test_2"], b"abcd1234") + + d.memory[0x0, 8, "binary"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + d.memory["main":"main+8", "hybrid"] = b"abcd1234" + self.assertEqual(d.memory["main":"main+8", "hybrid"], b"abcd1234") + + d.memory[0x0, 8, "binary"] = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + with self.assertRaises(ValueError): + d.memory["main":"main+8", "absolute"] = b"abcd1234" + + d.kill() + + def test_memory_large_read(self): + d = debugger("binaries/memory_test_3") + + d.run() + + bp = d.bp("do_nothing") + + d.cont() + + assert bp.hit_on(d) + + leak = d.regs.rdi + + # Read 256K of memory + data = d.memory[leak, 256 * 1024] + + assert data == b"".join(x.to_bytes(4, "little") for x in range(64 * 1024)) + + d.kill() + d.terminate() + + def test_invalid_memory_location(self): + d = debugger("binaries/memory_test") + + d.run() + + bp = d.bp("change_memory") + + d.cont() + + assert d.regs.rip == bp.address + + address = 0xDEADBEEFD00D + + with self.assertRaises(ValueError): + d.memory[address, 256, "absolute"] + + d.kill() + d.terminate() + + def test_memory_multiple_threads(self): + d = debugger("binaries/memory_test_4") + + d.run() + + leaks = [] + leak_addresses = [] + + def leak(t, _): + leaks.append(t.memory[t.regs.rdi, 16]) + leak_addresses.append(t.regs.rdi) + + d.bp("leak", callback=leak, hardware=True) + exit = d.bp("before_exit", hardware=True) + + d.cont() + d.wait() + + assert exit.hit_on(d) + + for i in range(8): + assert (chr(i).encode("latin-1") * 16) in leaks + + leaks = [d.memory[x, 16] for x in leak_addresses] + + # threads are stopped, check we correctly read the memory + for i in range(8): + assert (chr(i).encode("latin-1") * 16) in leaks + + d.kill() + d.terminate() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/scripts/multiple_debuggers_test.py b/test/amd64/scripts/multiple_debuggers_test.py similarity index 100% rename from test/scripts/multiple_debuggers_test.py rename to test/amd64/scripts/multiple_debuggers_test.py diff --git a/test/amd64/scripts/next_test.py b/test/amd64/scripts/next_test.py new file mode 100644 index 00000000..af9096d0 --- /dev/null +++ b/test/amd64/scripts/next_test.py @@ -0,0 +1,111 @@ +# +# This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +# Copyright (c) 2024 Francesco Panebianco. All rights reserved. +# Licensed under the MIT license. See LICENSE file in the project root for details. +# +import unittest + +from libdebug import debugger + +TEST_ENTRYPOINT = 0x4011f8 + +# Addresses of the dummy functions +CALL_C_ADDRESS = 0x4011fd +TEST_BREAKPOINT_ADDRESS = 0x4011f1 + +# Addresses of noteworthy instructions +RETURN_POINT_FROM_C = 0x401202 + +class NextTest(unittest.TestCase): + def setUp(self): + pass + + def test_next(self): + d = debugger("binaries/finish_test", auto_interrupt_on_command=False) + d.run() + + # Get to test entrypoint + entrypoint_bp = d.breakpoint(TEST_ENTRYPOINT) + d.cont() + + self.assertEqual(d.regs.rip, TEST_ENTRYPOINT) + + # -------- Block 1 ------- # + # Simple Step # + # ------------------------ # + + # Reach call of function c + d.next() + self.assertEqual(d.regs.rip, CALL_C_ADDRESS) + + # -------- Block 2 ------- # + # Skip a call # + # ------------------------ # + + d.next() + self.assertEqual(d.regs.rip, RETURN_POINT_FROM_C) + + d.kill() + d.terminate() + + def test_next_breakpoint(self): + d = debugger("binaries/finish_test", auto_interrupt_on_command=False) + d.run() + + # Get to test entrypoint + entrypoint_bp = d.breakpoint(TEST_ENTRYPOINT) + d.cont() + + self.assertEqual(d.regs.rip, TEST_ENTRYPOINT) + + # Reach call of function c + d.next() + + self.assertEqual(d.regs.rip, CALL_C_ADDRESS) + + # -------- Block 1 ------- # + # Call with breakpoint # + # ------------------------ # + + # Set breakpoint + test_breakpoint = d.breakpoint(TEST_BREAKPOINT_ADDRESS) + + d.next() + + # Check we hit the breakpoint + self.assertEqual(d.regs.rip, TEST_BREAKPOINT_ADDRESS) + self.assertEqual(test_breakpoint.hit_count, 1) + + d.kill() + d.terminate() + + def test_next_breakpoint_hw(self): + d = debugger("binaries/finish_test", auto_interrupt_on_command=False) + d.run() + + # Get to test entrypoint + entrypoint_bp = d.breakpoint(TEST_ENTRYPOINT) + d.cont() + + self.assertEqual(d.regs.rip, TEST_ENTRYPOINT) + + # Reach call of function c + d.next() + + self.assertEqual(d.regs.rip, CALL_C_ADDRESS) + + # -------- Block 1 ------- # + # Call with breakpoint # + # ------------------------ # + + # Set breakpoint + test_breakpoint = d.breakpoint(TEST_BREAKPOINT_ADDRESS, hardware=True) + + d.next() + + # Check we hit the breakpoint + self.assertEqual(d.regs.rip, TEST_BREAKPOINT_ADDRESS) + self.assertEqual(test_breakpoint.hit_count, 1) + + d.kill() + d.terminate() \ No newline at end of file diff --git a/test/scripts/nlinks_test.py b/test/amd64/scripts/nlinks_test.py similarity index 100% rename from test/scripts/nlinks_test.py rename to test/amd64/scripts/nlinks_test.py diff --git a/test/scripts/pprint_syscalls_test.py b/test/amd64/scripts/pprint_syscalls_test.py similarity index 100% rename from test/scripts/pprint_syscalls_test.py rename to test/amd64/scripts/pprint_syscalls_test.py diff --git a/test/scripts/signals_multithread_test.py b/test/amd64/scripts/signals_multithread_test.py similarity index 100% rename from test/scripts/signals_multithread_test.py rename to test/amd64/scripts/signals_multithread_test.py diff --git a/test/scripts/speed_test.py b/test/amd64/scripts/speed_test.py similarity index 100% rename from test/scripts/speed_test.py rename to test/amd64/scripts/speed_test.py diff --git a/test/scripts/thread_test.py b/test/amd64/scripts/thread_test.py similarity index 100% rename from test/scripts/thread_test.py rename to test/amd64/scripts/thread_test.py diff --git a/test/scripts/vmwhere1_test.py b/test/amd64/scripts/vmwhere1_test.py similarity index 100% rename from test/scripts/vmwhere1_test.py rename to test/amd64/scripts/vmwhere1_test.py diff --git a/test/scripts/waiting_test.py b/test/amd64/scripts/waiting_test.py similarity index 100% rename from test/scripts/waiting_test.py rename to test/amd64/scripts/waiting_test.py diff --git a/test/scripts/watchpoint_alias_test.py b/test/amd64/scripts/watchpoint_alias_test.py similarity index 100% rename from test/scripts/watchpoint_alias_test.py rename to test/amd64/scripts/watchpoint_alias_test.py diff --git a/test/scripts/watchpoint_test.py b/test/amd64/scripts/watchpoint_test.py similarity index 100% rename from test/scripts/watchpoint_test.py rename to test/amd64/scripts/watchpoint_test.py diff --git a/test/srcs/basic_test.c b/test/amd64/srcs/basic_test.c similarity index 100% rename from test/srcs/basic_test.c rename to test/amd64/srcs/basic_test.c diff --git a/test/amd64/srcs/floating_point_2696_test.c b/test/amd64/srcs/floating_point_2696_test.c new file mode 100644 index 00000000..ded32393 --- /dev/null +++ b/test/amd64/srcs/floating_point_2696_test.c @@ -0,0 +1,100 @@ +// +// This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +// Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for details. +// + +void rotate(char value[64]) +{ + char temp = value[0]; + for (int i = 0; i < 63; i++) { + value[i] = value[i + 1]; + } + value[63] = temp; +} + +int main() +{ + char value[64]; + + for (int i = 0; i < 64; i++) { + value[i] = i; + } + + __asm__ __volatile__("vmovdqu8 %0, %%zmm0" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm1" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm2" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm3" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm4" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm5" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm6" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm7" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm8" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm9" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm10" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm11" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm12" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm13" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm14" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm15" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm16" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm17" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm18" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm19" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm20" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm21" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm22" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm23" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm24" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm25" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm26" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm27" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm28" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm29" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm30" : : "m" (value)); + rotate(value); + __asm__ __volatile__("vmovdqu8 %0, %%zmm31" : : "m" (value)); + + __asm__ __volatile__("nop"); + + char result[64]; + __asm__ __volatile__("vmovdqu8 %%zmm0, %0" : "=m" (result)); + + unsigned long check = *(unsigned long*)result; + + if (check == 0xdeadbeefdeadbeef) { + __asm__ __volatile__("nop"); + } + + return 0; +} diff --git a/test/amd64/srcs/floating_point_512_test.c b/test/amd64/srcs/floating_point_512_test.c new file mode 100644 index 00000000..75763e39 --- /dev/null +++ b/test/amd64/srcs/floating_point_512_test.c @@ -0,0 +1,54 @@ +// +// This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +// Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for details. +// + +int main() +{ + char value0[] = {0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF}; + __asm__ __volatile__("vmovdqu %0, %%xmm0" : : "m" (value0)); + char value1[] = {0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00}; + __asm__ __volatile__("vmovdqu %0, %%xmm1" : : "m" (value1)); + char value2[] = {0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11}; + __asm__ __volatile__("vmovdqu %0, %%xmm2" : : "m" (value2)); + char value3[] = {0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11, 0x22}; + __asm__ __volatile__("vmovdqu %0, %%xmm3" : : "m" (value3)); + char value4[] = {0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33}; + __asm__ __volatile__("vmovdqu %0, %%xmm4" : : "m" (value4)); + char value5[] = {0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33, 0x44}; + __asm__ __volatile__("vmovdqu %0, %%xmm5" : : "m" (value5)); + char value6[] = {0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55}; + __asm__ __volatile__("vmovdqu %0, %%xmm6" : : "m" (value6)); + char value7[] = {0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66}; + __asm__ __volatile__("vmovdqu %0, %%xmm7" : : "m" (value7)); + char value8[] = {0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77}; + __asm__ __volatile__("vmovdqu %0, %%xmm8" : : "m" (value8)); + char value9[] = {0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}; + __asm__ __volatile__("vmovdqu %0, %%xmm9" : : "m" (value9)); + char value10[] = {0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99}; + __asm__ __volatile__("vmovdqu %0, %%xmm10" : : "m" (value10)); + char value11[] = {0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA}; + __asm__ __volatile__("vmovdqu %0, %%xmm11" : : "m" (value11)); + char value12[] = {0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB}; + __asm__ __volatile__("vmovdqu %0, %%xmm12" : : "m" (value12)); + char value13[] = {0xDD, 0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC}; + __asm__ __volatile__("vmovdqu %0, %%xmm13" : : "m" (value13)); + char value14[] = {0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD}; + __asm__ __volatile__("vmovdqu %0, %%xmm14" : : "m" (value14)); + char value15[] = {0xFF, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE}; + __asm__ __volatile__("vmovdqu %0, %%xmm15" : : "m" (value15)); + + __asm__ __volatile__("nop"); + + char value[16]; + __asm__ __volatile__("vmovdqu %%xmm0, %0" : "=m" (value)); + + unsigned long check = *(unsigned long*)value; + + if (check == 0xdeadbeefdeadbeef) { + __asm__ __volatile__("nop"); + } + + return 0; +} diff --git a/test/amd64/srcs/floating_point_896_test.c b/test/amd64/srcs/floating_point_896_test.c new file mode 100644 index 00000000..5d1a6960 --- /dev/null +++ b/test/amd64/srcs/floating_point_896_test.c @@ -0,0 +1,57 @@ +// +// This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +// Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for details. +// + +int main() +{ + + // load value into floating point register + char value0[] = {0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F}; + __asm__ __volatile__("vmovdqu %0, %%ymm0" : : "m" (value0)); + char value1[] = {0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00}; + __asm__ __volatile__("vmovdqu %0, %%ymm1" : : "m" (value1)); + char value2[] = {0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x11}; + __asm__ __volatile__("vmovdqu %0, %%ymm2" : : "m" (value2)); + char value3[] = {0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x11, 0x22}; + __asm__ __volatile__("vmovdqu %0, %%ymm3" : : "m" (value3)); + char value4[] = {0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x11, 0x22, 0x33}; + __asm__ __volatile__("vmovdqu %0, %%ymm4" : : "m" (value4)); + char value5[] = {0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x11, 0x22, 0x33, 0x44}; + __asm__ __volatile__("vmovdqu %0, %%ymm5" : : "m" (value5)); + char value6[] = {0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55}; + __asm__ __volatile__("vmovdqu %0, %%ymm6" : : "m" (value6)); + char value7[] = {0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66}; + __asm__ __volatile__("vmovdqu %0, %%ymm7" : : "m" (value7)); + char value8[] = {0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77}; + __asm__ __volatile__("vmovdqu %0, %%ymm8" : : "m" (value8)); + char value9[] = {0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}; + __asm__ __volatile__("vmovdqu %0, %%ymm9" : : "m" (value9)); + char value10[] = {0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99}; + __asm__ __volatile__("vmovdqu %0, %%ymm10" : : "m" (value10)); + char value11[] = {0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA}; + __asm__ __volatile__("vmovdqu %0, %%ymm11" : : "m" (value11)); + char value12[] = {0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB}; + __asm__ __volatile__("vmovdqu %0, %%ymm12" : : "m" (value12)); + char value13[] = {0xDD, 0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC}; + __asm__ __volatile__("vmovdqu %0, %%ymm13" : : "m" (value13)); + char value14[] = {0xEE, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD}; + __asm__ __volatile__("vmovdqu %0, %%ymm14" : : "m" (value14)); + char value15[] = {0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE}; + __asm__ __volatile__("vmovdqu %0, %%ymm15" : : "m" (value15)); + + // breakpoint location 1 + __asm__ __volatile__("nop"); + + char value[32]; + __asm__ __volatile__("vmovdqu %%ymm0, %0" : "=m" (value)); + + unsigned long check = *(unsigned long*)value; + + if (check == 0xdeadbeefdeadbeef) { + __asm__ __volatile__("nop"); + } + + return 0; +} diff --git a/test/srcs/thread_test.c b/test/amd64/srcs/thread_test.c similarity index 100% rename from test/srcs/thread_test.c rename to test/amd64/srcs/thread_test.c diff --git a/test/srcs/antidebug_brute_test.c b/test/common/srcs/antidebug_brute_test.c similarity index 100% rename from test/srcs/antidebug_brute_test.c rename to test/common/srcs/antidebug_brute_test.c diff --git a/test/srcs/attach_test.c b/test/common/srcs/attach_test.c similarity index 100% rename from test/srcs/attach_test.c rename to test/common/srcs/attach_test.c diff --git a/test/srcs/backtrace.c b/test/common/srcs/backtrace.c similarity index 100% rename from test/srcs/backtrace.c rename to test/common/srcs/backtrace.c diff --git a/test/srcs/basic_test_pie.c b/test/common/srcs/basic_test_pie.c similarity index 100% rename from test/srcs/basic_test_pie.c rename to test/common/srcs/basic_test_pie.c diff --git a/test/srcs/benchmark.c b/test/common/srcs/benchmark.c similarity index 100% rename from test/srcs/benchmark.c rename to test/common/srcs/benchmark.c diff --git a/test/srcs/breakpoint_test.c b/test/common/srcs/breakpoint_test.c similarity index 100% rename from test/srcs/breakpoint_test.c rename to test/common/srcs/breakpoint_test.c diff --git a/test/srcs/brute_test.c b/test/common/srcs/brute_test.c similarity index 100% rename from test/srcs/brute_test.c rename to test/common/srcs/brute_test.c diff --git a/test/srcs/catch_signal_test.c b/test/common/srcs/catch_signal_test.c similarity index 100% rename from test/srcs/catch_signal_test.c rename to test/common/srcs/catch_signal_test.c diff --git a/test/srcs/complex_thread_test.c b/test/common/srcs/complex_thread_test.c similarity index 100% rename from test/srcs/complex_thread_test.c rename to test/common/srcs/complex_thread_test.c diff --git a/test/srcs/executable_section_test.c b/test/common/srcs/executable_section_test.c similarity index 100% rename from test/srcs/executable_section_test.c rename to test/common/srcs/executable_section_test.c diff --git a/test/srcs/finish_test.c b/test/common/srcs/finish_test.c similarity index 100% rename from test/srcs/finish_test.c rename to test/common/srcs/finish_test.c diff --git a/test/srcs/handle_syscall_test.c b/test/common/srcs/handle_syscall_test.c similarity index 100% rename from test/srcs/handle_syscall_test.c rename to test/common/srcs/handle_syscall_test.c diff --git a/test/common/srcs/infinite_loop_test.c b/test/common/srcs/infinite_loop_test.c new file mode 100644 index 00000000..8fdc3f1c --- /dev/null +++ b/test/common/srcs/infinite_loop_test.c @@ -0,0 +1,18 @@ +// +// This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +// Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for details. +// + +#include + +int main() +{ + int number = 0; + + scanf("%d", &number); + + while (1); + + return 0; +} diff --git a/test/srcs/jumpstart_test.c b/test/common/srcs/jumpstart_test.c similarity index 100% rename from test/srcs/jumpstart_test.c rename to test/common/srcs/jumpstart_test.c diff --git a/test/srcs/jumpstart_test_preload.c b/test/common/srcs/jumpstart_test_preload.c similarity index 100% rename from test/srcs/jumpstart_test_preload.c rename to test/common/srcs/jumpstart_test_preload.c diff --git a/test/srcs/math_loop_test.c b/test/common/srcs/math_loop_test.c similarity index 100% rename from test/srcs/math_loop_test.c rename to test/common/srcs/math_loop_test.c diff --git a/test/srcs/memory_test.c b/test/common/srcs/memory_test.c similarity index 100% rename from test/srcs/memory_test.c rename to test/common/srcs/memory_test.c diff --git a/test/srcs/memory_test_2.c b/test/common/srcs/memory_test_2.c similarity index 100% rename from test/srcs/memory_test_2.c rename to test/common/srcs/memory_test_2.c diff --git a/test/common/srcs/memory_test_3.c b/test/common/srcs/memory_test_3.c new file mode 100644 index 00000000..e1992d4c --- /dev/null +++ b/test/common/srcs/memory_test_3.c @@ -0,0 +1,28 @@ +// +// This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +// Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for details. +// + +#include +#include + +void do_nothing(int *leak) +{ + +} + +int main() +{ + int *buffer = mmap(NULL, sizeof(int) * 1024 * 1024, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + + for (int i = 0; i < 1024 * 1024; i++) { + buffer[i] = i; + } + + do_nothing(buffer); + + munmap(buffer, sizeof(int) * 1024 * 1024); + + return 0; +} \ No newline at end of file diff --git a/test/common/srcs/memory_test_4.c b/test/common/srcs/memory_test_4.c new file mode 100644 index 00000000..5b59f92f --- /dev/null +++ b/test/common/srcs/memory_test_4.c @@ -0,0 +1,71 @@ +// +// This file is part of libdebug Python library (https://github.com/libdebug/libdebug). +// Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for details. +// + +#include +#include +#include +#include +#include + +sem_t semaphores[4]; +sem_t leaks_done; + +void leak(char *ptr) +{ + +} + +void before_exit() +{ + +} + +void* thread_fun(void *arg) +{ + // cast arg to int + int thread_index = (int) ((unsigned long) arg); + + char test[16]; + + memset(test, thread_index, 16); + + char *test_ptr = malloc(16); + + memset(test_ptr, thread_index + 4, 16); + + leak(test); + leak(test_ptr); + + sem_post(&leaks_done); + + sem_wait(&semaphores[thread_index]); +} + +int main() +{ + // allocate four threads + pthread_t threads[4]; + + sem_init(&leaks_done, 0, 0); + + for (int i = 0; i < 4; i++) { + sem_init(&semaphores[i], 0, 0); + pthread_create(&threads[i], NULL, thread_fun, (void *) ((unsigned long) i)); + } + + for (int i = 0; i < 4; i++) + sem_wait(&leaks_done); + + before_exit(); + + for (int i = 0; i < 4; i++) + sem_post(&semaphores[i]); + + for (int i = 0; i < 4; i++) + pthread_join(threads[i], NULL); + + return 0; +} \ No newline at end of file diff --git a/test/srcs/segfault_test.c b/test/common/srcs/segfault_test.c similarity index 100% rename from test/srcs/segfault_test.c rename to test/common/srcs/segfault_test.c diff --git a/test/srcs/signals_multithread_det_test.c b/test/common/srcs/signals_multithread_det_test.c similarity index 100% rename from test/srcs/signals_multithread_det_test.c rename to test/common/srcs/signals_multithread_det_test.c diff --git a/test/srcs/signals_multithread_undet_test.c b/test/common/srcs/signals_multithread_undet_test.c similarity index 100% rename from test/srcs/signals_multithread_undet_test.c rename to test/common/srcs/signals_multithread_undet_test.c diff --git a/test/srcs/speed_test.c b/test/common/srcs/speed_test.c similarity index 100% rename from test/srcs/speed_test.c rename to test/common/srcs/speed_test.c diff --git a/test/srcs/watchpoint_test.c b/test/common/srcs/watchpoint_test.c similarity index 100% rename from test/srcs/watchpoint_test.c rename to test/common/srcs/watchpoint_test.c diff --git a/test/run_suite.py b/test/run_suite.py index 27486054..7ebf2c61 100644 --- a/test/run_suite.py +++ b/test/run_suite.py @@ -1,239 +1,35 @@ # # This file is part of libdebug Python library (https://github.com/libdebug/libdebug). -# Copyright (c) 2023-2024 Gabriele Digregorio, Roberto Alessandro Bertolini, Francesco Panebianco. All rights reserved. +# Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved. # Licensed under the MIT license. See LICENSE file in the project root for details. # +import os +import platform import sys -import unittest -from scripts.alias_test import AliasTest -from scripts.attach_detach_test import AttachDetachTest -from scripts.auto_waiting_test import AutoWaitingNlinks, AutoWaitingTest -from scripts.backtrace_test import BacktraceTest -from scripts.basic_test import BasicPieTest, BasicTest, ControlFlowTest, HwBasicTest -from scripts.breakpoint_test import BreakpointTest -from scripts.brute_test import BruteTest -from scripts.builtin_handler_test import AntidebugEscapingTest -from scripts.callback_test import CallbackTest -from scripts.catch_signal_test import SignalCatchTest -from scripts.death_test import DeathTest -from scripts.deep_dive_division_test import DeepDiveDivision -from scripts.finish_test import FinishTest -from scripts.handle_syscall_test import HandleSyscallTest -from scripts.hijack_syscall_test import SyscallHijackTest -from scripts.jumpout_test import Jumpout -from scripts.jumpstart_test import JumpstartTest -from scripts.large_binary_sym_test import LargeBinarySymTest -from scripts.memory_test import MemoryTest -from scripts.multiple_debuggers_test import MultipleDebuggersTest -from scripts.nlinks_test import Nlinks -from scripts.pprint_syscalls_test import PPrintSyscallsTest -from scripts.signals_multithread_test import SignalMultithreadTest -from scripts.speed_test import SpeedTest -from scripts.thread_test import ComplexThreadTest, ThreadTest -from scripts.vmwhere1_test import Vmwhere1 -from scripts.waiting_test import WaitingNlinks, WaitingTest -from scripts.watchpoint_alias_test import WatchpointAliasTest -from scripts.watchpoint_test import WatchpointTest - - -def fast_suite(): - suite = unittest.TestSuite() - suite.addTest(BasicTest("test_basic")) - suite.addTest(BasicTest("test_registers")) - suite.addTest(BasicTest("test_step")) - suite.addTest(BasicTest("test_step_hardware")) - suite.addTest(BasicPieTest("test_basic")) - suite.addTest(BreakpointTest("test_bps")) - suite.addTest(BreakpointTest("test_bp_disable")) - suite.addTest(BreakpointTest("test_bp_disable_hw")) - suite.addTest(BreakpointTest("test_bp_disable_reenable")) - suite.addTest(BreakpointTest("test_bp_disable_reenable_hw")) - suite.addTest(BreakpointTest("test_bps_running")) - suite.addTest(BreakpointTest("test_bp_backing_file")) - suite.addTest(BreakpointTest("test_bp_disable_on_creation")) - suite.addTest(BreakpointTest("test_bp_disable_on_creation_2")) - suite.addTest(BreakpointTest("test_bp_disable_on_creation_hardware")) - suite.addTest(BreakpointTest("test_bp_disable_on_creation_2_hardware")) - suite.addTest(MemoryTest("test_memory")) - suite.addTest(MemoryTest("test_mem_access_libs")) - suite.addTest(MemoryTest("test_memory_access_methods_backing_file")) - suite.addTest(MemoryTest("test_memory_exceptions")) - suite.addTest(MemoryTest("test_memory_multiple_runs")) - suite.addTest(MemoryTest("test_memory_access_while_running")) - suite.addTest(MemoryTest("test_memory_access_methods")) - suite.addTest(HwBasicTest("test_basic")) - suite.addTest(HwBasicTest("test_registers")) - suite.addTest(BacktraceTest("test_backtrace_as_symbols")) - suite.addTest(BacktraceTest("test_backtrace")) - suite.addTest(AttachDetachTest("test_attach")) - suite.addTest(AttachDetachTest("test_attach_and_detach_1")) - suite.addTest(AttachDetachTest("test_attach_and_detach_2")) - suite.addTest(AttachDetachTest("test_attach_and_detach_3")) - suite.addTest(AttachDetachTest("test_attach_and_detach_4")) - suite.addTest(ThreadTest("test_thread")) - suite.addTest(ThreadTest("test_thread_hardware")) - suite.addTest(ComplexThreadTest("test_thread")) - suite.addTest(CallbackTest("test_callback_simple")) - suite.addTest(CallbackTest("test_callback_simple_hardware")) - suite.addTest(CallbackTest("test_callback_memory")) - suite.addTest(CallbackTest("test_callback_jumpout")) - suite.addTest(CallbackTest("test_callback_intermixing")) - suite.addTest(CallbackTest("test_callback_exception")) - suite.addTest(CallbackTest("test_callback_step")) - suite.addTest(CallbackTest("test_callback_pid_accessible")) - suite.addTest(CallbackTest("test_callback_pid_accessible_alias")) - suite.addTest(CallbackTest("test_callback_tid_accessible_alias")) - suite.addTest(FinishTest("test_finish_exact_no_auto_interrupt_no_breakpoint")) - suite.addTest(FinishTest("test_finish_heuristic_no_auto_interrupt_no_breakpoint")) - suite.addTest(FinishTest("test_finish_exact_auto_interrupt_no_breakpoint")) - suite.addTest(FinishTest("test_finish_heuristic_auto_interrupt_no_breakpoint")) - suite.addTest(FinishTest("test_finish_exact_no_auto_interrupt_breakpoint")) - suite.addTest(FinishTest("test_finish_heuristic_no_auto_interrupt_breakpoint")) - suite.addTest(FinishTest("test_heuristic_return_address")) - suite.addTest(FinishTest("test_exact_breakpoint_return")) - suite.addTest(FinishTest("test_heuristic_breakpoint_return")) - suite.addTest(FinishTest("test_breakpoint_collision")) - suite.addTest(Jumpout("test_jumpout")) - suite.addTest(Nlinks("test_nlinks")) - suite.addTest(JumpstartTest("test_cursed_ldpreload")) - suite.addTest(ControlFlowTest("test_step_until_1")) - suite.addTest(ControlFlowTest("test_step_until_2")) - suite.addTest(ControlFlowTest("test_step_until_3")) - suite.addTest(ControlFlowTest("test_step_and_cont")) - suite.addTest(ControlFlowTest("test_step_and_cont_hardware")) - suite.addTest(ControlFlowTest("test_step_until_and_cont")) - suite.addTest(ControlFlowTest("test_step_until_and_cont_hardware")) - suite.addTest(MultipleDebuggersTest("test_multiple_debuggers")) - suite.addTest(LargeBinarySymTest("test_large_binary_symbol_load_times")) - suite.addTest(LargeBinarySymTest("test_large_binary_demangle")) - suite.addTest(WaitingTest("test_bps_waiting")) - suite.addTest(WaitingTest("test_jumpout_waiting")) - suite.addTest(WaitingNlinks("test_nlinks")) - suite.addTest(AutoWaitingTest("test_bps_auto_waiting")) - suite.addTest(AutoWaitingTest("test_jumpout_auto_waiting")) - suite.addTest(AutoWaitingNlinks("test_nlinks")) - suite.addTest(WatchpointTest("test_watchpoint")) - suite.addTest(WatchpointTest("test_watchpoint_callback")) - suite.addTest(WatchpointTest("test_watchpoint_disable")) - suite.addTest(WatchpointTest("test_watchpoint_disable_reenable")) - suite.addTest(WatchpointAliasTest("test_watchpoint_alias")) - suite.addTest(WatchpointAliasTest("test_watchpoint_callback")) - suite.addTest(HandleSyscallTest("test_handles")) - suite.addTest(HandleSyscallTest("test_handles_with_pprint")) - suite.addTest(HandleSyscallTest("test_handle_disabling")) - suite.addTest(HandleSyscallTest("test_handle_disabling_with_pprint")) - suite.addTest(HandleSyscallTest("test_handle_overwrite")) - suite.addTest(HandleSyscallTest("test_handle_overwrite_with_pprint")) - suite.addTest(HandleSyscallTest("test_handles_sync")) - suite.addTest(HandleSyscallTest("test_handles_sync_with_pprint")) - suite.addTest(AntidebugEscapingTest("test_antidebug_escaping")) - suite.addTest(SyscallHijackTest("test_hijack_syscall")) - suite.addTest(SyscallHijackTest("test_hijack_syscall_with_pprint")) - suite.addTest(SyscallHijackTest("test_hijack_handle_syscall")) - suite.addTest(SyscallHijackTest("test_hijack_handle_syscall_with_pprint")) - suite.addTest(SyscallHijackTest("test_hijack_syscall_args")) - suite.addTest(SyscallHijackTest("test_hijack_syscall_args_with_pprint")) - suite.addTest(SyscallHijackTest("test_hijack_syscall_wrong_args")) - suite.addTest(SyscallHijackTest("loop_detection_test")) - suite.addTest(PPrintSyscallsTest("test_pprint_syscalls_generic")) - suite.addTest(PPrintSyscallsTest("test_pprint_syscalls_with_statement")) - suite.addTest(PPrintSyscallsTest("test_pprint_handle_syscalls")) - suite.addTest(PPrintSyscallsTest("test_pprint_hijack_syscall")) - suite.addTest(PPrintSyscallsTest("test_pprint_which_syscalls_pprint_after")) - suite.addTest(PPrintSyscallsTest("test_pprint_which_syscalls_pprint_before")) - suite.addTest(PPrintSyscallsTest("test_pprint_which_syscalls_pprint_after_and_before")) - suite.addTest(PPrintSyscallsTest("test_pprint_which_syscalls_not_pprint_after")) - suite.addTest(PPrintSyscallsTest("test_pprint_which_syscalls_not_pprint_before")) - suite.addTest(PPrintSyscallsTest("test_pprint_which_syscalls_not_pprint_after_and_before")) - suite.addTest(SignalCatchTest("test_signal_catch_signal_block")) - suite.addTest(SignalCatchTest("test_signal_pass_to_process")) - suite.addTest(SignalCatchTest("test_signal_disable_catch_signal")) - suite.addTest(SignalCatchTest("test_signal_unblock")) - suite.addTest(SignalCatchTest("test_signal_disable_catch_signal_unblock")) - suite.addTest(SignalCatchTest("test_hijack_signal_with_catch_signal")) - suite.addTest(SignalCatchTest("test_hijack_signal_with_api")) - suite.addTest(SignalCatchTest("test_recursive_true_with_catch_signal")) - suite.addTest(SignalCatchTest("test_recursive_true_with_api")) - suite.addTest(SignalCatchTest("test_recursive_false_with_catch_signal")) - suite.addTest(SignalCatchTest("test_recursive_false_with_api")) - suite.addTest(SignalCatchTest("test_hijack_signal_with_catch_signal_loop")) - suite.addTest(SignalCatchTest("test_hijack_signal_with_api_loop")) - suite.addTest(SignalCatchTest("test_signal_unhijacking")) - suite.addTest(SignalCatchTest("test_override_catch_signal")) - suite.addTest(SignalCatchTest("test_override_hijack")) - suite.addTest(SignalCatchTest("test_override_hybrid")) - suite.addTest(SignalCatchTest("test_signal_get_signal")) - suite.addTest(SignalCatchTest("test_signal_send_signal")) - suite.addTest(SignalCatchTest("test_signal_catch_sync_block")) - suite.addTest(SignalCatchTest("test_signal_catch_sync_pass")) - suite.addTest(SignalMultithreadTest("test_signal_multithread_undet_catch_signal_block")) - suite.addTest(SignalMultithreadTest("test_signal_multithread_undet_pass")) - suite.addTest(SignalMultithreadTest("test_signal_multithread_det_catch_signal_block")) - suite.addTest(SignalMultithreadTest("test_signal_multithread_det_pass")) - suite.addTest(SignalMultithreadTest("test_signal_multithread_send_signal")) - suite.addTest(DeathTest("test_io_death")) - suite.addTest(DeathTest("test_cont_death")) - suite.addTest(DeathTest("test_instr_death")) - suite.addTest(DeathTest("test_exit_signal_death")) - suite.addTest(DeathTest("test_exit_code_death")) - suite.addTest(DeathTest("test_exit_code_normal")) - suite.addTest(DeathTest("test_post_mortem_after_kill")) - suite.addTest(AliasTest("test_basic_alias")) - suite.addTest(AliasTest("test_step_alias")) - suite.addTest(AliasTest("test_step_until_alias")) - suite.addTest(AliasTest("test_memory_alias")) - suite.addTest(AliasTest("test_finish_alias")) - suite.addTest(AliasTest("test_waiting_alias")) - suite.addTest(AliasTest("test_interrupt_alias")) - return suite - - -def complete_suite(): - suite = fast_suite() - suite.addTest(Vmwhere1("test_vmwhere1")) - suite.addTest(Vmwhere1("test_vmwhere1_callback")) - suite.addTest(BruteTest("test_bruteforce")) - suite.addTest(CallbackTest("test_callback_bruteforce")) - suite.addTest(SpeedTest("test_speed")) - suite.addTest(SpeedTest("test_speed_hardware")) - suite.addTest(DeepDiveDivision("test_deep_dive_division")) - return suite - - -def thread_stress_suite(): - suite = unittest.TestSuite() - for _ in range(1024): - suite.addTest(ThreadTest("test_thread")) - suite.addTest(ThreadTest("test_thread_hardware")) - suite.addTest(ComplexThreadTest("test_thread")) - return suite - - -if __name__ == "__main__": - if sys.version_info >= (3, 12): - runner = unittest.TextTestRunner(verbosity=2, durations=3) - else: - runner = unittest.TextTestRunner(verbosity=2) - - if len(sys.argv) > 1 and sys.argv[1].lower() == "slow": - suite = complete_suite() - elif len(sys.argv) > 1 and sys.argv[1].lower() == "thread_stress": - suite = thread_stress_suite() - runner.verbosity = 1 - else: - suite = fast_suite() - - result = runner.run(suite) - - if result.wasSuccessful(): - print("All tests passed") - else: - print("Some tests failed") - print("\nFailed Tests:") - for test, err in result.failures: - print(f"{test}: {err}") - print("\nErrors:") - for test, err in result.errors: - print(f"{test}: {err}") +architectures = os.listdir(".") +architectures.remove("common") + +if len(sys.argv) > 1 and sys.argv[1] not in architectures: + print("Usage: python run_test_suite.py ") + print("Available architectures:") + for arch in architectures: + print(f" {arch}") + sys.exit(1) +elif len(sys.argv) > 1: + arch = sys.argv[1] +else: + arch = platform.machine() + match arch: + case "x86_64": + arch = "amd64" + case "i686": + arch = "i386" + case "aarch64": + arch = "aarch64" + case _: + raise ValueError(f"Unsupported architecture: {arch}") + +os.chdir(arch) +os.system(" ".join([sys.executable, "run_suite.py"]))