Skip to content

Commit

Permalink
Skip task if it's already predicted with the same model
Browse files Browse the repository at this point in the history
  • Loading branch information
Alyetama committed Apr 21, 2022
1 parent 39f729e commit 2077a86
Showing 1 changed file with 32 additions and 13 deletions.
45 changes: 32 additions & 13 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
import requests
from PIL import UnidentifiedImageError
from dotenv import load_dotenv
from tqdm import tqdm

from model_utils.mongodb_helper import get_tasks_from_mongodb
from tqdm import tqdm

try:
import torch
Expand All @@ -32,8 +31,8 @@

def keyboard_interrupt_handler(sig, _):
"""This function handles the KeyboardInterrupt (CTRL+C) signal.
It's a handler for the signal, which means it's called when the OS sends the signal.
The signal is sent when the user presses CTRL+C.
It's a handler for the signal, which means it's called when the OS sends
the signal. The signal is sent when the user presses CTRL+C.
Parameters
----------
Expand Down Expand Up @@ -98,9 +97,9 @@ def model(self):


class Predict(LoadModel, _Headers):
"""This class is used to predict bounding boxes for images in a given project.
It uses the YOLOv5 model to predict bounding boxes and then posts the
predictions to the Label Studio server.
"""This class is used to predict bounding boxes for images in a given
project. It uses the YOLOv5 model to predict bounding boxes and then posts
the predictions to the Label Studio server.
Parameters
----------
Expand Down Expand Up @@ -271,7 +270,10 @@ def get_task(self, _task_id: int) -> dict:
Examples
--------
>>> get_task(1)
{'id': 1, 'data': {'image': 'http://localhost:8000/data/local-files/1.jpg'}}
{
'id': 1,
'data': {'image': 'http://localhost:8000/data/local-files/1.jpg'}
}
"""
url = f'{os.environ["LS_HOST"]}/api/tasks/{_task_id}'
resp = requests.get(url, headers=self.headers)
Expand Down Expand Up @@ -398,7 +400,8 @@ def single_task(self, task_id):

@staticmethod
def pred_result(x, y, w, h, score, label):
"""This function takes in the x, y, width, height, score, and label of a prediction and returns a dictionary with the prediction's information.
"""This function takes in the x, y, width, height, score, and label of
a prediction and returns a dictionary with the prediction's information.
Parameters
----------
Expand Down Expand Up @@ -435,7 +438,8 @@ def pred_result(x, y, w, h, score, label):
}

def pred_post(self, results, scores, task_id):
"""This function is used to create an API POST request of a single prediction results.
"""This function is used to create an API POST request of a single
prediction results.
Parameters
----------
Expand All @@ -462,7 +466,8 @@ def pred_post(self, results, scores, task_id):
}

def post_prediction(self, task):
"""This function is called by the `predict` method. It takes a task as an argument and performs the following steps:
"""This function is called by the `predict` method. It takes a task as
an argument and performs the following steps:
1. It downloads the image from the task's `data` field.
2. It runs the image through the model and gets the predictions.
Expand All @@ -471,9 +476,11 @@ def post_prediction(self, task):
If the task has no data, it skips the task.
If the task has no predictions, it deletes the task if `delete_if_no_predictions` is set to `True`.
If the task has no predictions, it deletes the task if
`delete_if_no_predictions` is set to `True`.
If `if_empty_apply_label` is set to a label, it applies a the string of `if_empty_apply_label` if not set to `None`.
If `if_empty_apply_label` is set to a label, it applies a the string of
`if_empty_apply_label` if not set to `None`.
Parameters
----------
Expand All @@ -487,6 +494,18 @@ def post_prediction(self, task):
"""
try:
task_id = task['id']
pred_ids = task['predictions']

for pred_id in pred_ids:
url = f'{os.environ["LS_HOST"]}/api/predictions/{pred_id}'
resp = requests.get(url, headers=headers)
pred_details = resp.json()
if self.model_version == pred_details['model_version']:
logger.debug(
f'Task {task_id} is already predicted with model '
f'`{self.model_version}`. Skipping...')
return

try:
img = self.download_image(
self.get_task(task_id)['data']['image'])
Expand Down

0 comments on commit 2077a86

Please sign in to comment.