diff --git a/tests/cn_script/cn_script_test.py b/tests/cn_script/cn_script_test.py index 79e14c5bc..0557cbb79 100644 --- a/tests/cn_script/cn_script_test.py +++ b/tests/cn_script/cn_script_test.py @@ -150,22 +150,33 @@ def test_choose_input_image(self): ) with self.subTest(name="control net input"): - _, from_a1111 = Script.choose_input_image( - p=MockImg2ImgProcessing(init_images=[TestScript.sample_np_image]), + _, resize_mode = Script.choose_input_image( + p=MockImg2ImgProcessing( + init_images=[TestScript.sample_np_image], + resize_mode=external_code.ResizeMode.OUTER_FIT, + ), unit=external_code.ControlNetUnit( - image=TestScript.sample_base64_image, module="none" + image=TestScript.sample_base64_image, + module="none", + resize_mode=external_code.ResizeMode.INNER_FIT, ), idx=0, ) - self.assertFalse(from_a1111) + self.assertEqual(resize_mode, external_code.ResizeMode.INNER_FIT) with self.subTest(name="A1111 input"): - _, from_a1111 = Script.choose_input_image( - p=MockImg2ImgProcessing(init_images=[TestScript.sample_np_image]), - unit=external_code.ControlNetUnit(module="none"), + _, resize_mode = Script.choose_input_image( + p=MockImg2ImgProcessing( + init_images=[TestScript.sample_np_image], + resize_mode=external_code.ResizeMode.OUTER_FIT, + ), + unit=external_code.ControlNetUnit( + module="none", + resize_mode=external_code.ResizeMode.INNER_FIT, + ), idx=0, ) - self.assertTrue(from_a1111) + self.assertEqual(resize_mode, external_code.ResizeMode.OUTER_FIT) if __name__ == "__main__":