Add remove unwanted objects example
jrrodri committed Dec 6, 2024
1 parent b6c143b commit c835c36
Showing 7 changed files with 108 additions and 39 deletions.
28 changes: 28 additions & 0 deletions README.md
@@ -133,6 +133,34 @@ show_image(img)

![car license plate recognition](https://github.com/abraia/abraia-multiple/raw/master/images/car-plate.jpg)

## Remove unwanted objects

Remove unwanted objects from images and photos, with all processing running locally.

```python
from abraia.utils import load_image, Sketcher
from abraia.editing.inpaint import LAMA
from abraia.editing.sam import SAM


img = load_image('images/dog.jpg')

sam = SAM()
lama = LAMA()
sam.encode(img)

sketcher = Sketcher(img)

def on_click(point):
mask = sam.predict(img, f'[{{"type":"point","data":[{point[0]},{point[1]}],"label":1}}]')
sketcher.mask = sketcher.dilate(mask)

sketcher.on_click(on_click)
sketcher.run(lama.predict)
```
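
Click on an object in the window to segment it with SAM; the predicted mask is dilated and stored by the sketcher. Press SPACE to inpaint the masked region with the LAMA model, `r` to reset the mask, `s` to save the result as `output.png`, and ESC to exit.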

![inpaint output](https://github.com/abraia/abraia-multiple/raw/master/images/inpaint-output.jpg)

## Command line interface

The Abraia CLI provides access to the Abraia Cloud Platform through the command line. It makes it simple to manage your files and enables bulk image editing capabilities. It provides an easy way to resize, convert, and compress your images - JPEG, WebP, or PNG - and get them ready to publish on the web. Moreover, you can automatically remove the background, upscale, or anonymize your images in bulk.
6 changes: 2 additions & 4 deletions abraia/detect.py
@@ -5,7 +5,7 @@
import onnxruntime as ort

from .ops import py_cpu_nms, normalize, mask_to_polygon
from .utils import download_file, load_json, get_color
from .utils import download_file, load_json, get_color, get_providers
from .utils import load_image, show_image, save_image, Video
from .draw import render_results

@@ -160,9 +160,7 @@ class Model:
def load(self, model_uri):
config_uri = f"{os.path.splitext(model_uri)[0]}.json"
self.config = load_json(download_file(config_uri))
providers = ["CUDAExecutionProvider", "CoreMLExecutionProvider", "CPUExecutionProvider"]
providers = [provider for provider in ort.get_available_providers() if provider in providers]
self.session = ort.InferenceSession(download_file(model_uri), providers=providers)
self.session = ort.InferenceSession(download_file(model_uri), providers=get_providers())
self.input_name = self.session.get_inputs()[0].name
self.input_shape = self.config['inputShape']

41 changes: 23 additions & 18 deletions abraia/editing/sam.py
@@ -3,7 +3,7 @@
import numpy as np
import onnxruntime as ort

from ..utils import download_file
from ..utils import download_file, get_providers


def get_input_points(prompt):
@@ -26,19 +26,29 @@ class SAM:
def __init__(self):
self.target_size = 1024
self.input_size = (684, 1024)
sess_options = ort.SessionOptions()
providers = ort.get_available_providers()
encoder_src = download_file('multiple/models/mobile_sam.encoder.onnx')
decoder_src = download_file('multiple/models/mobile_sam.decoder.onnx')
self.encoder = ort.InferenceSession(encoder_src, providers=providers, sess_options=sess_options)
self.decoder = ort.InferenceSession(decoder_src, providers=providers, sess_options=sess_options)
self.encoder = ort.InferenceSession(encoder_src, providers=get_providers())
self.decoder = ort.InferenceSession(decoder_src, providers=get_providers())

def encode(self, img):
# TODO: Refactor to use resize
scale_x = self.input_size[1] / img.shape[1]
scale_y = self.input_size[0] / img.shape[0]
scale = min(scale_x, scale_y)

transform_matrix = np.array([[scale, 0, 0],
[0, scale, 0],
[0, 0, 1]])

size = (self.input_size[1], self.input_size[0])
img = cv2.warpAffine(img, transform_matrix[:2], size, flags=cv2.INTER_LINEAR)

encoder_input_name = self.encoder.get_inputs()[0].name
encoder_inputs = {encoder_input_name: img.astype(np.float32)}
encoder_output = self.encoder.run(None, encoder_inputs)
image_embedding = encoder_output[0]
return image_embedding
self.image_embedding = encoder_output[0]
return self.image_embedding

def decode(self, image_embedding, input_points, input_labels):
onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
@@ -68,22 +78,17 @@ def predict(self, img, prompt="[]"):
transform_matrix = np.array([[scale, 0, 0],
[0, scale, 0],
[0, 0, 1]])

size = (self.input_size[1], self.input_size[0])
img = cv2.warpAffine(img, transform_matrix[:2], size, flags=cv2.INTER_LINEAR)
image_embedding = self.encode(img)

# embedding = {"image_embedding": image_embedding,
# "original_size": (width, height),
# "transform_matrix": transform_matrix}

if self.image_embedding is None:
self.image_embedding = self.encode(img)

input_points, input_labels = get_input_points(prompt)
input_points = input_points * scale
masks = self.decode(image_embedding, input_points, input_labels)
masks = self.decode(self.image_embedding, input_points, input_labels)

inv_transform_matrix = np.linalg.inv(transform_matrix)
mask = np.zeros((height, width, 3), dtype=np.uint8)
mask = np.zeros((height, width), dtype=np.uint8)
for m in masks:
m = cv2.warpAffine(m, inv_transform_matrix[:2], (width, height), flags=cv2.INTER_LINEAR)
mask[m > 0.0] = [255, 255, 255]
mask[m > 0.0] = 255
return mask
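
For context, here is a minimal sketch of the encode-once flow this change enables; the image path and click coordinates are placeholders, and the prompt format follows the README example above:

```python
from abraia.utils import load_image
from abraia.editing.sam import SAM

img = load_image('images/dog.jpg')

sam = SAM()
sam.encode(img)  # compute and cache the image embedding once

# Each point prompt reuses the cached embedding instead of re-encoding the image
mask = sam.predict(img, '[{"type":"point","data":[250,300],"label":1}]')
print(mask.shape, mask.dtype)  # single-channel uint8 mask, 255 inside the selected object
```
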
7 changes: 7 additions & 0 deletions abraia/utils/__init__.py
@@ -3,6 +3,7 @@
import tempfile
import requests
import numpy as np
import onnxruntime as ort

from tqdm import tqdm
from PIL import Image, ImageOps
@@ -74,3 +75,9 @@ def get_color(idx):
def hex_to_rgb(hex):
h = hex.lstrip('#')
return tuple(int(h[i:i+2], 16) for i in (0, 2, 4))


def get_providers():
available_providers = ort.get_available_providers()
providers = ["CUDAExecutionProvider", "CoreMLExecutionProvider", "CPUExecutionProvider"]
return [provider for provider in available_providers if provider in providers]
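
As a rough illustration, the new helper is intended to be passed straight to `ort.InferenceSession`, as the `detect.py` and `sam.py` changes above now do (the model path is just a placeholder):

```python
import onnxruntime as ort
from abraia.utils import get_providers

# Pick CUDA, CoreML, or CPU depending on what this ONNX Runtime build exposes
session = ort.InferenceSession('model.onnx', providers=get_providers())
print(session.get_providers())
```
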
65 changes: 48 additions & 17 deletions abraia/utils/sketcher.py
@@ -1,37 +1,68 @@
'''
Sketcher.
Keys:
SPACE - callback
r - reset the mask
s - save output
ESC - exit
'''

import cv2
import numpy as np


def draw_mask(img, mask, color, opacity = 1):
img_copy = img.copy()
overlay = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
overlay[mask == 255] = color
img_over = cv2.addWeighted(img_copy, 1 - opacity, overlay, opacity, 0)
img_copy[mask == 255] = img_over[mask == 255]
return img_copy


class Sketcher:
def __init__(self, src, radius=15):
def __init__(self, img, radius=15):
print(__doc__)
self.img = cv2.cvtColor(cv2.imread(src), cv2.COLOR_BGR2RGB)
self.img_msk = self.img.copy()
self.mask = np.zeros(self.img.shape[:2], np.uint8)
self.prev_pt = None
self.win_name = 'Image'
self.dests = [self.img_msk, self.mask]
self.colors = [(255, 255, 255), 255]
self.radius = radius
self.dirty = False
self.show(self.img_msk)
self.load(img)
self.element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * radius + 1, 2 * radius + 1))
cv2.setMouseCallback(self.win_name, self.on_mouse)

def load(self, img):
self.img = img
self.prev_pt = None
self.mask = np.zeros(img.shape[:2], np.uint8)
self.show(self.img)

def dilate(self, mask):
return cv2.dilate(mask, self.element)

def show(self, img, mask=None):
if mask is not None:
img = draw_mask(img, mask, (255, 0, 0), 0.5)
self.output = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
cv2.imshow(self.win_name, self.output)

def show(self, img):
cv2.imshow(self.win_name, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
def on_click(self, callback):
self.handle_click = callback

def on_mouse(self, event, x, y, flags, param):
pt = (x, y)
if event == cv2.EVENT_LBUTTONDOWN:
self.prev_pt = pt
if self.prev_pt and flags & cv2.EVENT_FLAG_LBUTTON:
for dst, color in zip(self.dests, self.colors):
cv2.line(dst, self.prev_pt, pt, color, self.radius)
self.dirty = True
cv2.line(self.mask, self.prev_pt, pt, 255, self.radius)
self.prev_pt = pt
self.show(self.img_msk)
else:
self.prev_pt = None
if event == cv2.EVENT_LBUTTONUP:
if self.handle_click:
self.handle_click(pt)
if self.prev_pt:
self.show(self.img, self.mask)

def run(self, callback):
while True:
@@ -41,7 +72,7 @@ def run(self, callback):
if ch == ord(' '):
self.show(callback(self.img, self.mask))
if ch == ord('r'):
self.img_msk[:] = self.img
self.mask[:] = 0
self.show(self.img_msk)
self.load(self.img)
if ch == ord('s'):
cv2.imwrite('output.png', self.output)
cv2.destroyWindow(self.win_name)
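
A small usage sketch for the new module-level `draw_mask` helper, assuming it is imported from `abraia.utils.sketcher` as defined above; the image path and mask region are placeholders:

```python
import numpy as np
from abraia.utils import load_image, show_image
from abraia.utils.sketcher import draw_mask

img = load_image('images/dog.jpg')

# Toy binary mask: 255 inside the region to highlight
mask = np.zeros(img.shape[:2], np.uint8)
mask[100:200, 150:300] = 255

# Blend a semi-transparent red overlay where the mask is set
show_image(draw_mask(img, mask, (255, 0, 0), opacity=0.5))
```
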
Binary file added images/dog.jpg
Binary file added images/inpaint-output.jpg
