Commit
Restore optimized attention score for sd15 & fix the generated images quality issue (#646)

* restore

* fix
JingyaHuang authored Jul 1, 2024
1 parent 86900e7 commit fca97de
Showing 3 changed files with 8 additions and 53 deletions.
optimum/exporters/neuron/utils.py (16 changes: 4 additions & 12 deletions)
@@ -30,8 +30,7 @@
     DIFFUSION_MODEL_VAE_DECODER_NAME,
     DIFFUSION_MODEL_VAE_ENCODER_NAME,
     ENCODER_NAME,
-    get_attention_scores_sd2,
-    get_attention_scores_sd15,
+    get_attention_scores_sd,
     get_attention_scores_sdxl,
 )
 from ...utils import (
@@ -54,10 +53,7 @@
         "Please update diffusers by running `pip install --upgrade diffusers`"
     )
 from diffusers import ControlNetModel, UNet2DConditionModel
-from diffusers.models.attention_processor import (
-    Attention,
-    AttnProcessor,
-)
+from diffusers.models.attention_processor import Attention


 if TYPE_CHECKING:
@@ -388,7 +384,6 @@ def get_submodels_for_export_stable_diffusion(
         models_for_export.append((DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, copy.deepcopy(text_encoder_2)))

     # U-NET
-    pipeline.unet.set_attn_processor(AttnProcessor())
     pipeline.unet.config.text_encoder_projection_dim = projection_dim
     # The U-NET time_ids inputs shapes depends on the value of `requires_aesthetics_score`
     # https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571
@@ -400,12 +395,9 @@
         if is_sdxl:
             logger.info("Applying optimized attention score computation for sdxl.")
             Attention.get_attention_scores = get_attention_scores_sdxl
-        elif "v1-5" in pipeline.config._name_or_path:
-            logger.info("Applying optimized attention score computation for stable diffusion 1.5.")
-            Attention.get_attention_scores = get_attention_scores_sd15
         else:
-            logger.info("Applying optimized attention score computation for stable diffusion 2.")
-            Attention.get_attention_scores = get_attention_scores_sd2
+            logger.info("Applying optimized attention score computation for stable diffusion.")
+            Attention.get_attention_scores = get_attention_scores_sd
     else:
         logger.warning(
             "You are not applying optimized attention score computation. If you want better performance, please"
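Taken together with the import hunk above, the dispatch collapses from three branches to two. A minimal sketch of the resulting monkey-patch, assuming only what the diff shows (the wrapper apply_optimized_attention is hypothetical; the two imported names are the ones this commit exports):

    from diffusers.models.attention_processor import Attention

    from optimum.neuron.utils import get_attention_scores_sd, get_attention_scores_sdxl


    def apply_optimized_attention(is_sdxl: bool) -> None:
        """Hypothetical wrapper around the patching done during export."""
        # Replace the method on the Attention class itself, so every attention
        # module in the UNet picks up the optimized implementation.
        if is_sdxl:
            Attention.get_attention_scores = get_attention_scores_sdxl
        else:
            # One shared implementation for every non-SDXL checkpoint
            # (1.5, 2.x, ...) instead of the former sd15/sd2 split.
            Attention.get_attention_scores = get_attention_scores_sd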
optimum/neuron/utils/__init__.py (6 changes: 2 additions & 4 deletions)
@@ -52,8 +52,7 @@
     ],
     "model_utils": ["get_tied_parameters_dict", "tie_parameters"],
     "optimization_utils": [
-        "get_attention_scores_sd2",
-        "get_attention_scores_sd15",
+        "get_attention_scores_sd",
         "get_attention_scores_sdxl",
     ],
     "patching": [
@@ -105,8 +104,7 @@
     )
     from .model_utils import get_tied_parameters_dict, tie_parameters
     from .optimization_utils import (
-        get_attention_scores_sd2,
-        get_attention_scores_sd15,
+        get_attention_scores_sd,
         get_attention_scores_sdxl,
     )
     from .patching import (
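For downstream imports, only the public name changes; assuming the lazy-module layout is otherwise untouched, callers migrate like this:

    # Before this commit:
    # from optimum.neuron.utils import get_attention_scores_sd2, get_attention_scores_sd15

    # After:
    from optimum.neuron.utils import get_attention_scores_sd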
optimum/neuron/utils/optimization_utils.py (39 changes: 2 additions & 37 deletions)
@@ -17,43 +17,8 @@
 import torch


-def get_attention_scores_sd15(self, query, key, attention_mask) -> torch.Tensor:
-    """Optimized attention for Stable Diffusion 1.5 UNET."""
-    dtype = query.dtype
-
-    if self.upcast_attention:
-        query = query.float()
-        key = key.float()
-
-    baddbmm_input = torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device)
-    beta = 0
-
-    attention_scores = torch.baddbmm(
-        baddbmm_input,
-        query,
-        key.transpose(-1, -2),
-        beta=beta,
-        alpha=self.scale,
-    )
-    del baddbmm_input
-
-    # TODO: following line is supposed to give the same result and reduce unnecessary overhead(no attention mask)
-    # however the compiled model output is far off from the one on cpu, need to further investigate.
-    # attention_scores = self.scale * torch.bmm(query, key.transpose(-1, -2)) # -> bad perf, max diff: 5.696073055267334 (atol: 0.001)
-
-    if self.upcast_softmax:
-        attention_scores = attention_scores.float()
-
-    attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
-    del attention_scores
-
-    attention_probs = attention_probs.to(dtype)
-
-    return attention_probs
-
-
-def get_attention_scores_sd2(self, query, key, attn_mask):
-    """Optimized attention for Stable Diffusion 2 UNET."""
+def get_attention_scores_sd(self, query, key, attn_mask):
+    """Optimized attention for Stable Diffusion UNET."""
 dtype = query.dtype

 if self.upcast_attention:
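The retained function body is truncated by the hunk, but the deleted sd15 body above and the visible context suggest the two were nearly identical. A standalone sketch of the baddbmm-based score computation, assuming the kept get_attention_scores_sd mirrors the deleted body (the free function and its explicit scale/upcast parameters stand in for attributes of the Attention module and are illustration-only):

    import torch


    def attention_probs_sketch(
        query: torch.Tensor,  # (batch * heads, q_len, head_dim)
        key: torch.Tensor,    # (batch * heads, kv_len, head_dim)
        scale: float,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
    ) -> torch.Tensor:
        dtype = query.dtype
        if upcast_attention:
            query, key = query.float(), key.float()

        # With beta=0, torch.baddbmm ignores the (uninitialized) input tensor,
        # so this computes scale * (query @ key^T) as a single fused op; per the
        # TODO above, the plain bmm form diverged on the compiled Neuron model.
        scores_input = torch.empty(
            query.shape[0], query.shape[1], key.shape[1],
            dtype=query.dtype, device=query.device,
        )
        scores = torch.baddbmm(scores_input, query, key.transpose(-1, -2), beta=0, alpha=scale)

        if upcast_softmax:
            scores = scores.float()
        return torch.nn.functional.softmax(scores, dim=-1).to(dtype)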
