|
6 | 6 | from logging import INFO, getLogger |
7 | 7 | from logging.handlers import QueueHandler |
8 | 8 | from queue import Empty, SimpleQueue |
9 | | -from typing import Generator |
| 9 | +from typing import Generator, Optional |
| 10 | +from hashlib import sha256 |
| 11 | +from unittest.mock import MagicMock |
10 | 12 | import pytest |
11 | 13 |
|
12 | 14 | from openjd.sessions import PosixSessionUser, WindowsSessionUser, BadCredentialsException |
13 | 15 | from openjd.sessions._os_checker import is_posix, is_windows |
14 | 16 | from openjd.sessions._logging import LoggerAdapter |
| 17 | +from openjd.sessions._action_filter import ActionMonitoringFilter |
15 | 18 |
|
16 | 19 | if is_windows(): |
17 | 20 | from openjd.sessions._win32._helpers import ( # type: ignore |
@@ -55,15 +58,102 @@ def pytest_collection_modifyitems(config, items): |
55 | 58 | config.option.markexpr = mark_expr |
56 | 59 |
|
57 | 60 |
|
| 61 | +def create_unique_logger_name(prefix: str = "", seed: Optional[str] = None) -> str: |
| 62 | + """Create a unique logger name using a hash to avoid collisions. |
| 63 | +
|
| 64 | + Args: |
| 65 | + prefix: Optional prefix for the logger name |
| 66 | + seed: Optional seed string to use for generating the hash |
| 67 | +
|
| 68 | + Returns: |
| 69 | + A unique logger name |
| 70 | + """ |
| 71 | + if seed: |
| 72 | + h = sha256() |
| 73 | + h.update(seed.encode("utf-8")) |
| 74 | + suffix = h.hexdigest()[0:32] |
| 75 | + else: |
| 76 | + charset = string.ascii_letters + string.digits |
| 77 | + suffix = "".join(random.choices(charset, k=32)) |
| 78 | + |
| 79 | + return f"{prefix}{suffix}" |
| 80 | + |
| 81 | + |
58 | 82 | def build_logger(handler: QueueHandler) -> LoggerAdapter: |
59 | | - charset = string.ascii_letters + string.digits + string.punctuation |
60 | | - name_suffix = "".join(random.choices(charset, k=32)) |
| 83 | + """Build a logger for testing purposes. |
| 84 | +
|
| 85 | + Args: |
| 86 | + handler: The queue handler to attach to the logger |
| 87 | +
|
| 88 | + Returns: |
| 89 | + A configured LoggerAdapter |
| 90 | + """ |
| 91 | + name_suffix = create_unique_logger_name() |
61 | 92 | log = getLogger(".".join((__name__, name_suffix))) |
62 | 93 | log.setLevel(INFO) |
63 | 94 | log.addHandler(handler) |
64 | 95 | return LoggerAdapter(log, extra=dict()) |
65 | 96 |
|
66 | 97 |
|
| 98 | +def setup_action_filter_test( |
| 99 | + queue_handler: QueueHandler, |
| 100 | + session_id: str = "foo", |
| 101 | + callback: Optional[MagicMock] = None, |
| 102 | + suppress_filtered: bool = False, |
| 103 | + enabled_extensions: Optional[list[str]] = None, |
| 104 | +) -> tuple[LoggerAdapter, ActionMonitoringFilter, MagicMock]: |
| 105 | + """Set up a test environment for testing ActionMonitoringFilter. |
| 106 | +
|
| 107 | + This helper method creates a unique logger name, sets up the ActionMonitoringFilter, |
| 108 | + and configures the logger with the filter. |
| 109 | +
|
| 110 | + Args: |
| 111 | + queue_handler: The QueueHandler to attach to the logger |
| 112 | + session_id: The session ID to use for the filter |
| 113 | + callback: Optional mock callback to use for the filter |
| 114 | + suppress_filtered: Whether to suppress filtered messages |
| 115 | + enabled_extensions: Optional list of extensions to enable |
| 116 | +
|
| 117 | + Returns: |
| 118 | + A tuple containing (logger_adapter, action_filter, callback_mock) |
| 119 | +
|
| 120 | + Note: |
| 121 | + This helper works for most tests, but for tests that need to verify specific |
| 122 | + callback behavior with redacted values, it's better to create the filter and |
| 123 | + logger directly in the test. This is because when multiple filters are applied |
| 124 | + to the same log message (which can happen when running multiple tests), the |
| 125 | + redaction can happen before the callback is invoked, resulting in the callback |
| 126 | + receiving redacted values instead of the original values. |
| 127 | + """ |
| 128 | + # Create a unique logger name WITHOUT using the message as seed |
| 129 | + # This ensures each test gets a truly unique logger name |
| 130 | + logger_name = create_unique_logger_name(prefix="action_filter_") |
| 131 | + |
| 132 | + # Create a mock callback if one wasn't provided |
| 133 | + if callback is None: |
| 134 | + callback = MagicMock() |
| 135 | + |
| 136 | + # Create the filter directly with the provided parameters |
| 137 | + action_filter = ActionMonitoringFilter( |
| 138 | + session_id=session_id, |
| 139 | + callback=callback, |
| 140 | + suppress_filtered=suppress_filtered, |
| 141 | + enabled_extensions=enabled_extensions, |
| 142 | + ) |
| 143 | + |
| 144 | + # Set up the logger |
| 145 | + log = getLogger(".".join((__name__, logger_name))) |
| 146 | + log.setLevel(INFO) |
| 147 | + log.addHandler(queue_handler) |
| 148 | + log.addFilter(action_filter) |
| 149 | + |
| 150 | + # Create and return the logger adapter with the session_id |
| 151 | + # This is critical for the filter to work properly |
| 152 | + logger_adapter = LoggerAdapter(log, extra={"session_id": session_id}) |
| 153 | + |
| 154 | + return logger_adapter, action_filter, callback |
| 155 | + |
| 156 | + |
67 | 157 | def collect_queue_messages(queue: SimpleQueue) -> list[str]: |
68 | 158 | """Extract the text of messages from a SimpleQueue containing LogRecords""" |
69 | 159 | messages: list[str] = [] |
|
0 commit comments