Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Dec 12, 2024
1 parent 2e89b60 commit 07423f1
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
2 changes: 1 addition & 1 deletion optimum/exporters/neuron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def validate_model_outputs(
logger.info(f"\t\t-[✓] {output.shape} matches {ref_output.shape}")

# Values
if not torch.allclose(ref_output, output, atol=atol):
if not torch.allclose(ref_output, output.to(ref_output.dtype), atol=atol):
max_diff = torch.max(torch.abs(ref_output - output))
logger.error(f"\t\t-[x] values not close enough, max diff: {max_diff} (atol: {atol})")
value_failures.append((name, max_diff))
Expand Down
3 changes: 3 additions & 0 deletions optimum/neuron/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,9 @@ def __call__(self, *args, **kwargs):
kwargs.pop("width", None)
if kwargs.get("image", None):
kwargs["image"] = self.image_processor.preprocess(kwargs["image"], height=height, width=width)
# Override default `max_sequence_length`, eg. pixart
if "max_sequence_length" in inspect.signature(self.auto_model_class.__call__).parameters:
kwargs["max_sequence_length"] = self.text_encoder.config.neuron.get("static_sequence_length", None)
return self.auto_model_class.__call__(self, height=height, width=width, *args, **kwargs)


Expand Down
40 changes: 40 additions & 0 deletions tests/inference/test_stable_diffusion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
import cv2
import numpy as np
import PIL
import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import UniPCMultistepScheduler
from diffusers.utils import load_image
from parameterized import parameterized

from optimum.neuron import (
NeuronLatentConsistencyModelPipeline,
NeuronPixArtAlphaPipeline,
NeuronStableDiffusionControlNetPipeline,
NeuronStableDiffusionImg2ImgPipeline,
NeuronStableDiffusionInpaintPipeline,
Expand All @@ -38,6 +40,7 @@
from optimum.neuron.modeling_diffusion import (
NeuronControlNetModel,
NeuronModelTextEncoder,
NeuronModelTransformer,
NeuronModelUnet,
NeuronModelVaeDecoder,
NeuronModelVaeEncoder,
Expand Down Expand Up @@ -435,3 +438,40 @@ def test_from_pipe(self, model_arch):
prompt = "a dog running, lake, moat"
image = img2img_pipeline(prompt=prompt, image=init_image).images[0]
self.assertIsInstance(image, PIL.Image.Image)


is_inferentia_test


@requires_neuronx
@require_diffusers
class NeuronPixArtAlphaPipelineIntegrationTest(unittest.TestCase):
ATOL_FOR_VALIDATION = 1e-3

def test_export_and_inference_non_dyn(self):
model_id = "hf-internal-testing/tiny-pixart-alpha-pipe"
compiler_args = {"auto_cast": "none"}
input_shapes = {"batch_size": 1, "height": 64, "width": 64, "sequence_length": 32}
neuron_pipeline = NeuronPixArtAlphaPipeline.from_pretrained(
model_id,
export=True,
torch_dtype=torch.bfloat16,
dynamic_batch_size=False,
disable_neuron_cache=True,
**input_shapes,
**compiler_args,
)
self.assertIsInstance(neuron_pipeline.text_encoder, NeuronModelTextEncoder)
self.assertIsInstance(neuron_pipeline.transformer, NeuronModelTransformer)
self.assertIsInstance(neuron_pipeline.vae_encoder, NeuronModelVaeEncoder)
self.assertIsInstance(neuron_pipeline.vae_decoder, NeuronModelVaeDecoder)

prompt = "Mario eating hamburgers."

neuron_pipeline.transformer.config.sample_size = (
32 # Skip the sample size check because the dummy model uses a smaller sample size (8).
)
image = neuron_pipeline(prompt=prompt, use_resolution_binning=False).images[
0
] # Set `use_resolution_binning=False` to prevent resizing.
self.assertIsInstance(image, PIL.Image.Image)

0 comments on commit 07423f1

Please sign in to comment.