Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions acestep/core/generation/handler/generate_music.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@
import traceback
from typing import Any, Dict, List, Optional, Union

import torch
from loguru import logger

from acestep.constants import DEFAULT_DIT_INSTRUCTION
from acestep.gpu_config import (
DIT_INFERENCE_VRAM_PER_BATCH,
VRAM_SAFETY_MARGIN_GB,
get_effective_free_vram_gb,
)


class GenerateMusicMixin:
Expand All @@ -19,6 +25,65 @@ class GenerateMusicMixin:
orchestration flow.
"""

def _vram_preflight_check(
    self,
    actual_batch_size: int,
    audio_duration: Optional[float],
    guidance_scale: float,
) -> Optional[Dict[str, Any]]:
    """Check free VRAM headroom before attempting service_generate.

    Model weights are already resident in GPU memory at this point. We
    only need to verify there is enough room for the diffusion-pass
    activations (intermediate attention maps, FFN buffers, noise tensors)
    plus a project-standard safety margin.

    Args:
        actual_batch_size: Number of samples being generated.
        audio_duration: Requested audio length in seconds, or None for default.
        guidance_scale: CFG guidance value; values > 1.0 indicate CFG is active
            and the DiT runs two forward passes per step (doubling activation memory).

    Returns:
        An error payload dict when VRAM is insufficient, or None when the
        check passes, no CUDA device is present (CPU/MPS/XPU fall through),
        or the amount of free VRAM cannot be determined.
    """
    if not torch.cuda.is_available():
        return None

    duration_s = audio_duration or 60.0
    # CFG doubles forward-pass memory: two DiT evaluations per step.
    dit_key = "base" if guidance_scale > 1.0 else "turbo"
    per_batch_gb = DIT_INFERENCE_VRAM_PER_BATCH.get(dit_key, 0.6)
    # Longer audio = more latent frames (5 Hz rate) = more memory.
    duration_factor = max(1.0, duration_s / 60.0)
    needed_gb = per_batch_gb * actual_batch_size * duration_factor + VRAM_SAFETY_MARGIN_GB

    free_gb = get_effective_free_vram_gb()
    # get_effective_free_vram_gb() reports 0.0 when its internal VRAM query
    # fails, and 0.0 < needed_gb is always true -- without this guard a mere
    # query failure would falsely reject every generation. Treat a missing or
    # non-positive reading as "unable to query VRAM" and let the generation
    # attempt proceed instead of blocking it.
    if free_gb is None or free_gb <= 0.0:
        logger.warning(
            "[generate_music] VRAM pre-flight: unable to query free VRAM; "
            "skipping check and allowing generation to proceed."
        )
        return None

    logger.info(
        "[generate_music] VRAM pre-flight: {:.2f} GB free, ~{:.2f} GB needed "
        "(batch={}, duration={:.0f}s, mode={}).",
        free_gb, needed_gb, actual_batch_size, duration_s, dit_key,
    )

    if free_gb >= needed_gb:
        return None

    msg = (
        f"Insufficient free VRAM: need ~{needed_gb:.1f} GB, "
        f"only {free_gb:.1f} GB available. "
        f"Reduce batch size (currently {actual_batch_size}) "
        f"or audio duration (currently {duration_s:.0f}s)."
    )
    logger.warning("[generate_music] VRAM pre-flight failed: {}", msg)
    return {
        "audios": [],
        "status_message": f"Error: {msg}",
        "extra_outputs": {},
        "success": False,
        "error": msg,
    }

def generate_music(
self,
captions: str,
Expand Down Expand Up @@ -134,6 +199,14 @@ def generate_music(
repainting_start=repainting_start,
repainting_end=repainting_end,
)
vram_error = self._vram_preflight_check(
actual_batch_size=actual_batch_size,
audio_duration=audio_duration,
guidance_scale=guidance_scale,
)
if vram_error is not None:
return vram_error

service_run = self._run_generate_music_service_with_progress(
progress=progress,
actual_batch_size=actual_batch_size,
Expand Down
177 changes: 153 additions & 24 deletions acestep/core/generation/handler/generate_music_execute.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,45 @@
"""Execution helper for ``generate_music`` service invocation with progress tracking."""

import os
import threading
from typing import Any, Dict, List, Optional, Sequence

import torch
from loguru import logger


def _parse_generation_timeout() -> int:
    """Return the generation timeout (seconds) from ``ACESTEP_GENERATION_TIMEOUT``.

    Falls back to 600 whenever the variable is unset, non-numeric, or
    non-positive. Two misconfiguration traps motivate the explicit handling:

    - A non-numeric value would raise ``ValueError`` at module import time if
      parsed naively with ``int(os.environ.get(...))``, crashing the server
      before any generation runs.
    - A value of 0 or negative makes ``Thread.join(timeout=0)`` return
      immediately, so ``is_alive()`` is always ``True`` and every generation
      would time out instantly.
    """
    fallback = 600
    raw = os.environ.get("ACESTEP_GENERATION_TIMEOUT", "600")
    try:
        parsed = int(raw)
    except ValueError:
        logger.warning(
            "ACESTEP_GENERATION_TIMEOUT={!r} is not a valid integer; defaulting to 600s.", raw
        )
        return fallback
    if parsed > 0:
        return parsed
    logger.warning(
        "ACESTEP_GENERATION_TIMEOUT={} must be positive; defaulting to 600s.", parsed
    )
    return fallback


# Maximum wall-clock seconds to wait for service_generate before declaring a hang.
# Generous default: most generations finish in 30-120 s, but large batches on slow
# GPUs can take several minutes. Override via ACESTEP_GENERATION_TIMEOUT env var.
_DEFAULT_GENERATION_TIMEOUT: int = _parse_generation_timeout()


class GenerateMusicExecuteMixin:
"""Run service generation under diffusion progress estimation lifecycle."""
Expand All @@ -25,7 +63,42 @@ def _run_generate_music_service_with_progress(
shift: float,
infer_method: str,
) -> Dict[str, Any]:
"""Invoke ``service_generate`` while maintaining background progress estimation."""
"""Invoke ``service_generate`` while maintaining background progress estimation.

``service_generate`` is a blocking CUDA call. On mid-tier hardware with
VRAM fragmentation it can hang indefinitely, freezing the Gradio UI. We
run it in a daemon thread and enforce ``_DEFAULT_GENERATION_TIMEOUT``
seconds of wall-clock patience before surfacing a ``TimeoutError``.

Args:
progress: Gradio-style progress callback.
actual_batch_size: Number of audio samples to generate.
audio_duration: Requested audio length in seconds, or None for default.
inference_steps: Number of diffusion steps.
timesteps: Optional custom timestep schedule; overrides ``inference_steps``
for progress tracking when provided.
service_inputs: Pre-processed batch tensors and metadata from
``_prepare_generate_music_service_inputs``.
refer_audios: Optional reference audio tensors for conditioning.
guidance_scale: CFG guidance value forwarded to ``service_generate``.
actual_seed_list: Per-sample PRNG seeds.
audio_cover_strength: Cover strength parameter.
cover_noise_strength: Cover noise strength parameter.
use_adg: Whether to use adaptive guidance.
cfg_interval_start: CFG interval start fraction.
cfg_interval_end: CFG interval end fraction.
shift: Scheduler shift value.
infer_method: Diffusion method name (e.g. ``"ode"``).

Returns:
Dict with ``"outputs"`` (service_generate return value) and
``"infer_steps_for_progress"`` (effective step count used for tracking).

Raises:
TimeoutError: when ``service_generate`` exceeds the configured timeout.
BaseException: any exception raised by ``service_generate`` is re-raised
transparently so upstream handlers see the original error.
"""
infer_steps_for_progress = len(timesteps) if timesteps else inference_steps
progress_desc = f"Generating music (batch size: {actual_batch_size})..."
progress(0.52, desc=progress_desc)
Expand All @@ -41,30 +114,86 @@ def _run_generate_music_service_with_progress(
duration_sec=audio_duration if audio_duration and audio_duration > 0 else None,
desc=progress_desc,
)
outputs = self.service_generate(
captions=service_inputs["captions_batch"],
lyrics=service_inputs["lyrics_batch"],
metas=service_inputs["metas_batch"],
vocal_languages=service_inputs["vocal_languages_batch"],
refer_audios=refer_audios,
target_wavs=service_inputs["target_wavs_tensor"],
infer_steps=inference_steps,
guidance_scale=guidance_scale,
seed=actual_seed_list,
repainting_start=service_inputs["repainting_start_batch"],
repainting_end=service_inputs["repainting_end_batch"],
instructions=service_inputs["instructions_batch"],
audio_cover_strength=audio_cover_strength,
cover_noise_strength=cover_noise_strength,
use_adg=use_adg,
cfg_interval_start=cfg_interval_start,
cfg_interval_end=cfg_interval_end,
shift=shift,
infer_method=infer_method,
audio_code_hints=service_inputs["audio_code_hints_batch"],
return_intermediate=service_inputs["should_return_intermediate"],
timesteps=timesteps,

_result: Dict[str, Any] = {}
_error: Dict[str, BaseException] = {}

def _service_target() -> None:
try:
_result["outputs"] = self.service_generate(
captions=service_inputs["captions_batch"],
lyrics=service_inputs["lyrics_batch"],
metas=service_inputs["metas_batch"],
vocal_languages=service_inputs["vocal_languages_batch"],
refer_audios=refer_audios,
target_wavs=service_inputs["target_wavs_tensor"],
infer_steps=inference_steps,
guidance_scale=guidance_scale,
seed=actual_seed_list,
repainting_start=service_inputs["repainting_start_batch"],
repainting_end=service_inputs["repainting_end_batch"],
instructions=service_inputs["instructions_batch"],
audio_cover_strength=audio_cover_strength,
cover_noise_strength=cover_noise_strength,
use_adg=use_adg,
cfg_interval_start=cfg_interval_start,
cfg_interval_end=cfg_interval_end,
shift=shift,
infer_method=infer_method,
audio_code_hints=service_inputs["audio_code_hints_batch"],
return_intermediate=service_inputs["should_return_intermediate"],
timesteps=timesteps,
)
except BaseException as exc: # noqa: BLE001 — ferry all exceptions across thread boundary
_error["exc"] = exc

gen_thread = threading.Thread(
target=_service_target, daemon=True, name="service-generate"
)
gen_thread.start()
gen_thread.join(timeout=_DEFAULT_GENERATION_TIMEOUT)

if gen_thread.is_alive():
# Attempt to recover VRAM from the stalled CUDA context so that
# the next generation attempt has a fighting chance.
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Count orphaned threads so operators can detect buildup from
# repeated timeouts without restarting the server.
stalled = sum(
1 for t in threading.enumerate() if t.name == "service-generate"
)
logger.error(
"[generate_music] service_generate exceeded {}s timeout "
"(batch={}, steps={}, duration={:.1f}s, "
"orphaned_service_generate_threads={}). "
"The CUDA operation may still be running in the background.",
_DEFAULT_GENERATION_TIMEOUT,
actual_batch_size,
inference_steps,
audio_duration or 0.0,
stalled,
)
raise TimeoutError(
f"Music generation timed out after {_DEFAULT_GENERATION_TIMEOUT} seconds. "
"This usually means the GPU ran out of VRAM or the diffusion loop stalled. "
"Try reducing batch size, duration, or inference steps."
)

# Re-raise any exception that escaped the worker thread.
if "exc" in _error:
raise _error["exc"]

# Defensive guard: the thread completed without raising but also
# without populating _result. This should never happen with
# except BaseException above, but guard against future refactors.
if "outputs" not in _result:
raise RuntimeError(
"service_generate completed without producing outputs or raising "
"an exception — this is unexpected. Please report this as a bug."
)

outputs = _result["outputs"]
finally:
if stop_event is not None:
stop_event.set()
Expand Down
38 changes: 36 additions & 2 deletions acestep/core/generation/handler/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
import time
from typing import Optional

from loguru import logger

# Cold-start fallback for the diffusion progress estimator: per-step wall-clock
# estimate (in seconds) used when no timing history exists (i.e. first run on a
# fresh machine). 2.5 s/step is presumed typical for mid-tier GPUs -- TODO
# confirm against current hardware. The caller multiplies this by batch_size so
# larger batches get a proportionally longer expected duration, keeping the
# progress bar from jumping to the end too quickly.
_FALLBACK_PER_STEP_SEC_PER_BATCH: float = 2.5


class ProgressMixin:
def _get_project_root(self) -> str:
Expand Down Expand Up @@ -146,7 +155,26 @@ def _start_diffusion_progress_estimator(
duration_sec: Optional[float],
desc: str,
):
"""Best-effort progress updates during diffusion using previous step timing."""
"""Start a daemon thread that emits best-effort progress updates during diffusion.

Estimates expected duration from timing history or falls back to a
conservative constant scaled by ``batch_size``. Returns ``(None, None)``
when ``progress`` is ``None`` or ``infer_steps <= 0``.

Args:
progress: Gradio-style progress callback, or None to skip updates.
start: Progress range lower bound (0-1).
end: Progress range upper bound (0-1).
infer_steps: Number of diffusion steps.
batch_size: Number of items in the current batch.
duration_sec: Target audio duration in seconds, used to refine estimates.
desc: Description string forwarded to the progress callback.

Returns:
Tuple of ``(stop_event, thread)``; call ``stop_event.set()`` to halt
updates and then join the thread. Returns ``(None, None)`` when
updates cannot be started.
"""
if progress is None or infer_steps <= 0:
return None, None
per_step = self._estimate_diffusion_per_step(
Expand All @@ -155,7 +183,13 @@ def _start_diffusion_progress_estimator(
duration_sec=duration_sec,
) or self._last_diffusion_per_step_sec
if not per_step or per_step <= 0:
return None, None
per_step = _FALLBACK_PER_STEP_SEC_PER_BATCH * max(1, batch_size)
logger.debug(
"[progress] No timing history available; using conservative cold-start "
"fallback of {:.1f}s/step (batch_size={}).",
per_step,
batch_size,
)
expected = per_step * infer_steps
if expected <= 0:
return None, None
Expand Down