
Commit

Fix other Ruff complaints
akx committed Nov 22, 2023
1 parent f614ed5 commit a0ff858
Showing 10 changed files with 20 additions and 25 deletions.
5 changes: 2 additions & 3 deletions main.py
@@ -402,7 +402,6 @@ def log_img(self, pl_module, batch, batch_idx, split="train"):
             # batch_idx > 5 and
             self.max_images > 0
         ):
-            logger = type(pl_module.logger)
             is_train = pl_module.training
             if is_train:
                 pl_module.eval()
@@ -691,7 +690,7 @@ def init_wandb(save_dir, opt, config, group_name, name_str):
             # TODO change once leaving "swiffer" config directory
             try:
                 group_name = nowname.split(now)[-1].split("-")[1]
-            except:
+            except Exception:
                 group_name = nowname
             default_logger_cfg["params"]["group"] = group_name
             init_wandb(
@@ -839,7 +838,7 @@ def init_wandb(save_dir, opt, config, group_name, name_str):
                 print(
                     f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
                 )
-        except:
+        except Exception:
            print("datasets not yet initialized.")
 
        # configure learning rate
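Note: the `except:` to `except Exception:` changes in this file (and in several files below) address Ruff's E722 (bare `except`). A bare `except` also traps `KeyboardInterrupt` and `SystemExit`, which derive from `BaseException`; catching `Exception` keeps the fallback behaviour for ordinary errors while letting those propagate. A minimal sketch of the pattern, with a hypothetical helper name standing in for the surrounding main.py code:

def parse_group_name(nowname: str, now: str) -> str:
    # Hypothetical helper mirroring the main.py hunk above (illustration only).
    try:
        return nowname.split(now)[-1].split("-")[1]
    except Exception:
        # Ordinary failures such as IndexError fall back to the full name,
        # while KeyboardInterrupt / SystemExit still propagate.
        return nowname
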
2 changes: 1 addition & 1 deletion scripts/demo/detect.py
@@ -45,7 +45,7 @@ def decode(self, cv2Image, method="dwtDct", **configs):
             bits = embed.decode(cv2Image)
             return self.reconstruct(bits)
 
-        except:
+        except Exception:
             raise e
 
 
8 changes: 4 additions & 4 deletions scripts/tests/attention.py
@@ -87,7 +87,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
 ) as prof:
     with record_function("Default detailed stats"):
         for _ in range(25):
-            o = F.scaled_dot_product_attention(query, key, value)
+            _o = F.scaled_dot_product_attention(query, key, value)
 print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
 
 print(
@@ -99,7 +99,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
     ) as prof:
         with record_function("Math implmentation stats"):
             for _ in range(25):
-                o = F.scaled_dot_product_attention(query, key, value)
+                _o = F.scaled_dot_product_attention(query, key, value)
     print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
 
 with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
@@ -114,7 +114,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
     ) as prof:
         with record_function("FlashAttention stats"):
             for _ in range(25):
-                o = F.scaled_dot_product_attention(query, key, value)
+                _o = F.scaled_dot_product_attention(query, key, value)
     print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
 
 with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
@@ -129,7 +129,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
     ) as prof:
         with record_function("EfficientAttention stats"):
             for _ in range(25):
-                o = F.scaled_dot_product_attention(query, key, value)
+                _o = F.scaled_dot_product_attention(query, key, value)
     print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
 
 
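Note: renaming `o` to `_o` silences Ruff's F841 (local variable assigned but never used); an underscore-prefixed name matches Ruff's default dummy-variable pattern, so the deliberately discarded attention output no longer counts as an unused binding. A self-contained, CPU-only sketch of the same profiling pattern, with tensor shapes and labels that are illustrative rather than taken from the script:

import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile, record_function

# Illustrative shapes: (batch, heads, sequence, head_dim)
query = torch.randn(1, 8, 128, 64)
key = torch.randn(1, 8, 128, 64)
value = torch.randn(1, 8, 128, 64)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("sdpa stats"):
        for _ in range(25):
            # Result is intentionally discarded; the underscore prefix keeps F841 quiet.
            _out = F.scaled_dot_product_attention(query, key, value)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
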
1 change: 0 additions & 1 deletion sgm/inference/api.py
@@ -326,7 +326,6 @@ def get_discretization_config(params: SamplingParams):
 def get_sampler_config(params: SamplingParams):
     discretization_config = get_discretization_config(params)
     guider_config = get_guider_config(params)
-    sampler = None
     if params.sampler == Sampler.EULER_EDM:
         return EulerEDMSampler(
             num_steps=params.steps,
6 changes: 3 additions & 3 deletions sgm/inference/helpers.py
@@ -119,7 +119,7 @@ def do_sample(
         batch2model_input = []
 
     with torch.no_grad():
-        with autocast(device) as precision_scope:
+        with autocast(device):
             with model.ema_scope():
                 num_samples = [num_samples]
                 batch, batch_uc = get_batch(
@@ -131,7 +131,7 @@ def do_sample(
                     if isinstance(batch[key], torch.Tensor):
                         print(key, batch[key].shape)
                     elif isinstance(batch[key], list):
-                        print(key, [len(l) for l in batch[key]])
+                        print(key, [len(lst) for lst in batch[key]])
                     else:
                         print(key, batch[key])
                 c, uc = model.conditioner.get_unconditional_conditioning(
@@ -255,7 +255,7 @@ def do_img2img(
     device="cuda",
 ):
     with torch.no_grad():
-        with autocast(device) as precision_scope:
+        with autocast(device):
             with model.ema_scope():
                 batch, batch_uc = get_batch(
                     get_unique_embedder_keys_from_conditioner(model.conditioner),
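Note: `precision_scope` was bound but never referenced, so Ruff's F841 flagged it; the context manager enables autocasting whether or not its return value is bound, and dropping the `as` clause changes nothing at runtime. The `l` to `lst` rename in the other hunk likewise clears E741 (ambiguous variable name). A minimal sketch of the unbound-context-manager usage; the device handling below is an assumption for illustration, not code from helpers.py:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

with torch.no_grad():
    # No `as ...` binding needed: entering the context is what enables autocast.
    with torch.autocast(device):
        x = torch.randn(8, 8, device=device)
        y = x @ x  # may run in a reduced-precision dtype under autocast
print(y.dtype)
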
6 changes: 3 additions & 3 deletions sgm/models/diffusion.py
@@ -250,9 +250,9 @@ def sample(
     ):
         randn = torch.randn(batch_size, *shape).to(self.device)
 
-        denoiser = lambda input, sigma, c: self.denoiser(
-            self.model, input, sigma, c, **kwargs
-        )
+        def denoiser(input, sigma, c):
+            return self.denoiser(self.model, input, sigma, c, **kwargs)
+
         samples = self.sampler(denoiser, randn, cond, uc=uc)
         return samples
 
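Note: replacing the lambda assignment with a nested `def` is what Ruff's E731 (do not assign a lambda expression, use a def) asks for; the closure over `self.model` and `**kwargs` is unchanged, and the sampler still receives a callable with the `(input, sigma, c)` signature. A small sketch of the same closure pattern using illustrative names, not the repository's actual classes:

import torch

def make_denoiser(wrapped_denoiser, network, **kwargs):
    # The nested def plays the same role as the former lambda, but carries a
    # real __name__, which helps tracebacks and profiler output.
    def denoiser(x, sigma, cond):
        return wrapped_denoiser(network, x, sigma, cond, **kwargs)

    return denoiser

# Usage with stand-in callables: the "sampler" only needs the 3-argument signature.
denoise = make_denoiser(lambda net, x, sigma, cond: x / (1.0 + sigma), network=None)
out = denoise(torch.ones(2, 2), 0.5, cond=None)
print(out)  # each entry is 1 / 1.5
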
4 changes: 1 addition & 3 deletions sgm/modules/attention.py
@@ -51,12 +51,10 @@
     import xformers.ops
 
     XFORMERS_IS_AVAILABLE = True
-except:
+except Exception:
     XFORMERS_IS_AVAILABLE = False
     logpy.warn("no module 'xformers'. Processing without...")
 
-# from .diffusionmodules.util import mixed_checkpoint as checkpoint
-
 
 def exists(val):
     return val is not None
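Note: besides dropping a stale commented-out import, this makes the optional xformers probe catch `Exception` instead of using a bare `except`. If only a missing package should be tolerated, `except ImportError` would be narrower still; that is a possible follow-up, not something this commit does. A self-contained sketch of the availability-flag pattern:

import logging

logpy = logging.getLogger(__name__)

try:
    import xformers.ops  # noqa: F401  (only probed here; used elsewhere in the real module)

    XFORMERS_IS_AVAILABLE = True
except ImportError:
    # Narrower than `except Exception`: only a missing or broken import is tolerated.
    XFORMERS_IS_AVAILABLE = False
    logpy.warning("no module 'xformers'. Processing without...")

print("xformers available:", XFORMERS_IS_AVAILABLE)
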
2 changes: 1 addition & 1 deletion sgm/modules/autoencoding/lpips/loss/lpips.py
@@ -59,7 +59,7 @@ def forward(self, input, target):
             for kk in range(len(self.chns))
         ]
         val = res[0]
-        for l in range(1, len(self.chns)):
+        for l in range(1, len(self.chns)):  # noqa: E741
             val += res[l]
         return val
 
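Note: Ruff's E741 flags ambiguous single-character names such as `l`. Here the commit opts for a local `# noqa: E741` suppression, whereas the helpers.py hunk above renamed the variable instead; both satisfy the rule. A tiny sketch with illustrative data showing the two options side by side:

res = [1.0, 2.0, 3.0]

# Option 1: keep the name and suppress the rule locally.
val = res[0]
for l in range(1, len(res)):  # noqa: E741
    val += res[l]

# Option 2: rename the loop index so no suppression is needed.
val_renamed = res[0]
for idx in range(1, len(res)):
    val_renamed += res[idx]

assert val == val_renamed == 6.0
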
9 changes: 4 additions & 5 deletions sgm/modules/diffusionmodules/model.py
@@ -9,19 +9,19 @@
 from einops import rearrange
 from packaging import version
 
+from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
+
 logpy = logging.getLogger(__name__)
 
 try:
     import xformers
     import xformers.ops
 
     XFORMERS_IS_AVAILABLE = True
-except:
+except Exception:
     XFORMERS_IS_AVAILABLE = False
     logpy.warning("no module 'xformers'. Processing without...")
 
-from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
-
 
 def get_timestep_embedding(timesteps, embedding_dim):
     """
@@ -633,8 +633,7 @@ def __init__(
         self.give_pre_end = give_pre_end
         self.tanh_out = tanh_out
 
-        # compute in_ch_mult, block_in and curr_res at lowest res
-        in_ch_mult = (1,) + tuple(ch_mult)
+        # compute block_in and curr_res at lowest res
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
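Note: moving the `from ...modules.attention import ...` line above the logger assignment and the try/except probe addresses Ruff's E402 (module level import not at top of file), which fires for top-level imports placed after other statements; the removal of the unused `in_ch_mult` below is another F841 fix. A short sketch of the ordering rule, using standard-library stand-ins for the package's relative imports:

# E402: top-level imports must come before any other top-level statement.
import logging
from functools import reduce  # stand-in for the relative attention import; fine up here

logpy = logging.getLogger(__name__)  # first non-import statement

import os  # noqa: E402  (placed after a non-import statement, so Ruff flags it without the noqa)

print(reduce(lambda a, b: a + b, [1, 2, 3]), os.sep)
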
2 changes: 1 addition & 1 deletion sgm/util.py
@@ -28,7 +28,7 @@ def get_string_from_tuple(s):
                 return t[0]
             else:
                 pass
-    except:
+    except Exception:
         pass
     return s
 
