Skip to content

Commit

Permalink
add or remove objects
Browse files Browse the repository at this point in the history
  • Loading branch information
hkchengrex committed Mar 13, 2024
1 parent 9868939 commit f8cdbd9
Show file tree
Hide file tree
Showing 24 changed files with 70 additions and 55 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ python scripts/download_models.py

### Scripting Demo

(See also scripting_demo.py)
This is probably the best starting point if you want to use Cutie in your project. Hopefully, the script is self-explanatory. If not, feel free to open an issue. Run `scripting_demo.py` to see it in action. For more advanced usage, like adding or removing objects, see `scripting_demo_add_del_objects.py`.

```python
import os
Expand Down
59 changes: 7 additions & 52 deletions cutie/inference/inference_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


class InferenceCore:

def __init__(self,
network: CUTIE,
cfg: DictConfig,
Expand Down Expand Up @@ -327,55 +328,9 @@ def step(self,

return output_prob

def get_aux_outputs(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
image, pads = pad_divide_by(image, 16)
image = image.unsqueeze(0) # add the batch dimension
_, pix_feat = self.image_feature_store.get_features(self.curr_ti, image)

aux_inputs = self.memory.aux
aux_outputs = self.network.compute_aux(pix_feat, aux_inputs, selector=None)
aux_outputs['q_weights'] = aux_inputs['q_weights']
aux_outputs['p_weights'] = aux_inputs['p_weights']

for k, v in aux_outputs.items():
if len(v.shape) == 5:
aux_outputs[k] = F.interpolate(v[0],
size=image.shape[-2:],
mode='bilinear',
align_corners=False)
elif 'weights' in k:
b, num_objects, num_heads, num_queries, h, w = v.shape
v = v.view(num_objects * num_heads, num_queries, h, w)
v = F.interpolate(v, size=image.shape[-2:], mode='bilinear', align_corners=False)
aux_outputs[k] = v.view(num_objects, num_heads, num_queries, *image.shape[-2:])
else:
aux_outputs[k] = F.interpolate(v,
size=image.shape[-2:],
mode='bilinear',
align_corners=False)[0]
aux_outputs[k] = unpad(aux_outputs[k], pads)
if 'weights' in k:
weights = aux_outputs[k]
weights = weights / (weights.max(-1, keepdim=True)[0].max(-2, keepdim=True)[0] +
1e-8)
aux_outputs[k] = (weights * 255).cpu().numpy()
else:
aux_outputs[k] = (aux_outputs[k].softmax(dim=0) * 255).cpu().numpy()

self.image_feature_store.delete(self.curr_ti)
return aux_outputs

def get_aux_object_weights(self, image: torch.Tensor) -> np.ndarray:
image, pads = pad_divide_by(image, 16)
# B*num_objects*H*W*num_queries -> num_objects*num_queries*H*W
# weights = F.softmax(self.obj_logits, dim=-1)[0]
weights = F.sigmoid(self.obj_logits)[0]
weights = weights.permute(0, 3, 1, 2).contiguous()
weights = F.interpolate(weights,
size=image.shape[-2:],
mode='bilinear',
align_corners=False)
# weights = weights / (weights.max(-1, keepdim=True)[0].max(-2, keepdim=True)[0])
weights = unpad(weights, pads)
weights = (weights * 255).cpu().numpy()
return weights
def delete_objects(self, objects: List[int]) -> None:
"""
Delete the given objects from the memory.
"""
self.object_manager.delete_objects(objects)
self.memory.purge_except(self.object_manager.all_obj_ids)
5 changes: 3 additions & 2 deletions cutie/inference/object_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class ObjectManager:
Temporary IDs are the positions of each object in the tensor. It changes as objects get removed.
Temporary IDs start from 1.
"""

def __init__(self):
self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
Expand Down Expand Up @@ -52,7 +53,7 @@ def add_new_objects(
assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
return corresponding_tmp_ids, corresponding_obj_ids

def delete_object(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
# delete an object or a list of objects
# re-sort the tmp ids
if isinstance(obj_ids_to_remove, int):
Expand Down Expand Up @@ -93,7 +94,7 @@ def purge_inactive_objects(self,

purge_activated = len(obj_id_to_be_deleted) > 0
if purge_activated:
self.delete_object(obj_id_to_be_deleted)
self.delete_objects(obj_id_to_be_deleted)
return purge_activated, tmp_id_to_keep, obj_id_to_keep

def tmp_to_obj_cls(self, mask) -> torch.Tensor:
Expand Down
Binary file added examples/images/judo/00000.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00001.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00002.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00003.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00004.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00005.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00006.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00007.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00008.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00009.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00010.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00011.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00012.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00013.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00014.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/judo/00015.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/masks/judo/00000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/masks/judo/00005.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/masks/judo/00008.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/masks/judo/00013.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
59 changes: 59 additions & 0 deletions scripting_demo_add_del_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os

import torch
from torchvision.transforms.functional import to_tensor
from PIL import Image
import numpy as np

from cutie.inference.inference_core import InferenceCore
from cutie.utils.get_default_model import get_default_model


@torch.inference_mode()
@torch.cuda.amp.autocast()
def main():

cutie = get_default_model()
processor = InferenceCore(cutie, cfg=cutie.cfg)

image_path = './examples/images/judo'
mask_path = './examples/masks/judo'
images = sorted(os.listdir(image_path)) # ordering is important

for ti, image_name in enumerate(images):
image = Image.open(os.path.join(image_path, image_name))
image = to_tensor(image).cuda().float()

# deleting the red mask at time step 10 for no reason -- you can set your own condition
if ti == 10:
processor.delete_objects([1])

mask_name = image_name[:-4] + '.png'
if os.path.exists(os.path.join(mask_path, mask_name)):
# add the objects in the mask
mask = Image.open(os.path.join(mask_path, mask_name))
palette = mask.getpalette()
objects = np.unique(np.array(mask))
objects = objects[objects != 0].tolist() # background "0" does not count as an object
mask = torch.from_numpy(np.array(mask)).cuda()

prediction = processor.step(image, mask, objects=objects)
else:
prediction = processor.step(image)

# visualize prediction
mask = torch.argmax(prediction, dim=0)

# since the objects might shift in the channel dim due to deletion, remap the ids
new_mask = torch.zeros_like(mask)
for tmp_id, obj in processor.object_manager.tmp_id_to_obj.items():
new_mask[mask == tmp_id] = obj.id
mask = new_mask

mask = Image.fromarray(mask.cpu().numpy().astype(np.uint8))
mask.putpalette(palette)
# mask.show() # or use prediction.save(...) to save it somewhere
mask.save(os.path.join('./examples', mask_name))


main()

0 comments on commit f8cdbd9

Please sign in to comment.