Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Configure ruff and black for linting #52

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions .github/workflows/black.yml

This file was deleted.

12 changes: 12 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
name: Lint

on:
pull_request:

jobs:
Lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- uses: pre-commit/[email protected]
13 changes: 13 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
minimum_pre_commit_version: 2.15.0
ci:
autofix_prs: false
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.6
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/psf/black
rev: 23.11.0
hooks:
- id: black
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ now `DiffusionEngine`) has been cleaned up:
training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
- Autoencoding models have also been cleaned up.

### Style

* The repo is formatted with [Black](https://github.com/psf/black) and linted with [Ruff](https://beta.ruff.rs/).
* You can easily have these run on every commit by installing [`pre-commit`](https://pre-commit.com/) (`pip install pre-commit`) and running `pre-commit install` in the repo root.
benjaminaubin marked this conversation as resolved.
Show resolved Hide resolved

## Installation:

<a name="installation"></a>
Expand Down
9 changes: 4 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,6 @@ def log_img(self, pl_module, batch, batch_idx, split="train"):
# batch_idx > 5 and
self.max_images > 0
):
logger = type(pl_module.logger)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused local.

is_train = pl_module.training
if is_train:
pl_module.eval()
Expand Down Expand Up @@ -648,7 +647,7 @@ def init_wandb(save_dir, opt, config, group_name, name_str):

ckpt_resume_path = opt.resume_from_checkpoint

if not "devices" in trainer_config and trainer_config["accelerator"] != "gpu":
if "devices" not in trainer_config and trainer_config["accelerator"] != "gpu":
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x not in y is idiomatic.

del trainer_config["accelerator"]
cpu = True
else:
Expand Down Expand Up @@ -691,7 +690,7 @@ def init_wandb(save_dir, opt, config, group_name, name_str):
# TODO change once leaving "swiffer" config directory
try:
group_name = nowname.split(now)[-1].split("-")[1]
except:
except Exception:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

except: is dangerous as it catches SystemExits and KeyboardInterrupts.

group_name = nowname
default_logger_cfg["params"]["group"] = group_name
init_wandb(
Expand Down Expand Up @@ -814,7 +813,7 @@ def init_wandb(save_dir, opt, config, group_name, name_str):
trainer_kwargs["callbacks"] = [
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
]
if not "plugins" in trainer_kwargs:
if "plugins" not in trainer_kwargs:
trainer_kwargs["plugins"] = list()

# cmd line trainer args (which are in trainer_opt) have always priority over config-trainer-args (which are in trainer_kwargs)
Expand All @@ -839,7 +838,7 @@ def init_wandb(save_dir, opt, config, group_name, name_str):
print(
f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
)
except:
except Exception:
print("datasets not yet initialized.")

# configure learning rate
Expand Down
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,14 @@ test-inference = [
"pip install -r requirements/pt2.txt",
"pytest -v tests/inference/test_inference.py {args}",
]

[tool.ruff]
target-version = "py38"
extend-select = ["I"]
ignore = [
"E501",
]
[tool.ruff.per-file-ignores]
"*/__init__.py" = [
"F401", # unused imports (will be taken care of by PR #44)
]
2 changes: 1 addition & 1 deletion scripts/demo/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def decode(self, cv2Image, method="dwtDct", **configs):
bits = embed.decode(cv2Image)
return self.reconstruct(bits)

except:
except Exception:
raise e


Expand Down
8 changes: 4 additions & 4 deletions scripts/demo/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ def __init__(self, discretization: Discretization, strength: float = 1.0):
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
print("sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
print("prune index:", max(int(self.strength * len(sigmas)), 1))
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
print("sigmas after pruning: ", sigmas)
Comment on lines -21 to +26
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused fs.

return sigmas


Expand All @@ -45,7 +45,7 @@ def __init__(
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
print("sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
if self.original_steps is None:
steps = len(sigmas)
Expand All @@ -55,5 +55,5 @@ def __call__(self, *args, **kwargs):
sigmas = sigmas[prune_index:]
print("prune index:", prune_index)
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
print("sigmas after pruning: ", sigmas)
return sigmas
19 changes: 18 additions & 1 deletion scripts/demo/sampling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
import os

import numpy as np
import streamlit as st
import torch
from einops import repeat
from pytorch_lightning import seed_everything

from scripts.demo.streamlit_helpers import *
from scripts.demo.streamlit_helpers import (
do_img2img,
do_sample,
get_interactive_image,
get_unique_embedder_keys_from_conditioner,
init_embedder_options,
init_sampling,
init_save_locally,
init_st,
perform_save_locally,
set_lowvram_mode,
)

SAVE_PATH = "outputs/demo/txt2img/"

Expand Down
43 changes: 17 additions & 26 deletions scripts/demo/streamlit_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,37 @@
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as TT
from einops import rearrange, repeat
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig, OmegaConf
from PIL import Image
from safetensors.torch import load_file as load_safetensors
from torch import autocast
from torchvision import transforms
from torchvision.utils import make_grid, save_image

from scripts.demo.discretization import (Img2ImgDiscretizationWrapper,
Txt2NoisyDiscretizationWrapper)
from scripts.util.detection.nsfw_and_watermark_dectection import \
DeepFloydDataFiltering
from scripts.demo.discretization import (
Img2ImgDiscretizationWrapper,
Txt2NoisyDiscretizationWrapper,
)
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark
from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider,
VanillaCFG)
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
LinearMultistepSampler)
from sgm.modules.diffusionmodules.guiders import LinearPredictionGuider, VanillaCFG
from sgm.modules.diffusionmodules.sampling import (
DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
LinearMultistepSampler,
)
from sgm.util import append_dims, default, instantiate_from_config


@st.cache_resource()
def init_st(version_dict, load_ckpt=True, load_filter=True):
state = dict()
if not "model" in state:
if "model" not in state:
config = version_dict["config"]
ckpt = version_dict["ckpt"]

Expand Down Expand Up @@ -253,7 +253,7 @@ def get_guider(options, key):
min_value=1.0,
)
min_scale = st.number_input(
f"min guidance scale",
"min guidance scale",
value=options.get("min_cfg", 1.0),
min_value=1.0,
max_value=10.0,
Expand Down Expand Up @@ -432,15 +432,6 @@ def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1
return sampler


def get_interactive_image() -> Image.Image:
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
if image is not None:
image = Image.open(image)
if not image.mode == "RGB":
image = image.convert("RGB")
return image


Comment on lines -435 to -443
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicated function.

def load_img(
display: bool = True,
size: Union[None, int, Tuple[int, int]] = None,
Expand Down Expand Up @@ -776,7 +767,7 @@ def denoiser(x, sigma, c):
if filter is not None:
samples = filter(samples)

grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
grid = rearrange(grid, "n b c h w -> (n h) (b w) c") # noqa: F821
outputs.image(grid.cpu().numpy())
if return_latents:
return samples, samples_z
Expand Down
12 changes: 11 additions & 1 deletion scripts/demo/turbo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
from streamlit_helpers import *
import numpy as np
import streamlit as st
import torch
from st_keyup import st_keyup

from scripts.demo.streamlit_helpers import (
autocast,
get_batch,
get_unique_embedder_keys_from_conditioner,
init_st,
load_model,
)
from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler

VERSION2SPECS = {
Expand Down
16 changes: 13 additions & 3 deletions scripts/demo/video_sampling.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
import os

import streamlit as st
import torch
from pytorch_lightning import seed_everything

from scripts.demo.streamlit_helpers import *
from scripts.demo.streamlit_helpers import (
do_sample,
get_unique_embedder_keys_from_conditioner,
init_embedder_options,
init_sampling,
init_save_locally,
init_st,
load_img_for_prediction,
save_video_as_grid_and_mp4,
)

SAVE_PATH = "outputs/demo/vid/"

Expand Down Expand Up @@ -89,7 +100,6 @@
},
}


if __name__ == "__main__":
st.title("Stable Video Diffusion")
version = st.selectbox(
Expand Down Expand Up @@ -165,7 +175,7 @@

if st.checkbox("Overwrite fps in mp4 generator", False):
saving_fps = st.number_input(
f"saving video at fps:", value=value_dict["fps"], min_value=1
"saving video at fps:", value=value_dict["fps"], min_value=1
)
else:
saving_fps = value_dict["fps"]
Expand Down
3 changes: 1 addition & 2 deletions scripts/sampling/simple_video_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from PIL import Image
from torchvision.transforms import ToTensor

from scripts.util.detection.nsfw_and_watermark_dectection import \
DeepFloydDataFiltering
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark
from sgm.util import default, instantiate_from_config

Expand Down
10 changes: 5 additions & 5 deletions scripts/tests/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
dtype=dtype,
)

print(f"q/k/v shape:", query.shape, key.shape, value.shape)
print("q/k/v shape:", query.shape, key.shape, value.shape)

# Lets explore the speed of each of the 3 implementations
from torch.backends.cuda import SDPBackend, sdp_kernel
Expand Down Expand Up @@ -87,7 +87,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
) as prof:
with record_function("Default detailed stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
_o = F.scaled_dot_product_attention(query, key, value)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused local, mark it as such.

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

print(
Expand All @@ -99,7 +99,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
) as prof:
with record_function("Math implmentation stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
_o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
Expand All @@ -114,7 +114,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
) as prof:
with record_function("FlashAttention stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
_o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
Expand All @@ -129,7 +129,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
) as prof:
with record_function("EfficientAttention stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
_o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


Expand Down
2 changes: 1 addition & 1 deletion sgm/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

try:
from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError as e:
except ImportError:
print("#" * 100)
print("Datasets not yet available")
print("to enable, we need to add stable-datasets as a submodule")
Expand Down
Loading