@@ -22,9 +22,9 @@ def __init__(self, collectionX_path, checkpointL1, checkpointL2, deviceL1='cuda'
22
22
23
23
def condense (self , query , backs , ranking ):
24
24
stage1_preds = self ._stage1 (query , backs , ranking )
25
- stage2_preds = self ._stage2 (query , stage1_preds )
25
+ stage2_preds , stage2_preds_L3x = self ._stage2 (query , stage1_preds )
26
26
27
- return stage1_preds , stage2_preds
27
+ return stage1_preds , stage2_preds , stage2_preds_L3x
28
28
29
29
def _load_model (self , path , device ):
30
30
model = torch .load (path , map_location = 'cpu' )
@@ -128,8 +128,14 @@ def _stage2(self, query, preds):
128
128
129
129
preds = [(score , (pid , sid )) for (pid , sid ), score in zip (preds , scores )]
130
130
preds = sorted (preds , reverse = True )[:5 ]
131
+
132
+ preds_L3x = [x for score , x in preds if score > min (0 , preds [1 ][0 ] - 1e-10 )] # Take at least 2!
131
133
preds = [x for score , x in preds if score > 0 ]
132
134
133
- # TODO: Apply L3x for final stage.
134
-
135
- return preds
135
+ earliest_pids = f7 ([pid for pid , _ in preds_L3x ])[:4 ] # Take at most 4 docs.
136
+ preds_L3x = [(pid , sid ) for pid , sid in preds_L3x if pid in earliest_pids ]
137
+
138
+ assert len (preds_L3x ) >= 2
139
+ assert len (f7 ([pid for pid , _ in preds_L3x ])) <= 4
140
+
141
+ return preds , preds_L3x
0 commit comments