Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions news/841.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve UX for memray attach and detach commands. A new ``--wait`` flag is provided that displays a live progress bar showing time elapsed and remaining.
285 changes: 254 additions & 31 deletions src/memray/commands/attach.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
import tempfile
import textwrap
import threading
import time

from rich.console import Console
from rich.progress import BarColumn
from rich.progress import Progress
from rich.progress import SpinnerColumn
from rich.progress import TextColumn
from rich.progress import TimeElapsedColumn
from rich.progress import TimeRemainingColumn

import memray
from memray._errors import MemrayCommandError
Expand Down Expand Up @@ -331,6 +340,50 @@ def recvall(sock: socket.socket) -> str:
return b"".join(iter(lambda: sock.recv(4096), b"")).decode("utf-8")


def show_progress_with_duration(duration: int, pid: int) -> None:
"""Show a progress indicator while waiting for the specified duration.

Args:
duration: Duration in seconds to wait
pid: Process ID being tracked
"""
console = Console()

with Progress(
SpinnerColumn(),
TextColumn("[bold blue]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
TimeRemainingColumn(),
console=console,
) as progress:
task = progress.add_task(
f"Tracking process {pid}", total=duration * 10 # 10 updates per second
)

try:
start_time = time.time()
while not progress.finished:
elapsed = time.time() - start_time
if elapsed >= duration:
progress.update(task, completed=duration * 10)
break

progress.update(task, completed=int(elapsed * 10))
time.sleep(0.1)

except KeyboardInterrupt:
console.print()
console.print(
"[yellow]⚠ Interrupted! Tracking is still running in the background.[/yellow]"
)
console.print(
f"[yellow] Use 'memray detach {pid}' to stop tracking immediately.[/yellow]"
)
raise


class ErrorReaderThread(threading.Thread):
def __init__(self, sock: socket.socket) -> None:
self._sock = sock
Expand Down Expand Up @@ -488,6 +541,13 @@ def prepare_parser(self, parser: argparse.ArgumentParser) -> None:
"--duration", type=int, help="Duration to track for (in seconds)"
)

parser.add_argument(
"--wait",
help="Wait for tracking to complete before exiting (use with --duration)",
action="store_true",
default=False,
)

super().prepare_parser(parser)

def run(self, args: argparse.Namespace, parser: argparse.ArgumentParser) -> None:
Expand All @@ -503,9 +563,18 @@ def run(self, args: argparse.Namespace, parser: argparse.ArgumentParser) -> None

destination: memray.Destination
if args.output:
# Check if output file exists before doing any work
output_path = pathlib.Path(args.output).resolve()
if output_path.exists() and not args.force:
raise MemrayCommandError(
f"Output file already exists: {output_path}\n"
f"Use --force to overwrite it.",
exit_code=1,
)

live_port = None
destination = memray.FileDestination(
path=os.path.abspath(args.output),
path=str(output_path),
overwrite=args.force,
compress_on_exit=not args.no_compress,
)
Expand All @@ -530,23 +599,131 @@ def run(self, args: argparse.Namespace, parser: argparse.ArgumentParser) -> None
f"{file_format})"
)

client = self.inject_control_channel(args.method, args.pid, verbose=verbose)
client.sendall(
PAYLOAD.format(
tracker_call=tracker_call,
mode=mode,
duration=duration,
).encode("utf-8")
)
client.shutdown(socket.SHUT_WR)
console = Console()

# Show detailed attaching progress with steps
with Progress(
SpinnerColumn(),
TextColumn("[bold blue]{task.description}"),
console=console,
transient=True, # Clear progress when done
) as progress:
# Step 1: Resolve debugger method
task1 = progress.add_task("Resolving injection method...", total=None)
resolved_method = self.resolve_debugger(args.method, verbose=verbose)
progress.update(
task1,
description=f"[green]✓[/green] Using {resolved_method} for injection",
)
progress.stop_task(task1)

# Step 2: Inject control channel
task2 = progress.add_task(
f"Injecting into process {args.pid} using {resolved_method}...",
total=None,
)
client = self.inject_control_channel(
resolved_method, args.pid, verbose=verbose
)
progress.update(
task2, description="[green]✓[/green] Control channel established"
)
progress.stop_task(task2)

# Step 3: Send tracking payload
task3 = progress.add_task("Sending tracking configuration...", total=None)
client.sendall(
PAYLOAD.format(
tracker_call=tracker_call,
mode=mode,
duration=duration,
).encode("utf-8")
)
client.shutdown(socket.SHUT_WR)
progress.update(task3, description="[green]✓[/green] Configuration sent")
progress.stop_task(task3)

# Step 4: Wait for confirmation
if not live_port:
task4 = progress.add_task(
"Waiting for confirmation from process...", total=None
)
err = recvall(client)
if err:
raise MemrayCommandError(
f"Failed to start tracking in remote process: {err}",
exit_code=1,
)
progress.update(
task4,
description="[green]✓[/green] Tracking activated in remote process",
)
progress.stop_task(task4)

# Only show confirmation after attach succeeded
if not live_port:
err = recvall(client)
if err:
raise MemrayCommandError(
f"Failed to start tracking in remote process: {err}",
exit_code=1,
console.print(
f"[green]✓[/green] Successfully attached to process [bold]{args.pid}[/bold]"
)
console.print(f" Output file: [cyan]{args.output}[/cyan]")

# If duration and --wait are specified, wait and show progress
if duration and args.wait:
console.print(f" Tracking for [bold]{duration}[/bold] seconds...")
console.print() # Add blank line before progress bar
try:
show_progress_with_duration(duration, args.pid)
console.print() # Add blank line after completion
console.print(
f"[green]✓[/green] Tracking complete. "
f"Results saved to: [cyan]{args.output}[/cyan]"
)
except KeyboardInterrupt:
console.print(
f"\n[yellow]⚠ Note: Tracking will continue in process "
f"{args.pid} until the duration expires.[/yellow]"
)
console.print(
f"[yellow] Use 'memray detach {args.pid}' "
f"to stop tracking immediately.[/yellow]"
)
raise MemrayCommandError("Interrupted by user", exit_code=130)
elif duration:
# Duration specified but not waiting - show prominent info message
console.print() # Blank line for emphasis
console.print(
"[blue]ℹ[/blue] This command will exit immediately, "
"but tracking continues in the background."
)
console.print(
f" The process will be tracked for [bold]{duration}[/bold] "
f"seconds and results will be saved to [cyan]{args.output}[/cyan]."
)
console.print() # Blank line
console.print(
f" To stop tracking early: "
f"[bold]memray detach {args.pid}[/bold]"
)
console.print(
" To wait and see progress: "
"Use the [bold]--wait[/bold] flag next time"
)
else:
# No duration - indefinite tracking
console.print() # Blank line for emphasis
console.print(
"[blue]ℹ[/blue] This command will exit immediately, "
"but tracking continues indefinitely."
)
console.print(
f" Results will be saved to [cyan]{args.output}[/cyan] "
f"when tracking stops."
)
console.print() # Blank line
console.print(
f" To stop tracking: " f"[bold]memray detach {args.pid}[/bold]"
)

return

# If an error prevents the tracked process from binding a server to
Expand Down Expand Up @@ -585,21 +762,67 @@ class DetachCommand(_DebuggerCommand):
def run(self, args: argparse.Namespace, parser: argparse.ArgumentParser) -> None:
verbose = args.verbose
mode: TrackingMode = "DEACTIVATE"
args.method = self.resolve_debugger(args.method, verbose=verbose)
client = self.inject_control_channel(args.method, args.pid, verbose=verbose)

client.sendall(
PAYLOAD.format(
tracker_call=None,
mode=mode,
duration=None,
).encode("utf-8")
)
client.shutdown(socket.SHUT_WR)
console = Console()

# Show detailed detaching progress with steps
with Progress(
SpinnerColumn(),
TextColumn("[bold blue]{task.description}"),
console=console,
transient=True, # Clear progress when done
) as progress:
# Step 1: Resolve debugger method
task1 = progress.add_task("Resolving injection method...", total=None)
resolved_method = self.resolve_debugger(args.method, verbose=verbose)
progress.update(
task1,
description=f"[green]✓[/green] Using {resolved_method} for injection",
)
progress.stop_task(task1)

err = recvall(client)
if err:
raise MemrayCommandError(
f"Failed to stop tracking in remote process: {err}",
exit_code=1,
# Step 2: Inject control channel
task2 = progress.add_task(
f"Connecting to process {args.pid} using {resolved_method}...",
total=None,
)
client = self.inject_control_channel(
resolved_method, args.pid, verbose=verbose
)
progress.update(
task2, description="[green]✓[/green] Control channel established"
)
progress.stop_task(task2)

# Step 3: Send detach command
task3 = progress.add_task("Sending stop tracking command...", total=None)
client.sendall(
PAYLOAD.format(
tracker_call=None,
mode=mode,
duration=None,
).encode("utf-8")
)
client.shutdown(socket.SHUT_WR)
progress.update(task3, description="[green]✓[/green] Stop command sent")
progress.stop_task(task3)

# Step 4: Wait for confirmation
task4 = progress.add_task(
"Waiting for confirmation from process...", total=None
)
err = recvall(client)
if err:
raise MemrayCommandError(
f"Failed to stop tracking in remote process: {err}",
exit_code=1,
)
progress.update(
task4, description="[green]✓[/green] Tracking stopped in remote process"
)
progress.stop_task(task4)

# Show final confirmation
console.print(
f"[green]✓[/green] Successfully stopped tracking in process "
f"[bold]{args.pid}[/bold]"
)
Loading