-
Notifications
You must be signed in to change notification settings - Fork 225
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #126 from tgisaturday/master
Add SPICE from coco-caption
- Loading branch information
Showing
25 changed files
with
142 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from __future__ import division | ||
import os | ||
import sys | ||
import subprocess | ||
import threading | ||
import json | ||
import numpy as np | ||
import ast | ||
import tempfile | ||
|
||
# Assumes spice.jar is in the same directory as spice.py. Change as needed. | ||
SPICE_JAR = 'spice-1.0.jar' | ||
TEMP_DIR = 'tmp' | ||
CACHE_DIR = 'cache' | ||
|
||
|
||
def enc(s): | ||
return s.encode('utf-8') | ||
|
||
|
||
def dec(s): | ||
return s.decode('utf-8') | ||
|
||
|
||
|
||
class Spice: | ||
""" | ||
Main Class to compute the SPICE metric | ||
""" | ||
|
||
def float_convert(self, obj): | ||
try: | ||
return float(obj) | ||
except: | ||
return np.nan | ||
|
||
def compute_score(self, gts, res): | ||
assert(sorted(gts.keys()) == sorted(res.keys())) | ||
imgIds = sorted(gts.keys()) | ||
|
||
# Prepare temp input file for the SPICE scorer | ||
input_data = [] | ||
for id in imgIds: | ||
hypo = res[id] | ||
ref = gts[id] | ||
|
||
# Sanity check. | ||
assert(type(hypo) is list) | ||
assert(len(hypo) == 1) | ||
assert(type(ref) is list) | ||
assert(len(ref) >= 1) | ||
|
||
input_data.append({ | ||
"image_id" : id, | ||
"test" : hypo[0], | ||
"refs" : ref | ||
}) | ||
|
||
cwd = os.path.dirname(os.path.abspath(__file__)) | ||
temp_dir=os.path.join(cwd, TEMP_DIR) | ||
if not os.path.exists(temp_dir): | ||
os.makedirs(temp_dir) | ||
in_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir, mode = "w") | ||
json.dump(input_data, in_file, indent=2) | ||
in_file.close() | ||
|
||
# Start job | ||
out_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) | ||
out_file.close() | ||
cache_dir=os.path.join(cwd, CACHE_DIR) | ||
if not os.path.exists(cache_dir): | ||
os.makedirs(cache_dir) | ||
spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR, in_file.name, | ||
'-cache', cache_dir, | ||
'-out', out_file.name, | ||
'-subset', | ||
'-silent' | ||
] | ||
subprocess.check_call(spice_cmd, | ||
cwd=os.path.dirname(os.path.abspath(__file__))) | ||
|
||
# Read and process results | ||
with open(out_file.name) as data_file: | ||
results = json.load(data_file) | ||
os.remove(in_file.name) | ||
os.remove(out_file.name) | ||
|
||
imgId_to_scores = {} | ||
spice_scores = [] | ||
for item in results: | ||
imgId_to_scores[item['image_id']] = item['scores'] | ||
spice_scores.append(self.float_convert(item['scores']['All']['f'])) | ||
average_score = np.mean(np.array(spice_scores)) | ||
scores = [] | ||
for image_id in imgIds: | ||
# Convert none to NaN before saving scores over subcategories | ||
score_set = {} | ||
for category,score_tuple in imgId_to_scores[image_id].items(): | ||
score_set[category] = {k: self.float_convert(v) for k, v in score_tuple.items()} | ||
scores.append(score_set) | ||
return average_score, scores | ||
|
||
def method(self): | ||
return "SPICE" | ||
|
||
|