forked from max-andr/square-attack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gvision_utils.py
93 lines (68 loc) · 2.56 KB
/
gvision_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from PIL import Image
import random
import numpy as np
from google.cloud import vision
import io
import os
# Point the Google client libraries at a service-account key file, but only if the
# caller has not already configured GOOGLE_APPLICATION_CREDENTIALS themselves, so
# a different key can be supplied without editing this module.
# NOTE(review): this is a relative path, presumably resolved against the process's
# working directory; the key file itself should be kept out of version control.
os.environ.setdefault("GOOGLE_APPLICATION_CREDENTIALS", "keys/Trema-14000fdb4eac.json")
def convert_to_pillow(img, channels_first=True):
    """Convert a numpy image with values in [0, 1] to a PIL Image.

    Args:
        img -- numpy array. A 3-D array is read as [C, H, W] when
            channels_first is True and as [H, W, C] otherwise; a 2-D array
            is treated as a single-channel [H, W] image either way.
        channels_first -- whether a 3-D img has its channel axis first.

    Returns:
        PIL Image built from an 8-bit copy of img (img itself is not modified).
    """
    if channels_first and img.ndim != 2:
        # PIL expects channels-last, so move the channel axis from 0 to 2.
        img = img.transpose(1, 2, 0)
    # Round instead of truncating: a value such as k/255 can land just below k
    # after scaling, and truncation would turn it into k-1. Clip so slightly
    # out-of-range inputs saturate at 0/255 instead of wrapping in the uint8 cast.
    img = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
    return Image.fromarray(img)
class GVisionResults:
    """Container for (label, score) pairs produced by Cloud Vision label detection.

    Args:
        results -- list of (label, score) tuples. The top_* properties return
            the highest-scoring pair whether or not the list is sorted.
    """

    def __init__(self, results):
        self.results = results

    def match(self, labelset, inverse=False):
        """Return a new GVisionResults filtered by case-insensitive substring match.

        Args:
            labelset -- iterable of pattern strings; a label matches when any
                lower-cased pattern is a substring of the lower-cased label.
                Any iterable works, including a one-shot generator.
            inverse -- if true, keep only labels that match NONE of the patterns.

        Returns:
            GVisionResults holding the kept (label, score) pairs in original order.
        """
        # Materialize and lower-case the patterns once; iterating `labelset`
        # per result would exhaust a generator after the first label.
        patterns = [pattern.lower() for pattern in labelset]
        keep_matches = not inverse
        kept = []
        for label, score in self.results:
            lowered = label.lower()
            matched = any(pattern in lowered for pattern in patterns)
            if matched == keep_matches:
                kept.append((label, score))
        return GVisionResults(kept)

    @property
    def labels(self):
        """List of labels, in result order."""
        return [label for label, _ in self.results]

    @property
    def scores(self):
        """List of scores, in result order."""
        return [score for _, score in self.results]

    def _top(self):
        """Return the highest-scoring (label, score) pair.

        Raises:
            ValueError -- if there are no results.
        """
        if not self.results:
            raise ValueError("GVisionResults is empty; there is no top label/score")
        # max() returns the first maximal element, so for score-sorted results
        # this is results[0]; unlike an `assert`, it stays correct under -O and
        # for unsorted input.
        return max(self.results, key=lambda pair: pair[1])

    @property
    def top_label(self):
        """Label with the highest score (first one on ties)."""
        return self._top()[0]

    @property
    def top_score(self):
        """Highest score among the results."""
        return self._top()[1]

    def __str__(self):
        return "\n".join(f"{label}: {score!s}" for label, score in self.results)

    def __repr__(self):
        return f"{type(self).__name__}({self.results!r})"
def gvision_classify_numpy(img, *, channels_first=True, max_results=100):
    """Return the labels and scores from the Cloud Vision label-detection API.

    Args:
        img -- numpy image with values in [0, 1]. With the default
            channels_first=True it is laid out [C, H, W] and transposed to
            [H, W, C] by convert_to_pillow; pass channels_first=False for an
            image that is already [H, W, C].
        channels_first -- forwarded to convert_to_pillow.
        max_results -- maximum number of label annotations to request.

    Returns:
        GVisionResults of (description, score) pairs in the order the API
        returned them.

    Raises:
        RuntimeError -- if the API response carries an error message.
    """
    pil_img = convert_to_pillow(img, channels_first=channels_first)
    # Encode the PNG in memory so nothing is written to (or left behind on) disk
    # and concurrent calls cannot collide on a temporary filename.
    buffer = io.BytesIO()
    pil_img.save(buffer, format="PNG")
    image = vision.Image(content=buffer.getvalue())

    client = vision.ImageAnnotatorClient()
    # Performs label detection on the encoded image.
    response = client.label_detection(image=image, max_results=max_results)
    # Per the Vision API docs, per-image failures are reported in
    # response.error rather than raised; surface them instead of silently
    # returning an empty result set.
    if response.error.message:
        raise RuntimeError(
            "Cloud Vision label detection failed: " + response.error.message
        )
    return GVisionResults(
        [(label.description, label.score) for label in response.label_annotations]
    )