Skip to content

Commit 59ef0f0

Browse files
committed
JsonPickler
1 parent 5d3a484 commit 59ef0f0

File tree

3 files changed

+30
-11
lines changed

3 files changed

+30
-11
lines changed

Diff for: backend/main/views.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,7 @@ def run_code(self, code, source):
121121

122122
for call in birdseye_objects["calls"]:
123123
call["function_id"] = function_ids[call["function_id"]]
124-
if isinstance(call["start_time"], str):
125-
call["start_time"] = datetime.fromisoformat(call["start_time"])
124+
call["start_time"] = datetime.fromisoformat(call["start_time"])
126125
call = eye.db.Call(**call)
127126
session.add(call)
128127
# TODO get correct call from top level

Diff for: backend/main/workers/master.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import atexit
2-
import multiprocessing
32
import queue
43
from collections import defaultdict
54
from functools import lru_cache
6-
from multiprocessing.context import Process
5+
from multiprocessing import Queue, Process
76
from threading import Thread
87

98
from main import simple_settings
@@ -15,10 +14,10 @@
1514

1615

1716
class UserProcess:
18-
def __init__(self, manager):
19-
self.task_queue = manager.Queue()
20-
self.input_queue = manager.Queue()
21-
self.result_queue = manager.Queue()
17+
def __init__(self):
18+
self.task_queue = Queue()
19+
self.input_queue = Queue()
20+
self.result_queue = Queue()
2221
self.awaiting_input = False
2322
self.process = None
2423
self.start_process()
@@ -85,8 +84,7 @@ def _await_result(self):
8584

8685
def master_consumer_loop(comms: AbstractCommunications):
8786
comms = comms.make_master_side_communications()
88-
manager = multiprocessing.Manager()
89-
user_processes = defaultdict(lambda: UserProcess(manager))
87+
user_processes = defaultdict(UserProcess)
9088

9189
while True:
9290
entry = comms.recv_entry()

Diff for: backend/main/workers/utils.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import sys
22
import traceback
3+
import multiprocessing.queues
4+
import json
35

6+
from littleutils import DecentJSONEncoder
47
from sentry_sdk import capture_exception
58

69

@@ -36,6 +39,8 @@ def string(self):
3639

3740
output_buffer = OutputBuffer()
3841

42+
json_encoder = DecentJSONEncoder()
43+
3944

4045
def make_result(
4146
passed=False,
@@ -52,7 +57,7 @@ def make_result(
5257
if output_parts is None:
5358
output_parts = output_buffer.pop()
5459

55-
return dict(
60+
result = dict(
5661
passed=passed,
5762
message=message,
5863
awaiting_input=awaiting_input,
@@ -61,6 +66,10 @@ def make_result(
6166
birdseye_objects=birdseye_objects,
6267
error=error,
6368
)
69+
# Check that JSON encoding works here
70+
# because failures in the queue pickling are silent
71+
json_pickler.dumps(result)
72+
return result
6473

6574

6675
def internal_error_result(sentry_offline=False):
@@ -87,3 +96,16 @@ def internal_error_result(sentry_offline=False):
8796
output_parts=[dict(color="red", text=output)],
8897
error=dict(traceback=tb, sentry_event=sentry_event),
8998
)
99+
100+
101+
# Queues don't communicate in pickle so that the worker
102+
# can't put something malicious for the master to unpickle
103+
class JsonPickler:
104+
def loads(self, b):
105+
return json.loads(b.decode("utf8"))
106+
107+
def dumps(self, x):
108+
return json_encoder.encode(x).encode("utf8")
109+
110+
111+
multiprocessing.queues._ForkingPickler = json_pickler = JsonPickler()

0 commit comments

Comments
 (0)