diff --git a/acestep/core/generation/handler/generate_music.py b/acestep/core/generation/handler/generate_music.py
index 39d0d59e..4160bea5 100644
--- a/acestep/core/generation/handler/generate_music.py
+++ b/acestep/core/generation/handler/generate_music.py
@@ -4,6 +4,7 @@
 ``AceStepHandler`` so orchestration stays separate from lower-level helpers.
 """
 
+import gc
 import traceback
 from typing import Any, Dict, List, Optional, Union
 
@@ -251,7 +252,7 @@ def generate_music(
             use_tiled_decode=use_tiled_decode,
             time_costs=time_costs,
         )
-        return self._build_generate_music_success_payload(
+        result = self._build_generate_music_success_payload(
             outputs=outputs,
             pred_wavs=pred_wavs,
             pred_latents_cpu=pred_latents_cpu,
@@ -260,6 +261,20 @@
             actual_batch_size=actual_batch_size,
             progress=progress,
         )
+        # Clear GPU tensor references from the mutable outputs dict so
+        # accelerator memory is reclaimable before the next generation.
+        _gpu_keys = (
+            "src_latents", "target_latents_input", "chunk_masks",
+            "latent_masks", "encoder_hidden_states",
+            "encoder_attention_mask", "context_latents",
+            "lyric_token_idss",
+        )
+        for _k in _gpu_keys:
+            outputs.pop(_k, None)
+        del outputs, pred_wavs, pred_latents_cpu
+        gc.collect()
+        self._empty_cache()
+        return result
     except Exception as exc:
         error_msg = f"Error: {exc!s}\n{traceback.format_exc()}"
         logger.exception("[generate_music] Generation failed")
diff --git a/acestep/core/generation/handler/generate_music_decode.py b/acestep/core/generation/handler/generate_music_decode.py
index ffaaebbe..1ec9b0fb 100644
--- a/acestep/core/generation/handler/generate_music_decode.py
+++ b/acestep/core/generation/handler/generate_music_decode.py
@@ -1,5 +1,6 @@
 """Decode/validation helpers for ``generate_music`` orchestration."""
 
+import gc
 import os
 import time
 from typing import Any, Dict, Optional, Tuple
@@ -180,7 +181,6 @@
         if vae_cpu and vae_device is not None:
             logger.info("[generate_music] Restoring VAE to original device after CPU decode path...")
             self.vae = self.vae.to(vae_device)
-            pred_latents_for_decode = pred_latents_for_decode.to(vae_device)
             self._empty_cache()
         logger.debug(
             "[generate_music] After VAE decode: "
@@ -194,6 +194,8 @@
     if torch.any(peak > 1.0):
         pred_wavs = pred_wavs / peak.clamp(min=1.0)
     self._empty_cache()
+    gc.collect()
+    self._empty_cache()
     end_time = time.time()
     time_costs["vae_decode_time_cost"] = end_time - start_time
     time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"]
diff --git a/acestep/core/generation/handler/generate_music_decode_test.py b/acestep/core/generation/handler/generate_music_decode_test.py
index 82cd04fe..ece612ff 100644
--- a/acestep/core/generation/handler/generate_music_decode_test.py
+++ b/acestep/core/generation/handler/generate_music_decode_test.py
@@ -251,5 +251,66 @@ def _empty_cache(self):
         self.assertGreaterEqual(host.empty_cache_calls, 2)
 
+    def test_decode_pred_latents_does_not_restore_latents_to_gpu_after_successful_cpu_decode(self):
+        """It does not move pred_latents_for_decode back to GPU after a successful CPU decode.
+
+        The removed ``pred_latents_for_decode = pred_latents_for_decode.to(vae_device)`` line
+        was causing a wasteful re-allocation of the already-decoded input tensor on the
+        GPU. After the fix, only the VAE itself is restored; the input latent is not.
+        """
+
+        class _SuccessVae(_FakeVae):
+            """VAE double that records transfer calls and succeeds on decode."""
+
+            def __init__(self):
+                """Initialize transfer call trackers."""
+                super().__init__()
+                self.vae_to_calls = []
+                self._device = "cuda"
+
+            def decode(self, latents: torch.Tensor):
+                """Return a simple decoded output."""
+                return _FakeDecodeOutput(torch.ones(latents.shape[0], 2, 8))
+
+            def cpu(self):
+                """Simulate VAE being moved to CPU."""
+                self._device = "cpu"
+                return self
+
+            def to(self, *args, **kwargs):
+                """Record VAE device-transfer destinations."""
+                self.vae_to_calls.append(args[0] if args else kwargs)
+                return self
+
+        class _SuccessHost(_Host):
+            """Host that forces non-MLX VAE so the CPU-decode path is exercised."""
+
+            def __init__(self):
+                """Configure non-MLX state and a tracking VAE."""
+                super().__init__()
+                self.use_mlx_vae = False
+                self.mlx_vae = None
+                self.vae = _SuccessVae()
+                self.device = "cuda"
+
+        host = _SuccessHost()
+        pred_latents = torch.ones(1, 4, 3)
+        time_costs = {"total_time_cost": 1.0}
+
+        with patch.dict(GENERATE_MUSIC_DECODE_MODULE.os.environ, {"ACESTEP_VAE_ON_CPU": "1"}, clear=False):
+            pred_wavs, _cpu_latents, _costs = host._decode_generate_music_pred_latents(
+                pred_latents=pred_latents,
+                progress=None,
+                use_tiled_decode=False,
+                time_costs=time_costs,
+            )
+
+        # VAE itself must be restored to its original device.
+        self.assertEqual(len(host.vae.vae_to_calls), 1)
+        # The decoded waveform must be returned correctly.
+        self.assertEqual(tuple(pred_wavs.shape), (1, 2, 8))
+
 
 if __name__ == "__main__":
     unittest.main()
+
diff --git a/acestep/core/generation/handler/generate_music_payload.py b/acestep/core/generation/handler/generate_music_payload.py
index d8ee4c7e..fec55858 100644
--- a/acestep/core/generation/handler/generate_music_payload.py
+++ b/acestep/core/generation/handler/generate_music_payload.py
@@ -40,6 +40,8 @@ def _build_generate_music_success_payload(
     for index in range(actual_batch_size):
         audio_tensor = pred_wavs[index].cpu()
         audio_tensors.append(audio_tensor)
+    # Free the GPU waveform tensor now that all per-sample CPU copies are done.
+    del pred_wavs
 
     status_message = "Generation completed successfully!"
     logger.info(f"[generate_music] Done! Generated {len(audio_tensors)} audio tensors.")
diff --git a/acestep/core/generation/handler/generate_music_payload_test.py b/acestep/core/generation/handler/generate_music_payload_test.py
index 47a3ba8a..bf5d11a1 100644
--- a/acestep/core/generation/handler/generate_music_payload_test.py
+++ b/acestep/core/generation/handler/generate_music_payload_test.py
@@ -135,5 +135,66 @@ def test_build_success_payload_handles_missing_optional_outputs_without_progress
         self.assertEqual(payload["extra_outputs"]["pred_latents"].device.type, "cpu")
 
+    def test_build_success_payload_does_not_mutate_outputs_dict(self):
+        """Payload builder must not remove keys from the caller's outputs dict."""
+        host = _Host()
+        outputs = {
+            "target_latents_input": torch.ones(1, 4, 3),
+            "src_latents": torch.ones(1, 4, 3),
+            "chunk_masks": torch.ones(1, 4),
+            "latent_masks": torch.ones(1, 4),
+            "spans": [(0, 4)],
+            "encoder_hidden_states": torch.ones(1, 2, 3),
+            "encoder_attention_mask": torch.ones(1, 2),
+            "context_latents": torch.ones(1, 4, 3),
+            "lyric_token_idss": torch.ones(1, 2, dtype=torch.long),
+        }
+        original_keys = set(outputs.keys())
+
+        host._build_generate_music_success_payload(
+            outputs=outputs,
+            pred_wavs=torch.ones(1, 2, 8),
+            pred_latents_cpu=torch.ones(1, 4, 3),
+            time_costs={"total_time_cost": 1.0},
+            seed_value_for_ui=0,
+            actual_batch_size=1,
+            progress=None,
+        )
+
+        self.assertEqual(set(outputs.keys()), original_keys)
+
+    def test_build_success_payload_all_extra_outputs_are_cpu_tensors(self):
+        """It ensures every tensor in extra_outputs is on CPU."""
+        host = _Host()
+        outputs = {
+            "target_latents_input": torch.ones(1, 4, 3),
+            "src_latents": torch.ones(1, 4, 3),
+            "encoder_hidden_states": torch.ones(1, 2, 3),
+            "encoder_attention_mask": torch.ones(1, 2),
+            "context_latents": torch.ones(1, 4, 3),
+            "lyric_token_idss": torch.ones(1, 2, dtype=torch.long),
+        }
+        pred_wavs = torch.ones(1, 2, 8)
+        pred_latents_cpu = torch.ones(1, 4, 3)
+
+        payload = host._build_generate_music_success_payload(
+            outputs=outputs,
+            pred_wavs=pred_wavs,
+            pred_latents_cpu=pred_latents_cpu,
+            time_costs={"total_time_cost": 1.0},
+            seed_value_for_ui=0,
+            actual_batch_size=1,
+            progress=None,
+        )
+
+        for key, val in payload["extra_outputs"].items():
+            if isinstance(val, torch.Tensor):
+                self.assertEqual(
+                    val.device.type, "cpu",
+                    f"extra_outputs['{key}'] is not on CPU (device={val.device})",
+                )
+
 
 if __name__ == "__main__":
     unittest.main()
+
diff --git a/acestep/core/generation/handler/init_service_offload_context.py b/acestep/core/generation/handler/init_service_offload_context.py
index df61afa7..cdafd450 100644
--- a/acestep/core/generation/handler/init_service_offload_context.py
+++ b/acestep/core/generation/handler/init_service_offload_context.py
@@ -1,5 +1,6 @@
 """Context manager for temporary model loading/offloading."""
 
+import gc
 import time
 from contextlib import contextmanager
 
@@ -61,6 +62,7 @@ def _load_model_context(self, model_name: str):
         else:
             self._recursive_to_device(model, "cpu")
 
+        gc.collect()
         self._empty_cache()
         offload_time = time.time() - start_time
         self.current_offload_cost += offload_time
diff --git a/acestep/ui/gradio/events/results/batch_management_wrapper.py b/acestep/ui/gradio/events/results/batch_management_wrapper.py
index bad0f2df..cc323211 100644
--- a/acestep/ui/gradio/events/results/batch_management_wrapper.py
+++ b/acestep/ui/gradio/events/results/batch_management_wrapper.py
@@ -84,6 +84,13 @@ def generate_with_batch_management(
             gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
         )
 
+    # Release the generator frame and run GC to reclaim any accelerator memory
+    # that was not yet freed at the end of the inner generator.
+    del generator
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
     result = final_result_from_inner
     if result is None:
         error_msg = t("messages.batch_failed", error="No generation result was produced")