Commit
Update for black 24.1.0 (#308)
Co-authored-by: Jay Qi <[email protected]>
jayqi authored Jan 28, 2024
1 parent 3c7b310 commit dc87a8a
Showing 6 changed files with 44 additions and 28 deletions.
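Every hunk in this commit applies the same rule from black's 2024 stable style (new in 24.1.0): conditional (ternary) expressions that span multiple lines are wrapped in parentheses. A minimal sketch of the before/after layout, using a hypothetical checkpoint() helper and early_stopping_config placeholder rather than the real zamba objects:

def checkpoint(monitor=None, mode="min"):
    # Stand-in for a real callback constructor; it only exists to make the sketch runnable.
    return {"monitor": monitor, "mode": mode}

early_stopping_config = None  # placeholder: pretend early stopping is not configured

# Layout produced by black < 24.1.0: the multiline ternary hangs without parentheses.
old_style = checkpoint(
    monitor=early_stopping_config.monitor
    if early_stopping_config is not None
    else None,
)

# Layout produced by black >= 24.1.0: the same ternary is wrapped in parentheses.
new_style = checkpoint(
    monitor=(
        early_stopping_config.monitor
        if early_stopping_config is not None
        else None
    ),
)

assert old_style == new_style  # identical semantics; only the formatting differs

Both layouts parse to the same code, so the commit changes formatting only; pinning black>=24.1.0 in requirements-dev/lint.txt keeps contributors on the style the rest of the diff now follows.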
2 changes: 1 addition & 1 deletion requirements-dev/lint.txt
@@ -1,2 +1,2 @@
-black
+black>=24.1.0
 flake8
16 changes: 10 additions & 6 deletions zamba/models/densepose/densepose_manager.py
@@ -253,12 +253,16 @@ def serialize_image_output(self, instances, filename=None, write_embeddings=Fals
                     "value": labels[i],
                     "mesh_name": self.vis_class_to_mesh_name[labels[i]],
                 },
-                "embedding": pose_result.embedding[[i], ...].cpu().tolist()
-                if write_embeddings
-                else None,
-                "segmentation": pose_result.coarse_segm[[i], ...].cpu().tolist()
-                if write_embeddings
-                else None,
+                "embedding": (
+                    pose_result.embedding[[i], ...].cpu().tolist()
+                    if write_embeddings
+                    else None
+                ),
+                "segmentation": (
+                    pose_result.coarse_segm[[i], ...].cpu().tolist()
+                    if write_embeddings
+                    else None
+                ),
             }
             for i in range(len(instances))
         ]
24 changes: 15 additions & 9 deletions zamba/models/model_manager.py
@@ -258,12 +258,16 @@ def train_model(
     model_checkpoint = ModelCheckpoint(
         dirpath=logging_and_save_dir,
         filename=train_config.model_name,
-        monitor=train_config.early_stopping_config.monitor
-        if train_config.early_stopping_config is not None
-        else None,
-        mode=train_config.early_stopping_config.mode
-        if train_config.early_stopping_config is not None
-        else "min",
+        monitor=(
+            train_config.early_stopping_config.monitor
+            if train_config.early_stopping_config is not None
+            else None
+        ),
+        mode=(
+            train_config.early_stopping_config.mode
+            if train_config.early_stopping_config is not None
+            else "min"
+        ),
     )

     callbacks = [model_checkpoint]
@@ -283,9 +287,11 @@ def train_model(
         logger=tensorboard_logger,
         callbacks=callbacks,
         fast_dev_run=train_config.dry_run,
-        strategy=DDPStrategy(find_unused_parameters=False)
-        if (data_module.multiprocessing_context is not None) and (train_config.gpus > 1)
-        else "auto",
+        strategy=(
+            DDPStrategy(find_unused_parameters=False)
+            if (data_module.multiprocessing_context is not None) and (train_config.gpus > 1)
+            else "auto"
+        ),
     )

     if video_loader_config.cache_dir is None:
6 changes: 3 additions & 3 deletions zamba/models/slowfast_models.py
@@ -74,9 +74,9 @@ def __init__(
             ),
             activation=None,
             pool=None,
-            dropout=None
-            if post_backbone_dropout is None
-            else torch.nn.Dropout(post_backbone_dropout),
+            dropout=(
+                None if post_backbone_dropout is None else torch.nn.Dropout(post_backbone_dropout)
+            ),
             output_pool=torch.nn.AdaptiveAvgPool3d(1),
         )

16 changes: 10 additions & 6 deletions zamba/pytorch/dataloaders.py
@@ -121,12 +121,16 @@ def __getitem__(self, index: int):
             video = np.zeros(
                 (
                     self.video_loader_config.total_frames,
-                    self.video_loader_config.model_input_height
-                    if self.video_loader_config.model_input_height is not None
-                    else self.video_loader_config.frame_selection_height,
-                    self.video_loader_config.model_input_width
-                    if self.video_loader_config.model_input_width is not None
-                    else self.video_loader_config.frame_selection_width,
+                    (
+                        self.video_loader_config.model_input_height
+                        if self.video_loader_config.model_input_height is not None
+                        else self.video_loader_config.frame_selection_height
+                    ),
+                    (
+                        self.video_loader_config.model_input_width
+                        if self.video_loader_config.model_input_width is not None
+                        else self.video_loader_config.frame_selection_width
+                    ),
                     3,
                 ),
                 dtype="int",
8 changes: 5 additions & 3 deletions zamba/pytorch/transforms.py
@@ -76,9 +76,11 @@ def compute_left_and_right_pad(original_size: int, padded_size: int) -> Tuple[in
     def forward(self, vid: torch.Tensor) -> torch.Tensor:
         padding = tuple(
             itertools.chain.from_iterable(
-                (0, 0)
-                if padded_size is None
-                else self.compute_left_and_right_pad(original_size, padded_size)
+                (
+                    (0, 0)
+                    if padded_size is None
+                    else self.compute_left_and_right_pad(original_size, padded_size)
+                )
                 for original_size, padded_size in zip(vid.shape, self.dimension_sizes)
             )
         )
