Enhance Video and Image Handling, Add Interactive Query Mode #93

Open · wants to merge 6 commits into main

6 changes: 5 additions & 1 deletion README.md
@@ -88,11 +88,15 @@ We offer a number of other ways to interact with CoTracker:
[Google Colab](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb).
- Or explore the notebook located at [`notebooks/demo.ipynb`](./notebooks/demo.ipynb).
2. You can [install](#installation-instructions) CoTracker _locally_ and then:
- Run an *offline* demo with 10 ⨉ 10 points sampled on a grid on the first frame of a video or a folder of images (results will be saved to `./saved_videos/demo.mp4`):

```bash
python demo.py --grid_size 10
```
Or select the query points interactively with mouse clicks by passing `--interactive_query`. A matplotlib window opens on the frame given by `--grid_query_frame`; left-click to add points, then close the window to start tracking:
```bash
python demo.py --video_path ./your/images/or/video --checkpoint ./checkpoints/cotracker2.pth --grid_query_frame 5 --interactive_query
```
- Run an *online* demo:

```bash
82 changes: 75 additions & 7 deletions cotracker/utils/visualizer.py
@@ -14,18 +14,86 @@
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw

# Event handler for mouse clicks
def on_click(event, queries, frame_idx):
    if event.button == 1 and event.inaxes == ax:  # Left mouse button clicked
        x, y = int(np.round(event.xdata)), int(np.round(event.ydata))

        # Record the clicked point as a query [frame_index, x, y]
        queries.append([frame_idx, x, y])

        # Cycle through the 20 colors of the 'tab20' colormap
        color_index = len(queries) % 20
        color = colormap(color_index / 20)[:3]  # drop the alpha channel

        # Mark the clicked point on the displayed frame
        ax.plot(x, y, 'o', color=color, markersize=2)
        plt.draw()

# Function to get queries from mouse clicks
def get_queries_from_clicks(frame, frame_idx=0):
    global ax, colormap

    # Collect clicked points here as [frame_index, x, y] triples
    queries = []

    # Convert the (C, H, W) tensor to a numpy array normalized to [0, 1]
    frame_np = frame.permute(1, 2, 0).cpu().numpy()
    frame_np = (frame_np - frame_np.min()) / (frame_np.max() - frame_np.min())

    # Display the frame and set up the click handler
    fig, ax = plt.subplots()
    ax.imshow(frame_np)
    colormap = plt.get_cmap('tab20')
    cid = fig.canvas.mpl_connect(
        'button_press_event', lambda event: on_click(event, queries, frame_idx)
    )

    # Block until the user closes the window
    plt.show()

    # Disconnect the event handler
    fig.canvas.mpl_disconnect(cid)

    # Convert the list of queries to a tensor of shape (num_clicks, 3)
    queries_tensor = torch.tensor(queries)

    # Move queries to the GPU if one is available
    if torch.cuda.is_available():
        queries_tensor = queries_tensor.cuda()

    return queries_tensor
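
# A minimal usage sketch (not part of this diff; `model` and `video` are
# placeholder names): the returned tensor has one [frame_index, x, y] row per
# click, and the predictor expects an extra batch dimension, e.g.:
#
#     queries = get_queries_from_clicks(video[0][0])                  # (N, 3)
#     pred_tracks, pred_vis = model(video, queries=queries.float()[None])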

def read_video_from_path(path):
    try:
        # Check if the path is a video file
        if os.path.isfile(path):
            reader = imageio.get_reader(path)
            frames = []
            for i, im in enumerate(reader):
                frames.append(np.array(im))
            return np.stack(frames)
        # Check if the path is a directory of images
        elif os.path.isdir(path):
            images = []
            # Get all image files in the directory, sorted by filename
            filenames = sorted(os.listdir(path))
            for filename in filenames:
                if filename.endswith(('.png', '.jpg', '.jpeg')):
                    img = imageio.imread(os.path.join(path, filename))
                    images.append(img)
            return np.stack(images)
        else:
            print("Error: Invalid path")
            return None
    except Exception as e:
        print("Error opening video file or images folder: ", e)
        return None
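
# Both input styles load to the same (num_frames, H, W, C) uint8 array, e.g.
# (hypothetical paths):
#
#     video = read_video_from_path("./assets/apple.mp4")   # a single video file
#     video = read_video_from_path("./assets/frames/")     # a folder of images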


def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
# Create a draw object
17 changes: 16 additions & 1 deletion demo.py
@@ -10,7 +10,7 @@
import numpy as np

from PIL import Image
from cotracker.utils.visualizer import Visualizer, read_video_from_path, get_queries_from_clicks
from cotracker.predictor import CoTrackerPredictor

# Unfortunately MPS acceleration does not support all the features we require,
@@ -58,6 +58,14 @@
help="Compute tracks in both directions, not only forward",
)

# Flag to enable interactive queries
parser.add_argument(
    "--interactive_query",
    action="store_true",
    default=False,  # explicit, though store_true already defaults to False
    help="Enable interactive query mode for user input.",
)

args = parser.parse_args()

# load the input video frame by frame
@@ -73,8 +81,15 @@
model = model.to(DEFAULT_DEVICE)
video = video.to(DEFAULT_DEVICE)
# video = video[:, :20]
# Determine the queries based on interactive mode
if args.interactive_query:
    # Let the user click points on the same frame used for grid sampling
    queries = get_queries_from_clicks(
        video[0][args.grid_query_frame], frame_idx=args.grid_query_frame
    ).float()[None]
else:
    queries = None

pred_tracks, pred_visibility = model(
    video,
    queries=queries,
    grid_size=args.grid_size,
    grid_query_frame=args.grid_query_frame,
    backward_tracking=args.backward_tracking,
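
For reference, the changed pieces compose into roughly the following standalone flow. This is a sketch assembled from the diff, assuming the `get_queries_from_clicks` signature shown above; the paths and the frame index are placeholders, not values from the PR:

```python
import torch

from cotracker.predictor import CoTrackerPredictor
from cotracker.utils.visualizer import read_video_from_path, get_queries_from_clicks

# Load a video file or a folder of frames: (T, H, W, C) -> (1, T, C, H, W)
video = read_video_from_path("./assets/apple.mp4")
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()

model = CoTrackerPredictor(checkpoint="./checkpoints/cotracker2.pth")

# Click query points on frame 5, then track them through the whole clip
queries = get_queries_from_clicks(video[0][5], frame_idx=5).float()[None]
pred_tracks, pred_visibility = model(video, queries=queries)
```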