Skip to content

Commit 462d59f

Browse files
make LitDaliWrapper robust to custom dali pipelines
1 parent d0036fa commit 462d59f

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

lightning_pose/data/dali.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,13 @@ def _dali_output_to_tensors(
190190
frames = batch[0]["frames"][0, :, :, :, :]
191191
# shape (1,) or (2, 3)
192192
transforms = batch[0]["transforms"][0]
193-
# get frame size, order is seq_len,H,W,C
194-
height = batch[0]["frame_size"][0, 1]
195-
width = batch[0]["frame_size"][0, 2]
193+
# get frame size
194+
if batch[0]["frame_size"][0, -1] == 3: # order is seq_len,H,W,C
195+
height = batch[0]["frame_size"][0, 1]
196+
width = batch[0]["frame_size"][0, 2]
197+
else: # order is seq_len,C,H,W
198+
height = batch[0]["frame_size"][0, 2]
199+
width = batch[0]["frame_size"][0, 3]
196200
bbox = torch.tensor([0, 0, height, width], device=frames.device).repeat(
197201
(frames.shape[0], 1))
198202

0 commit comments

Comments
 (0)