eval_kagglebirds2020.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Evaluates bird classification predictions in the Kaggle submission format.

For usage information, call with --help.

Author: Jan Schlüter
"""

from __future__ import print_function

from argparse import ArgumentParser

import numpy as np
import pandas as pd
import sklearn.metrics

from definitions.datasets.kagglebirds2020 import (derive_labelset,
                                                  make_multilabel_target)
def opts_parser():
    descr = ("Evaluates bird classification predictions in the Kaggle "
             "submission format.")
    parser = ArgumentParser(description=descr)
    parser.add_argument('gt', metavar='GTFILE',
            type=str,
            help='Ground truth .csv file (perfect_submission.csv)')
    parser.add_argument('preds', metavar='PREDFILE',
            type=str,
            help='Prediction .csv file (submission.csv)')
    parser.add_argument('--train-csv',
            type=str, required=True,
            help='File to read the set of class labels from.')
    parser.add_argument('--f1-mode',
            type=str, default='micro', choices=('micro', 'samples'),
            help='F1 score aggregation: "micro" (default) or "samples"')
    parser.add_argument('--nocall-class',
            action='store_true', default=False,
            help='If given, treat "nocall" as a separate class')
    return parser
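
# Example invocation ("train.csv" is a hypothetical file name; both GTFILE and
# PREDFILE need a "birds" column of space-separated labels, as in the Kaggle
# submission format):
#
#   ./eval_kagglebirds2020.py perfect_submission.csv submission.csv \
#       --train-csv train.csv
#
# With --f1-mode micro (the default), sklearn pools true and false positives
# and negatives over all rows and classes; with --f1-mode samples, it scores
# each row separately and averages the results.
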
def main():
    # parse command line
    parser = opts_parser()
    options = parser.parse_args()
    gtfile = options.gt
    predfile = options.preds

    # figure out the set of labels
    labelset = derive_labelset(pd.read_csv(options.train_csv))
    if options.nocall_class:
        labelset.append('nocall')
    label_to_idx = dict((label, idx) for idx, label in enumerate(labelset))

    # read ground truth
    gt = pd.read_csv(gtfile)
    gt = np.stack([make_multilabel_target(len(labelset),
                                          [label_to_idx[label]
                                           for label in birds.split(' ')
                                           if label in label_to_idx])
                   for birds in gt.birds])

    # read predictions
    pr = pd.read_csv(predfile)
    pr = np.stack([make_multilabel_target(len(labelset),
                                          [label_to_idx[label]
                                           for label in birds.split(' ')
                                           if label in label_to_idx])
                   for birds in pr.birds])

    # evaluate
    p, r, f, _ = sklearn.metrics.precision_recall_fscore_support(
            gt, pr, average=options.f1_mode)
    # label the output with the chosen aggregation mode rather than
    # hard-coding "micro"
    print('%s-prec' % options.f1_mode, p)
    print('%s-rec' % options.f1_mode, r)
    print('%s-f1' % options.f1_mode, f)


if __name__ == "__main__":
    main()
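
For reference, a minimal sketch of what the two helpers imported from
definitions.datasets.kagglebirds2020 presumably do, inferred only from how
they are used above; the column name "ebird_code" is an assumption about the
training .csv, and the actual implementations in the repository may differ:

import numpy as np

def derive_labelset(train_csv):
    # assumed: the sorted list of unique class labels in the training csv
    return sorted(train_csv.ebird_code.unique())

def make_multilabel_target(num_classes, class_indices):
    # assumed: a binary indicator vector with ones at the given class indices
    target = np.zeros(num_classes, dtype=np.int64)
    target[class_indices] = 1
    return target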