Skip to content

Commit 3f8832d

Browse files
Label confidence pair post processor logits as tensor
1 parent d683d1d commit 3f8832d

File tree

1 file changed

+11
-12
lines changed
  • facetorch/analyzer/predictor

1 file changed

+11
-12
lines changed

facetorch/analyzer/predictor/post.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -297,20 +297,19 @@ def run(self, preds: torch.Tensor) -> List[Prediction]:
297297
if isinstance(preds, tuple):
298298
preds = preds[0]
299299

300-
# Convert tensor to numpy array once instead of in the loop
301-
preds_np = preds.cpu().numpy()
302-
303-
# Use list comprehension instead of loop for creating pred_list
304-
pred_list = [
305-
Prediction(
300+
pred_list = []
301+
for i in range(preds.shape[0]):
302+
preds_sample = preds[i]
303+
preds_sample_list = preds_sample.cpu().numpy().tolist()
304+
other_labels = {
305+
label: preds_sample_list[j] + self.offsets[j]
306+
for j, label in enumerate(self.labels)
307+
}
308+
pred = Prediction(
306309
label="other",
307310
logits=preds_sample,
308-
other={
309-
label: preds_np[i, j] + offset
310-
for j, (label, offset) in enumerate(zip(self.labels, self.offsets))
311-
},
311+
other=other_labels,
312312
)
313-
for i, preds_sample in enumerate(preds_np)
314-
]
313+
pred_list.append(pred)
315314

316315
return pred_list

0 commit comments

Comments
 (0)