Skip to content

Commit 0b6edd2

Browse files
Release CoTracker3 (#108)
* release cotracker3 * gradio demo fixes * new gradio demo * fix readme demo * update readme * update evaluation * update readme and gradio demo
1 parent 5951295 commit 0b6edd2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+4557
-697
lines changed

README.md

+148-52
Large diffs are not rendered by default.

assets/teaser.png

-1.2 MB
Loading

cotracker/datasets/dataclass_utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
115115
types = cls.__annotations__.values()
116116
dlist_T = zip(*dlist)
117117
res_T = [
118-
_dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types)
118+
_dataclass_list_from_dict_list(key_list, tp)
119+
for key_list, tp in zip(dlist_T, types)
119120
]
120121
return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
121122
elif issubclass(cls, (list, tuple)):
@@ -125,7 +126,8 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
125126
types = types * len(dlist[0])
126127
dlist_T = zip(*dlist)
127128
res_T = (
128-
_dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types)
129+
_dataclass_list_from_dict_list(pos_list, tp)
130+
for pos_list, tp in zip(dlist_T, types)
129131
)
130132
if issubclass(cls, tuple):
131133
return list(zip(*res_T))

cotracker/datasets/dr_dataset.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ def __init__(
6767
with gzip.open(
6868
os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
6969
) as zipfile:
70-
frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
70+
frame_annots_list = load_dataclass(
71+
zipfile, List[DynamicReplicaFrameAnnotation]
72+
)
7173
seq_annot = defaultdict(list)
7274
for frame_annot in frame_annots_list:
7375
if frame_annot.camera_name == "left":
@@ -102,7 +104,10 @@ def crop(self, rgbs, trajs):
102104
# simple random crop
103105
y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
104106
x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
105-
rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
107+
rgbs = [
108+
rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
109+
for rgb in rgbs
110+
]
106111

107112
trajs[:, :, 0] -= x0
108113
trajs[:, :, 1] -= y0
@@ -118,7 +123,9 @@ def __getitem__(self, index):
118123
image_size = (H, W)
119124

120125
for i in range(T):
121-
traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
126+
traj_path = os.path.join(
127+
self.root, self.split, sample[i].trajectories["path"]
128+
)
122129
traj = torch.load(traj_path)
123130

124131
visibilities.append(traj["verts_inds_vis"].numpy())

0 commit comments

Comments
 (0)