Skip to content

Commit

Permalink
from_pipe support
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Oct 9, 2024
1 parent 04a4427 commit c07a68e
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 32 deletions.
89 changes: 57 additions & 32 deletions optimum/neuron/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,29 +119,42 @@ class NeuronDiffusionPipelineBase(NeuronTracedModel, ConfigMixin):
base_model_prefix = "neuron_model"
config_name = "model_index.json"
sub_component_config_name = "config.json"
_optional_components = [
"tokenizer",
"tokenizer_2",
"text_encoder",
"text_encoder_2",
"vae_encoder",
"image_encoder",
"unet",
"transformer",
"feature_extractor",
]

def __init__(
self,
config: Dict[str, Any],
configs: Dict[str, "PretrainedConfig"],
neuron_configs: Dict[str, "NeuronDefaultConfig"],
data_parallel_mode: Literal["none", "unet", "transformer", "all"],
vae_decoder: torch.jit._script.ScriptModule,
text_encoder: Optional[torch.jit._script.ScriptModule] = None,
text_encoder_2: Optional[torch.jit._script.ScriptModule] = None,
unet: Optional[torch.jit._script.ScriptModule] = None,
transformer: Optional[torch.jit._script.ScriptModule] = None,
vae_encoder: Optional[torch.jit._script.ScriptModule] = None,
scheduler: Optional[SchedulerMixin],
vae_decoder: Union[torch.jit._script.ScriptModule, "NeuronModelVaeDecoder"],
text_encoder: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelTextEncoder"]] = None,
text_encoder_2: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelTextEncoder"]] = None,
unet: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelUnet"]] = None,
transformer: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelTransformer"]] = None,
vae_encoder: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelVaeEncoder"]] = None,
image_encoder: Optional[torch.jit._script.ScriptModule] = None,
safety_checker: Optional[torch.jit._script.ScriptModule] = None,
tokenizer: Optional[Union[CLIPTokenizer, T5Tokenizer]] = None,
tokenizer_2: Optional[CLIPTokenizer] = None,
scheduler: Optional[SchedulerMixin] = None,
feature_extractor: Optional[CLIPFeatureExtractor] = None,
controlnet: Optional[
Union[
torch.jit._script.ScriptModule,
List[torch.jit._script.ScriptModule],
"NeuronControlNetModel",
"NeuronMultiControlNetModel",
]
] = None,
# stable diffusion xl specific arguments
Expand All @@ -163,17 +176,19 @@ def __init__(
data_parallel_mode (`Literal["none", "unet", "transformer", "all"]`):
Mode to decide what components to load into both NeuronCores of a Neuron device. Can be "none" (no data parallel), "unet" (only
load unet into both cores of each device), "transformer" (only load transformer into both cores of each device), "all" (load the whole pipeline into both cores).
vae_decoder (`torch.jit._script.ScriptModule`):
scheduler (`Optional[SchedulerMixin]`):
A scheduler to be used in combination with the U-NET component to denoise the encoded image latents.
vae_decoder (`Union[torch.jit._script.ScriptModule, "NeuronModelVaeDecoder"]`):
The Neuron TorchScript module associated to the VAE decoder.
text_encoder (`Optional[torch.jit._script.ScriptModule]`, defaults to `None`):
text_encoder (`Optional[Union[torch.jit._script.ScriptModule, "NeuronModelTextEncoder"]]`, defaults to `None`):
The Neuron TorchScript module associated to the text encoder.
text_encoder_2 (`Optional[torch.jit._script.ScriptModule]`, defaults to `None`):
text_encoder_2 (`Optional[Union[torch.jit._script.ScriptModule, "NeuronModelTextEncoder"]]`, defaults to `None`):
The Neuron TorchScript module associated to the second frozen text encoder. Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant.
unet (`Optional[torch.jit._script.ScriptModule]`, defaults to `None`):
unet (`Optional[Union[torch.jit._script.ScriptModule, "NeuronModelUnet"]]`, defaults to `None`):
The Neuron TorchScript module associated to the U-NET.
transformer (`Optional[torch.jit._script.ScriptModule]`, defaults to `None`):
transformer (`Optional[Union[torch.jit._script.ScriptModule, "NeuronModelTransformer"]]`, defaults to `None`):
The Neuron TorchScript module associated to the diffuser transformer.
vae_encoder (`Optional[torch.jit._script.ScriptModule]`, defaults to `None`):
vae_encoder (`Optional[Union[torch.jit._script.ScriptModule, "NeuronModelVaeEncoder"]]`, defaults to `None`):
The Neuron TorchScript module associated to the VAE encoder.
image_encoder (`Optional[torch.jit._script.ScriptModule]`, defaults to `None`):
The Neuron TorchScript module associated to the frozen CLIP image-encoder.
Expand All @@ -186,11 +201,9 @@ def __init__(
tokenizer_2 (`Optional[CLIPTokenizer]`, defaults to `None`):
Second tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
scheduler (`Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]`):
A scheduler to be used in combination with the U-NET component to denoise the encoded image latents.
feature_extractor (`Optional[CLIPFeatureExtractor]`, defaults to `None`):
A model extracting features from generated images to be used as inputs for the `safety_checker`.
controlnet (`Optional[Union[torch.jit._script.ScriptModule, List[torch.jit._script.ScriptModule]]]`, defaults to `None`):
controlnet (`Optional[Union[torch.jit._script.ScriptModule, List[torch.jit._script.ScriptModule], "NeuronControlNetModel", "NeuronMultiControlNetModel"]]`, defaults to `None`):
The Neuron TorchScript module(s) associated to the ControlNet(s).
requires_aesthetics_score (`bool`, defaults to `False`):
Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the
Expand Down Expand Up @@ -225,8 +238,8 @@ def __init__(
self.configs[DIFFUSION_MODEL_TEXT_ENCODER_NAME],
self.neuron_configs[DIFFUSION_MODEL_TEXT_ENCODER_NAME],
)
if text_encoder is not None
else None
if text_encoder is not None and not isinstance(text_encoder, NeuronModelTextEncoder)
else text_encoder
)
self.text_encoder_2 = (
NeuronModelTextEncoder(
Expand All @@ -235,15 +248,15 @@ def __init__(
self.configs[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME],
self.neuron_configs[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME],
)
if text_encoder_2 is not None
else None
if text_encoder_2 is not None and not isinstance(text_encoder_2, NeuronModelTextEncoder)
else text_encoder_2
)
self.unet = (
NeuronModelUnet(
unet, self, self.configs[DIFFUSION_MODEL_UNET_NAME], self.neuron_configs[DIFFUSION_MODEL_UNET_NAME]
)
if unet is not None
else None
if unet is not None and not isinstance(unet, NeuronModelUnet)
else unet
)
self.transformer = (
NeuronModelTransformer(
Expand All @@ -252,8 +265,8 @@ def __init__(
self.configs[DIFFUSION_MODEL_TRANSFORMER_NAME],
self.neuron_configs[DIFFUSION_MODEL_TRANSFORMER_NAME],
)
if transformer is not None
else None
if transformer is not None and not isinstance(transformer, NeuronModelTransformer)
else transformer
)
self.vae_encoder = (
NeuronModelVaeEncoder(
Expand All @@ -262,8 +275,8 @@ def __init__(
self.configs[DIFFUSION_MODEL_VAE_ENCODER_NAME],
self.neuron_configs[DIFFUSION_MODEL_VAE_ENCODER_NAME],
)
if vae_encoder is not None
else None
if vae_encoder is not None and not isinstance(vae_encoder, NeuronModelVaeEncoder)
else vae_encoder
)
self.vae_decoder = (
NeuronModelVaeDecoder(
Expand All @@ -272,12 +285,16 @@ def __init__(
self.configs[DIFFUSION_MODEL_VAE_DECODER_NAME],
self.neuron_configs[DIFFUSION_MODEL_VAE_DECODER_NAME],
)
if vae_decoder is not None
else None
if vae_decoder is not None and not isinstance(vae_decoder, NeuronModelVaeDecoder)
else vae_decoder
)
self.vae = NeuronModelVae(self.vae_encoder, self.vae_decoder)

if controlnet is not None:
if (
controlnet
and not isinstance(controlnet, NeuronControlNetModel)
and not isinstance(controlnet, NeuronMultiControlNetModel)
):
controlnet_cls = (
NeuronMultiControlNetModel
if isinstance(controlnet, list) and len(controlnet) > 1
Expand All @@ -290,7 +307,7 @@ def __init__(
self.neuron_configs[DIFFUSION_MODEL_CONTROLNET_NAME],
)
else:
self.controlnet = None
self.controlnet = controlnet

self.tokenizer = tokenizer
self.tokenizer_2 = tokenizer_2
Expand Down Expand Up @@ -1048,15 +1065,23 @@ def _save_config(self, save_directory):
@property
def components(self) -> Dict[str, Any]:
components = {
"vae": self.vae,
"vae_encoder": self.vae_encoder,
"vae_decoder": self.vae_decoder,
"unet": self.unet,
"transformer": self.transformer,
"text_encoder": self.text_encoder,
"text_encoder_2": self.text_encoder_2,
"image_encoder": self.image_encoder,
"safety_checker": self.safety_checker,
"neuron_configs": self.neuron_configs,
"data_parallel_mode": self.data_parallel_mode,
"feature_extractor": self.feature_extractor,
"configs": self.configs,
"config": self.config,
"tokenizer": self.tokenizer,
"tokenizer_2": self.tokenizer_2,
"scheduler": self.scheduler,
}
components = {k: v for k, v in components.items() if v is not None}
return components

@property
Expand Down
16 changes: 16 additions & 0 deletions tests/inference/test_stable_diffusion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,19 @@ def test_compatibility_with_compel(self, model_arch):
num_inference_steps=1,
).images[0]
self.assertIsInstance(image, PIL.Image.Image)

@parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True)
def test_from_pipe(self, model_arch):
    """Check that an img2img pipeline created with `from_pipe` produces a PIL image.

    A text-to-image SDXL Neuron pipeline is loaded with `export=True` and the
    class-level static shapes / compiler arguments, then handed to
    `NeuronStableDiffusionXLImg2ImgPipeline.from_pipe`. The resulting pipeline is
    called once on a downloaded reference image.
    """
    source_pipe = NeuronStableDiffusionXLPipeline.from_pretrained(
        MODEL_NAMES[model_arch],
        export=True,
        dynamic_batch_size=False,
        **self.STATIC_INPUTS_SHAPES,
        **self.COMPILER_ARGS,
    )
    derived_pipe = NeuronStableDiffusionXLImg2ImgPipeline.from_pipe(source_pipe)

    # The reference image is fetched only after the derived pipeline exists.
    reference_image = download_image(
        "https://huggingface.co/datasets/optimum/documentation-images/resolve/main/intel/openvino/sd_xl/castle_friedrich.png"
    )
    outputs = derived_pipe(prompt="a dog running, lake, moat", image=reference_image)

    # Only the first generated image is inspected.
    self.assertIsInstance(outputs.images[0], PIL.Image.Image)

0 comments on commit c07a68e

Please sign in to comment.