Skip to content

Commit

Permalink
Merge pull request #85 from DevLinyan/main
Browse files Browse the repository at this point in the history
fix evaluation
  • Loading branch information
ChonghaoSima authored Apr 26, 2024
2 parents eb1d4b7 + 647d4b2 commit 6f20873
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
3 changes: 3 additions & 0 deletions challenge/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,9 @@ Chances are that you are not logged in to the current competition space.
Please refresh the page, click `Login with Hugging Face` at the bottom of the left panel.
### If I encounter a reshape error, what should I do?
Start by checking this [location](https://github.com/OpenDriveLab/DriveLM/blob/main/challenge/evaluation.py#L90) in `challenge/evaluation.py`; most reshape errors originate there.
### Finally, which dataset do we submit to the competition?
Expand Down
41 changes: 29 additions & 12 deletions challenge/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,8 @@ def eval_match(self):
for i in range(len(self.match["match"]["answer"])):
answer = self.match["match"]["answer"][i]
GT = self.match["match"]["GT"][i]
matched = self.match_result(answer, GT)
GT_nums = re.findall(r'\d+\.\d+', GT)
GT_nums = np.array([list(map(float, x.split()))[0] for x in GT_nums]).reshape(-1, 2)
GT_nums = [list(i) for i in GT_nums]
outs1.append(len(matched) / len(GT_nums) * 100)
_, F1_score = self.match_result(answer, GT)
outs1.append(F1_score * 100)

outs1 = sum(outs1) / len(outs1)
outs2 = self.eval_chatGPT(self.match["GPT"])
Expand Down Expand Up @@ -90,18 +87,38 @@ def match_result(self, answer, GT):
answer_nums = np.array([list(map(float, x.split()))[0] for x in answer_nums]).reshape(-1, 2)
GT_nums = np.array([list(map(float, x.split()))[0] for x in GT_nums]).reshape(-1, 2)

if len(answer_nums) == 0:
return [], 0

matched_out = []
for ans in answer_nums:
true_positives = 0
false_positives = 0
false_negatives = 0
for pred in answer_nums:
closest_distance = float('inf')
closest_gt = None
for gt in GT_nums:
distance = np.sum(np.abs(ans - gt))
if distance < 16:
matched_out.append(gt)
break
distance = np.sum(np.abs(pred - gt))
if distance < closest_distance:
closest_distance = distance
closest_gt = gt

if closest_distance < 16:
true_positives += 1
matched_out.append(closest_gt)
GT_nums.remove(closest_gt)
else:
false_positives += 1

false_negatives = len(GT_nums) - true_positives
precision = true_positives / (true_positives + false_positives)
recall = true_positives / (true_positives + false_negatives)
F1 = 2 * precision * recall / (precision + recall)

return matched_out
return matched_out, F1

def set_graph(self, answer, GT):
self.graph = self.match_result(answer, GT)
self.graph, _ = self.match_result(answer, GT)
self.graph = [list(i) for i in self.graph]

def forward(self, tag, answer, GT):
Expand Down

0 comments on commit 6f20873

Please sign in to comment.