diff --git a/invokeai/app/invocations/z_image_image_to_latents.py b/invokeai/app/invocations/z_image_image_to_latents.py index d74ce073b69..5a70fdba13d 100644 --- a/invokeai/app/invocations/z_image_image_to_latents.py +++ b/invokeai/app/invocations/z_image_image_to_latents.py @@ -20,6 +20,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux # Z-Image can use either the Diffusers AutoencoderKL or the FLUX AutoEncoder ZImageVAE = Union[AutoencoderKL, FluxAutoEncoder] @@ -47,7 +48,14 @@ def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tenso "Ensure you are using a compatible VAE model." ) - with vae_info.model_on_device() as (_, vae): + # Estimate working memory needed for VAE encode + estimated_working_memory = estimate_vae_working_memory_flux( + operation="encode", + image_tensor=image_tensor, + vae=vae_info.model, + ) + + with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae): if not isinstance(vae, (AutoencoderKL, FluxAutoEncoder)): raise TypeError( f"Expected AutoencoderKL or FluxAutoEncoder, got {type(vae).__name__}. " diff --git a/invokeai/app/invocations/z_image_latents_to_image.py b/invokeai/app/invocations/z_image_latents_to_image.py index d29ec0998a4..a2e6fdcc077 100644 --- a/invokeai/app/invocations/z_image_latents_to_image.py +++ b/invokeai/app/invocations/z_image_latents_to_image.py @@ -21,6 +21,7 @@ from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux # Z-Image can use either the Diffusers AutoencoderKL or the FLUX AutoEncoder ZImageVAE = Union[AutoencoderKL, FluxAutoEncoder] @@ -53,12 +54,19 @@ def invoke(self, context: InvocationContext) -> ImageOutput: is_flux_vae = isinstance(vae_info.model, FluxAutoEncoder) + # Estimate working memory needed for VAE decode + estimated_working_memory = estimate_vae_working_memory_flux( + operation="decode", + image_tensor=latents, + vae=vae_info.model, + ) + # FLUX VAE doesn't support seamless, so only apply for AutoencoderKL seamless_context = ( nullcontext() if is_flux_vae else SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes) ) - with seamless_context, vae_info.model_on_device() as (_, vae): + with seamless_context, vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae): context.util.signal_progress("Running VAE") if not isinstance(vae, (AutoencoderKL, FluxAutoEncoder)): raise TypeError( diff --git a/tests/app/invocations/test_z_image_working_memory.py b/tests/app/invocations/test_z_image_working_memory.py new file mode 100644 index 00000000000..c3f953ae527 --- /dev/null +++ b/tests/app/invocations/test_z_image_working_memory.py @@ -0,0 +1,139 @@ +"""Test that Z-Image VAE invocations properly estimate and request working memory.""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL + +from invokeai.app.invocations.z_image_image_to_latents import ZImageImageToLatentsInvocation +from invokeai.backend.flux.modules.autoencoder import 
AutoEncoder as FluxAutoEncoder + + +class TestZImageWorkingMemory: + """Test that Z-Image VAE invocations request working memory.""" + + @pytest.mark.parametrize("vae_type", [AutoencoderKL, FluxAutoEncoder]) + def test_z_image_latents_to_image_requests_working_memory(self, vae_type): + """Test that ZImageLatentsToImageInvocation estimates and requests working memory.""" + # Create mock VAE + mock_vae = MagicMock(spec=vae_type) + + # Only set config for AutoencoderKL (FluxAutoEncoder doesn't use config) + if vae_type == AutoencoderKL: + mock_vae.config.scaling_factor = 1.0 + mock_vae.config.shift_factor = None + + # Create mock parameter for dtype detection + mock_param = torch.zeros(1) + mock_vae.parameters.return_value = iter([mock_param]) + + # Create mock vae_info + mock_vae_info = MagicMock() + mock_vae_info.model = mock_vae + + # Create mock context manager return value + mock_cm = MagicMock() + mock_cm.__enter__ = MagicMock(return_value=(None, mock_vae)) + mock_cm.__exit__ = MagicMock(return_value=None) + mock_vae_info.model_on_device = MagicMock(return_value=mock_cm) + + # Mock the context + mock_context = MagicMock() + mock_context.models.load.return_value = mock_vae_info + + # Mock latents + mock_latents = torch.zeros(1, 16, 64, 64) + mock_context.tensors.load.return_value = mock_latents + + estimation_path = "invokeai.app.invocations.z_image_latents_to_image.estimate_vae_working_memory_flux" + + with patch(estimation_path) as mock_estimate: + expected_memory = 1024 * 1024 * 500 # 500MB + mock_estimate.return_value = expected_memory + + # Mock VAE decode to avoid actual computation + if vae_type == FluxAutoEncoder: + mock_vae.decode.return_value = torch.zeros(1, 3, 512, 512) + else: + mock_vae.decode.return_value = (torch.zeros(1, 3, 512, 512),) + + # Mock image save + mock_image_dto = MagicMock() + mock_context.images.save.return_value = mock_image_dto + + # Import and create invocation using model_construct to bypass validation + from invokeai.app.invocations.z_image_latents_to_image import ZImageLatentsToImageInvocation + + invocation = ZImageLatentsToImageInvocation.model_construct( + latents=MagicMock(latents_name="test_latents"), + vae=MagicMock(vae=MagicMock(), seamless_axes=["x", "y"]), + ) + + try: + invocation.invoke(mock_context) + except Exception: + # We expect some errors due to mocking, but we just want to verify the working memory was requested + pass + + # Verify that working memory estimation was called + mock_estimate.assert_called_once() + # Verify that model_on_device was called with the estimated working memory + mock_vae_info.model_on_device.assert_called_once_with(working_mem_bytes=expected_memory) + + @pytest.mark.parametrize("vae_type", [AutoencoderKL, FluxAutoEncoder]) + def test_z_image_image_to_latents_requests_working_memory(self, vae_type): + """Test that ZImageImageToLatentsInvocation estimates and requests working memory.""" + # Create mock VAE + mock_vae = MagicMock(spec=vae_type) + + # Only set config for AutoencoderKL (FluxAutoEncoder doesn't use config) + if vae_type == AutoencoderKL: + mock_vae.config.scaling_factor = 1.0 + mock_vae.config.shift_factor = None + + # Create mock parameter for dtype detection + mock_param = torch.zeros(1) + mock_vae.parameters.return_value = iter([mock_param]) + + # Create mock vae_info + mock_vae_info = MagicMock() + mock_vae_info.model = mock_vae + + # Create mock context manager return value + mock_cm = MagicMock() + mock_cm.__enter__ = MagicMock(return_value=(None, mock_vae)) + mock_cm.__exit__ = 
MagicMock(return_value=None) + mock_vae_info.model_on_device = MagicMock(return_value=mock_cm) + + # Mock image tensor + mock_image_tensor = torch.zeros(1, 3, 512, 512) + + # Mock the estimation function + estimation_path = "invokeai.app.invocations.z_image_image_to_latents.estimate_vae_working_memory_flux" + + with patch(estimation_path) as mock_estimate: + expected_memory = 1024 * 1024 * 250 # 250MB + mock_estimate.return_value = expected_memory + + # Mock VAE encode to avoid actual computation + if vae_type == FluxAutoEncoder: + mock_vae.encode.return_value = torch.zeros(1, 16, 64, 64) + else: + mock_latent_dist = MagicMock() + mock_latent_dist.sample.return_value = torch.zeros(1, 16, 64, 64) + mock_encode_result = MagicMock() + mock_encode_result.latent_dist = mock_latent_dist + mock_vae.encode.return_value = mock_encode_result + + # Call the static method directly + try: + ZImageImageToLatentsInvocation.vae_encode(mock_vae_info, mock_image_tensor) + except Exception: + # We expect some errors due to mocking, but we just want to verify the working memory was requested + pass + + # Verify that working memory estimation was called + mock_estimate.assert_called_once() + # Verify that model_on_device was called with the estimated working memory + mock_vae_info.model_on_device.assert_called_once_with(working_mem_bytes=expected_memory)
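Note on the helper introduced above: estimate_vae_working_memory_flux is imported from invokeai/backend/util/vae_working_memory.py, which is not part of this patch's visible context. The sketch below is a minimal illustration of the shape the two call sites assume (keyword arguments operation, image_tensor, and vae, returning a byte count); the element-size lookup and the multiplier constants are assumptions made for illustration, not the module's actual implementation.

    from typing import Literal, Union

    import torch
    from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

    from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder


    def estimate_vae_working_memory_flux(
        operation: Literal["encode", "decode"],
        image_tensor: torch.Tensor,
        vae: Union[AutoencoderKL, FluxAutoEncoder],
    ) -> int:
        """Return an estimated activation-memory budget, in bytes, for one VAE pass.

        Illustrative sketch only: the constants below are assumed, not taken from
        invokeai/backend/util/vae_working_memory.py.
        """
        # Size of one element at the VAE's working dtype (e.g. 2 bytes for bfloat16).
        element_size = next(vae.parameters()).element_size()
        # Assumed scaling: decode receives latents and expands them roughly 8x per
        # spatial dimension, so its intermediate feature maps dominate; encode
        # receives the full-resolution image, so a smaller multiplier suffices.
        multiplier = 2200 if operation == "decode" else 260
        return int(image_tensor.numel() * element_size * multiplier)

Because both call sites pass keyword arguments and treat the result as an opaque byte count handed to model_on_device(working_mem_bytes=...), any refinement of the real estimator only needs to preserve that interface, and the tests above would continue to pass unchanged.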