From d9e32c0efd24a3eb5644938e48201ba306b3f9f0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 19 Jul 2023 17:58:52 +0300 Subject: [PATCH 1/6] Configure ruff and black to be run via `pre-commit` (and check those in GitHub Actions) --- .github/workflows/black.yml | 15 --------------- .github/workflows/lint.yml | 12 ++++++++++++ .pre-commit-config.yaml | 13 +++++++++++++ README.md | 5 +++++ pyproject.toml | 11 +++++++++++ 5 files changed, 41 insertions(+), 15 deletions(-) delete mode 100644 .github/workflows/black.yml create mode 100644 .github/workflows/lint.yml create mode 100644 .pre-commit-config.yaml diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml deleted file mode 100644 index ab652601..00000000 --- a/.github/workflows/black.yml +++ /dev/null @@ -1,15 +0,0 @@ -name: Run black -on: [pull_request] - -jobs: - lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Install venv - run: | - sudo apt-get -y install python3.10-venv - - uses: psf/black@stable - with: - options: "--check --verbose -l88" - src: "./sgm ./scripts ./main.py" diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..94a92dc0 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,12 @@ +name: Lint + +on: + pull_request: + +jobs: + Lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + - uses: pre-commit/action@v3.0.0 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..6f6176b4 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,13 @@ +minimum_pre_commit_version: 2.15.0 +ci: + autofix_prs: false +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.6 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + - repo: https://github.com/psf/black + rev: 23.11.0 + hooks: + - id: black diff --git a/README.md b/README.md index b717d7d3..a7702631 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,11 @@ now `DiffusionEngine`) has been cleaned up: training (`sgm/modules/diffusionmodules/sigma_sampling.py`). - Autoencoding models have also been cleaned up. +### Style + +* The repo is formatted with [Black](https://github.com/psf/black) and linted with [Ruff](https://beta.ruff.rs/). + * You can easily have these run on every commit by installing [`pre-commit`](https://pre-commit.com/) (`pip install pre-commit`) and running `pre-commit install` in the repo root. + ## Installation: diff --git a/pyproject.toml b/pyproject.toml index 2cc50216..5831d399 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,3 +46,14 @@ test-inference = [ "pip install -r requirements/pt2.txt", "pytest -v tests/inference/test_inference.py {args}", ] + +[tool.ruff] +target-version = "py38" +extend-select = ["I"] +ignore = [ + "E501", +] +[tool.ruff.per-file-ignores] +"*/__init__.py" = [ + "F401", # unused imports (will be taken care of by PR #44) +] From 6379cb137ca47fe7fed211b109a3a42c05a40dc7 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 22 Nov 2023 09:19:09 +0200 Subject: [PATCH 2/6] Don't use star imports --- scripts/demo/sampling.py | 19 ++++++++++++++++++- scripts/demo/turbo.py | 12 +++++++++++- scripts/demo/video_sampling.py | 14 ++++++++++++-- sgm/modules/diffusionmodules/video_model.py | 14 +++++++++++--- sgm/modules/video_attention.py | 15 +++++++++++++-- 5 files changed, 65 insertions(+), 9 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 46c20048..ef286219 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -1,6 +1,23 @@ +import os + +import numpy as np +import streamlit as st +import torch +from einops import repeat from pytorch_lightning import seed_everything -from scripts.demo.streamlit_helpers import * +from scripts.demo.streamlit_helpers import ( + do_img2img, + do_sample, + get_interactive_image, + get_unique_embedder_keys_from_conditioner, + init_embedder_options, + init_sampling, + init_save_locally, + init_st, + perform_save_locally, + set_lowvram_mode, +) SAVE_PATH = "outputs/demo/txt2img/" diff --git a/scripts/demo/turbo.py b/scripts/demo/turbo.py index 3f348792..5a3c433c 100644 --- a/scripts/demo/turbo.py +++ b/scripts/demo/turbo.py @@ -1,5 +1,15 @@ -from streamlit_helpers import * +import numpy as np +import streamlit as st +import torch from st_keyup import st_keyup + +from scripts.demo.streamlit_helpers import ( + autocast, + get_batch, + get_unique_embedder_keys_from_conditioner, + init_st, + load_model, +) from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler VERSION2SPECS = { diff --git a/scripts/demo/video_sampling.py b/scripts/demo/video_sampling.py index 95789020..6d2ea860 100644 --- a/scripts/demo/video_sampling.py +++ b/scripts/demo/video_sampling.py @@ -1,8 +1,19 @@ import os +import streamlit as st +import torch from pytorch_lightning import seed_everything -from scripts.demo.streamlit_helpers import * +from scripts.demo.streamlit_helpers import ( + do_sample, + get_unique_embedder_keys_from_conditioner, + init_embedder_options, + init_sampling, + init_save_locally, + init_st, + load_img_for_prediction, + save_video_as_grid_and_mp4, +) SAVE_PATH = "outputs/demo/vid/" @@ -89,7 +100,6 @@ }, } - if __name__ == "__main__": st.title("Stable Video Diffusion") version = st.selectbox( diff --git a/sgm/modules/diffusionmodules/video_model.py b/sgm/modules/diffusionmodules/video_model.py index ff2d077c..3242c21d 100644 --- a/sgm/modules/diffusionmodules/video_model.py +++ b/sgm/modules/diffusionmodules/video_model.py @@ -1,12 +1,20 @@ -from functools import partial from typing import List, Optional, Union +import torch as th +import torch.nn as nn from einops import rearrange -from ...modules.diffusionmodules.openaimodel import * +from .openaimodel import ResBlock, Timestep, TimestepEmbedSequential, Downsample, Upsample +from .util import AlphaBlender +from ...modules.diffusionmodules.util import ( + conv_nd, + linear, + normalization, + timestep_embedding, + zero_module, +) from ...modules.video_attention import SpatialVideoTransformer from ...util import default -from .util import AlphaBlender class VideoResBlock(ResBlock): diff --git a/sgm/modules/video_attention.py b/sgm/modules/video_attention.py index 783395aa..482b3bcb 100644 --- a/sgm/modules/video_attention.py +++ b/sgm/modules/video_attention.py @@ -1,6 +1,17 @@ -import torch +from typing import Optional -from ..modules.attention import * +import torch +from einops import rearrange, repeat +from torch import nn +from torch.utils.checkpoint import checkpoint + +from ..modules.attention import ( + CrossAttention, + FeedForward, + MemoryEfficientCrossAttention, + SpatialTransformer, + exists, +) from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding From 295787e6d6c75c4b50b8dd55c860442725b92121 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 22 Nov 2023 09:26:57 +0200 Subject: [PATCH 3/6] Remove duplicate get_interactive_image --- scripts/demo/streamlit_helpers.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 6c5760e2..58b86da5 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -432,14 +432,6 @@ def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1 return sampler -def get_interactive_image() -> Image.Image: - image = st.file_uploader("Input", type=["jpg", "JPEG", "png"]) - if image is not None: - image = Image.open(image) - if not image.mode == "RGB": - image = image.convert("RGB") - return image - def load_img( display: bool = True, From a770cd24911f823750417b185f2a33dc33c78381 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 22 Nov 2023 09:10:45 +0200 Subject: [PATCH 4/6] Apply Ruff autofixes --- main.py | 4 +-- scripts/demo/discretization.py | 8 ++--- scripts/demo/streamlit_helpers.py | 33 +++++++++---------- scripts/demo/video_sampling.py | 2 +- scripts/sampling/simple_video_sample.py | 3 +- scripts/tests/attention.py | 2 +- sgm/data/dataset.py | 2 +- sgm/inference/api.py | 17 +++++----- sgm/inference/helpers.py | 4 +-- sgm/models/autoencoder.py | 10 ++++-- sgm/models/diffusion.py | 9 +++-- sgm/modules/attention.py | 2 +- .../autoencoding/regularizers/__init__.py | 3 +- sgm/modules/diffusionmodules/openaimodel.py | 11 +++++-- sgm/modules/diffusionmodules/sampling.py | 11 ++++--- sgm/modules/diffusionmodules/video_model.py | 10 ++++-- sgm/modules/ema.py | 4 +-- sgm/modules/encoders/modules.py | 23 +++++++++---- sgm/util.py | 2 +- tests/inference/test_inference.py | 15 +++++---- 20 files changed, 104 insertions(+), 71 deletions(-) diff --git a/main.py b/main.py index 5e03c1c5..ea20d0be 100644 --- a/main.py +++ b/main.py @@ -648,7 +648,7 @@ def init_wandb(save_dir, opt, config, group_name, name_str): ckpt_resume_path = opt.resume_from_checkpoint - if not "devices" in trainer_config and trainer_config["accelerator"] != "gpu": + if "devices" not in trainer_config and trainer_config["accelerator"] != "gpu": del trainer_config["accelerator"] cpu = True else: @@ -814,7 +814,7 @@ def init_wandb(save_dir, opt, config, group_name, name_str): trainer_kwargs["callbacks"] = [ instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg ] - if not "plugins" in trainer_kwargs: + if "plugins" not in trainer_kwargs: trainer_kwargs["plugins"] = list() # cmd line trainer args (which are in trainer_opt) have always priority over config-trainer-args (which are in trainer_kwargs) diff --git a/scripts/demo/discretization.py b/scripts/demo/discretization.py index b7030a22..2a5ef178 100644 --- a/scripts/demo/discretization.py +++ b/scripts/demo/discretization.py @@ -18,12 +18,12 @@ def __init__(self, discretization: Discretization, strength: float = 1.0): def __call__(self, *args, **kwargs): # sigmas start large first, and decrease then sigmas = self.discretization(*args, **kwargs) - print(f"sigmas after discretization, before pruning img2img: ", sigmas) + print("sigmas after discretization, before pruning img2img: ", sigmas) sigmas = torch.flip(sigmas, (0,)) sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] print("prune index:", max(int(self.strength * len(sigmas)), 1)) sigmas = torch.flip(sigmas, (0,)) - print(f"sigmas after pruning: ", sigmas) + print("sigmas after pruning: ", sigmas) return sigmas @@ -45,7 +45,7 @@ def __init__( def __call__(self, *args, **kwargs): # sigmas start large first, and decrease then sigmas = self.discretization(*args, **kwargs) - print(f"sigmas after discretization, before pruning img2img: ", sigmas) + print("sigmas after discretization, before pruning img2img: ", sigmas) sigmas = torch.flip(sigmas, (0,)) if self.original_steps is None: steps = len(sigmas) @@ -55,5 +55,5 @@ def __call__(self, *args, **kwargs): sigmas = sigmas[prune_index:] print("prune index:", prune_index) sigmas = torch.flip(sigmas, (0,)) - print(f"sigmas after pruning: ", sigmas) + print("sigmas after pruning: ", sigmas) return sigmas diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 58b86da5..fd500ce6 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -8,10 +8,8 @@ import numpy as np import streamlit as st import torch -import torch.nn as nn import torchvision.transforms as TT from einops import rearrange, repeat -from imwatermark import WatermarkEncoder from omegaconf import ListConfig, OmegaConf from PIL import Image from safetensors.torch import load_file as load_safetensors @@ -19,26 +17,28 @@ from torchvision import transforms from torchvision.utils import make_grid, save_image -from scripts.demo.discretization import (Img2ImgDiscretizationWrapper, - Txt2NoisyDiscretizationWrapper) -from scripts.util.detection.nsfw_and_watermark_dectection import \ - DeepFloydDataFiltering +from scripts.demo.discretization import ( + Img2ImgDiscretizationWrapper, + Txt2NoisyDiscretizationWrapper, +) +from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering from sgm.inference.helpers import embed_watermark -from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider, - VanillaCFG) -from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler, - DPMPP2SAncestralSampler, - EulerAncestralSampler, - EulerEDMSampler, - HeunEDMSampler, - LinearMultistepSampler) +from sgm.modules.diffusionmodules.guiders import LinearPredictionGuider, VanillaCFG +from sgm.modules.diffusionmodules.sampling import ( + DPMPP2MSampler, + DPMPP2SAncestralSampler, + EulerAncestralSampler, + EulerEDMSampler, + HeunEDMSampler, + LinearMultistepSampler, +) from sgm.util import append_dims, default, instantiate_from_config @st.cache_resource() def init_st(version_dict, load_ckpt=True, load_filter=True): state = dict() - if not "model" in state: + if "model" not in state: config = version_dict["config"] ckpt = version_dict["ckpt"] @@ -253,7 +253,7 @@ def get_guider(options, key): min_value=1.0, ) min_scale = st.number_input( - f"min guidance scale", + "min guidance scale", value=options.get("min_cfg", 1.0), min_value=1.0, max_value=10.0, @@ -432,7 +432,6 @@ def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1 return sampler - def load_img( display: bool = True, size: Union[None, int, Tuple[int, int]] = None, diff --git a/scripts/demo/video_sampling.py b/scripts/demo/video_sampling.py index 6d2ea860..143a884e 100644 --- a/scripts/demo/video_sampling.py +++ b/scripts/demo/video_sampling.py @@ -175,7 +175,7 @@ if st.checkbox("Overwrite fps in mp4 generator", False): saving_fps = st.number_input( - f"saving video at fps:", value=value_dict["fps"], min_value=1 + "saving video at fps:", value=value_dict["fps"], min_value=1 ) else: saving_fps = value_dict["fps"] diff --git a/scripts/sampling/simple_video_sample.py b/scripts/sampling/simple_video_sample.py index c3f4ad2a..3a326a33 100644 --- a/scripts/sampling/simple_video_sample.py +++ b/scripts/sampling/simple_video_sample.py @@ -13,8 +13,7 @@ from PIL import Image from torchvision.transforms import ToTensor -from scripts.util.detection.nsfw_and_watermark_dectection import \ - DeepFloydDataFiltering +from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering from sgm.inference.helpers import embed_watermark from sgm.util import default, instantiate_from_config diff --git a/scripts/tests/attention.py b/scripts/tests/attention.py index d7c3f7c8..febdb489 100644 --- a/scripts/tests/attention.py +++ b/scripts/tests/attention.py @@ -51,7 +51,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs): dtype=dtype, ) - print(f"q/k/v shape:", query.shape, key.shape, value.shape) + print("q/k/v shape:", query.shape, key.shape, value.shape) # Lets explore the speed of each of the 3 implementations from torch.backends.cuda import SDPBackend, sdp_kernel diff --git a/sgm/data/dataset.py b/sgm/data/dataset.py index b7261499..2a19491a 100644 --- a/sgm/data/dataset.py +++ b/sgm/data/dataset.py @@ -7,7 +7,7 @@ try: from sdata import create_dataset, create_dummy_dataset, create_loader -except ImportError as e: +except ImportError: print("#" * 100) print("Datasets not yet available") print("to enable, we need to add stable-datasets as a submodule") diff --git a/sgm/inference/api.py b/sgm/inference/api.py index a359a67b..8b16a84c 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -5,14 +5,15 @@ from omegaconf import OmegaConf -from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img, - do_sample) -from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler, - DPMPP2SAncestralSampler, - EulerAncestralSampler, - EulerEDMSampler, - HeunEDMSampler, - LinearMultistepSampler) +from sgm.inference.helpers import Img2ImgDiscretizationWrapper, do_img2img, do_sample +from sgm.modules.diffusionmodules.sampling import ( + DPMPP2MSampler, + DPMPP2SAncestralSampler, + EulerAncestralSampler, + EulerEDMSampler, + HeunEDMSampler, + LinearMultistepSampler, +) from sgm.util import load_model_from_config diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 31b0ec3d..f29597ea 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -89,12 +89,12 @@ def __init__(self, discretization, strength: float = 1.0): def __call__(self, *args, **kwargs): # sigmas start large first, and decrease then sigmas = self.discretization(*args, **kwargs) - print(f"sigmas after discretization, before pruning img2img: ", sigmas) + print("sigmas after discretization, before pruning img2img: ", sigmas) sigmas = torch.flip(sigmas, (0,)) sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] print("prune index:", max(int(self.strength * len(sigmas)), 1)) sigmas = torch.flip(sigmas, (0,)) - print(f"sigmas after pruning: ", sigmas) + print("sigmas after pruning: ", sigmas) return sigmas diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py index 2949b910..fb39ec47 100644 --- a/sgm/models/autoencoder.py +++ b/sgm/models/autoencoder.py @@ -13,8 +13,12 @@ from ..modules.autoencoding.regularizers import AbstractRegularizer from ..modules.ema import LitEma -from ..util import (default, get_nested_attribute, get_obj_from_str, - instantiate_from_config) +from ..util import ( + default, + get_nested_attribute, + get_obj_from_str, + instantiate_from_config, +) logpy = logging.getLogger(__name__) @@ -529,7 +533,7 @@ def __init__( **kwargs, ): if "lossconfig" in kwargs: - logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.") + logpy.warn("Parameter `lossconfig` is deprecated, use `loss_config`.") kwargs["loss_config"] = kwargs.pop("lossconfig") super().__init__( regularizer_config={ diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py index 2f3efd3c..261a4812 100644 --- a/sgm/models/diffusion.py +++ b/sgm/models/diffusion.py @@ -12,8 +12,13 @@ from ..modules.autoencoding.temporal_ae import VideoDecoder from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER from ..modules.ema import LitEma -from ..util import (default, disabled_train, get_obj_from_str, - instantiate_from_config, log_txt_as_img) +from ..util import ( + default, + disabled_train, + get_obj_from_str, + instantiate_from_config, + log_txt_as_img, +) class DiffusionEngine(pl.LightningModule): diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py index 52a50b7b..a0239a3c 100644 --- a/sgm/modules/attention.py +++ b/sgm/modules/attention.py @@ -200,7 +200,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, L, C) else: - raise NotImplemented + raise NotImplementedError x = self.proj(x) x = self.proj_drop(x) diff --git a/sgm/modules/autoencoding/regularizers/__init__.py b/sgm/modules/autoencoding/regularizers/__init__.py index ff2b1815..6065fb20 100644 --- a/sgm/modules/autoencoding/regularizers/__init__.py +++ b/sgm/modules/autoencoding/regularizers/__init__.py @@ -5,8 +5,7 @@ import torch.nn as nn import torch.nn.functional as F -from ....modules.distributions.distributions import \ - DiagonalGaussianDistribution +from ....modules.distributions.distributions import DiagonalGaussianDistribution from .base import AbstractRegularizer diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py index b58e1b0e..5874fdd6 100644 --- a/sgm/modules/diffusionmodules/openaimodel.py +++ b/sgm/modules/diffusionmodules/openaimodel.py @@ -10,9 +10,14 @@ from torch.utils.checkpoint import checkpoint from ...modules.attention import SpatialTransformer -from ...modules.diffusionmodules.util import (avg_pool_nd, conv_nd, linear, - normalization, - timestep_embedding, zero_module) +from ...modules.diffusionmodules.util import ( + avg_pool_nd, + conv_nd, + linear, + normalization, + timestep_embedding, + zero_module, +) from ...modules.video_attention import SpatialVideoTransformer from ...util import exists diff --git a/sgm/modules/diffusionmodules/sampling.py b/sgm/modules/diffusionmodules/sampling.py index af07566d..6346829c 100644 --- a/sgm/modules/diffusionmodules/sampling.py +++ b/sgm/modules/diffusionmodules/sampling.py @@ -9,10 +9,13 @@ from omegaconf import ListConfig, OmegaConf from tqdm import tqdm -from ...modules.diffusionmodules.sampling_utils import (get_ancestral_step, - linear_multistep_coeff, - to_d, to_neg_log_sigma, - to_sigma) +from ...modules.diffusionmodules.sampling_utils import ( + get_ancestral_step, + linear_multistep_coeff, + to_d, + to_neg_log_sigma, + to_sigma, +) from ...util import append_dims, default, instantiate_from_config DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} diff --git a/sgm/modules/diffusionmodules/video_model.py b/sgm/modules/diffusionmodules/video_model.py index 3242c21d..8942e02f 100644 --- a/sgm/modules/diffusionmodules/video_model.py +++ b/sgm/modules/diffusionmodules/video_model.py @@ -4,8 +4,6 @@ import torch.nn as nn from einops import rearrange -from .openaimodel import ResBlock, Timestep, TimestepEmbedSequential, Downsample, Upsample -from .util import AlphaBlender from ...modules.diffusionmodules.util import ( conv_nd, linear, @@ -15,6 +13,14 @@ ) from ...modules.video_attention import SpatialVideoTransformer from ...util import default +from .openaimodel import ( + Downsample, + ResBlock, + Timestep, + TimestepEmbedSequential, + Upsample, +) +from .util import AlphaBlender class VideoResBlock(ResBlock): diff --git a/sgm/modules/ema.py b/sgm/modules/ema.py index 97b5ae2b..bf33b510 100644 --- a/sgm/modules/ema.py +++ b/sgm/modules/ema.py @@ -51,7 +51,7 @@ def forward(self, model): one_minus_decay * (shadow_params[sname] - m_param[key]) ) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) @@ -60,7 +60,7 @@ def copy_to(self, model): if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def store(self, parameters): """ diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py index d77b8ed7..0faa9b46 100644 --- a/sgm/modules/encoders/modules.py +++ b/sgm/modules/encoders/modules.py @@ -11,17 +11,28 @@ from einops import rearrange, repeat from omegaconf import ListConfig from torch.utils.checkpoint import checkpoint -from transformers import (ByT5Tokenizer, CLIPTextModel, CLIPTokenizer, - T5EncoderModel, T5Tokenizer) +from transformers import ( + ByT5Tokenizer, + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5Tokenizer, +) from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer from ...modules.diffusionmodules.model import Encoder from ...modules.diffusionmodules.openaimodel import Timestep -from ...modules.diffusionmodules.util import (extract_into_tensor, - make_beta_schedule) +from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule from ...modules.distributions.distributions import DiagonalGaussianDistribution -from ...util import (append_dims, autocast, count_params, default, - disabled_train, expand_dims_like, instantiate_from_config) +from ...util import ( + append_dims, + autocast, + count_params, + default, + disabled_train, + expand_dims_like, + instantiate_from_config, +) class AbstractEmbModel(nn.Module): diff --git a/sgm/util.py b/sgm/util.py index 66d9b2a6..9c59b3cc 100644 --- a/sgm/util.py +++ b/sgm/util.py @@ -166,7 +166,7 @@ def count_params(model, verbose=False): def instantiate_from_config(config): - if not "target" in config: + if "target" not in config: if config == "__is_first_stage__": return None elif config == "__is_unconditional__": diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 2b2af11e..f9dee55b 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -1,18 +1,19 @@ +from typing import Tuple + import numpy -from PIL import Image import pytest -from pytest import fixture import torch -from typing import Tuple +from PIL import Image +from pytest import fixture +import sgm.inference.helpers as helpers from sgm.inference.api import ( - model_specs, + ModelArchitecture, + Sampler, SamplingParams, SamplingPipeline, - Sampler, - ModelArchitecture, + model_specs, ) -import sgm.inference.helpers as helpers @pytest.mark.inference From e0f746af392699c13ea5cafbdf088f0ec93832ec Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 19 Jul 2023 18:13:24 +0300 Subject: [PATCH 5/6] Fix other Ruff complaints --- main.py | 5 ++--- scripts/demo/detect.py | 2 +- scripts/tests/attention.py | 8 ++++---- sgm/inference/api.py | 1 - sgm/inference/helpers.py | 6 +++--- sgm/models/diffusion.py | 6 +++--- sgm/modules/attention.py | 4 +--- sgm/modules/autoencoding/lpips/loss/lpips.py | 2 +- sgm/modules/diffusionmodules/model.py | 9 ++++----- sgm/util.py | 2 +- 10 files changed, 20 insertions(+), 25 deletions(-) diff --git a/main.py b/main.py index ea20d0be..579e67e7 100644 --- a/main.py +++ b/main.py @@ -402,7 +402,6 @@ def log_img(self, pl_module, batch, batch_idx, split="train"): # batch_idx > 5 and self.max_images > 0 ): - logger = type(pl_module.logger) is_train = pl_module.training if is_train: pl_module.eval() @@ -691,7 +690,7 @@ def init_wandb(save_dir, opt, config, group_name, name_str): # TODO change once leaving "swiffer" config directory try: group_name = nowname.split(now)[-1].split("-")[1] - except: + except Exception: group_name = nowname default_logger_cfg["params"]["group"] = group_name init_wandb( @@ -839,7 +838,7 @@ def init_wandb(save_dir, opt, config, group_name, name_str): print( f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}" ) - except: + except Exception: print("datasets not yet initialized.") # configure learning rate diff --git a/scripts/demo/detect.py b/scripts/demo/detect.py index 96e9f212..6e219204 100644 --- a/scripts/demo/detect.py +++ b/scripts/demo/detect.py @@ -45,7 +45,7 @@ def decode(self, cv2Image, method="dwtDct", **configs): bits = embed.decode(cv2Image) return self.reconstruct(bits) - except: + except Exception: raise e diff --git a/scripts/tests/attention.py b/scripts/tests/attention.py index febdb489..d52cc100 100644 --- a/scripts/tests/attention.py +++ b/scripts/tests/attention.py @@ -87,7 +87,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs): ) as prof: with record_function("Default detailed stats"): for _ in range(25): - o = F.scaled_dot_product_attention(query, key, value) + _o = F.scaled_dot_product_attention(query, key, value) print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) print( @@ -99,7 +99,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs): ) as prof: with record_function("Math implmentation stats"): for _ in range(25): - o = F.scaled_dot_product_attention(query, key, value) + _o = F.scaled_dot_product_attention(query, key, value) print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]): @@ -114,7 +114,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs): ) as prof: with record_function("FlashAttention stats"): for _ in range(25): - o = F.scaled_dot_product_attention(query, key, value) + _o = F.scaled_dot_product_attention(query, key, value) print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): @@ -129,7 +129,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs): ) as prof: with record_function("EfficientAttention stats"): for _ in range(25): - o = F.scaled_dot_product_attention(query, key, value) + _o = F.scaled_dot_product_attention(query, key, value) print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 8b16a84c..697b4f7b 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -326,7 +326,6 @@ def get_discretization_config(params: SamplingParams): def get_sampler_config(params: SamplingParams): discretization_config = get_discretization_config(params) guider_config = get_guider_config(params) - sampler = None if params.sampler == Sampler.EULER_EDM: return EulerEDMSampler( num_steps=params.steps, diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index f29597ea..7a461394 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -119,7 +119,7 @@ def do_sample( batch2model_input = [] with torch.no_grad(): - with autocast(device) as precision_scope: + with autocast(device): with model.ema_scope(): num_samples = [num_samples] batch, batch_uc = get_batch( @@ -131,7 +131,7 @@ def do_sample( if isinstance(batch[key], torch.Tensor): print(key, batch[key].shape) elif isinstance(batch[key], list): - print(key, [len(l) for l in batch[key]]) + print(key, [len(lst) for lst in batch[key]]) else: print(key, batch[key]) c, uc = model.conditioner.get_unconditional_conditioning( @@ -255,7 +255,7 @@ def do_img2img( device="cuda", ): with torch.no_grad(): - with autocast(device) as precision_scope: + with autocast(device): with model.ema_scope(): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py index 261a4812..7cf7915e 100644 --- a/sgm/models/diffusion.py +++ b/sgm/models/diffusion.py @@ -250,9 +250,9 @@ def sample( ): randn = torch.randn(batch_size, *shape).to(self.device) - denoiser = lambda input, sigma, c: self.denoiser( - self.model, input, sigma, c, **kwargs - ) + def denoiser(input, sigma, c): + return self.denoiser(self.model, input, sigma, c, **kwargs) + samples = self.sampler(denoiser, randn, cond, uc=uc) return samples diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py index a0239a3c..75386d58 100644 --- a/sgm/modules/attention.py +++ b/sgm/modules/attention.py @@ -51,12 +51,10 @@ import xformers.ops XFORMERS_IS_AVAILABLE = True -except: +except Exception: XFORMERS_IS_AVAILABLE = False logpy.warn("no module 'xformers'. Processing without...") -# from .diffusionmodules.util import mixed_checkpoint as checkpoint - def exists(val): return val is not None diff --git a/sgm/modules/autoencoding/lpips/loss/lpips.py b/sgm/modules/autoencoding/lpips/loss/lpips.py index 3e34f3d0..a15b3e18 100644 --- a/sgm/modules/autoencoding/lpips/loss/lpips.py +++ b/sgm/modules/autoencoding/lpips/loss/lpips.py @@ -59,7 +59,7 @@ def forward(self, input, target): for kk in range(len(self.chns)) ] val = res[0] - for l in range(1, len(self.chns)): + for l in range(1, len(self.chns)): # noqa: E741 val += res[l] return val diff --git a/sgm/modules/diffusionmodules/model.py b/sgm/modules/diffusionmodules/model.py index 4cf9d921..a549125d 100644 --- a/sgm/modules/diffusionmodules/model.py +++ b/sgm/modules/diffusionmodules/model.py @@ -9,6 +9,8 @@ from einops import rearrange from packaging import version +from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention + logpy = logging.getLogger(__name__) try: @@ -16,12 +18,10 @@ import xformers.ops XFORMERS_IS_AVAILABLE = True -except: +except Exception: XFORMERS_IS_AVAILABLE = False logpy.warning("no module 'xformers'. Processing without...") -from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention - def get_timestep_embedding(timesteps, embedding_dim): """ @@ -633,8 +633,7 @@ def __init__( self.give_pre_end = give_pre_end self.tanh_out = tanh_out - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) + # compute block_in and curr_res at lowest res block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) diff --git a/sgm/util.py b/sgm/util.py index 9c59b3cc..4777e439 100644 --- a/sgm/util.py +++ b/sgm/util.py @@ -28,7 +28,7 @@ def get_string_from_tuple(s): return t[0] else: pass - except: + except Exception: pass return s From e4a7d363248aaa0390b2f2a3b946a5319f4f676a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 22 Nov 2023 09:36:21 +0200 Subject: [PATCH 6/6] Add a noqa directive for now --- scripts/demo/streamlit_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index fd500ce6..421a9837 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -767,7 +767,7 @@ def denoiser(x, sigma, c): if filter is not None: samples = filter(samples) - grid = rearrange(grid, "n b c h w -> (n h) (b w) c") + grid = rearrange(grid, "n b c h w -> (n h) (b w) c") # noqa: F821 outputs.image(grid.cpu().numpy()) if return_latents: return samples, samples_z