Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions acestep/core/generation/handler/generate_music.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@
import traceback
from typing import Any, Dict, List, Optional, Union

import torch
from loguru import logger

from acestep.constants import DEFAULT_DIT_INSTRUCTION
from acestep.gpu_config import (
DIT_INFERENCE_VRAM_PER_BATCH,
VRAM_SAFETY_MARGIN_GB,
get_effective_free_vram_gb,
)


class GenerateMusicMixin:
Expand All @@ -19,6 +25,65 @@ class GenerateMusicMixin:
orchestration flow.
"""

def _vram_preflight_check(
self,
actual_batch_size: int,
audio_duration: Optional[float],
guidance_scale: float,
) -> Optional[Dict[str, Any]]:
"""Check free VRAM headroom before attempting service_generate.

Model weights are already resident in GPU memory at this point. We
only need to verify there is enough room for the diffusion-pass
activations (intermediate attention maps, FFN buffers, noise tensors)
plus a project-standard safety margin.

Args:
actual_batch_size: Number of samples being generated.
audio_duration: Requested audio length in seconds, or None for default.
guidance_scale: CFG guidance value; values > 1.0 indicate CFG is active
and the DiT runs two forward passes per step (doubling activation memory).

Returns:
An error payload dict when VRAM is insufficient, or None when the
check passes or no CUDA device is present (CPU/MPS/XPU fall through).
"""
if not torch.cuda.is_available():
return None

duration_s = audio_duration or 60.0
# CFG doubles forward-pass memory: two DiT evaluations per step.
dit_key = "base" if guidance_scale > 1.0 else "turbo"
per_batch_gb = DIT_INFERENCE_VRAM_PER_BATCH.get(dit_key, 0.6)
# Longer audio = more latent frames (5 Hz rate) = more memory.
duration_factor = max(1.0, duration_s / 60.0)
needed_gb = per_batch_gb * actual_batch_size * duration_factor + VRAM_SAFETY_MARGIN_GB

free_gb = get_effective_free_vram_gb()
logger.info(
"[generate_music] VRAM pre-flight: {:.2f} GB free, ~{:.2f} GB needed "
"(batch={}, duration={:.0f}s, mode={}).",
free_gb, needed_gb, actual_batch_size, duration_s, dit_key,
)

if free_gb >= needed_gb:
return None
Comment on lines +51 to +70
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

get_effective_free_vram_gb() returning 0.0 on internal error will falsely block generation.

When CUDA is available but get_effective_free_vram_gb() hits an unexpected internal exception, it returns 0.0 (per its own except clause). Since 0.0 < needed_gb is always true, this will surface a "not enough VRAM" error even though the GPU may have plenty of memory.

Consider guarding against this — e.g., treat free_gb <= 0.0 as "unable to query VRAM" and let the generation attempt proceed:

Proposed fix
         free_gb = get_effective_free_vram_gb()
+        if free_gb <= 0.0:
+            logger.warning(
+                "[generate_music] VRAM pre-flight: unable to query free VRAM; "
+                "skipping check and allowing generation to proceed."
+            )
+            return None
+
         logger.info(
             "[generate_music] VRAM pre-flight: {:.2f} GB free, ~{:.2f} GB needed "
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/generate_music.py` around lines 51 - 70, The
VRAM check in generate_music is blocking when get_effective_free_vram_gb()
returns 0.0 on error; change the logic so that if free_gb is None or free_gb <=
0.0 you treat it as "unable to query VRAM" and do not block generation: call
logger.warning (or similar) to record the failed VRAM query referencing
get_effective_free_vram_gb/free_gb/needed_gb and then return None to proceed;
otherwise keep the existing check (if free_gb >= needed_gb return None, else
raise or handle insufficient VRAM).


msg = (
f"Insufficient free VRAM: need ~{needed_gb:.1f} GB, "
f"only {free_gb:.1f} GB available. "
f"Reduce batch size (currently {actual_batch_size}) "
f"or audio duration (currently {duration_s:.0f}s)."
)
logger.warning("[generate_music] VRAM pre-flight failed: {}", msg)
return {
"audios": [],
"status_message": f"Error: {msg}",
"extra_outputs": {},
"success": False,
"error": msg,
}

def generate_music(
self,
captions: str,
Expand Down Expand Up @@ -134,6 +199,14 @@ def generate_music(
repainting_start=repainting_start,
repainting_end=repainting_end,
)
vram_error = self._vram_preflight_check(
actual_batch_size=actual_batch_size,
audio_duration=audio_duration,
guidance_scale=guidance_scale,
)
if vram_error is not None:
return vram_error

service_run = self._run_generate_music_service_with_progress(
progress=progress,
actual_batch_size=actual_batch_size,
Expand Down
103 changes: 78 additions & 25 deletions acestep/core/generation/handler/generate_music_execute.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
"""Execution helper for ``generate_music`` service invocation with progress tracking."""

import os
import threading
from typing import Any, Dict, List, Optional, Sequence

from loguru import logger

# Maximum wall-clock seconds to wait for service_generate before declaring a hang.
# Generous default: most generations finish in 30-120s, but large batches on slow
# GPUs can take several minutes.  Override via ACESTEP_GENERATION_TIMEOUT env var.
# Parsed defensively: a non-numeric value must not crash at import time, and a
# zero/negative value would make thread.join(timeout=...) return immediately,
# spuriously reporting a timeout on every generation.
_raw_timeout = os.environ.get("ACESTEP_GENERATION_TIMEOUT", "600")
try:
    _DEFAULT_GENERATION_TIMEOUT = int(_raw_timeout)
except ValueError:
    logger.warning(
        f"ACESTEP_GENERATION_TIMEOUT={_raw_timeout!r} is not a valid integer; "
        f"falling back to 600s."
    )
    _DEFAULT_GENERATION_TIMEOUT = 600
if _DEFAULT_GENERATION_TIMEOUT <= 0:
    logger.warning(
        f"ACESTEP_GENERATION_TIMEOUT={_raw_timeout!r} is <= 0; falling back to 600s."
    )
    _DEFAULT_GENERATION_TIMEOUT = 600
Comment on lines +9 to +12
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Unsafe env-var parsing — non-numeric values crash at import time.

The PR description promises safe parsing with fallback to 600 s, but this bare int() raises ValueError on non-numeric input (e.g., ACESTEP_GENERATION_TIMEOUT="abc"), crashing the import. Zero and negative values also produce degenerate behavior (thread.join(timeout=0) returns immediately, always triggering the timeout path).

🐛 Proposed fix: safe parsing with validation
-# Maximum wall-clock seconds to wait for service_generate before declaring a hang.
-# Generous default: most generations finish in 30-120s, but large batches on slow
-# GPUs can take several minutes.  Override via ACESTEP_GENERATION_TIMEOUT env var.
-_DEFAULT_GENERATION_TIMEOUT = int(os.environ.get("ACESTEP_GENERATION_TIMEOUT", "600"))
+# Maximum wall-clock seconds to wait for service_generate before declaring a hang.
+# Generous default: most generations finish in 30-120s, but large batches on slow
+# GPUs can take several minutes.  Override via ACESTEP_GENERATION_TIMEOUT env var.
+_DEFAULT_GENERATION_TIMEOUT: int = 600
+
+_raw_timeout = os.environ.get("ACESTEP_GENERATION_TIMEOUT")
+if _raw_timeout is not None:
+    try:
+        _parsed = int(_raw_timeout)
+        if _parsed > 0:
+            _DEFAULT_GENERATION_TIMEOUT = _parsed
+        else:
+            logger.warning(
+                f"ACESTEP_GENERATION_TIMEOUT={_raw_timeout!r} is <= 0; "
+                f"falling back to {_DEFAULT_GENERATION_TIMEOUT}s."
+            )
+    except ValueError:
+        logger.warning(
+            f"ACESTEP_GENERATION_TIMEOUT={_raw_timeout!r} is not a valid integer; "
+            f"falling back to {_DEFAULT_GENERATION_TIMEOUT}s."
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/generate_music_execute.py` around lines 9 -
12, The module-level _DEFAULT_GENERATION_TIMEOUT is currently set with bare
int(os.environ.get(...)) which will raise ValueError on non-numeric input and
allows zero/negative values; change the initialization to parse
ACESTEP_GENERATION_TIMEOUT safely: read the env var, attempt to convert to int
inside a try/except, fall back to 600 on any exception or if the parsed value is
<= 0 (optionally clamp to a minimum like 1), and use that validated value for
_DEFAULT_GENERATION_TIMEOUT; reference the symbol _DEFAULT_GENERATION_TIMEOUT
and the env var name ACESTEP_GENERATION_TIMEOUT when making the change and add a
debug/warn log if parsing fails.



class GenerateMusicExecuteMixin:
"""Run service generation under diffusion progress estimation lifecycle."""
Expand All @@ -25,12 +34,55 @@ def _run_generate_music_service_with_progress(
shift: float,
infer_method: str,
) -> Dict[str, Any]:
"""Invoke ``service_generate`` while maintaining background progress estimation."""
"""Invoke ``service_generate`` while maintaining background progress estimation.

Wraps the synchronous CUDA call in a monitored thread so that a hung
diffusion loop becomes a recoverable ``TimeoutError`` instead of a
permanent UI freeze.
"""
infer_steps_for_progress = len(timesteps) if timesteps else inference_steps
progress_desc = f"Generating music (batch size: {actual_batch_size})..."
progress(0.52, desc=progress_desc)
stop_event = None
progress_thread = None

# --- Timeout-wrapped service_generate ---
# Run the actual CUDA work in a child thread so we can join() with a
# deadline. If it exceeds the timeout the calling thread unblocks and
# raises TimeoutError, which propagates to generate_music()'s
# try/except and becomes a clean error payload for the UI.
_result: Dict[str, Any] = {}
_error: Dict[str, BaseException] = {}

def _service_target():
try:
_result["outputs"] = self.service_generate(
captions=service_inputs["captions_batch"],
lyrics=service_inputs["lyrics_batch"],
metas=service_inputs["metas_batch"],
vocal_languages=service_inputs["vocal_languages_batch"],
refer_audios=refer_audios,
target_wavs=service_inputs["target_wavs_tensor"],
infer_steps=inference_steps,
guidance_scale=guidance_scale,
seed=actual_seed_list,
repainting_start=service_inputs["repainting_start_batch"],
repainting_end=service_inputs["repainting_end_batch"],
instructions=service_inputs["instructions_batch"],
audio_cover_strength=audio_cover_strength,
cover_noise_strength=cover_noise_strength,
use_adg=use_adg,
cfg_interval_start=cfg_interval_start,
cfg_interval_end=cfg_interval_end,
shift=shift,
infer_method=infer_method,
audio_code_hints=service_inputs["audio_code_hints_batch"],
return_intermediate=service_inputs["should_return_intermediate"],
timesteps=timesteps,
)
except Exception as exc:
_error["exc"] = exc
Comment on lines +83 to +84
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

except Exception won't catch BaseException subclasses in the worker thread.

If SystemExit (or another BaseException subclass) is raised inside the worker, neither _result nor _error gets populated. The thread exits silently, gen_thread.is_alive() returns False, the error check on line 117 passes, and line 126 crashes with KeyError: 'outputs'.

Use except BaseException here to ferry all exceptions across the thread boundary, matching the stated intent in the PR description. This also addresses the Ruff BLE001 hint — in this specific pattern (worker thread → re-raise on caller) broad catching is intentional.

🐛 Proposed fix
-            except Exception as exc:
+            except BaseException as exc:
                 _error["exc"] = exc
🧰 Tools
🪛 Ruff (0.15.1)

[warning] 83-83: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/generate_music_execute.py` around lines 83 -
84, Replace the narrow except with a BaseException catcher so worker-thread
fatal exceptions (e.g. SystemExit, KeyboardInterrupt) are captured and
transported back to the caller: in the exception handling block that assigns to
_error (the try/except around the worker that currently does "except Exception
as exc" and sets _error["exc"] = exc), change it to catch BaseException and
store the exception object so the caller (the code that checks
gen_thread.is_alive() and later inspects _error/_result) can re-raise or handle
it; keep the same _error key and behavior but use BaseException to match the
intended cross-thread re-raise pattern used by this module
(generate_music_execute.py, the _error dict and the exc variable).


try:
stop_event, progress_thread = self._start_diffusion_progress_estimator(
progress=progress,
Expand All @@ -41,33 +93,34 @@ def _run_generate_music_service_with_progress(
duration_sec=audio_duration if audio_duration and audio_duration > 0 else None,
desc=progress_desc,
)
outputs = self.service_generate(
captions=service_inputs["captions_batch"],
lyrics=service_inputs["lyrics_batch"],
metas=service_inputs["metas_batch"],
vocal_languages=service_inputs["vocal_languages_batch"],
refer_audios=refer_audios,
target_wavs=service_inputs["target_wavs_tensor"],
infer_steps=inference_steps,
guidance_scale=guidance_scale,
seed=actual_seed_list,
repainting_start=service_inputs["repainting_start_batch"],
repainting_end=service_inputs["repainting_end_batch"],
instructions=service_inputs["instructions_batch"],
audio_cover_strength=audio_cover_strength,
cover_noise_strength=cover_noise_strength,
use_adg=use_adg,
cfg_interval_start=cfg_interval_start,
cfg_interval_end=cfg_interval_end,
shift=shift,
infer_method=infer_method,
audio_code_hints=service_inputs["audio_code_hints_batch"],
return_intermediate=service_inputs["should_return_intermediate"],
timesteps=timesteps,

gen_thread = threading.Thread(
target=_service_target,
name="service-generate",
daemon=True,
)
gen_thread.start()
gen_thread.join(timeout=_DEFAULT_GENERATION_TIMEOUT)

if gen_thread.is_alive():
logger.error(
f"[generate_music] service_generate exceeded {_DEFAULT_GENERATION_TIMEOUT}s "
f"timeout (batch={actual_batch_size}, steps={inference_steps}, "
f"duration={audio_duration}s). The CUDA operation may still be "
f"running in the background."
)
raise TimeoutError(
f"Music generation timed out after {_DEFAULT_GENERATION_TIMEOUT} seconds. "
f"This usually means the GPU ran out of VRAM or the diffusion loop "
f"stalled. Try reducing batch size, duration, or inference steps."
)
if "exc" in _error:
raise _error["exc"]

finally:
if stop_event is not None:
stop_event.set()
if progress_thread is not None:
progress_thread.join(timeout=1.0)
return {"outputs": outputs, "infer_steps_for_progress": infer_steps_for_progress}

return {"outputs": _result["outputs"], "infer_steps_for_progress": infer_steps_for_progress}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Missing defensive guard on _result["outputs"] — potential KeyError.

The PR description mentions a defensive guard here, but the code directly accesses _result["outputs"] which will raise KeyError if the worker thread exits without populating it (e.g., via an uncaught BaseException). Even with the except BaseException fix above, a guard here provides belt-and-suspenders safety.

🐛 Proposed fix
-        return {"outputs": _result["outputs"], "infer_steps_for_progress": infer_steps_for_progress}
+        if "outputs" not in _result:
+            raise RuntimeError(
+                "service_generate thread exited without producing outputs or raising an exception."
+            )
+        return {"outputs": _result["outputs"], "infer_steps_for_progress": infer_steps_for_progress}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/generate_music_execute.py` at line 126, The
return unconditionally indexing _result["outputs"] can raise KeyError if the
worker never populated outputs; update the return to safely read outputs (e.g.,
use _result.get("outputs", []) or a similar defensive check) before returning
from the function/method in generate_music_execute.py (look for the return that
references _result["outputs"] in the generate_music_execute handler), so the
returned dict always contains a valid outputs value even when the worker thread
failed to populate it.

30 changes: 27 additions & 3 deletions acestep/core/generation/handler/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
import time
from typing import Optional

from loguru import logger

# Conservative per-step estimate used when no historical timing data exists
# (i.e., first-ever generation on this machine). 2.5s/step is deliberately
# slow so the progress bar undershoots rather than overshoots — reaching 79%
# early and pausing is far less alarming than freezing at 52% with zero
# movement. The estimate self-corrects after the first successful generation.
_FALLBACK_PER_STEP_SEC: float = 2.5


class ProgressMixin:
def _get_project_root(self) -> str:
Expand Down Expand Up @@ -146,16 +155,31 @@ def _start_diffusion_progress_estimator(
duration_sec: Optional[float],
desc: str,
):
"""Best-effort progress updates during diffusion using previous step timing."""
"""Best-effort progress updates during diffusion using previous step timing.

Falls back to a conservative default estimate when no historical data
exists (first-ever generation). This ensures the progress bar always
moves during Phase 2 instead of freezing at 52%.
"""
if progress is None or infer_steps <= 0:
return None, None
per_step = self._estimate_diffusion_per_step(
infer_steps=infer_steps,
batch_size=batch_size,
duration_sec=duration_sec,
) or self._last_diffusion_per_step_sec

if not per_step or per_step <= 0:
return None, None
# No history at all — use conservative fallback so progress bar
# still moves on first run. Scale by batch size for a rough
# approximation.
per_step = _FALLBACK_PER_STEP_SEC * max(1, batch_size)
logger.info(
f"[progress] No timing history — using fallback estimate "
f"({per_step:.1f}s/step for batch_size={batch_size}). "
f"This will self-calibrate after the first generation."
)

expected = per_step * infer_steps
if expected <= 0:
return None, None
Expand All @@ -175,4 +199,4 @@ def _runner():

thread = threading.Thread(target=_runner, name="diffusion-progress", daemon=True)
thread.start()
return stop_event, thread
return stop_event, thread