Skip to content

Commit ad42017

Browse files
committed
fix(upgrade): use start and end timestamps to filter out irrelevant timestamps
1 parent 2662751 commit ad42017

File tree

8 files changed

+149
-86
lines changed

8 files changed

+149
-86
lines changed

pychunkedgraph/app/meshing/common.py

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
import threading
55

66
import numpy as np
7-
import redis
8-
from rq import Queue, Connection, Retry
97
from flask import Response, current_app, jsonify, make_response, request
108

119
from pychunkedgraph import __version__
@@ -145,37 +143,15 @@ def _check_post_options(cg, resp, data, seg_ids):
145143
def handle_remesh(table_id):
146144
current_app.request_type = "remesh_enque"
147145
current_app.table_id = table_id
148-
is_priority = request.args.get("priority", True, type=str2bool)
149-
is_redisjob = request.args.get("use_redis", False, type=str2bool)
150-
151146
new_lvl2_ids = json.loads(request.data)["new_lvl2_ids"]
152-
153-
if is_redisjob:
154-
with Connection(redis.from_url(current_app.config["REDIS_URL"])):
155-
156-
if is_priority:
157-
retry = Retry(max=3, interval=[1, 10, 60])
158-
queue_name = "mesh-chunks"
159-
else:
160-
retry = Retry(max=3, interval=[60, 60, 60])
161-
queue_name = "mesh-chunks-low-priority"
162-
q = Queue(queue_name, retry=retry, default_timeout=1200)
163-
task = q.enqueue(meshing_tasks.remeshing, table_id, new_lvl2_ids)
164-
165-
response_object = {"status": "success", "data": {"task_id": task.get_id()}}
166-
167-
return jsonify(response_object), 202
168-
else:
169-
new_lvl2_ids = np.array(new_lvl2_ids, dtype=np.uint64)
170-
cg = app_utils.get_cg(table_id)
171-
172-
if len(new_lvl2_ids) > 0:
173-
t = threading.Thread(
174-
target=_remeshing, args=(cg.get_serialized_info(), new_lvl2_ids)
175-
)
176-
t.start()
177-
178-
return Response(status=202)
147+
new_lvl2_ids = np.array(new_lvl2_ids, dtype=np.uint64)
148+
cg = app_utils.get_cg(table_id)
149+
if len(new_lvl2_ids) > 0:
150+
t = threading.Thread(
151+
target=_remeshing, args=(cg.get_serialized_info(), new_lvl2_ids)
152+
)
153+
t.start()
154+
return Response(status=202)
179155

180156

181157
def _remeshing(serialized_cg_info, lvl2_nodes):

pychunkedgraph/graph/edges/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,10 @@ def get_latest_edges(
237237
Then get supervoxels of those L2 IDs and get parent(s) at `node` level.
238238
These parents would be the new identities for the stale `partner`.
239239
"""
240-
_nodes = np.unique(stale_edges[:, 1])
240+
_nodes = np.unique(stale_edges)
241241
nodes_ts_map = dict(
242242
zip(_nodes, cg.get_node_timestamps(_nodes, return_numpy=False, normalize=True))
243243
)
244-
_nodes = np.unique(stale_edges)
245244
layers, coords = cg.get_chunk_layers_and_coordinates(_nodes)
246245
layers_d = dict(zip(_nodes, layers))
247246
coords_d = dict(zip(_nodes, coords))
@@ -352,7 +351,9 @@ def _filter(node):
352351

353352
_edges = []
354353
edges_d = cg.get_cross_chunk_edges(
355-
node_ids=l2ids_a, time_stamp=nodes_ts_map[node_b], raw_only=True
354+
node_ids=l2ids_a,
355+
time_stamp=max(nodes_ts_map[node_a], nodes_ts_map[node_b]),
356+
raw_only=True,
356357
)
357358
for v in edges_d.values():
358359
_edges.append(v.get(edge_layer, types.empty_2d))
@@ -382,7 +383,8 @@ def _filter(node):
382383

383384
parents_a = np.array([node_a] * parents_b.size, dtype=basetypes.NODE_ID)
384385
_new_edges = np.column_stack((parents_a, parents_b))
385-
assert _new_edges.size, f"No edge found for {node_a}, {node_b} at {parent_ts}"
386+
err = f"No edge found for {node_a}, {node_b} at {edge_layer}; {parent_ts}"
387+
assert _new_edges.size, err
386388
result.append(_new_edges)
387389
return np.concatenate(result)
388390

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
# pylint: disable=invalid-name, missing-docstring, c-extension-no-member
22

3-
from datetime import timedelta
3+
from concurrent.futures import ThreadPoolExecutor, as_completed
4+
import logging, math, time
45

56
import fastremap
67
import numpy as np
8+
from tqdm import tqdm
79
from pychunkedgraph.graph import ChunkedGraph
8-
from pychunkedgraph.graph.attributes import Connectivity
10+
from pychunkedgraph.graph.attributes import Connectivity, Hierarchy
911
from pychunkedgraph.graph.utils import serializers
12+
from pychunkedgraph.utils.general import chunked
1013

11-
from .utils import exists_as_parent, get_parent_timestamps
14+
from .utils import exists_as_parent, get_end_timestamps, get_parent_timestamps
15+
16+
CHILDREN = {}
1217

1318

1419
def update_cross_edges(
15-
cg: ChunkedGraph, node, cx_edges_d: dict, node_ts, timestamps: set, earliest_ts
20+
cg: ChunkedGraph, node, cx_edges_d: dict, node_ts, node_end_ts, timestamps: set
1621
) -> list:
1722
"""
1823
Helper function to update a single L2 ID.
@@ -27,13 +32,15 @@ def update_cross_edges(
2732
assert not exists_as_parent(cg, node, edges[:, 0])
2833
return rows
2934

30-
partner_parent_ts_d = get_parent_timestamps(cg, edges[:, 1])
35+
partner_parent_ts_d = get_parent_timestamps(cg, np.unique(edges[:, 1]))
3136
for v in partner_parent_ts_d.values():
3237
timestamps.update(v)
3338

3439
for ts in sorted(timestamps):
35-
if ts < earliest_ts:
36-
ts = earliest_ts
40+
if ts < node_ts:
41+
continue
42+
if ts > node_end_ts:
43+
break
3744
val_dict = {}
3845
svs = edges[:, 1]
3946
parents = cg.get_parents(svs, time_stamp=ts)
@@ -51,35 +58,78 @@ def update_cross_edges(
5158
return rows
5259

5360

54-
def update_nodes(cg: ChunkedGraph, nodes) -> list:
55-
nodes_ts = cg.get_node_timestamps(nodes, return_numpy=False, normalize=True)
56-
earliest_ts = cg.get_earliest_timestamp()
61+
def update_nodes(cg: ChunkedGraph, nodes, nodes_ts, children_map=None) -> list:
62+
if children_map is None:
63+
children_map = CHILDREN
64+
end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, children_map)
5765
timestamps_d = get_parent_timestamps(cg, nodes)
5866
cx_edges_d = cg.get_atomic_cross_edges(nodes)
5967
rows = []
60-
for node, node_ts in zip(nodes, nodes_ts):
68+
for node, node_ts, end_ts in zip(nodes, nodes_ts, end_timestamps):
6169
if cg.get_parent(node) is None:
62-
# invalid id caused by failed ingest task
70+
# invalid id caused by failed ingest task / edits
6371
continue
6472
_cx_edges_d = cx_edges_d.get(node, {})
6573
if not _cx_edges_d:
6674
continue
6775
_rows = update_cross_edges(
68-
cg, node, _cx_edges_d, node_ts, timestamps_d[node], earliest_ts
76+
cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d[node]
6977
)
7078
rows.extend(_rows)
7179
return rows
7280

7381

74-
def update_chunk(cg: ChunkedGraph, chunk_coords: list[int], layer: int = 2):
82+
def _update_nodes_helper(args):
83+
cg, nodes, nodes_ts = args
84+
return update_nodes(cg, nodes, nodes_ts)
85+
86+
87+
def update_chunk(
88+
cg: ChunkedGraph, chunk_coords: list[int], layer: int = 2, debug: bool = False
89+
):
7590
"""
7691
Iterate over all L2 IDs in a chunk and update their cross chunk edges,
7792
within the periods they were valid/active.
7893
"""
94+
global CHILDREN
95+
96+
start = time.time()
7997
x, y, z = chunk_coords
8098
chunk_id = cg.get_chunk_id(layer=layer, x=x, y=y, z=z)
8199
cg.copy_fake_edges(chunk_id)
82100
rr = cg.range_read_chunk(chunk_id)
83-
nodes = list(rr.keys())
84-
rows = update_nodes(cg, nodes)
85-
cg.client.write(rows)
101+
102+
nodes = []
103+
nodes_ts = []
104+
earliest_ts = cg.get_earliest_timestamp()
105+
for k, v in rr.items():
106+
nodes.append(k)
107+
CHILDREN[k] = v[Hierarchy.Child][0].value
108+
ts = v[Hierarchy.Child][0].timestamp
109+
nodes_ts.append(earliest_ts if ts < earliest_ts else ts)
110+
111+
if len(nodes) > 0:
112+
logging.info(f"Processing {len(nodes)} nodes.")
113+
assert len(CHILDREN) > 0, (nodes, CHILDREN)
114+
else:
115+
return
116+
117+
if debug:
118+
rows = update_nodes(cg, nodes, nodes_ts)
119+
cg.client.write(rows)
120+
else:
121+
task_size = int(math.ceil(len(nodes) / 64))
122+
chunked_nodes = chunked(nodes, task_size)
123+
chunked_nodes_ts = chunked(nodes_ts, task_size)
124+
tasks = []
125+
for chunk, ts_chunk in zip(chunked_nodes, chunked_nodes_ts):
126+
args = (cg, chunk, ts_chunk)
127+
tasks.append(args)
128+
129+
rows = []
130+
with ThreadPoolExecutor(max_workers=8) as executor:
131+
futures = [executor.submit(_update_nodes_helper, task) for task in tasks]
132+
for future in tqdm(as_completed(futures), total=len(futures)):
133+
rows.extend(future.result())
134+
135+
print(f"total elaspsed time: {time.time() - start}")

pychunkedgraph/ingest/upgrade/parent_layer.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging, math, random, time
44
import multiprocessing as mp
55
from collections import defaultdict
6+
from concurrent.futures import ThreadPoolExecutor, as_completed
67

78
import fastremap
89
import numpy as np
@@ -15,7 +16,7 @@
1516
from pychunkedgraph.graph.types import empty_2d
1617
from pychunkedgraph.utils.general import chunked
1718

18-
from .utils import exists_as_parent, get_parent_timestamps
19+
from .utils import exists_as_parent, get_end_timestamps, get_parent_timestamps
1920

2021

2122
CHILDREN = {}
@@ -51,7 +52,7 @@ def _get_cx_edges_at_timestamp(node, response, ts):
5152

5253

5354
def _populate_cx_edges_with_timestamps(
54-
cg: ChunkedGraph, layer: int, nodes: list, nodes_ts: list, earliest_ts
55+
cg: ChunkedGraph, layer: int, nodes: list, nodes_ts: list
5556
):
5657
"""
5758
Collect timestamps of edits from children, since we use the same timestamp
@@ -63,7 +64,8 @@ def _populate_cx_edges_with_timestamps(
6364
all_children = np.concatenate(list(CHILDREN.values()))
6465
response = cg.client.read_nodes(node_ids=all_children, properties=attrs)
6566
timestamps_d = get_parent_timestamps(cg, nodes)
66-
for node, node_ts in zip(nodes, nodes_ts):
67+
end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, CHILDREN)
68+
for node, node_ts, node_end_ts in zip(nodes, nodes_ts, end_timestamps):
6769
CX_EDGES[node] = {}
6870
timestamps = timestamps_d[node]
6971
cx_edges_d_node_ts = _get_cx_edges_at_timestamp(node, response, node_ts)
@@ -75,8 +77,8 @@ def _populate_cx_edges_with_timestamps(
7577
CX_EDGES[node][node_ts] = cx_edges_d_node_ts
7678

7779
for ts in sorted(timestamps):
78-
if ts < earliest_ts:
79-
ts = earliest_ts
80+
if ts > node_end_ts:
81+
break
8082
CX_EDGES[node][ts] = _get_cx_edges_at_timestamp(node, response, ts)
8183

8284

@@ -107,7 +109,7 @@ def update_cross_edges(cg: ChunkedGraph, layer, node, node_ts, earliest_ts) -> l
107109

108110
row_id = serializers.serialize_uint64(node)
109111
for ts, cx_edges_d in CX_EDGES[node].items():
110-
if node_ts > ts:
112+
if ts < node_ts:
111113
continue
112114
edges = get_latest_edges_wrapper(cg, cx_edges_d, parent_ts=ts)
113115
if edges.size == 0:
@@ -129,17 +131,29 @@ def update_cross_edges(cg: ChunkedGraph, layer, node, node_ts, earliest_ts) -> l
129131
return rows
130132

131133

134+
def _update_cross_edges_helper_thread(args):
135+
cg, layer, node, node_ts, earliest_ts = args
136+
return update_cross_edges(cg, layer, node, node_ts, earliest_ts)
137+
138+
132139
def _update_cross_edges_helper(args):
133140
cg_info, layer, nodes, nodes_ts, earliest_ts = args
134141
rows = []
135142
cg = ChunkedGraph(**cg_info)
136143
parents = cg.get_parents(nodes, fail_to_zero=True)
144+
145+
tasks = []
137146
for node, parent, node_ts in zip(nodes, parents, nodes_ts):
138147
if parent == 0:
139-
# invalid id caused by failed ingest task
148+
# invalid id caused by failed ingest task / edits
140149
continue
141-
_rows = update_cross_edges(cg, layer, node, node_ts, earliest_ts)
142-
rows.extend(_rows)
150+
tasks.append((cg, layer, node, node_ts, earliest_ts))
151+
152+
with ThreadPoolExecutor(max_workers=4) as executor:
153+
futures = [executor.submit(_update_cross_edges_helper_thread, task) for task in tasks]
154+
for future in tqdm(as_completed(futures), total=len(futures)):
155+
rows.extend(future.result())
156+
143157
cg.client.write(rows)
144158

145159

@@ -159,7 +173,7 @@ def update_chunk(
159173
nodes = list(CHILDREN.keys())
160174
random.shuffle(nodes)
161175
nodes_ts = cg.get_node_timestamps(nodes, return_numpy=False, normalize=True)
162-
_populate_cx_edges_with_timestamps(cg, layer, nodes, nodes_ts, earliest_ts)
176+
_populate_cx_edges_with_timestamps(cg, layer, nodes, nodes_ts)
163177

164178
task_size = int(math.ceil(len(nodes) / mp.cpu_count() / 2))
165179
chunked_nodes = chunked(nodes, task_size)
@@ -171,8 +185,9 @@ def update_chunk(
171185
args = (cg_info, layer, chunk, ts_chunk, earliest_ts)
172186
tasks.append(args)
173187

174-
logging.info(f"Processing {len(nodes)} nodes.")
175-
with mp.Pool(min(mp.cpu_count(), len(tasks))) as pool:
188+
processes = min(mp.cpu_count() * 2, len(tasks))
189+
logging.info(f"Processing {len(nodes)} nodes with {processes} workers.")
190+
with mp.Pool(processes) as pool:
176191
_ = list(
177192
tqdm(
178193
pool.imap_unordered(_update_cross_edges_helper, tasks),

0 commit comments

Comments
 (0)