diff --git a/src/pytest_memray/marks.py b/src/pytest_memray/marks.py index eac7e04..01fdb49 100644 --- a/src/pytest_memray/marks.py +++ b/src/pytest_memray/marks.py @@ -193,20 +193,13 @@ def limit_memory( ) -> _MemoryInfo | _MoreMemoryInfo | None: """Limit memory used by the test.""" reader = FileReader(_result_file) - allocations: list[AllocationRecord] = [] - if current_thread_only: - main_thread = reader.metadata.main_thread_id - allocations.extend( - record - for record in reader.get_high_watermark_allocation_records( - merge_threads=False - ) - if record.tid == main_thread - ) - else: - allocations.extend( - reader.get_high_watermark_allocation_records(merge_threads=True) + allocations: list[AllocationRecord] = [ + record + for record in reader.get_high_watermark_allocation_records( + merge_threads=not current_thread_only ) + if not current_thread_only or record.tid == reader.metadata.main_thread_id + ] max_memory = parse_memory_string(limit) total_allocated_memory = sum(record.size for record in allocations) @@ -243,16 +236,13 @@ def limit_leaks( _test_id: str, ) -> _LeakedInfo | None: reader = FileReader(_result_file) - allocations: list[AllocationRecord] = [] - if current_thread_only: - main_thread_id = reader.metadata.main_thread_id - allocations.extend( - record - for record in reader.get_leaked_allocation_records(merge_threads=False) - if record.tid == main_thread_id + allocations: list[AllocationRecord] = [ + record + for record in reader.get_leaked_allocation_records( + merge_threads=not current_thread_only ) - else: - allocations.extend(reader.get_leaked_allocation_records(merge_threads=True)) + if not current_thread_only or record.tid == reader.metadata.main_thread_id + ] memory_limit = parse_memory_string(location_limit)