Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion acestep/core/generation/handler/generate_music.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
``AceStepHandler`` so orchestration stays separate from lower-level helpers.
"""

import gc
import traceback
from typing import Any, Dict, List, Optional, Union

Expand Down Expand Up @@ -251,7 +252,7 @@ def generate_music(
use_tiled_decode=use_tiled_decode,
time_costs=time_costs,
)
return self._build_generate_music_success_payload(
result = self._build_generate_music_success_payload(
outputs=outputs,
pred_wavs=pred_wavs,
pred_latents_cpu=pred_latents_cpu,
Expand All @@ -260,6 +261,20 @@ def generate_music(
actual_batch_size=actual_batch_size,
progress=progress,
)
# Clear GPU tensor references from the mutable outputs dict so
# accelerator memory is reclaimable before the next generation.
_gpu_keys = (
"src_latents", "target_latents_input", "chunk_masks",
"latent_masks", "encoder_hidden_states",
"encoder_attention_mask", "context_latents",
"lyric_token_idss",
)
for _k in _gpu_keys:
outputs.pop(_k, None)
del outputs, pred_wavs, pred_latents_cpu
gc.collect()
self._empty_cache()
return result
except Exception as exc:
error_msg = f"Error: {exc!s}\n{traceback.format_exc()}"
logger.exception("[generate_music] Generation failed")
Expand Down
4 changes: 3 additions & 1 deletion acestep/core/generation/handler/generate_music_decode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Decode/validation helpers for ``generate_music`` orchestration."""

import gc
import os
import time
from typing import Any, Dict, Optional, Tuple
Expand Down Expand Up @@ -180,7 +181,6 @@ def _decode_generate_music_pred_latents(
if vae_cpu and vae_device is not None:
logger.info("[generate_music] Restoring VAE to original device after CPU decode path...")
self.vae = self.vae.to(vae_device)
pred_latents_for_decode = pred_latents_for_decode.to(vae_device)
self._empty_cache()
logger.debug(
"[generate_music] After VAE decode: "
Expand All @@ -194,6 +194,8 @@ def _decode_generate_music_pred_latents(
if torch.any(peak > 1.0):
pred_wavs = pred_wavs / peak.clamp(min=1.0)
self._empty_cache()
gc.collect()
self._empty_cache()
end_time = time.time()
time_costs["vae_decode_time_cost"] = end_time - start_time
time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"]
Expand Down
61 changes: 61 additions & 0 deletions acestep/core/generation/handler/generate_music_decode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,5 +251,66 @@ def _empty_cache(self):
self.assertGreaterEqual(host.empty_cache_calls, 2)


def test_decode_pred_latents_does_not_restore_latents_to_gpu_after_successful_cpu_decode(self):
    """It does not move pred_latents_for_decode back to GPU after a successful CPU decode.

    The removed ``pred_latents_for_decode = pred_latents_for_decode.to(vae_device)``
    line re-allocated the already-decoded input tensor on the GPU for no benefit.
    After the fix, only the VAE itself is restored; the input latent is not.
    """
    # NOTE(review): in the diff this method appears at module level; confirm it
    # is indented inside the TestCase class, otherwise it will not be collected.

    class _TrackingVae(_FakeVae):
        """VAE double that records device transfers and succeeds on decode."""

        def __init__(self):
            """Start on the accelerator with an empty transfer log."""
            super().__init__()
            self.vae_to_calls = []
            self._device = "cuda"

        def decode(self, latents: torch.Tensor):
            """Produce a trivial decoded waveform batch."""
            return _FakeDecodeOutput(torch.ones(latents.shape[0], 2, 8))

        def cpu(self):
            """Pretend to migrate the VAE onto the CPU."""
            self._device = "cpu"
            return self

        def to(self, *args, **kwargs):
            """Log every requested device-transfer destination."""
            destination = args[0] if args else kwargs
            self.vae_to_calls.append(destination)
            return self

    class _CpuDecodeHost(_Host):
        """Host configured without MLX so the CPU-decode branch is exercised."""

        def __init__(self):
            """Wire up the tracking VAE and a CUDA device string."""
            super().__init__()
            self.use_mlx_vae = False
            self.mlx_vae = None
            self.vae = _TrackingVae()
            self.device = "cuda"

    host = _CpuDecodeHost()
    env_patch = patch.dict(
        GENERATE_MUSIC_DECODE_MODULE.os.environ,
        {"ACESTEP_VAE_ON_CPU": "1"},
        clear=False,
    )
    with env_patch:
        pred_wavs, _cpu_latents, _costs = host._decode_generate_music_pred_latents(
            pred_latents=torch.ones(1, 4, 3),
            progress=None,
            use_tiled_decode=False,
            time_costs={"total_time_cost": 1.0},
        )

    # Exactly one transfer: restoring the VAE to its original device.
    self.assertEqual(len(host.vae.vae_to_calls), 1)
    # The decoded waveform keeps the fake decoder's output shape.
    self.assertEqual(tuple(pred_wavs.shape), (1, 2, 8))


if __name__ == "__main__":
unittest.main()

2 changes: 2 additions & 0 deletions acestep/core/generation/handler/generate_music_payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def _build_generate_music_success_payload(
for index in range(actual_batch_size):
audio_tensor = pred_wavs[index].cpu()
audio_tensors.append(audio_tensor)
# Drop this function's reference to the waveform tensor; the underlying GPU
# memory becomes reclaimable only once the caller releases its reference too.
del pred_wavs

status_message = "Generation completed successfully!"
logger.info(f"[generate_music] Done! Generated {len(audio_tensors)} audio tensors.")
Expand Down
61 changes: 61 additions & 0 deletions acestep/core/generation/handler/generate_music_payload_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,5 +135,66 @@ def test_build_success_payload_handles_missing_optional_outputs_without_progress
self.assertEqual(payload["extra_outputs"]["pred_latents"].device.type, "cpu")


def test_build_success_payload_does_not_mutate_outputs_dict(self):
    """Payload builder must not remove keys from the caller's outputs dict."""
    host = _Host()
    # One fresh latent-shaped tensor per key, plus the remaining entries.
    outputs = {
        key: torch.ones(1, 4, 3)
        for key in ("target_latents_input", "src_latents", "context_latents")
    }
    outputs["chunk_masks"] = torch.ones(1, 4)
    outputs["latent_masks"] = torch.ones(1, 4)
    outputs["spans"] = [(0, 4)]
    outputs["encoder_hidden_states"] = torch.ones(1, 2, 3)
    outputs["encoder_attention_mask"] = torch.ones(1, 2)
    outputs["lyric_token_idss"] = torch.ones(1, 2, dtype=torch.long)
    keys_before = set(outputs.keys())

    host._build_generate_music_success_payload(
        outputs=outputs,
        pred_wavs=torch.ones(1, 2, 8),
        pred_latents_cpu=torch.ones(1, 4, 3),
        time_costs={"total_time_cost": 1.0},
        seed_value_for_ui=0,
        actual_batch_size=1,
        progress=None,
    )

    # The caller's dict must still expose exactly the keys it started with.
    self.assertEqual(set(outputs.keys()), keys_before)

def test_build_success_payload_all_extra_outputs_are_cpu_tensors(self):
    """It ensures every tensor in extra_outputs is on CPU."""
    host = _Host()
    outputs = {
        "target_latents_input": torch.ones(1, 4, 3),
        "src_latents": torch.ones(1, 4, 3),
        "encoder_hidden_states": torch.ones(1, 2, 3),
        "encoder_attention_mask": torch.ones(1, 2),
        "context_latents": torch.ones(1, 4, 3),
        "lyric_token_idss": torch.ones(1, 2, dtype=torch.long),
    }

    payload = host._build_generate_music_success_payload(
        outputs=outputs,
        pred_wavs=torch.ones(1, 2, 8),
        pred_latents_cpu=torch.ones(1, 4, 3),
        time_costs={"total_time_cost": 1.0},
        seed_value_for_ui=0,
        actual_batch_size=1,
        progress=None,
    )

    # Only tensor-valued entries are subject to the CPU-residency check.
    tensor_entries = [
        (key, val)
        for key, val in payload["extra_outputs"].items()
        if isinstance(val, torch.Tensor)
    ]
    for key, val in tensor_entries:
        self.assertEqual(
            val.device.type, "cpu",
            f"extra_outputs['{key}'] is not on CPU (device={val.device})",
        )


if __name__ == "__main__":
unittest.main()

Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Context manager for temporary model loading/offloading."""

import gc
import time
from contextlib import contextmanager

Expand Down Expand Up @@ -61,6 +62,7 @@ def _load_model_context(self, model_name: str):
else:
self._recursive_to_device(model, "cpu")

gc.collect()
self._empty_cache()
offload_time = time.time() - start_time
self.current_offload_cost += offload_time
Expand Down
7 changes: 7 additions & 0 deletions acestep/ui/gradio/events/results/batch_management_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ def generate_with_batch_management(
gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
)

# Release the generator frame and run GC to reclaim any accelerator memory
# that was not yet freed at the end of the inner generator.
del generator
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()

result = final_result_from_inner
if result is None:
error_msg = t("messages.batch_failed", error="No generation result was produced")
Expand Down
Loading