@@ -1,50 +1,52 @@
 # pylint: disable=invalid-name, missing-docstring, c-extension-no-member
 
+from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import logging, math, time
+from copy import copy
 
 import fastremap
 import numpy as np
 from tqdm import tqdm
-from pychunkedgraph.graph import ChunkedGraph
+from pychunkedgraph.graph import ChunkedGraph, types
 from pychunkedgraph.graph.attributes import Connectivity, Hierarchy
 from pychunkedgraph.graph.utils import serializers
 from pychunkedgraph.utils.general import chunked
 
-from .utils import exists_as_parent, get_end_timestamps, get_parent_timestamps
+from .utils import get_end_timestamps, get_parent_timestamps
 
 CHILDREN = {}
 
 
 def update_cross_edges(
-    cg: ChunkedGraph, node, cx_edges_d: dict, node_ts, node_end_ts, timestamps: set
+    cg: ChunkedGraph,
+    node,
+    cx_edges_d: dict,
+    node_ts,
+    node_end_ts,
+    timestamps_d: defaultdict[int, set],
 ) -> list:
     """
     Helper function to update a single L2 ID.
     Returns a list of mutations with given timestamps.
     """
     rows = []
     edges = np.concatenate(list(cx_edges_d.values()))
-    uparents = np.unique(cg.get_parents(edges[:, 0], time_stamp=node_ts))
-    assert uparents.size <= 1, f"{node}, {node_ts}, {uparents}"
-    if uparents.size == 0 or node != uparents[0]:
-        # if node is not the parent at this ts, it must be invalid
-        assert not exists_as_parent(cg, node, edges[:, 0])
-        return rows
+    partners = np.unique(edges[:, 1])
 
-    partner_parent_ts_d = get_parent_timestamps(cg, np.unique(edges[:, 1]))
-    for v in partner_parent_ts_d.values():
-        timestamps.update(v)
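+    # merge the node's own parent-update timestamps with those of all its
+    # cross-edge partners; copy() so the shared entry in timestamps_d is not mutated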
+    timestamps = copy(timestamps_d[node])
+    for partner in partners:
+        timestamps.update(timestamps_d[partner])
 
     for ts in sorted(timestamps):
         if ts < node_ts:
             continue
         if ts > node_end_ts:
             break
         val_dict = {}
-        svs = edges[:, 1]
-        parents = cg.get_parents(svs, time_stamp=ts)
-        edge_parents_d = dict(zip(svs, parents))
+
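+        # map each cross-edge partner to its parent ID as of this timestamp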
+        parents = cg.get_parents(partners, time_stamp=ts)
+        edge_parents_d = dict(zip(partners, parents))
         for layer, layer_edges in cx_edges_d.items():
             layer_edges = fastremap.remap(
                 layer_edges, edge_parents_d, preserve_missing_labels=True
@@ -62,19 +64,21 @@ def update_nodes(cg: ChunkedGraph, nodes, nodes_ts, children_map=None) -> list:
     if children_map is None:
         children_map = CHILDREN
     end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, children_map)
-    timestamps_d = get_parent_timestamps(cg, nodes)
+
     cx_edges_d = cg.get_atomic_cross_edges(nodes)
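+    # gather every cross-edge partner up front so parent timestamps for the
+    # nodes and all their partners can be fetched in a single batched call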
+    all_cx_edges = [types.empty_2d]
+    for _cx_edges_d in cx_edges_d.values():
+        if _cx_edges_d:
+            all_cx_edges.append(np.concatenate(list(_cx_edges_d.values())))
+    all_partners = np.unique(np.concatenate(all_cx_edges)[:, 1])
+    timestamps_d = get_parent_timestamps(cg, np.concatenate([nodes, all_partners]))
+
     rows = []
     for node, node_ts, end_ts in zip(nodes, nodes_ts, end_timestamps):
-        if cg.get_parent(node) is None:
-            # invalid id caused by failed ingest task / edits
-            continue
         _cx_edges_d = cx_edges_d.get(node, {})
         if not _cx_edges_d:
             continue
-        _rows = update_cross_edges(
-            cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d[node]
-        )
+        _rows = update_cross_edges(cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d)
         rows.extend(_rows)
     return rows
 
@@ -84,9 +88,7 @@ def _update_nodes_helper(args):
     return update_nodes(cg, nodes, nodes_ts)
 
 
-def update_chunk(
-    cg: ChunkedGraph, chunk_coords: list[int], layer: int = 2, debug: bool = False
-):
+def update_chunk(cg: ChunkedGraph, chunk_coords: list[int], debug: bool = False):
     """
     Iterate over all L2 IDs in a chunk and update their cross chunk edges,
     within the periods they were valid/active.
@@ -95,7 +97,7 @@ def update_chunk(
 
     start = time.time()
     x, y, z = chunk_coords
-    chunk_id = cg.get_chunk_id(layer=layer, x=x, y=y, z=z)
+    chunk_id = cg.get_chunk_id(layer=2, x=x, y=y, z=z)
     cg.copy_fake_edges(chunk_id)
     rr = cg.range_read_chunk(chunk_id)
 
@@ -108,22 +110,29 @@ def update_chunk(
         ts = v[Hierarchy.Child][0].timestamp
         nodes_ts.append(earliest_ts if ts < earliest_ts else ts)
 
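+    # drop invalid IDs (no parent), e.g. left over from failed ingest tasks or edits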
+    nodes = np.array(nodes, dtype=np.uint64)
+    node_parents = cg.get_parents(nodes, fail_to_zero=True)
+    valid_mask = node_parents != 0
+    nodes = nodes[valid_mask]
+    nodes_ts = np.array(nodes_ts)[valid_mask]
+
     if len(nodes) > 0:
-        logging.info(f"Processing {len(nodes)} nodes.")
+        logging.info(f"processing {len(nodes)} nodes.")
         assert len(CHILDREN) > 0, (nodes, CHILDREN)
     else:
         return
 
     if debug:
         rows = update_nodes(cg, nodes, nodes_ts)
     else:
-        task_size = int(math.ceil(len(nodes) / 64))
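+        # chunk size chosen to yield roughly 16 tasks for the 8-thread pool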
+        task_size = int(math.ceil(len(nodes) / 16))
         chunked_nodes = chunked(nodes, task_size)
         chunked_nodes_ts = chunked(nodes_ts, task_size)
         tasks = []
         for chunk, ts_chunk in zip(chunked_nodes, chunked_nodes_ts):
             args = (cg, chunk, ts_chunk)
             tasks.append(args)
+        logging.info(f"task size {task_size}, count {len(tasks)}.")
 
         rows = []
         with ThreadPoolExecutor(max_workers=8) as executor:
@@ -132,4 +141,4 @@ def update_chunk(
                 rows.extend(future.result())
 
     cg.client.write(rows)
-    print(f"total elaspsed time: {time.time() - start}")
+    logging.info(f"total elapsed time: {time.time() - start}")
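For reference, a minimal sketch of how the rewritten entry point might be invoked. The graph ID and chunk coordinates below are hypothetical, and constructing `ChunkedGraph` from a graph ID is an assumption, not part of this commit:

```python
# Hypothetical driver for update_chunk(); a sketch, not part of this commit.
from pychunkedgraph.graph import ChunkedGraph

cg = ChunkedGraph(graph_id="my_graph")  # assumed constructor kwarg; hypothetical ID
# Update cross-chunk edges for every valid L2 ID in the chunk at (x, y, z).
# debug=True processes the nodes serially instead of using the thread pool.
update_chunk(cg, chunk_coords=[10, 8, 4], debug=True)
```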