From 5e91f48cb066a9dc3e22acd080849cc47223ad23 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Thu, 10 Aug 2023 00:04:45 +0000 Subject: [PATCH] Improve stable diffusion tests --- .../stable_diffusion/stable_diffusion_test.py | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/keras_cv/models/stable_diffusion/stable_diffusion_test.py b/keras_cv/models/stable_diffusion/stable_diffusion_test.py index 39600b4f2a..097f036929 100644 --- a/keras_cv/models/stable_diffusion/stable_diffusion_test.py +++ b/keras_cv/models/stable_diffusion/stable_diffusion_test.py @@ -49,11 +49,32 @@ def test_image_encoder_golden_value(self): [2.451568, 1.607522, -0.546311, -1.194388], atol=1e-4, ) + + def test_text_encoder_golden_value(self): + prompt = "a caterpillar smoking a hookah while sitting on a mushroom" + stablediff = StableDiffusion(128, 128) + text_encoding = stablediff.encode_text(prompt) + self.assertAllClose( + text_encoding[0][1][0:5], + [0.029033, -1.325784, 0.308457, -0.061469, 0.03983], + atol=1e-4, + ) + + def test_text_tokenizer_golden_value(self): + prompt = "a caterpillar smoking a hookah while sitting on a mushroom" + stablediff = StableDiffusion(128, 128) + text_encoding = stablediff.tokenizer.encode(prompt) + self.assertEqual( + text_encoding[0:5], + [49406, 320, 27111, 9038, 320], + ) def test_mixed_precision(self): mixed_precision.set_global_policy("mixed_float16") stablediff = StableDiffusion(128, 128) - _ = stablediff.text_to_image("Testing123 haha!") + _ = stablediff.text_to_image("Testing123 haha!", num_steps=2) + # Clean up global policy + mixed_precision.set_global_policy("float32") def test_generate_image_rejects_noise_and_seed(self): stablediff = StableDiffusion(128, 128)