11
2+ from collections import defaultdict
23from typing import Optional
34from sqlmodel import Session , func , select , update
45
@@ -25,17 +26,10 @@ async def reset_user_oid(session: Session, oid: int):
2526 select (
2627 UserModel .id ,
2728 UserModel .oid ,
28- func .coalesce (
29- func .array_remove (
30- func .array_agg (UserWsModel .oid ),
31- None
32- ),
33- []
34- ).label ("oid_list" )
29+ UserWsModel .oid .label ("associated_oid" )
3530 )
3631 .join (UserWsModel , UserModel .id == UserWsModel .uid , isouter = True )
3732 .where (UserModel .id != 1 )
38- .group_by (UserModel .id )
3933 )
4034
4135 user_filter = (
@@ -46,17 +40,30 @@ async def reset_user_oid(session: Session, oid: int):
4640 )
4741 stmt = stmt .where (UserModel .id .in_ (user_filter ))
4842
49- result_user_list = session .exec (stmt )
50- for row in result_user_list :
51- result_dict = {}
52- for item , key in zip (row , row ._fields ):
53- result_dict [key ] = item
54-
55- origin_oid = result_dict ['oid' ]
56- oid_list : list = list (filter (lambda x : x != oid , result_dict ['oid_list' ]))
43+ result_user_list = session .exec (stmt ).all ()
44+ if not result_user_list :
45+ return
46+
47+ merged = defaultdict (list )
48+ extra_attrs = {}
49+
50+ for (id , oid , associated_oid ) in result_user_list :
51+ item = {"id" : id , "oid" : oid }
52+ merged [id ].append (associated_oid )
53+ if id not in extra_attrs :
54+ extra_attrs [id ] = {k : v for k , v in item .items ()}
55+
56+ # 组合结果
57+ result = [
58+ {** extra_attrs [user_id ], "oid_list" : oid_list }
59+ for user_id , oid_list in merged .items ()
60+ ]
61+
62+ for row in result :
63+ origin_oid = row ['oid' ]
64+ oid_list : list = list (filter (lambda x : x != oid , row ['oid_list' ]))
5765 if origin_oid not in oid_list :
58- result_dict ['oid' ] = oid_list [0 ] if oid_list else 0
59- if result_dict ['oid' ] != origin_oid :
60- result_dict .pop ("oid_list" , None )
61- update_stmt = update (UserModel ).where (UserModel .id == result_dict ['id' ]).values (oid = result_dict ['oid' ])
66+ row ['oid' ] = oid_list [0 ] if oid_list else 0
67+ if row ['oid' ] != origin_oid :
68+ update_stmt = update (UserModel ).where (UserModel .id == row ['id' ]).values (oid = row ['oid' ])
6269 session .exec (update_stmt )
0 commit comments