Skip to content
39 changes: 32 additions & 7 deletions text_localization_environment/TextLocEnv.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,24 +93,33 @@ def calculate_reward(self, action):
else:
reward = -self.ETA
else:
new_iou = self.compute_best_iou()
reward = np.sign(new_iou - self.iou)
new_iou, box = self.compute_best_iou()

if reward == 0:
env_box_center = np.array(self.bbox[:2]) + (np.array(self.bbox[2:]) - np.array(self.bbox[:2]))/2
text_box_center = np.array(self.episode_true_bboxes[box][0]) + \
(np.array(self.episode_true_bboxes[box][1]) - np.array(
self.episode_true_bboxes[box][0])) / 2
new_center_distance = np.linalg.norm(env_box_center - text_box_center)

reward = np.sign(new_iou - self.iou) + np.sign(self.center_distance - new_center_distance)/2

if new_iou == self.iou:
self.steps_since_last_change += 1
else:
self.steps_since_last_change = 0

if self.steps_since_last_change >= 3:
reward = -1
reward -= 1

self.iou = new_iou
self.center_distance = new_center_distance

return reward - self.current_step * self.DURATION_PENALTY

def calculate_potential_reward(self, action):
old_bbox = self.bbox
old_iou = self.iou
old_center_distance = self.center_distance

if self.action_set[action] != self.trigger:
self.action_set[action]()
Expand All @@ -119,6 +128,7 @@ def calculate_potential_reward(self, action):

self.bbox = old_bbox
self.iou = old_iou
self.center_distance = old_center_distance

return reward

Expand Down Expand Up @@ -193,11 +203,20 @@ def create_ior_mark(self):

def compute_best_iou(self):
max_iou = 0
max_i = -1
i = -1

for box in self.episode_true_bboxes:
max_iou = max(max_iou, self.compute_iou(box))
i += 1
candidate_iou = self.compute_iou(box)
if candidate_iou >= max_iou:
max_i = i
max_iou = candidate_iou

return max_iou
if max_i == -1:
raise ValueError("There is a problem in compute_best_iou")

return max_iou, max_i

def compute_iou(self, other_bbox):
"""Computes the intersection over union of the argument and the current bounding box."""
Expand Down Expand Up @@ -288,7 +307,13 @@ def reset(self, image_index=None):
self.current_step = 0
self.state = self.compute_state()
self.done = False
self.iou = self.compute_best_iou()
self.iou, box = self.compute_best_iou()

env_box_center = np.array(self.bbox[:2]) + (np.array(self.bbox[2:]) - np.array(self.bbox[:2])) / 2
text_box_center = np.array(self.episode_true_bboxes[box][0]) + \
(np.array(self.episode_true_bboxes[box][1]) - np.array(
self.episode_true_bboxes[box][0])) / 2
self.center_distance = np.linalg.norm(env_box_center - text_box_center)
self.max_iou = self.iou
self.steps_since_last_change = 0

Expand Down