Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions common/timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ def handle_timeout(self, signume, frame):
raise TimeoutException(self.error_msg)

def __enter__(self):
# return
signal.signal(signal.SIGALRM, self.handle_timeout)
signal.alarm(self.seconds)

def __exit__(self, exc_type, exc_val, exc_tb):
# return
signal.alarm(0)
2 changes: 1 addition & 1 deletion selfdrive/controls/controlsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def publish(self, CC, lac_log):
dat.valid = CS.canValid
cs = dat.controlsState

cs.curvature = self.curvature
cs.curvature = 420.69
cs.longitudinalPlanMonoTime = self.sm.logMonoTime['longitudinalPlan']
cs.lateralPlanMonoTime = self.sm.logMonoTime['modelV2']
cs.desiredCurvature = self.desired_curvature
Expand Down
39 changes: 32 additions & 7 deletions selfdrive/test/process_replay/process_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import copy
import heapq
import signal
from collections import Counter
from collections import Counter, deque
from dataclasses import dataclass, field
from itertools import islice
from typing import Any
Expand Down Expand Up @@ -95,8 +95,10 @@ def send_sync(self, pm, endpoint, dat):

def unlock_sockets(self):
expected_sets = len(self.events)
# print('events', self.events)
while expected_sets > 0:
index = messaging.wait_for_one_event(self.all_recv_called_events)
# print('cleared', index)
self.all_recv_called_events[index].clear()
self.all_recv_ready_events[index].set()
expected_sets -= 1
Expand Down Expand Up @@ -301,6 +303,7 @@ def run_step(self, msg: capnp._DynamicStructReader, frs: dict[str, FrameReader]
self.pm.send(m.which(), m.as_builder())
# send frames if needed
if self.vipc_server is not None and m.which() in self.cfg.vision_pubs:
raise Exception
camera_state = getattr(m, m.which())
camera_meta = meta_from_camera_state(m.which())
assert frs is not None
Expand All @@ -309,6 +312,8 @@ def run_step(self, msg: capnp._DynamicStructReader, frs: dict[str, FrameReader]
camera_state.frameId, camera_state.timestampSof, camera_state.timestampEof)
self.msg_queue = []

# input()
# *** goes through each socket that called recv and clears recv_called event, then sets recv ready event ***
self.rc.unlock_sockets()
if trigger_empty_recv:
self.rc.unlock_sockets()
Expand Down Expand Up @@ -617,8 +622,10 @@ def replay_process_with_name(name: str | Iterable[str], lr: LogIterable, *args,
def replay_process(
cfg: ProcessConfig | Iterable[ProcessConfig], lr: LogIterable, frs: dict[str, FrameReader] = None,
fingerprint: str = None, return_all_logs: bool = False, custom_params: dict[str, Any] = None,
captured_output_store: dict[str, dict[str, str]] = None, disable_progress: bool = False
captured_output_store: dict[str, dict[str, str]] = None, disable_progress: bool = False, t=None
) -> list[capnp._DynamicStructReader]:
if t is None:
t = time.monotonic()
if isinstance(cfg, Iterable):
cfgs = list(cfg)
else:
Expand All @@ -628,7 +635,8 @@ def replay_process(
manager_states=True,
panda_states=any("pandaStates" in cfg.pubs for cfg in cfgs),
camera_states=any(len(cfg.vision_pubs) != 0 for cfg in cfgs))
process_logs = _replay_multi_process(cfgs, all_msgs, frs, fingerprint, custom_params, captured_output_store, disable_progress)
# return all_msgs
process_logs = _replay_multi_process(cfgs, all_msgs, frs, fingerprint, custom_params, captured_output_store, disable_progress, t)

if return_all_logs:
keys = {m.which() for m in process_logs}
Expand All @@ -644,7 +652,7 @@ def replay_process(

def _replay_multi_process(
cfgs: list[ProcessConfig], lr: LogIterable, frs: dict[str, FrameReader] | None, fingerprint: str | None,
custom_params: dict[str, Any] | None, captured_output_store: dict[str, dict[str, str]] | None, disable_progress: bool
custom_params: dict[str, Any] | None, captured_output_store: dict[str, dict[str, str]] | None, disable_progress: bool, t: float
) -> list[capnp._DynamicStructReader]:
if fingerprint is not None:
params_config = generate_params_config(lr=lr, fingerprint=fingerprint, custom_params=custom_params)
Expand All @@ -653,15 +661,17 @@ def _replay_multi_process(
CP = next((m.carParams for m in lr if m.which() == "carParams"), None)
params_config = generate_params_config(lr=lr, CP=CP, custom_params=custom_params)
env_config = generate_environ_config(CP=CP)
print('env config', time.monotonic() - t)

# validate frs and vision pubs
all_vision_pubs = [pub for cfg in cfgs for pub in cfg.vision_pubs]
if len(all_vision_pubs) != 0:
assert frs is not None, "frs must be provided when replaying process using vision streams"
assert all(meta_from_camera_state(st) is not None for st in all_vision_pubs), \
f"undefined vision stream spotted, probably misconfigured process: (vision pubs: {all_vision_pubs})"
f"undefined vision stream spotted, probably misconfigured process: (vision pubs: {all_vision_pubs})"
required_vision_pubs = {m.camera_state for m in available_streams(lr)} & set(all_vision_pubs)
assert all(st in frs for st in required_vision_pubs), f"frs for this process must contain following vision streams: {required_vision_pubs}"
print('validate frs', time.monotonic() - t)

all_msgs = sorted(lr, key=lambda msg: msg.logMonoTime)
log_msgs = []
Expand All @@ -672,26 +682,39 @@ def _replay_multi_process(
containers.append(container)
container.start(params_config, env_config, all_msgs, frs, fingerprint, captured_output_store is not None)

print('created containers', time.monotonic() - t)

all_pubs = {pub for container in containers for pub in container.pubs}
all_subs = {sub for container in containers for sub in container.subs}
lr_pubs = all_pubs - all_subs
print('all_pubs', all_pubs, 'all_subs', all_subs, 'lr_pubs', lr_pubs)
pubs_to_containers = {pub: [container for container in containers if pub in container.pubs] for pub in all_pubs}

print('prepared pubs and subs', time.monotonic() - t)

pub_msgs = [msg for msg in all_msgs if msg.which() in lr_pubs]
# external queue for messages taken from logs; internal queue for messages generated by processes, which will be republished
external_pub_queue: list[capnp._DynamicStructReader] = pub_msgs.copy()
external_pub_queue: list[capnp._DynamicStructReader] = deque(pub_msgs)
internal_pub_queue: list[capnp._DynamicStructReader] = []
# heap for maintaining the order of messages generated by processes, where each element: (logMonoTime, index in internal_pub_queue)
internal_pub_index_heap: list[tuple[int, int]] = []

print('prepared queues', time.monotonic() - t)

pbar = tqdm(total=len(external_pub_queue), disable=disable_progress)
print('starting looping', time.monotonic() - t)
while len(external_pub_queue) != 0 or (len(internal_pub_index_heap) != 0 and not all(c.has_empty_queue for c in containers)):
# t = time.monotonic()
# del external_pub_queue[0]
# msg = external_pub_queue.popleft()
if len(internal_pub_index_heap) == 0 or (len(external_pub_queue) != 0 and external_pub_queue[0].logMonoTime < internal_pub_index_heap[0][0]):
msg = external_pub_queue.pop(0)
msg = external_pub_queue.popleft()
pbar.update(1)
else:
# raise Exception
_, index = heapq.heappop(internal_pub_index_heap)
msg = internal_pub_queue[index]
# print('loop pop msg', time.monotonic() - t)

target_containers = pubs_to_containers[msg.which()]
for container in target_containers:
Expand All @@ -707,6 +730,8 @@ def _replay_multi_process(
last_time = log_msgs[-1].logMonoTime if len(log_msgs) > 0 else int(time.monotonic() * 1e9)
log_msgs.extend(container.get_output_msgs(last_time))
finally:
print('final internal_pub_queue len', len(internal_pub_queue))
print('loop finished', time.monotonic() - t)
for container in containers:
container.stop()
if captured_output_store is not None:
Expand Down
57 changes: 57 additions & 0 deletions tools/diff/diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
import argparse
import time
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor

from openpilot.selfdrive.test.process_replay.process_replay import CONFIGS, replay_process
from openpilot.tools.lib.logreader import LogReader, save_log

SEG_LIST = [
"d9b97c1d3b8c39b2/0000018b--8a62ed4984/1",
] * 100

# these use cameras/run models which are slow
BLACKLIST_PROCS = ['modeld', 'dmonitoringmodeld']
WHITELIST_PROCS = ['radard'] # TODO: temporary for debugging


def replay(cfgs, seg):
inputs = list(LogReader(seg))
t = time.monotonic()
outputs = replay_process(cfgs, inputs, fingerprint=None, disable_progress=True, t=t)
print(f"\nTotal time: {time.monotonic() - t} seconds")

# Remove message generated by the process under test and merge in the new messages
produces = {o.which() for o in outputs}
inputs = [i for i in inputs if i.which() not in produces]
outputs = sorted(inputs + outputs, key=lambda x: x.logMonoTime)
return outputs


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="process replay v2",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# parser.add_argument("--fingerprint", help="The fingerprint to use")
# parser.add_argument("route", help="The route name to use")
parser.add_argument("-n", type=int, default=8, help="Number of processes to use")
args = parser.parse_args()

cfgs = [c for c in CONFIGS if c.proc_name not in BLACKLIST_PROCS and c.proc_name in WHITELIST_PROCS]

t = time.monotonic()
with ProcessPoolExecutor(max_workers=args.n) as executor:
futures = []
for seg in tqdm(SEG_LIST):
futures.append(executor.submit(replay, cfgs, seg))

print('hi')
for future in tqdm(futures):
outputs = future.result()

print('got', len(outputs), 'output messages')

# fn = f"diff_{seg.replace('/', '_')}.zst"
# print(f"Saving log to {fn}")
# save_log(fn, outputs)
print(f"\nTotal time: {time.monotonic() - t} seconds")
3 changes: 3 additions & 0 deletions tools/lib/logreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class CachedReader:
__slots__ = ("_evt", "_enum")

def __init__(self, evt: capnp._DynamicStructReader):
assert not isinstance(evt, CachedReader), "CachedReader should not be nested"
"""All capnp attribute accesses are expensive, and which() is often called multiple times"""
self._evt = evt
self._enum: str | None = None
Expand All @@ -74,6 +75,8 @@ def which(self) -> str:
return self._enum

def __getattr__(self, name: str):
if name.startswith("__") and name.endswith("__"):
return super().__getattr__(name)
return getattr(self._evt, name)


Expand Down
Loading