Skip to content

Commit

Permalink
Minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
sarlinpe committed Nov 21, 2023
1 parent 266423f commit e89f8c7
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions lightglue/lightglue.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def _forward(self, data: dict) -> dict:

if do_early_stop:
token0, token1 = self.token_confidence[i](desc0, desc1)
if self.check_if_stop(token0[..., :m, :], token1[..., :n, :], i, m + n):
if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
break
if do_point_pruning and desc0.shape[-2] > pruning_th:
scores0 = self.log_assignment[i].get_matchability(desc0)
Expand Down Expand Up @@ -571,7 +571,7 @@ def _forward(self, data: dict) -> dict:
"prune1": prune1,
}

desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]
desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding
scores, _ = self.log_assignment[i](desc0, desc1)
m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
matches, mscores = [], []
Expand Down Expand Up @@ -600,7 +600,7 @@ def _forward(self, data: dict) -> dict:
prune0 = torch.ones_like(mscores0) * self.conf.n_layers
prune1 = torch.ones_like(mscores1) * self.conf.n_layers

pred = {
return {
"matches0": m0,
"matches1": m1,
"matching_scores0": mscores0,
Expand All @@ -612,8 +612,6 @@ def _forward(self, data: dict) -> dict:
"prune1": prune1,
}

return pred

def confidence_threshold(self, layer_index: int) -> float:
"""scaled confidence threshold"""
threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
Expand Down

0 comments on commit e89f8c7

Please sign in to comment.