Fix other Ruff complaints
akx committed Jul 25, 2023
1 parent 763a8e7 commit e943461
Showing 7 changed files with 18 additions and 15 deletions.
4 changes: 2 additions & 2 deletions main.py
@@ -694,7 +694,7 @@ def init_wandb(save_dir, opt, config, group_name, name_str):
# TODO change once leaving "swiffer" config directory
try:
group_name = nowname.split(now)[-1].split("-")[1]
-except:
+except Exception:
group_name = nowname
default_logger_cfg["params"]["group"] = group_name
init_wandb(
@@ -842,7 +842,7 @@ def init_wandb(save_dir, opt, config, group_name, name_str):
print(
f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
)
-except:
+except Exception:
print("datasets not yet initialized.")

# configure learning rate
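For context, Ruff's E722 rule is what flags the bare `except:` clauses above: a bare except also swallows `KeyboardInterrupt` and `SystemExit`, so even the blanket `except Exception:` is a meaningful improvement. A minimal sketch of the pattern (the function and variable names are illustrative, not from main.py):

    import logging

    logger = logging.getLogger(__name__)


    def derive_group_name(nowname: str, now: str) -> str:
        """Best-effort extraction of a W&B group name, falling back to the run name."""
        try:
            return nowname.split(now)[-1].split("-")[1]
        except Exception:  # unlike a bare `except:`, Ctrl-C still propagates
            logger.debug("could not derive a group name from %r", nowname)
            return nowname


    print(derive_group_name("2023-07-25T12-00-00_example-run", "2023-07-25T12-00-00"))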
2 changes: 1 addition & 1 deletion scripts/demo/detect.py
@@ -45,7 +45,7 @@ def decode(self, cv2Image, method="dwtDct", **configs):
bits = embed.decode(cv2Image)
return self.reconstruct(bits)

-except:
+except Exception:
raise e


2 changes: 1 addition & 1 deletion scripts/demo/streamlit_helpers.py
@@ -489,7 +489,7 @@ def do_sample(
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
-print(key, [len(l) for l in batch[key]])
+print(key, [len(lst) for lst in batch[key]])
else:
print(key, batch[key])
c, uc = model.conditioner.get_unconditional_conditioning(
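The `l` → `lst` rename in do_sample addresses Ruff's E741 (ambiguous variable name): a lone `l` is easily misread as `1` or `I`. The same pattern in a self-contained sketch (the batch contents are made up for illustration):

    batch = {"txt": ["a photo of a cat", "a painting"], "num_steps": 40}

    for key, value in batch.items():
        if isinstance(value, list):
            # any descriptive name (`lst`, `item`, ...) satisfies E741
            print(key, [len(lst) for lst in value])
        else:
            print(key, value)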
6 changes: 2 additions & 4 deletions sgm/modules/attention.py
@@ -9,10 +9,10 @@
from torch import nn

from ..util import default, exists
+from .diffusionmodules.util import checkpoint

logger = logging.getLogger(__name__)


if version.parse(torch.__version__) >= version.parse("2.0.0"):
SDP_IS_AVAILABLE = True
from torch.backends.cuda import SDPBackend, sdp_kernel
@@ -51,12 +51,10 @@
import xformers.ops

XFORMERS_IS_AVAILABLE = True
-except:
+except Exception:
XFORMERS_IS_AVAILABLE = False
logger.debug("no module 'xformers'. Processing without...")

-from .diffusionmodules.util import checkpoint


def uniq(arr): # TODO: this seems unused
return {el: True for el in arr}.keys()
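This file's change moves `from .diffusionmodules.util import checkpoint` up into the regular import block, so only the genuinely optional `xformers` dependency stays behind a try/except (Ruff's E402 flags module-level imports that are not at the top of the file). A rough sketch of the resulting layout, with the project-relative import omitted so the snippet stands alone:

    import logging

    # mandatory project imports sit here, unconditionally, at the top of the module

    logger = logging.getLogger(__name__)

    # only the truly optional dependency is guarded
    try:
        import xformers  # noqa: F401
        import xformers.ops  # noqa: F401

        XFORMERS_IS_AVAILABLE = True
    except Exception:
        XFORMERS_IS_AVAILABLE = False
        logger.debug("no module 'xformers'. Processing without...")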
10 changes: 6 additions & 4 deletions sgm/modules/diffusionmodules/model.py
@@ -10,19 +10,19 @@
from einops import rearrange
from packaging import version

+from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention

logger = logging.getLogger(__name__)

try:
import xformers
import xformers.ops

XFORMERS_IS_AVAILABLE = True
-except:
+except Exception:
XFORMERS_IS_AVAILABLE = False
logger.debug("no module 'xformers'. Processing without...")

-from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention


def get_timestep_embedding(timesteps, embedding_dim):
"""
@@ -299,7 +299,9 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
)
attn_type = "vanilla"
attn_kwargs = None
-logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+logger.debug(
+    f"making attention of type '{attn_type}' with {in_channels} in_channels"
+)
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
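The logger.debug call above is simply re-wrapped so the f-string stays under the configured line-length limit (most likely Ruff's E501). The same black-style wrapping, runnable on its own:

    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger(__name__)

    attn_type, in_channels = "vanilla", 512

    # splitting the call keeps each physical line within the length limit
    logger.debug(
        f"making attention of type '{attn_type}' with {in_channels} in_channels"
    )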
3 changes: 2 additions & 1 deletion sgm/modules/diffusionmodules/openaimodel.py
@@ -649,7 +649,8 @@ def __init__(

self.use_fairscale_checkpoint = False
checkpoint_wrapper_fn = (
-partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
+# TODO: this can't work since `checkpoint_wrapper` is not defined
+partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu) # noqa: F821
if self.use_fairscale_checkpoint
else lambda x: x
)
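`checkpoint_wrapper` is never imported in openaimodel.py, so Ruff raises F821 (undefined name); because the branch is dead while `use_fairscale_checkpoint` stays hard-coded to False, the commit documents the problem with a TODO and a `# noqa: F821` rather than fixing it. If fairscale were actually wired in, a guarded import along these lines could resolve the warning (hypothetical, and the fairscale import path is an assumption on my part, not something this commit adds):

    from functools import partial

    try:
        from fairscale.nn import checkpoint_wrapper  # optional dependency
    except Exception:
        checkpoint_wrapper = None

    use_fairscale_checkpoint = False
    offload_to_cpu = False

    # the partial is only built when the fairscale path is actually enabled
    checkpoint_wrapper_fn = (
        partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
        if use_fairscale_checkpoint and checkpoint_wrapper is not None
        else (lambda x: x)
    )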
6 changes: 4 additions & 2 deletions sgm/util.py
@@ -31,7 +31,7 @@ def get_string_from_tuple(s):
return t[0]
else:
pass
-except:
+except Exception:
pass
return s

@@ -164,7 +164,7 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
-logger.info(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+logger.info(
+    f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
+)
return total_params


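In get_string_from_tuple the bare except becomes `except Exception: pass`, preserving the original swallow-and-fall-through behaviour; `contextlib.suppress` expresses the same intent more compactly, sketched here with `ast.literal_eval` standing in for whatever parsing the real helper does (an alternative, not what the commit changes):

    from ast import literal_eval
    from contextlib import suppress


    def get_string_from_tuple(s: str) -> str:
        """Return the first element of a stringified tuple, or the input unchanged."""
        with suppress(Exception):  # same effect as `try: ... except Exception: pass`
            if s.startswith("(") and s.endswith(")"):
                t = literal_eval(s)
                if isinstance(t, tuple):
                    return t[0]
        return s


    print(get_string_from_tuple("('crop_coords_top_left',)"))  # -> crop_coords_top_left
    print(get_string_from_tuple("not a tuple"))                # -> not a tuple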
