diff --git a/acestep/ui/gradio/i18n/en.json b/acestep/ui/gradio/i18n/en.json index 3e33d12f..652d693d 100644 --- a/acestep/ui/gradio/i18n/en.json +++ b/acestep/ui/gradio/i18n/en.json @@ -5,8 +5,8 @@ }, "common": { "language_metadata": { - "name": "English", - "native_name": "English" + "name": "English", + "native_name": "English" } }, "dataset": { @@ -424,25 +424,55 @@ "export_path": "Export Path", "export_lora_btn": "📦 Export LoRA", "export_status": "Export Status", - "stop_no_training": "\u2139 No training in progress", - "stop_stopping": "\u23F9\uFE0F Stopping training...", + "stop_no_training": "ℹ No training in progress", + "stop_stopping": "⏹️ Stopping training...", "latest_auto": "Latest (auto)", - "export_path_required": "\u274C Please enter an export path", - "invalid_lora_output_dir": "\u274C Invalid LoRA output directory", - "no_checkpoints_found": "\u274C No checkpoints found", - "no_trained_model_found": "\u274C No trained model found in {path}", - "invalid_export_path": "\u274C Invalid export path", - "lora_exported": "\u2705 LoRA exported to {path}", - "export_failed": "\u274C Export failed: {error}", - "lokr_output_dir_required": "\u26A0\uFE0F Enter LoKr output directory first", - "lokr_no_checkpoints_use_latest": "\u2139 No checkpoints found; export will use latest available weights", - "lokr_no_exportable_checkpoints": "\u2139 No exportable epoch checkpoints found", - "lokr_found_checkpoints": "\u2705 Found {count} LoKr checkpoints", - "lokr_selected_epoch_not_found": "\u274C Selected epoch not found: {chosen}. 
Available: {available}", - "lokr_no_weights_selected_epoch": "\u274C No LoKr weights found for selected epoch: {epoch}", - "lokr_no_weights_latest_checkpoint": "\u274C No LoKr weights found in latest checkpoint: {checkpoint}", - "lokr_no_trained_weights_found": "\u274C No trained LoKr weights found in {path}", - "lokr_exported": "\u2705 LoKr exported to {path}" + "export_path_required": "❌ Please enter an export path", + "invalid_lora_output_dir": "❌ Invalid LoRA output directory", + "no_checkpoints_found": "❌ No checkpoints found", + "no_trained_model_found": "❌ No trained model found in {path}", + "invalid_export_path": "❌ Invalid export path", + "lora_exported": "✅ LoRA exported to {path}", + "export_failed": "❌ Export failed: {error}", + "lokr_output_dir_required": "⚠️ Enter LoKr output directory first", + "lokr_no_checkpoints_use_latest": "ℹ No checkpoints found; export will use latest available weights", + "lokr_no_exportable_checkpoints": "ℹ No exportable epoch checkpoints found", + "lokr_found_checkpoints": "✅ Found {count} LoKr checkpoints", + "lokr_selected_epoch_not_found": "❌ Selected epoch not found: {chosen}. 
Available: {available}", + "lokr_no_weights_selected_epoch": "❌ No LoKr weights found for selected epoch: {epoch}", + "lokr_no_weights_latest_checkpoint": "❌ No LoKr weights found in latest checkpoint: {checkpoint}", + "lokr_no_trained_weights_found": "❌ No trained LoKr weights found in {path}", + "lokr_exported": "✅ LoKr exported to {path}", + "tab_train_lokr": "🚀 Train LoKr", + "train_section_tensors": "Preprocessed Tensors", + "train_section_lora": "LoRA Settings", + "train_section_params": "Training Parameters", + "lokr_section_tensors": "Preprocessed Tensors", + "lokr_section_settings": "LoKr Settings", + "lokr_tensor_selection_desc": "Select the directory containing preprocessed tensor files (`.pt` files).\nThese are created using the 'Preprocess' button in the 'Dataset Builder' tab.", + "lokr_linear_dim": "LoKr Linear Dim", + "lokr_linear_dim_info": "Rank (dimension) for LoKr adaptation matrices.", + "lokr_linear_alpha": "LoKr Linear Alpha", + "lokr_linear_alpha_info": "Scaling factor for LoKr (usually similar to dim).", + "lokr_factor": "LoKr Factor", + "lokr_factor_info": "Kronecker factor (-1 for auto).", + "lokr_decompose_both": "Decompose Both Sides", + "lokr_decompose_both_info": "When enabled, decomposes both left and right matrices.", + "lokr_use_tucker": "Use Tucker Decomposition", + "lokr_use_tucker_info": "Apply Tucker decomposition when applicable.", + "lokr_use_scalar": "Use Scalar Gate", + "lokr_use_scalar_info": "Enable scalar gating for LoKr weights.", + "lokr_weight_decompose": "Weight Decompose (WD)", + "lokr_weight_decompose_info": "Enable weight decomposition for more stable LoKr training.", + "lokr_learning_rate_info": "LoKr commonly uses a higher LR than LoRA. 
Tune per dataset.", + "lokr_output_dir_info": "Directory to save trained LoKr weights.", + "start_lokr_training_btn": "๐Ÿš€ Start Training LoKr", + "lokr_training_loss_title": "LoKr Training Loss", + "lokr_export_header": "Export LoKr", + "export_lokr_btn": "๐Ÿ“ฆ Export LoKr", + "lokr_checkpoint_epoch": "Checkpoint Epoch", + "lokr_checkpoint_epoch_info": "Select a specific epoch checkpoint to export, or keep Latest (auto).", + "refresh_epochs_btn": "โ†ป Refresh Epochs" }, "help": { "btn_label": "?", diff --git a/acestep/ui/gradio/i18n/he.json b/acestep/ui/gradio/i18n/he.json index db2b8a7f..b5cf715e 100644 --- a/acestep/ui/gradio/i18n/he.json +++ b/acestep/ui/gradio/i18n/he.json @@ -442,7 +442,37 @@ "lokr_no_weights_selected_epoch": "โŒ ืœื ื ืžืฆืื• ืžืฉืงืœื™ LoKr ืœืื™ื˜ืจืฆื™ื” ืฉื ื‘ื—ืจื”: {epoch}", "lokr_no_weights_latest_checkpoint": "โŒ ืœื ื ืžืฆืื• ืžืฉืงืœื™ LoKr ื‘ื ืงื•ื“ืช ื”ื‘ื™ืงื•ืจืช ื”ืื—ืจื•ื ื”", "lokr_no_trained_weights_found": "โŒ ืœื ื ืžืฆืื• ืžืฉืงืœื™ LoKr ืžืื•ืžื ื™ื ื‘-{path}", - "lokr_exported": "โœ… LoKr ื™ื•ืฆื ืœ-{path}" + "lokr_exported": "โœ… LoKr ื™ื•ืฆื ืœ-{path}", + "tab_train_lokr": "๐Ÿš€ ืื™ืžื•ืŸ LoKr", + "train_section_tensors": "Preprocessed Tensors", + "train_section_lora": "LoRA Settings", + "train_section_params": "Training Parameters", + "lokr_section_tensors": "Preprocessed Tensors", + "lokr_section_settings": "LoKr Settings", + "lokr_tensor_selection_desc": "ื‘ื—ืจ ืืช ื”ืกืคืจื™ื™ื” ื”ืžื›ื™ืœื” ืืช ืงื‘ืฆื™ ื”ื˜ื ื–ื•ืจื™ื ื”ืžืขื•ื‘ื“ื™ื (`.pt`).", + "lokr_linear_dim": "LoKr Linear Dim", + "lokr_linear_dim_info": "Rank (dimension) for LoKr adaptation matrices.", + "lokr_linear_alpha": "LoKr Linear Alpha", + "lokr_linear_alpha_info": "Scaling factor for LoKr (usually similar to dim).", + "lokr_factor": "LoKr Factor", + "lokr_factor_info": "Kronecker factor (-1 for auto).", + "lokr_decompose_both": "Decompose Both Sides", + "lokr_decompose_both_info": "When enabled, decomposes both left and 
right matrices.", + "lokr_use_tucker": "Use Tucker Decomposition", + "lokr_use_tucker_info": "Apply Tucker decomposition when applicable.", + "lokr_use_scalar": "Use Scalar Gate", + "lokr_use_scalar_info": "Enable scalar gating for LoKr weights.", + "lokr_weight_decompose": "Weight Decompose (WD)", + "lokr_weight_decompose_info": "Enable weight decomposition for more stable LoKr training.", + "lokr_learning_rate_info": "LoKr commonly uses a higher LR than LoRA. Tune per dataset.", + "lokr_output_dir_info": "ืกืคืจื™ื™ื” ืœืฉืžื™ืจืช ืžืฉืงืœื™ LoKr ื”ืžืื•ืžื ื™ื.", + "start_lokr_training_btn": "๐Ÿš€ ื”ืชื—ืœ ืื™ืžื•ืŸ LoKr", + "lokr_training_loss_title": "LoKr ืื•ื‘ื“ืŸ ืื™ืžื•ืŸ (Loss)", + "lokr_export_header": "ื™ื™ืฆื•ื LoKr", + "export_lokr_btn": "๐Ÿ“ฆ ื™ื™ืฆื•ื LoKr", + "lokr_checkpoint_epoch": "Checkpoint Epoch", + "lokr_checkpoint_epoch_info": "Select a specific epoch checkpoint to export, or keep Latest (auto).", + "refresh_epochs_btn": "โ†ป ืจืขื ื•ืŸ Epochs" }, "help": { "btn_label": "?", diff --git a/acestep/ui/gradio/i18n/ja.json b/acestep/ui/gradio/i18n/ja.json index beb6d446..34bae681 100644 --- a/acestep/ui/gradio/i18n/ja.json +++ b/acestep/ui/gradio/i18n/ja.json @@ -5,8 +5,8 @@ }, "common": { "language_metadata": { - "name": "Japanese", - "native_name": "ๆ—ฅๆœฌ่ชž" + "name": "Japanese", + "native_name": "ๆ—ฅๆœฌ่ชž" } }, "dataset": { @@ -426,7 +426,34 @@ "lokr_no_weights_selected_epoch": "โŒ ้ธๆŠžใ—ใŸใ‚จใƒใƒƒใ‚ฏ {epoch} ใซLoKr้‡ใฟใŒ่ฆ‹ใคใ‹ใ‚Šใพใ›ใ‚“", "lokr_no_weights_latest_checkpoint": "โŒ ๆœ€ๆ–ฐใƒใ‚งใƒƒใ‚ฏใƒใ‚คใƒณใƒˆ {checkpoint} ใซLoKr้‡ใฟใŒ่ฆ‹ใคใ‹ใ‚Šใพใ›ใ‚“", "lokr_no_trained_weights_found": "โŒ {path} ใซใƒˆใƒฌใƒผใƒ‹ใƒณใ‚ฐๆธˆใฟLoKr้‡ใฟใŒ่ฆ‹ใคใ‹ใ‚Šใพใ›ใ‚“", - "lokr_exported": "โœ… LoKr ใ‚’ {path} ใซใ‚จใ‚ฏใ‚นใƒใƒผใƒˆใ—ใพใ—ใŸ" + "lokr_exported": "โœ… LoKr ใ‚’ {path} ใซใ‚จใ‚ฏใ‚นใƒใƒผใƒˆใ—ใพใ—ใŸ", + "tab_train_lokr": "๐Ÿš€ LoKr ใ‚’ใƒˆใƒฌใƒผใƒ‹ใƒณใ‚ฐ", + "lokr_section_tensors": 
"ๅ‰ๅ‡ฆ็†ๆธˆใฟใƒ‡ใƒผใ‚ฟใ‚ปใƒƒใƒˆ้ธๆŠž", + "lokr_section_settings": "LoKr ่จญๅฎš", + "lokr_tensor_selection_desc": "ๅ‰ๅ‡ฆ็†ใ•ใ‚ŒใŸใƒ†ใƒณใ‚ฝใƒซใƒ•ใ‚กใ‚คใƒซ๏ผˆ`.pt` ใƒ•ใ‚กใ‚คใƒซ๏ผ‰ใ‚’ๅซใ‚€ใƒ‡ใ‚ฃใƒฌใ‚ฏใƒˆใƒชใ‚’้ธๆŠžใ—ใพใ™ใ€‚\nใ“ใ‚Œใ‚‰ใฏใ€Œใƒ‡ใƒผใ‚ฟใ‚ปใƒƒใƒˆใƒ“ใƒซใƒ€ใƒผใ€ใ‚ฟใƒ–ใฎใ€Œๅ‰ๅ‡ฆ็†ใ€ใƒœใ‚ฟใƒณใ‚’ไฝฟ็”จใ—ใฆไฝœๆˆใ•ใ‚Œใพใ™ใ€‚", + "lokr_linear_dim": "LoKr Linear Dim", + "lokr_linear_dim_info": "Rank (dimension) for LoKr adaptation matrices.", + "lokr_linear_alpha": "LoKr Linear Alpha", + "lokr_linear_alpha_info": "Scaling factor for LoKr (usually similar to dim).", + "lokr_factor": "LoKr Factor", + "lokr_factor_info": "Kronecker factor (-1 for auto).", + "lokr_decompose_both": "Decompose Both Sides", + "lokr_decompose_both_info": "When enabled, decomposes both left and right matrices.", + "lokr_use_tucker": "Use Tucker Decomposition", + "lokr_use_tucker_info": "Apply Tucker decomposition when applicable.", + "lokr_use_scalar": "Use Scalar Gate", + "lokr_use_scalar_info": "Enable scalar gating for LoKr weights.", + "lokr_weight_decompose": "Weight Decompose (WD)", + "lokr_weight_decompose_info": "Enable weight decomposition for more stable LoKr training.", + "lokr_learning_rate_info": "LoKr commonly uses a higher LR than LoRA. 
Tune per dataset.", + "lokr_output_dir_info": "ใƒˆใƒฌใƒผใƒ‹ใƒณใ‚ฐๆธˆใฟ LoKr ้‡ใฟใฎไฟๅญ˜ๅ…ˆใƒ‡ใ‚ฃใƒฌใ‚ฏใƒˆใƒชใ€‚", + "start_lokr_training_btn": "๐Ÿš€ ใƒˆใƒฌใƒผใƒ‹ใƒณใ‚ฐ้–‹ๅง‹ LoKr", + "lokr_training_loss_title": "LoKr ใƒˆใƒฌใƒผใƒ‹ใƒณใ‚ฐๆๅคฑ", + "lokr_export_header": "LoKr ใ‚’ใ‚จใ‚ฏใ‚นใƒใƒผใƒˆ", + "export_lokr_btn": "๐Ÿ“ฆ LoKr ใ‚’ใ‚จใ‚ฏใ‚นใƒใƒผใƒˆ", + "lokr_checkpoint_epoch": "Checkpoint Epoch", + "lokr_checkpoint_epoch_info": "Select a specific epoch checkpoint to export, or keep Latest (auto).", + "refresh_epochs_btn": "โ†ป ๆ›ดๆ–ฐ Epochs" }, "help": { "btn_label": "?", diff --git a/acestep/ui/gradio/i18n/zh.json b/acestep/ui/gradio/i18n/zh.json index 3a2c5716..74f28b3e 100644 --- a/acestep/ui/gradio/i18n/zh.json +++ b/acestep/ui/gradio/i18n/zh.json @@ -5,8 +5,8 @@ }, "common": { "language_metadata": { - "name": "Chinese", - "native_name": "ไธญๆ–‡" + "name": "Chinese", + "native_name": "ไธญๆ–‡" } }, "dataset": { @@ -426,7 +426,34 @@ "lokr_no_weights_selected_epoch": "โŒ ๆ‰€้€‰่ฝฎๆฌก {epoch} ไธญๆœชๆ‰พๅˆฐ LoKr ๆƒ้‡", "lokr_no_weights_latest_checkpoint": "โŒ ๆœ€ๆ–ฐๆฃ€ๆŸฅ็‚น {checkpoint} ไธญๆœชๆ‰พๅˆฐ LoKr ๆƒ้‡", "lokr_no_trained_weights_found": "โŒ ๅœจ {path} ไธญๆœชๆ‰พๅˆฐ่ฎญ็ปƒๅฅฝ็š„ LoKr ๆƒ้‡", - "lokr_exported": "โœ… LoKr ๅทฒๅฏผๅ‡บๅˆฐ {path}" + "lokr_exported": "โœ… LoKr ๅทฒๅฏผๅ‡บๅˆฐ {path}", + "tab_train_lokr": "๐Ÿš€ ่ฎญ็ปƒ LoKr", + "lokr_section_tensors": "้ข„ๅค„็†ๆ•ฐๆฎ้›†้€‰ๆ‹ฉ", + "lokr_section_settings": "LoKr ่ฎพ็ฝฎ", + "lokr_tensor_selection_desc": "้€‰ๆ‹ฉๅŒ…ๅซ้ข„ๅค„็†ๅผ ้‡ๆ–‡ไปถ๏ผˆ`.pt` ๆ–‡ไปถ๏ผ‰็š„็›ฎๅฝ•ใ€‚\n่ฟ™ไบ›ๆ–‡ไปถๅœจใ€Œๆ•ฐๆฎ้›†ๆž„ๅปบใ€ๆ ‡็ญพ้กตไธญไฝฟ็”จใ€Œ้ข„ๅค„็†ใ€ๆŒ‰้’ฎๅˆ›ๅปบใ€‚", + "lokr_linear_dim": "LoKr Linear Dim", + "lokr_linear_dim_info": "Rank (dimension) for LoKr adaptation matrices.", + "lokr_linear_alpha": "LoKr Linear Alpha", + "lokr_linear_alpha_info": "Scaling factor for LoKr (usually similar to dim).", + "lokr_factor": "LoKr Factor", + "lokr_factor_info": "Kronecker factor (-1 for auto).", + 
"lokr_decompose_both": "Decompose Both Sides", + "lokr_decompose_both_info": "When enabled, decomposes both left and right matrices.", + "lokr_use_tucker": "Use Tucker Decomposition", + "lokr_use_tucker_info": "Apply Tucker decomposition when applicable.", + "lokr_use_scalar": "Use Scalar Gate", + "lokr_use_scalar_info": "Enable scalar gating for LoKr weights.", + "lokr_weight_decompose": "Weight Decompose (WD)", + "lokr_weight_decompose_info": "Enable weight decomposition for more stable LoKr training.", + "lokr_learning_rate_info": "LoKr commonly uses a higher LR than LoRA. Tune per dataset.", + "lokr_output_dir_info": "ไฟๅญ˜่ฎญ็ปƒๅŽ LoKr ๆƒ้‡็š„็›ฎๅฝ•ใ€‚", + "start_lokr_training_btn": "๐Ÿš€ ๅผ€ๅง‹่ฎญ็ปƒ LoKr", + "lokr_training_loss_title": "LoKr ่ฎญ็ปƒๆŸๅคฑ", + "lokr_export_header": "ๅฏผๅ‡บ LoKr", + "export_lokr_btn": "๐Ÿ“ฆ ๅฏผๅ‡บ LoKr", + "lokr_checkpoint_epoch": "Checkpoint Epoch", + "lokr_checkpoint_epoch_info": "Select a specific epoch checkpoint to export, or keep Latest (auto).", + "refresh_epochs_btn": "โ†ป ๅˆทๆ–ฐ Epochs" }, "help": { "btn_label": "?", diff --git a/acestep/ui/gradio/interfaces/training.py b/acestep/ui/gradio/interfaces/training.py index 4512bb4b..c207a77e 100644 --- a/acestep/ui/gradio/interfaces/training.py +++ b/acestep/ui/gradio/interfaces/training.py @@ -1,835 +1,68 @@ -""" -Gradio UI Training Tab Module +"""Gradio UI training-tab facade that composes focused tab builders.""" -Contains the dataset builder and LoRA training interface components. -The outer gr.Tab wrapper is now created in __init__.py. 
-""" +from __future__ import annotations -import os import gradio as gr -from acestep.ui.gradio.i18n import t -from acestep.ui.gradio.help_content import create_help_button + from acestep.constants import DEBUG_TRAINING +from acestep.ui.gradio.interfaces.training_dataset_builder_tab import ( + create_dataset_builder_tab, +) +from acestep.ui.gradio.interfaces.training_lokr_tab import create_training_lokr_tab +from acestep.ui.gradio.interfaces.training_lora_tab import create_training_lora_tab + + +def _resolve_epoch_slider_defaults() -> tuple[int, int, int]: + """Return epoch slider defaults adjusted for training debug mode.""" + + debug_training_enabled = str(DEBUG_TRAINING).strip().upper() != "OFF" + if debug_training_enabled: + return 1, 1, 1 + return 100, 100, 1000 def create_training_section(dit_handler, llm_handler, init_params=None) -> dict: - """Create the training tab content (without the outer gr.Tab wrapper). - - The outer gr.Tab is now created in __init__.py so that Generation and - Training tabs are siblings under the same gr.Tabs container. - + """Create the training-tab content without the outer ``gr.Tab`` wrapper. + Args: - dit_handler: DiT handler instance - llm_handler: LLM handler instance - init_params: Dictionary containing initialization parameters and state. - If None, service will not be pre-initialized. - + dit_handler: DiT handler instance. + llm_handler: LLM handler instance. + init_params: Optional initialization parameters. + Returns: - Dictionary of Gradio components for event handling + Mapping of component keys to Gradio components for training-event wiring. 
""" - debug_training_enabled = str(DEBUG_TRAINING).strip().upper() != "OFF" - epoch_min = 1 if debug_training_enabled else 100 - epoch_step = 1 if debug_training_enabled else 100 - epoch_default = 1 if debug_training_enabled else 1000 - gr.HTML(""" + del dit_handler, llm_handler, init_params + + epoch_min, epoch_step, epoch_default = _resolve_epoch_slider_defaults() + + gr.HTML( + """

🎵 LoRA Training for ACE-Step

Build datasets from your audio files and train custom LoRA adapters

- """) - - with gr.Tabs(): - # ==================== Dataset Builder Tab ==================== - with gr.Tab(t("training.tab_dataset_builder")): - create_help_button("training_dataset") - # ========== Load Existing OR Scan New ========== - gr.HTML(f""" -
-

{t("training.quick_start_title")}

-

Choose one: Load existing dataset OR Scan new directory

-
- """) - - with gr.Row(): - with gr.Column(scale=1): - gr.HTML("

📂 Load Existing Dataset

") - with gr.Row(): - load_json_path = gr.Textbox( - label=t("training.load_dataset_label"), - placeholder="./datasets/my_lora_dataset.json", - info=t("training.load_dataset_info"), elem_classes=["has-info-container"], - scale=3, - ) - load_json_btn = gr.Button(t("training.load_btn"), variant="primary", scale=1) - load_json_status = gr.Textbox( - label=t("training.load_status"), - interactive=False, - ) - - with gr.Column(scale=1): - gr.HTML("

๐Ÿ” Scan New Directory

") - with gr.Row(): - audio_directory = gr.Textbox( - label=t("training.scan_label"), - placeholder="/path/to/your/audio/folder", - info=t("training.scan_info"), elem_classes=["has-info-container"], - scale=3, - ) - scan_btn = gr.Button(t("training.scan_btn"), variant="secondary", scale=1) - scan_status = gr.Textbox( - label=t("training.scan_status"), - interactive=False, - ) - - gr.HTML("
") - - with gr.Row(): - with gr.Column(scale=2): - - # Audio files table - audio_files_table = gr.Dataframe( - headers=["#", "Filename", "Duration", "Lyrics", "Labeled", "BPM", "Key", "Caption"], - datatype=["number", "str", "str", "str", "str", "str", "str", "str"], - label=t("training.found_audio_files"), - interactive=False, - wrap=True, - ) - - with gr.Column(scale=1): - gr.HTML(f"

โš™๏ธ {t('training.dataset_settings_header')}

") - - dataset_name = gr.Textbox( - label=t("training.dataset_name"), - value="my_lora_dataset", - placeholder=t("training.dataset_name_placeholder"), - ) - - all_instrumental = gr.Checkbox( - label=t("training.all_instrumental"), - value=True, - info=t("training.all_instrumental_info"), elem_classes=["has-info-container"], - ) - - format_lyrics = gr.Checkbox( - label="Format Lyrics (LM)", - value=False, - info="Use LM to format/structure user-provided lyrics from .txt files (coming soon)", elem_classes=["has-info-container"], - interactive=False, # Disabled for now - model update needed - ) - - transcribe_lyrics = gr.Checkbox( - label="Transcribe Lyrics (LM)", - value=False, - info="Use LM to transcribe lyrics from audio (coming soon)", elem_classes=["has-info-container"], - interactive=False, # Disabled for now - model update needed - ) - - custom_tag = gr.Textbox( - label=t("training.custom_tag"), - placeholder="e.g., 8bit_retro, my_style", - info=t("training.custom_tag_info"), elem_classes=["has-info-container"], - ) - - tag_position = gr.Radio( - choices=[ - (t("training.tag_prepend"), "prepend"), - (t("training.tag_append"), "append"), - (t("training.tag_replace"), "replace"), - ], - value="replace", - label=t("training.tag_position"), - info=t("training.tag_position_info"), elem_classes=["has-info-container"], - ) - - genre_ratio = gr.Slider( - minimum=0, - maximum=100, - step=10, - value=0, - label=t("training.genre_ratio"), - info=t("training.genre_ratio_info"), elem_classes=["has-info-container"], - ) - - gr.HTML(f"

🤖 {t('training.step2_title')}

") - - with gr.Row(): - with gr.Column(scale=3): - gr.Markdown(t('training.step2_instruction')) - skip_metas = gr.Checkbox( - label=t("training.skip_metas"), - value=False, - info=t("training.skip_metas_info"), elem_classes=["has-info-container"], - ) - only_unlabeled = gr.Checkbox( - label=t("training.only_unlabeled"), - value=False, - info=t("training.only_unlabeled_info"), elem_classes=["has-info-container"], - ) - with gr.Column(scale=1): - auto_label_btn = gr.Button( - t("training.auto_label_btn"), - variant="primary", - size="lg", - ) - - label_progress = gr.Textbox( - label=t("training.label_progress"), - interactive=False, - lines=2, - ) - - gr.HTML(f"

👀 {t('training.step3_title')}

") - - with gr.Row(): - with gr.Column(scale=1): - sample_selector = gr.Slider( - minimum=0, - maximum=0, - step=1, - value=0, - label=t("training.select_sample"), - info=t("training.select_sample_info"), elem_classes=["has-info-container"], - ) - - preview_audio = gr.Audio( - label=t("training.audio_preview"), - type="filepath", - interactive=False, - ) - - preview_filename = gr.Textbox( - label=t("training.filename"), - interactive=False, - ) - - with gr.Column(scale=2): - with gr.Row(): - edit_caption = gr.Textbox( - label=t("training.caption"), - lines=3, - placeholder="Music description...", - ) - - with gr.Row(): - edit_genre = gr.Textbox( - label=t("training.genre"), - lines=1, - placeholder="pop, electronic, dance...", - ) - prompt_override = gr.Dropdown( - choices=["Use Global Ratio", "Caption", "Genre"], - value="Use Global Ratio", - label=t("training.prompt_override_label"), - info=t("training.prompt_override_info"), elem_classes=["has-info-container"], - ) - - with gr.Row(): - edit_lyrics = gr.Textbox( - label=t("training.lyrics_editable_label"), - lines=6, - placeholder="[Verse 1]\nLyrics here...\n\n[Chorus]\n...", - ) - raw_lyrics_display = gr.Textbox( - label=t("training.raw_lyrics_label"), - lines=6, - placeholder=t("training.no_lyrics_placeholder"), - interactive=False, # Read-only, can copy but not edit - visible=False, # Hidden when no raw lyrics - ) - has_raw_lyrics_state = gr.State(False) # Track visibility - - with gr.Row(): - edit_bpm = gr.Number( - label=t("training.bpm"), - precision=0, - ) - edit_keyscale = gr.Textbox( - label=t("training.key_label"), - placeholder=t("training.key_placeholder"), - ) - edit_timesig = gr.Dropdown( - choices=["", "2", "3", "4", "6", "N/A"], - label=t("training.time_sig"), - ) - edit_duration = gr.Number( - label=t("training.duration_s"), - precision=1, - interactive=False, - ) - - with gr.Row(): - edit_language = gr.Dropdown( - choices=["instrumental", "en", "zh", "ja", "ko", "es", "fr", "de", "pt", "ru", 
"unknown"], - value="instrumental", - label=t("training.language"), - ) - edit_instrumental = gr.Checkbox( - label=t("training.instrumental"), - value=True, - ) - save_edit_btn = gr.Button(t("training.save_changes_btn"), variant="secondary") - - edit_status = gr.Textbox( - label=t("training.edit_status"), - interactive=False, - ) - - gr.HTML(f"

💾 {t('training.step4_title')}

") - - with gr.Row(): - with gr.Column(scale=3): - save_path = gr.Textbox( - label=t("training.save_path"), - value="./datasets/my_lora_dataset.json", - placeholder="./datasets/dataset_name.json", - info=t("training.save_path_info"), elem_classes=["has-info-container"], - ) - with gr.Column(scale=1): - save_dataset_btn = gr.Button( - t("training.save_dataset_btn"), - variant="primary", - size="lg", - ) - - save_status = gr.Textbox( - label=t("training.save_status"), - interactive=False, - lines=2, - ) - - gr.HTML(f"

⚡ {t('training.step5_title')}

") - - gr.Markdown(t('training.step5_intro')) - - with gr.Row(): - with gr.Column(scale=3): - load_existing_dataset_path = gr.Textbox( - label=t("training.load_existing_label"), - placeholder="./datasets/my_lora_dataset.json", - info=t("training.load_existing_info"), elem_classes=["has-info-container"], - ) - with gr.Column(scale=1): - load_existing_dataset_btn = gr.Button( - t("training.load_dataset_btn"), - variant="secondary", - size="lg", - ) - - load_existing_status = gr.Textbox( - label=t("training.load_status"), - interactive=False, - ) - - gr.Markdown(t('training.step5_details')) - - with gr.Row(): - preprocess_mode = gr.Dropdown( - label="Preprocess For", - choices=["LoRA", "LoKr"], - value="LoRA", - info="LoRA keeps compatibility mode; LoKr uses per-sample source-style context.", elem_classes=["has-info-container"], - ) - - with gr.Row(): - with gr.Column(scale=3): - preprocess_output_dir = gr.Textbox( - label=t("training.tensor_output_dir"), - value="./datasets/preprocessed_tensors", - placeholder="./datasets/preprocessed_tensors", - info=t("training.tensor_output_info"), elem_classes=["has-info-container"], - ) - with gr.Column(scale=1): - preprocess_btn = gr.Button( - t("training.preprocess_btn"), - variant="primary", - size="lg", - ) - - preprocess_progress = gr.Textbox( - label=t("training.preprocess_progress"), - interactive=False, - lines=3, - ) - - # ==================== Training Tab ==================== - with gr.Tab(t("training.tab_train_lora")): - create_help_button("training_train") - with gr.Row(): - with gr.Column(scale=2): - gr.HTML(f"

📊 {t('training.train_section_tensors')}

") - - gr.Markdown(t('training.train_tensor_selection_desc')) - - training_tensor_dir = gr.Textbox( - label=t("training.preprocessed_tensors_dir"), - placeholder="./datasets/preprocessed_tensors", - value="./datasets/preprocessed_tensors", - info=t("training.preprocessed_tensors_info"), elem_classes=["has-info-container"], - ) - - load_dataset_btn = gr.Button(t("training.load_dataset_btn"), variant="secondary") - - training_dataset_info = gr.Textbox( - label=t("training.dataset_info"), - interactive=False, - lines=3, - ) - - with gr.Column(scale=1): - gr.HTML(f"

โš™๏ธ {t('training.train_section_lora')}

") - - lora_rank = gr.Slider( - minimum=4, - maximum=256, - step=4, - value=64, - label=t("training.lora_rank"), - info=t("training.lora_rank_info"), elem_classes=["has-info-container"], - ) - - lora_alpha = gr.Slider( - minimum=4, - maximum=512, - step=4, - value=128, - label=t("training.lora_alpha"), - info=t("training.lora_alpha_info"), elem_classes=["has-info-container"], - ) - - lora_dropout = gr.Slider( - minimum=0.0, - maximum=0.5, - step=0.05, - value=0.1, - label=t("training.lora_dropout"), - ) - - gr.HTML(f"

๐ŸŽ›๏ธ {t('training.train_section_params')}

") - - with gr.Row(): - learning_rate = gr.Number( - label=t("training.learning_rate"), - value=3e-4, - info=t("training.learning_rate_info"), elem_classes=["has-info-container"], - ) - - train_epochs = gr.Slider( - minimum=epoch_min, - maximum=4000, - step=epoch_step, - value=epoch_default, - label=t("training.max_epochs"), - ) - - train_batch_size = gr.Slider( - minimum=1, - maximum=8, - step=1, - value=1, - label=t("training.batch_size"), - info=t("training.batch_size_info"), elem_classes=["has-info-container"], - ) - - gradient_accumulation = gr.Slider( - minimum=1, - maximum=16, - step=1, - value=1, - label=t("training.gradient_accumulation"), - info=t("training.gradient_accumulation_info"), elem_classes=["has-info-container"], - ) - - with gr.Row(): - save_every_n_epochs = gr.Slider( - minimum=1, - maximum=1000, - step=1, - value=10, - label=t("training.save_every_n_epochs"), - ) - - training_shift = gr.Slider( - minimum=1.0, - maximum=5.0, - step=0.5, - value=3.0, - label=t("training.shift"), - info=t("training.shift_info"), elem_classes=["has-info-container"], - ) - - training_seed = gr.Number( - label=t("training.seed"), - value=42, - precision=0, - ) - - with gr.Row(): - lora_output_dir = gr.Textbox( - label=t("training.output_dir"), - value="./lora_output", - placeholder="./lora_output", - info=t("training.output_dir_info"), elem_classes=["has-info-container"], - ) - - with gr.Row(): - resume_checkpoint_dir = gr.Textbox( - label="Resume Checkpoint", - placeholder="./lora_output/checkpoints/epoch_200", - info="Directory of a saved LoRA checkpoint to resume from", elem_classes=["has-info-container"], - ) - - gr.HTML("
") - - with gr.Row(): - with gr.Column(scale=1): - start_training_btn = gr.Button( - t("training.start_training_btn"), - variant="primary", - size="lg", - ) - with gr.Column(scale=1): - stop_training_btn = gr.Button( - t("training.stop_training_btn"), - variant="stop", - size="lg", - ) - - training_progress = gr.Textbox( - label=t("training.training_progress"), - interactive=False, - lines=2, - ) - - with gr.Row(): - training_log = gr.Textbox( - label=t("training.training_log"), - interactive=False, - lines=10, - max_lines=15, - scale=1, - ) - training_loss_plot = gr.Plot( - label=t("training.training_loss_title"), - scale=1, - ) - - gr.HTML(f"

📦 {t('training.export_header')}

") - - with gr.Row(): - export_path = gr.Textbox( - label=t("training.export_path"), - value="./lora_output/final_lora", - placeholder="./lora_output/my_lora", - ) - export_lora_btn = gr.Button(t("training.export_lora_btn"), variant="secondary") - - export_status = gr.Textbox( - label=t("training.export_status"), - interactive=False, - ) - - # ==================== Train LoKr Tab ==================== - with gr.Tab("๐Ÿš€ Train LoKr"): - create_help_button("training_lokr") - with gr.Row(): - with gr.Column(scale=2): - gr.HTML("

📊 Preprocessed Tensors

") - gr.Markdown( - "Select the directory containing preprocessed tensor files (`.pt` files). " - "These are created in the Dataset Builder tab." - ) - - lokr_training_tensor_dir = gr.Textbox( - label="Preprocessed Tensors Directory", - placeholder="./datasets/preprocessed_tensors", - value="./datasets/preprocessed_tensors", - info="Path to directory containing manifest.json and tensor .pt files.", elem_classes=["has-info-container"], - ) - - lokr_load_dataset_btn = gr.Button("Load Dataset", variant="secondary") - - lokr_training_dataset_info = gr.Textbox( - label="Dataset Info", - interactive=False, - lines=3, - ) - - with gr.Column(scale=1): - gr.HTML("

โš™๏ธ LoKr Settings

") - - lokr_linear_dim = gr.Slider( - minimum=4, - maximum=256, - step=4, - value=64, - label="LoKr Linear Dim", - info="Adapter rank-like width for LoKr linear layers.", elem_classes=["has-info-container"], - ) - lokr_linear_alpha = gr.Slider( - minimum=4, - maximum=512, - step=4, - value=128, - label="LoKr Linear Alpha", - info="Scaling factor for LoKr adapters.", elem_classes=["has-info-container"], - ) - lokr_factor = gr.Number( - label="LoKr Factor", - value=-1, - precision=0, - info="-1 uses automatic Kronecker factor selection.", elem_classes=["has-info-container"], - ) - lokr_decompose_both = gr.Checkbox( - label="Decompose Both", - value=False, - info="Enable decomposition on both matrices.", elem_classes=["has-info-container"], - ) - lokr_use_tucker = gr.Checkbox( - label="Use Tucker", - value=False, - info="Enable Tucker decomposition mode.", elem_classes=["has-info-container"], - ) - lokr_use_scalar = gr.Checkbox( - label="Use Scalar", - value=False, - info="Enable scalar calibration in LyCORIS.", elem_classes=["has-info-container"], - ) - lokr_weight_decompose = gr.Checkbox( - label="Weight Decompose (DoRA)", - value=True, - info="Enable DoRA-style weight decomposition when supported.", elem_classes=["has-info-container"], - ) - - gr.HTML("

๐ŸŽ›๏ธ Training Parameters

") - - with gr.Row(): - lokr_learning_rate = gr.Number( - label="Learning Rate", - value=1e-3, - info="LoKr commonly uses a higher LR than LoRA. Tune per dataset.", elem_classes=["has-info-container"], - ) - - lokr_train_epochs = gr.Slider( - minimum=1, - maximum=4000, - step=1, - value=500, - label="Max Epochs", - ) - - lokr_train_batch_size = gr.Slider( - minimum=1, - maximum=8, - step=1, - value=1, - label="Batch Size", - ) - - lokr_gradient_accumulation = gr.Slider( - minimum=1, - maximum=16, - step=1, - value=4, - label="Gradient Accumulation", - ) - - with gr.Row(): - lokr_save_every_n_epochs = gr.Slider( - minimum=1, - maximum=1000, - step=1, - value=10, - label="Save Every N Epochs", - ) - - lokr_training_shift = gr.Slider( - minimum=1.0, - maximum=5.0, - step=0.5, - value=3.0, - label="Shift", - info="Turbo model training timestep shift.", elem_classes=["has-info-container"], - ) - - lokr_training_seed = gr.Number( - label="Seed", - value=42, - precision=0, - ) - - with gr.Row(): - lokr_output_dir = gr.Textbox( - label="Output Directory", - value="./lokr_output", - placeholder="./lokr_output", - info="Where LoKr checkpoints and final weights will be written.", elem_classes=["has-info-container"], - ) - - gr.HTML("
") - - with gr.Row(): - with gr.Column(scale=1): - start_lokr_training_btn = gr.Button( - "Start LoKr Training", - variant="primary", - size="lg", - ) - with gr.Column(scale=1): - stop_lokr_training_btn = gr.Button( - "Stop Training", - variant="stop", - size="lg", - ) - - lokr_training_progress = gr.Textbox( - label="Training Progress", - interactive=False, - lines=2, - ) - - with gr.Row(): - lokr_training_log = gr.Textbox( - label="Training Log", - interactive=False, - lines=10, - max_lines=15, - scale=1, - ) - lokr_training_loss_plot = gr.Plot( - label="LoKr Training Loss", - scale=1, - ) - - gr.HTML("

📦 Export LoKr

") - - with gr.Row(): - lokr_export_path = gr.Textbox( - label="Export Path", - value="./lokr_output/final_lokr", - placeholder="./lokr_output/my_lokr", - ) - export_lokr_btn = gr.Button("๐Ÿ“ฆ Export LoKr", variant="secondary") - - with gr.Row(): - lokr_export_epoch = gr.Dropdown( - choices=["Latest (auto)"], - value="Latest (auto)", - label="Checkpoint Epoch", - info="Select a specific epoch checkpoint to export, or keep Latest (auto).", elem_classes=["has-info-container"], - ) - refresh_lokr_export_epochs_btn = gr.Button("โ†ป Refresh Epochs", variant="secondary") + """ + ) - lokr_export_status = gr.Textbox( - label="Export Status", - interactive=False, + training_section: dict[str, object] = {} + with gr.Tabs(): + training_section.update(create_dataset_builder_tab()) + training_section.update( + create_training_lora_tab( + epoch_min=epoch_min, + epoch_step=epoch_step, + epoch_default=epoch_default, ) - - # Store dataset builder state - dataset_builder_state = gr.State(None) - training_state = gr.State({"is_training": False, "should_stop": False}) - - return { - # Dataset Builder - Load or Scan - "load_json_path": load_json_path, - "load_json_btn": load_json_btn, - "load_json_status": load_json_status, - "audio_directory": audio_directory, - "scan_btn": scan_btn, - "scan_status": scan_status, - "audio_files_table": audio_files_table, - "dataset_name": dataset_name, - "all_instrumental": all_instrumental, - "format_lyrics": format_lyrics, - "transcribe_lyrics": transcribe_lyrics, - "custom_tag": custom_tag, - "tag_position": tag_position, - "skip_metas": skip_metas, - "only_unlabeled": only_unlabeled, - "auto_label_btn": auto_label_btn, - "label_progress": label_progress, - "sample_selector": sample_selector, - "preview_audio": preview_audio, - "preview_filename": preview_filename, - "edit_caption": edit_caption, - "edit_genre": edit_genre, - "prompt_override": prompt_override, - "genre_ratio": genre_ratio, - "edit_lyrics": edit_lyrics, - "raw_lyrics_display": 
raw_lyrics_display, - "has_raw_lyrics_state": has_raw_lyrics_state, - "edit_bpm": edit_bpm, - "edit_keyscale": edit_keyscale, - "edit_timesig": edit_timesig, - "edit_duration": edit_duration, - "edit_language": edit_language, - "edit_instrumental": edit_instrumental, - "save_edit_btn": save_edit_btn, - "edit_status": edit_status, - "save_path": save_path, - "save_dataset_btn": save_dataset_btn, - "save_status": save_status, - # Preprocessing - "load_existing_dataset_path": load_existing_dataset_path, - "load_existing_dataset_btn": load_existing_dataset_btn, - "load_existing_status": load_existing_status, - "preprocess_mode": preprocess_mode, - "preprocess_output_dir": preprocess_output_dir, - "preprocess_btn": preprocess_btn, - "preprocess_progress": preprocess_progress, - "dataset_builder_state": dataset_builder_state, - # Training - "training_tensor_dir": training_tensor_dir, - "load_dataset_btn": load_dataset_btn, - "training_dataset_info": training_dataset_info, - "lora_rank": lora_rank, - "lora_alpha": lora_alpha, - "lora_dropout": lora_dropout, - "learning_rate": learning_rate, - "train_epochs": train_epochs, - "train_batch_size": train_batch_size, - "gradient_accumulation": gradient_accumulation, - "save_every_n_epochs": save_every_n_epochs, - "training_shift": training_shift, - "training_seed": training_seed, - "lora_output_dir": lora_output_dir, - "resume_checkpoint_dir": resume_checkpoint_dir, - "start_training_btn": start_training_btn, - "stop_training_btn": stop_training_btn, - "training_progress": training_progress, - "training_log": training_log, - "training_loss_plot": training_loss_plot, - "export_path": export_path, - "export_lora_btn": export_lora_btn, - "export_status": export_status, - # LoKr training - "lokr_training_tensor_dir": lokr_training_tensor_dir, - "lokr_load_dataset_btn": lokr_load_dataset_btn, - "lokr_training_dataset_info": lokr_training_dataset_info, - "lokr_linear_dim": lokr_linear_dim, - "lokr_linear_alpha": lokr_linear_alpha, - 
"lokr_factor": lokr_factor, - "lokr_decompose_both": lokr_decompose_both, - "lokr_use_tucker": lokr_use_tucker, - "lokr_use_scalar": lokr_use_scalar, - "lokr_weight_decompose": lokr_weight_decompose, - "lokr_learning_rate": lokr_learning_rate, - "lokr_train_epochs": lokr_train_epochs, - "lokr_train_batch_size": lokr_train_batch_size, - "lokr_gradient_accumulation": lokr_gradient_accumulation, - "lokr_save_every_n_epochs": lokr_save_every_n_epochs, - "lokr_training_shift": lokr_training_shift, - "lokr_training_seed": lokr_training_seed, - "lokr_output_dir": lokr_output_dir, - "start_lokr_training_btn": start_lokr_training_btn, - "stop_lokr_training_btn": stop_lokr_training_btn, - "lokr_training_progress": lokr_training_progress, - "lokr_training_log": lokr_training_log, - "lokr_training_loss_plot": lokr_training_loss_plot, - "lokr_export_path": lokr_export_path, - "lokr_export_epoch": lokr_export_epoch, - "refresh_lokr_export_epochs_btn": refresh_lokr_export_epochs_btn, - "export_lokr_btn": export_lokr_btn, - "lokr_export_status": lokr_export_status, - "training_state": training_state, - } + ) + training_section.update(create_training_lokr_tab()) + dataset_builder_state = gr.State(None) + training_state = gr.State({"is_training": False, "should_stop": False}) + training_section.update( + { + "dataset_builder_state": dataset_builder_state, + "training_state": training_state, + } + ) + return training_section diff --git a/acestep/ui/gradio/interfaces/training_contract_ast_utils.py b/acestep/ui/gradio/interfaces/training_contract_ast_utils.py new file mode 100644 index 00000000..967c807a --- /dev/null +++ b/acestep/ui/gradio/interfaces/training_contract_ast_utils.py @@ -0,0 +1,77 @@ +"""AST utility helpers for training interface decomposition contract tests.""" + +from __future__ import annotations + +import ast +from pathlib import Path + + +INTERFACES_DIR = Path(__file__).resolve().parent +WIRING_DIR = INTERFACES_DIR.parent / "events" / "wiring" + + +def 
load_module(module_name: str) -> ast.Module: + """Parse and return AST for an interfaces module.""" + + path = INTERFACES_DIR / module_name + return ast.parse(path.read_text(encoding="utf-8")) + + +def call_name(node: ast.AST) -> str | None: + """Extract a simple call-target name from an AST call function node.""" + + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + return node.attr + return None + + +def collect_return_dict_keys(module_name: str, function_name: str) -> set[str]: + """Collect string keys from dict literals assigned/updated/returned in a function.""" + + module = load_module(module_name) + function_node = None + for node in module.body: + if isinstance(node, ast.FunctionDef) and node.name == function_name: + function_node = node + break + if function_node is None: + raise AssertionError(f"{function_name} not found in {module_name}") + + keys: set[str] = set() + for node in ast.walk(function_node): + if isinstance(node, ast.Assign) and isinstance(node.value, ast.Dict): + for key in node.value.keys: + if isinstance(key, ast.Constant) and isinstance(key.value, str): + keys.add(key.value) + if isinstance(node, ast.AnnAssign) and isinstance(node.value, ast.Dict): + for key in node.value.keys: + if isinstance(key, ast.Constant) and isinstance(key.value, str): + keys.add(key.value) + if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and node.func.attr == "update": + if node.args and isinstance(node.args[0], ast.Dict): + for key in node.args[0].keys: + if isinstance(key, ast.Constant) and isinstance(key.value, str): + keys.add(key.value) + if isinstance(node, ast.Return) and isinstance(node.value, ast.Dict): + for key in node.value.keys: + if isinstance(key, ast.Constant) and isinstance(key.value, str): + keys.add(key.value) + return keys + + +def collect_training_section_keys_used_by_wiring() -> set[str]: + """Collect ``training_section[...]`` keys referenced by training wiring modules.""" + + 
keys: set[str] = set() + for path in WIRING_DIR.glob("training_*wiring.py"): + module = ast.parse(path.read_text(encoding="utf-8")) + for node in ast.walk(module): + if not isinstance(node, ast.Subscript): + continue + if not isinstance(node.value, ast.Name) or node.value.id != "training_section": + continue + if isinstance(node.slice, ast.Constant) and isinstance(node.slice.value, str): + keys.add(node.slice.value) + return keys diff --git a/acestep/ui/gradio/interfaces/training_dataset_builder_tab.py b/acestep/ui/gradio/interfaces/training_dataset_builder_tab.py new file mode 100644 index 00000000..87551c0a --- /dev/null +++ b/acestep/ui/gradio/interfaces/training_dataset_builder_tab.py @@ -0,0 +1,29 @@ +"""Dataset-builder tab facade for the Gradio training interface.""" + +from __future__ import annotations + +import gradio as gr + +from acestep.ui.gradio.help_content import create_help_button +from acestep.ui.gradio.i18n import t +from acestep.ui.gradio.interfaces.training_dataset_tab_label_preview import ( + build_dataset_label_and_preview_controls, +) +from acestep.ui.gradio.interfaces.training_dataset_tab_save_preprocess import ( + build_dataset_save_and_preprocess_controls, +) +from acestep.ui.gradio.interfaces.training_dataset_tab_scan_settings import ( + build_dataset_scan_and_settings_controls, +) + + +def create_dataset_builder_tab() -> dict[str, object]: + """Create the Dataset Builder tab and return all exposed component handles.""" + + with gr.Tab(t("training.tab_dataset_builder")): + create_help_button("training_dataset") + tab_controls: dict[str, object] = {} + tab_controls.update(build_dataset_scan_and_settings_controls()) + tab_controls.update(build_dataset_label_and_preview_controls()) + tab_controls.update(build_dataset_save_and_preprocess_controls()) + return tab_controls diff --git a/acestep/ui/gradio/interfaces/training_dataset_tab_label_preview.py b/acestep/ui/gradio/interfaces/training_dataset_tab_label_preview.py new file mode 100644 
index 00000000..dcc4838f --- /dev/null +++ b/acestep/ui/gradio/interfaces/training_dataset_tab_label_preview.py @@ -0,0 +1,175 @@ +"""Dataset labeling and sample-preview controls for the training dataset tab.""" + +from __future__ import annotations + +import gradio as gr + +from acestep.ui.gradio.i18n import t + + +def build_dataset_label_and_preview_controls() -> dict[str, object]: + """Render auto-label and sample-preview editors for dataset-builder workflows.""" + + gr.HTML(f"

🤖 {t('training.step2_title')}

") + + with gr.Row(): + with gr.Column(scale=3): + gr.Markdown(t("training.step2_instruction")) + skip_metas = gr.Checkbox( + label=t("training.skip_metas"), + value=False, + info=t("training.skip_metas_info"), + elem_classes=["has-info-container"], + ) + only_unlabeled = gr.Checkbox( + label=t("training.only_unlabeled"), + value=False, + info=t("training.only_unlabeled_info"), + elem_classes=["has-info-container"], + ) + with gr.Column(scale=1): + auto_label_btn = gr.Button( + t("training.auto_label_btn"), + variant="primary", + size="lg", + ) + + label_progress = gr.Textbox( + label=t("training.label_progress"), + interactive=False, + lines=2, + ) + + gr.HTML(f"

👀 {t('training.step3_title')}

") + + with gr.Row(): + with gr.Column(scale=1): + sample_selector = gr.Slider( + minimum=0, + maximum=0, + step=1, + value=0, + label=t("training.select_sample"), + info=t("training.select_sample_info"), + elem_classes=["has-info-container"], + ) + + preview_audio = gr.Audio( + label=t("training.audio_preview"), + type="filepath", + interactive=False, + ) + + preview_filename = gr.Textbox( + label=t("training.filename"), + interactive=False, + ) + + with gr.Column(scale=2): + with gr.Row(): + edit_caption = gr.Textbox( + label=t("training.caption"), + lines=3, + placeholder="Music description...", + ) + + with gr.Row(): + edit_genre = gr.Textbox( + label=t("training.genre"), + lines=1, + placeholder="pop, electronic, dance...", + ) + prompt_override = gr.Dropdown( + choices=["Use Global Ratio", "Caption", "Genre"], + value="Use Global Ratio", + label=t("training.prompt_override_label"), + info=t("training.prompt_override_info"), + elem_classes=["has-info-container"], + ) + + with gr.Row(): + edit_lyrics = gr.Textbox( + label=t("training.lyrics_editable_label"), + lines=6, + placeholder="[Verse 1]\nLyrics here...\n\n[Chorus]\n...", + ) + raw_lyrics_display = gr.Textbox( + label=t("training.raw_lyrics_label"), + lines=6, + placeholder=t("training.no_lyrics_placeholder"), + interactive=False, + visible=False, + ) + has_raw_lyrics_state = gr.State(False) + + with gr.Row(): + edit_bpm = gr.Number( + label=t("training.bpm"), + precision=0, + ) + edit_keyscale = gr.Textbox( + label=t("training.key_label"), + placeholder=t("training.key_placeholder"), + ) + edit_timesig = gr.Dropdown( + choices=["", "2", "3", "4", "6", "N/A"], + label=t("training.time_sig"), + ) + edit_duration = gr.Number( + label=t("training.duration_s"), + precision=1, + interactive=False, + ) + + with gr.Row(): + edit_language = gr.Dropdown( + choices=[ + "instrumental", + "en", + "zh", + "ja", + "ko", + "es", + "fr", + "de", + "pt", + "ru", + "unknown", + ], + value="instrumental", + 
label=t("training.language"), + ) + edit_instrumental = gr.Checkbox( + label=t("training.instrumental"), + value=True, + ) + save_edit_btn = gr.Button(t("training.save_changes_btn"), variant="secondary") + + edit_status = gr.Textbox( + label=t("training.edit_status"), + interactive=False, + ) + + return { + "skip_metas": skip_metas, + "only_unlabeled": only_unlabeled, + "auto_label_btn": auto_label_btn, + "label_progress": label_progress, + "sample_selector": sample_selector, + "preview_audio": preview_audio, + "preview_filename": preview_filename, + "edit_caption": edit_caption, + "edit_genre": edit_genre, + "prompt_override": prompt_override, + "edit_lyrics": edit_lyrics, + "raw_lyrics_display": raw_lyrics_display, + "has_raw_lyrics_state": has_raw_lyrics_state, + "edit_bpm": edit_bpm, + "edit_keyscale": edit_keyscale, + "edit_timesig": edit_timesig, + "edit_duration": edit_duration, + "edit_language": edit_language, + "edit_instrumental": edit_instrumental, + "save_edit_btn": save_edit_btn, + "edit_status": edit_status, + } diff --git a/acestep/ui/gradio/interfaces/training_dataset_tab_save_preprocess.py b/acestep/ui/gradio/interfaces/training_dataset_tab_save_preprocess.py new file mode 100644 index 00000000..2729ae0a --- /dev/null +++ b/acestep/ui/gradio/interfaces/training_dataset_tab_save_preprocess.py @@ -0,0 +1,105 @@ +"""Dataset save and preprocess controls for the training dataset tab.""" + +from __future__ import annotations + +import gradio as gr + +from acestep.ui.gradio.i18n import t + + +def build_dataset_save_and_preprocess_controls() -> dict[str, object]: + """Render dataset save/load-preprocess controls and return component handles.""" + + gr.HTML(f"

💾 {t('training.step4_title')}

") + + with gr.Row(): + with gr.Column(scale=3): + save_path = gr.Textbox( + label=t("training.save_path"), + value="./datasets/my_lora_dataset.json", + placeholder="./datasets/dataset_name.json", + info=t("training.save_path_info"), + elem_classes=["has-info-container"], + ) + with gr.Column(scale=1): + save_dataset_btn = gr.Button( + t("training.save_dataset_btn"), + variant="primary", + size="lg", + ) + + save_status = gr.Textbox( + label=t("training.save_status"), + interactive=False, + lines=2, + ) + + gr.HTML(f"

⚡ {t('training.step5_title')}

") + + gr.Markdown(t("training.step5_intro")) + + with gr.Row(): + with gr.Column(scale=3): + load_existing_dataset_path = gr.Textbox( + label=t("training.load_existing_label"), + placeholder="./datasets/my_lora_dataset.json", + info=t("training.load_existing_info"), + elem_classes=["has-info-container"], + ) + with gr.Column(scale=1): + load_existing_dataset_btn = gr.Button( + t("training.load_dataset_btn"), + variant="secondary", + size="lg", + ) + + load_existing_status = gr.Textbox( + label=t("training.load_status"), + interactive=False, + ) + + gr.Markdown(t("training.step5_details")) + + with gr.Row(): + preprocess_mode = gr.Dropdown( + label="Preprocess For", + choices=["LoRA", "LoKr"], + value="LoRA", + info="LoRA keeps compatibility mode; LoKr uses per-sample source-style context.", + elem_classes=["has-info-container"], + ) + + with gr.Row(): + with gr.Column(scale=3): + preprocess_output_dir = gr.Textbox( + label=t("training.tensor_output_dir"), + value="./datasets/preprocessed_tensors", + placeholder="./datasets/preprocessed_tensors", + info=t("training.tensor_output_info"), + elem_classes=["has-info-container"], + ) + with gr.Column(scale=1): + preprocess_btn = gr.Button( + t("training.preprocess_btn"), + variant="primary", + size="lg", + ) + + preprocess_progress = gr.Textbox( + label=t("training.preprocess_progress"), + interactive=False, + lines=3, + ) + + return { + "save_path": save_path, + "save_dataset_btn": save_dataset_btn, + "save_status": save_status, + "load_existing_dataset_path": load_existing_dataset_path, + "load_existing_dataset_btn": load_existing_dataset_btn, + "load_existing_status": load_existing_status, + "preprocess_mode": preprocess_mode, + "preprocess_output_dir": preprocess_output_dir, + "preprocess_btn": preprocess_btn, + "preprocess_progress": preprocess_progress, + } diff --git a/acestep/ui/gradio/interfaces/training_dataset_tab_scan_settings.py b/acestep/ui/gradio/interfaces/training_dataset_tab_scan_settings.py new file 
mode 100644 index 00000000..b7b7a463 --- /dev/null +++ b/acestep/ui/gradio/interfaces/training_dataset_tab_scan_settings.py @@ -0,0 +1,143 @@ +"""Dataset scan/load and settings controls for the training dataset tab.""" + +from __future__ import annotations + +import gradio as gr + +from acestep.ui.gradio.i18n import t + + +def build_dataset_scan_and_settings_controls() -> dict[str, object]: + """Render scan/load controls and dataset settings for the dataset-builder tab.""" + + gr.HTML( + f""" +
+

{t("training.quick_start_title")}

+

Choose one: Load existing dataset OR Scan new directory

+
+ """ + ) + + with gr.Row(): + with gr.Column(scale=1): + gr.HTML("

📂 Load Existing Dataset

") + with gr.Row(): + load_json_path = gr.Textbox( + label=t("training.load_dataset_label"), + placeholder="./datasets/my_lora_dataset.json", + info=t("training.load_dataset_info"), + elem_classes=["has-info-container"], + scale=3, + ) + load_json_btn = gr.Button(t("training.load_btn"), variant="primary", scale=1) + load_json_status = gr.Textbox( + label=t("training.load_status"), + interactive=False, + ) + + with gr.Column(scale=1): + gr.HTML("

๐Ÿ” Scan New Directory

") + with gr.Row(): + audio_directory = gr.Textbox( + label=t("training.scan_label"), + placeholder="/path/to/your/audio/folder", + info=t("training.scan_info"), + elem_classes=["has-info-container"], + scale=3, + ) + scan_btn = gr.Button(t("training.scan_btn"), variant="secondary", scale=1) + scan_status = gr.Textbox( + label=t("training.scan_status"), + interactive=False, + ) + + gr.HTML("
") + + with gr.Row(): + with gr.Column(scale=2): + audio_files_table = gr.Dataframe( + headers=["#", "Filename", "Duration", "Lyrics", "Labeled", "BPM", "Key", "Caption"], + datatype=["number", "str", "str", "str", "str", "str", "str", "str"], + label=t("training.found_audio_files"), + interactive=False, + wrap=True, + ) + + with gr.Column(scale=1): + gr.HTML(f"

โš™๏ธ {t('training.dataset_settings_header')}

") + + dataset_name = gr.Textbox( + label=t("training.dataset_name"), + value="my_lora_dataset", + placeholder=t("training.dataset_name_placeholder"), + ) + + all_instrumental = gr.Checkbox( + label=t("training.all_instrumental"), + value=True, + info=t("training.all_instrumental_info"), + elem_classes=["has-info-container"], + ) + + format_lyrics = gr.Checkbox( + label="Format Lyrics (LM)", + value=False, + info="Use LM to format/structure user-provided lyrics from .txt files (coming soon)", + elem_classes=["has-info-container"], + interactive=False, + ) + + transcribe_lyrics = gr.Checkbox( + label="Transcribe Lyrics (LM)", + value=False, + info="Use LM to transcribe lyrics from audio (coming soon)", + elem_classes=["has-info-container"], + interactive=False, + ) + + custom_tag = gr.Textbox( + label=t("training.custom_tag"), + placeholder="e.g., 8bit_retro, my_style", + info=t("training.custom_tag_info"), + elem_classes=["has-info-container"], + ) + + tag_position = gr.Radio( + choices=[ + (t("training.tag_prepend"), "prepend"), + (t("training.tag_append"), "append"), + (t("training.tag_replace"), "replace"), + ], + value="replace", + label=t("training.tag_position"), + info=t("training.tag_position_info"), + elem_classes=["has-info-container"], + ) + + genre_ratio = gr.Slider( + minimum=0, + maximum=100, + step=10, + value=0, + label=t("training.genre_ratio"), + info=t("training.genre_ratio_info"), + elem_classes=["has-info-container"], + ) + + return { + "load_json_path": load_json_path, + "load_json_btn": load_json_btn, + "load_json_status": load_json_status, + "audio_directory": audio_directory, + "scan_btn": scan_btn, + "scan_status": scan_status, + "audio_files_table": audio_files_table, + "dataset_name": dataset_name, + "all_instrumental": all_instrumental, + "format_lyrics": format_lyrics, + "transcribe_lyrics": transcribe_lyrics, + "custom_tag": custom_tag, + "tag_position": tag_position, + "genre_ratio": genre_ratio, + } diff --git 
a/acestep/ui/gradio/interfaces/training_decomposition_contract_test.py b/acestep/ui/gradio/interfaces/training_decomposition_contract_test.py new file mode 100644 index 00000000..6018778a --- /dev/null +++ b/acestep/ui/gradio/interfaces/training_decomposition_contract_test.py @@ -0,0 +1,154 @@ +"""AST contract tests for training interface decomposition.""" + +from __future__ import annotations + +import ast +import unittest +from pathlib import Path + +try: + from .training_contract_ast_utils import ( + call_name, + collect_return_dict_keys, + collect_training_section_keys_used_by_wiring, + load_module, + ) +except ImportError: + from training_contract_ast_utils import ( # type: ignore[no-redef] + call_name, + collect_return_dict_keys, + collect_training_section_keys_used_by_wiring, + load_module, + ) + + +class TrainingDecompositionContractTests(unittest.TestCase): + """Verify the training interface facade composes focused helper modules.""" + + def test_training_facade_imports_tab_helpers(self) -> None: + """``training.py`` should import dataset, LoRA, and LoKr helper modules.""" + + module = load_module("training.py") + imported_modules = [] + for node in ast.walk(module): + if isinstance(node, ast.ImportFrom) and node.module: + imported_modules.append(node.module) + + self.assertIn( + "acestep.ui.gradio.interfaces.training_dataset_builder_tab", + imported_modules, + ) + self.assertIn("acestep.ui.gradio.interfaces.training_lora_tab", imported_modules) + self.assertIn("acestep.ui.gradio.interfaces.training_lokr_tab", imported_modules) + + def test_training_facade_merges_helper_sections(self) -> None: + """``training.py`` should compose helper returns into one training-section map.""" + + module = load_module("training.py") + call_names: list[str] = [] + update_calls = 0 + for node in ast.walk(module): + if not isinstance(node, ast.Call): + continue + name = call_name(node.func) + if name: + call_names.append(name) + if isinstance(node.func, ast.Attribute) and 
node.func.attr == "update": + update_calls += 1 + + self.assertIn("create_dataset_builder_tab", call_names) + self.assertIn("create_training_lora_tab", call_names) + self.assertIn("create_training_lokr_tab", call_names) + self.assertGreaterEqual(update_calls, 4) + + def test_dataset_builder_tab_delegates_to_section_builders(self) -> None: + """Dataset-builder facade should delegate to scan/label/preprocess builders.""" + + module = load_module("training_dataset_builder_tab.py") + call_names: list[str] = [] + for node in ast.walk(module): + if isinstance(node, ast.Call): + name = call_name(node.func) + if name: + call_names.append(name) + + self.assertIn("build_dataset_scan_and_settings_controls", call_names) + self.assertIn("build_dataset_label_and_preview_controls", call_names) + self.assertIn("build_dataset_save_and_preprocess_controls", call_names) + + def test_training_keys_cover_wiring_requirements(self) -> None: + """Returned training keys should cover all keys consumed by wiring modules.""" + + produced_keys: set[str] = set() + key_sources = [ + ("training.py", "create_training_section"), + ("training_dataset_tab_scan_settings.py", "build_dataset_scan_and_settings_controls"), + ( + "training_dataset_tab_label_preview.py", + "build_dataset_label_and_preview_controls", + ), + ( + "training_dataset_tab_save_preprocess.py", + "build_dataset_save_and_preprocess_controls", + ), + ("training_lora_tab_dataset.py", "build_lora_dataset_and_adapter_controls"), + ("training_lora_tab_run_export.py", "build_lora_run_and_export_controls"), + ("training_lokr_tab_dataset.py", "build_lokr_dataset_and_adapter_controls"), + ("training_lokr_tab_run_export.py", "build_lokr_run_and_export_controls"), + ] + for module_name, function_name in key_sources: + produced_keys |= collect_return_dict_keys(module_name, function_name) + + required_keys = collect_training_section_keys_used_by_wiring() + self.assertTrue( + required_keys.issubset(produced_keys), + f"Missing training_section 
keys: {sorted(required_keys - produced_keys)}", + ) + + def test_training_ui_markers_preserved(self) -> None: + """Key emoji UI markers should remain present after decomposition.""" + + interfaces_dir = Path(__file__).resolve().parent + expected_markers = { + "training.py": ["🎵 LoRA Training for ACE-Step"], + "training_dataset_tab_scan_settings.py": ["📂 Load Existing Dataset", "🔍 Scan New Directory"], + "training_dataset_tab_label_preview.py": ["🤖", "👀"], + "training_dataset_tab_save_preprocess.py": ["💾", "⚡"], + "training_lora_tab_dataset.py": ["📊", "⚙️"], + "training_lora_tab_run_export.py": ["🎛️", "📦"], + "training_lokr_tab_run_export.py": ["🎛️", "📦"], + } + for module_name, markers in expected_markers.items(): + source = (interfaces_dir / module_name).read_text(encoding="utf-8") + for marker in markers: + self.assertIn(marker, source, f"Missing marker {marker!r} in {module_name}") + + def test_lokr_helpers_use_i18n_translation_calls(self) -> None: + """LoKr modules should import and call ``t(...)`` for user-facing labels.""" + + for module_name in ( + "training_lokr_tab.py", + "training_lokr_tab_dataset.py", + "training_lokr_tab_run_export.py", + ): + module = load_module(module_name) + imported_i18n = False + call_names: list[str] = [] + for node in ast.walk(module): + if ( + isinstance(node, ast.ImportFrom) + and node.module == "acestep.ui.gradio.i18n" + and any(alias.name == "t" for alias in node.names) + ): + imported_i18n = True + if isinstance(node, ast.Call): + name = call_name(node.func) + if name: + call_names.append(name) + + self.assertTrue(imported_i18n, f"{module_name} does not import t from i18n") + self.assertIn("t", call_names, f"{module_name} does not call t(...)") + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/ui/gradio/interfaces/training_lokr_tab.py b/acestep/ui/gradio/interfaces/training_lokr_tab.py new file mode 100644 index 00000000..7938d271 --- /dev/null +++
b/acestep/ui/gradio/interfaces/training_lokr_tab.py @@ -0,0 +1,25 @@ +"""LoKr training-tab facade for the Gradio training interface.""" + +from __future__ import annotations + +import gradio as gr + +from acestep.ui.gradio.help_content import create_help_button +from acestep.ui.gradio.i18n import t +from acestep.ui.gradio.interfaces.training_lokr_tab_dataset import ( + build_lokr_dataset_and_adapter_controls, +) +from acestep.ui.gradio.interfaces.training_lokr_tab_run_export import ( + build_lokr_run_and_export_controls, +) + + +def create_training_lokr_tab() -> dict[str, object]: + """Create the LoKr training tab and return component handles for wiring.""" + + with gr.Tab(t("training.tab_train_lokr")): + create_help_button("training_lokr") + tab_controls: dict[str, object] = {} + tab_controls.update(build_lokr_dataset_and_adapter_controls()) + tab_controls.update(build_lokr_run_and_export_controls()) + return tab_controls diff --git a/acestep/ui/gradio/interfaces/training_lokr_tab_dataset.py b/acestep/ui/gradio/interfaces/training_lokr_tab_dataset.py new file mode 100644 index 00000000..b07101ee --- /dev/null +++ b/acestep/ui/gradio/interfaces/training_lokr_tab_dataset.py @@ -0,0 +1,98 @@ +"""LoKr tab dataset and adapter-setting controls.""" + +from __future__ import annotations + +import gradio as gr + +from acestep.ui.gradio.i18n import t + + +def build_lokr_dataset_and_adapter_controls() -> dict[str, object]: + """Render LoKr dataset selector and adapter-parameter controls.""" + + with gr.Row(): + with gr.Column(scale=2): + gr.HTML(f"

📊 {t('training.lokr_section_tensors')}

") + gr.Markdown(t("training.lokr_tensor_selection_desc")) + + lokr_training_tensor_dir = gr.Textbox( + label=t("training.preprocessed_tensors_dir"), + placeholder="./datasets/preprocessed_tensors", + value="./datasets/preprocessed_tensors", + info=t("training.preprocessed_tensors_info"), + elem_classes=["has-info-container"], + ) + + lokr_load_dataset_btn = gr.Button(t("training.load_dataset_btn"), variant="secondary") + + lokr_training_dataset_info = gr.Textbox( + label=t("training.dataset_info"), + interactive=False, + lines=3, + ) + + with gr.Column(scale=1): + gr.HTML(f"

โš™๏ธ {t('training.lokr_section_settings')}

") + + lokr_linear_dim = gr.Slider( + minimum=4, + maximum=256, + step=4, + value=64, + label=t("training.lokr_linear_dim"), + info=t("training.lokr_linear_dim_info"), + elem_classes=["has-info-container"], + ) + lokr_linear_alpha = gr.Slider( + minimum=4, + maximum=512, + step=4, + value=128, + label=t("training.lokr_linear_alpha"), + info=t("training.lokr_linear_alpha_info"), + elem_classes=["has-info-container"], + ) + lokr_factor = gr.Number( + label=t("training.lokr_factor"), + value=-1, + precision=0, + info=t("training.lokr_factor_info"), + elem_classes=["has-info-container"], + ) + lokr_decompose_both = gr.Checkbox( + label=t("training.lokr_decompose_both"), + value=False, + info=t("training.lokr_decompose_both_info"), + elem_classes=["has-info-container"], + ) + lokr_use_tucker = gr.Checkbox( + label=t("training.lokr_use_tucker"), + value=False, + info=t("training.lokr_use_tucker_info"), + elem_classes=["has-info-container"], + ) + lokr_use_scalar = gr.Checkbox( + label=t("training.lokr_use_scalar"), + value=False, + info=t("training.lokr_use_scalar_info"), + elem_classes=["has-info-container"], + ) + lokr_weight_decompose = gr.Checkbox( + label=t("training.lokr_weight_decompose"), + value=True, + info=t("training.lokr_weight_decompose_info"), + elem_classes=["has-info-container"], + ) + + return { + "lokr_training_tensor_dir": lokr_training_tensor_dir, + "lokr_load_dataset_btn": lokr_load_dataset_btn, + "lokr_training_dataset_info": lokr_training_dataset_info, + "lokr_linear_dim": lokr_linear_dim, + "lokr_linear_alpha": lokr_linear_alpha, + "lokr_factor": lokr_factor, + "lokr_decompose_both": lokr_decompose_both, + "lokr_use_tucker": lokr_use_tucker, + "lokr_use_scalar": lokr_use_scalar, + "lokr_weight_decompose": lokr_weight_decompose, + } diff --git a/acestep/ui/gradio/interfaces/training_lokr_tab_run_export.py b/acestep/ui/gradio/interfaces/training_lokr_tab_run_export.py new file mode 100644 index 00000000..40e4f1ea --- /dev/null +++ 
b/acestep/ui/gradio/interfaces/training_lokr_tab_run_export.py @@ -0,0 +1,162 @@ +"""LoKr tab run and export controls.""" + +from __future__ import annotations + +import gradio as gr + +from acestep.ui.gradio.i18n import t + + +def build_lokr_run_and_export_controls() -> dict[str, object]: + """Render LoKr training-run and export controls for the training tab.""" + + gr.HTML(f"

๐ŸŽ›๏ธ {t('training.train_section_params')}

") + + with gr.Row(): + lokr_learning_rate = gr.Number( + label=t("training.learning_rate"), + value=1e-3, + info=t("training.lokr_learning_rate_info"), + elem_classes=["has-info-container"], + ) + + lokr_train_epochs = gr.Slider( + minimum=1, + maximum=4000, + step=1, + value=500, + label=t("training.max_epochs"), + ) + + lokr_train_batch_size = gr.Slider( + minimum=1, + maximum=8, + step=1, + value=1, + label=t("training.batch_size"), + ) + + lokr_gradient_accumulation = gr.Slider( + minimum=1, + maximum=16, + step=1, + value=4, + label=t("training.gradient_accumulation"), + ) + + with gr.Row(): + lokr_save_every_n_epochs = gr.Slider( + minimum=1, + maximum=1000, + step=1, + value=10, + label=t("training.save_every_n_epochs"), + ) + + lokr_training_shift = gr.Slider( + minimum=1.0, + maximum=5.0, + step=0.5, + value=3.0, + label=t("training.shift"), + info=t("training.shift_info"), + elem_classes=["has-info-container"], + ) + + lokr_training_seed = gr.Number( + label=t("training.seed"), + value=42, + precision=0, + ) + + with gr.Row(): + lokr_output_dir = gr.Textbox( + label=t("training.output_dir"), + value="./lokr_output", + placeholder="./lokr_output", + info=t("training.lokr_output_dir_info"), + elem_classes=["has-info-container"], + ) + + gr.HTML("
") + + with gr.Row(): + with gr.Column(scale=1): + start_lokr_training_btn = gr.Button( + t("training.start_lokr_training_btn"), + variant="primary", + size="lg", + ) + with gr.Column(scale=1): + stop_lokr_training_btn = gr.Button( + t("training.stop_training_btn"), + variant="stop", + size="lg", + ) + + lokr_training_progress = gr.Textbox( + label=t("training.training_progress"), + interactive=False, + lines=2, + ) + + with gr.Row(): + lokr_training_log = gr.Textbox( + label=t("training.training_log"), + interactive=False, + lines=10, + max_lines=15, + scale=1, + ) + lokr_training_loss_plot = gr.Plot( + label=t("training.lokr_training_loss_title"), + scale=1, + ) + + gr.HTML(f"

๐Ÿ“ฆ {t('training.lokr_export_header')}

") + + with gr.Row(): + lokr_export_path = gr.Textbox( + label=t("training.export_path"), + value="./lokr_output/final_lokr", + placeholder="./lokr_output/my_lokr", + ) + export_lokr_btn = gr.Button(t("training.export_lokr_btn"), variant="secondary") + + with gr.Row(): + lokr_export_epoch = gr.Dropdown( + choices=[t("training.latest_auto")], + value=t("training.latest_auto"), + label=t("training.lokr_checkpoint_epoch"), + info=t("training.lokr_checkpoint_epoch_info"), + elem_classes=["has-info-container"], + ) + refresh_lokr_export_epochs_btn = gr.Button( + t("training.refresh_epochs_btn"), variant="secondary" + ) + + lokr_export_status = gr.Textbox( + label=t("training.export_status"), + interactive=False, + ) + + return { + "lokr_learning_rate": lokr_learning_rate, + "lokr_train_epochs": lokr_train_epochs, + "lokr_train_batch_size": lokr_train_batch_size, + "lokr_gradient_accumulation": lokr_gradient_accumulation, + "lokr_save_every_n_epochs": lokr_save_every_n_epochs, + "lokr_training_shift": lokr_training_shift, + "lokr_training_seed": lokr_training_seed, + "lokr_output_dir": lokr_output_dir, + "start_lokr_training_btn": start_lokr_training_btn, + "stop_lokr_training_btn": stop_lokr_training_btn, + "lokr_training_progress": lokr_training_progress, + "lokr_training_log": lokr_training_log, + "lokr_training_loss_plot": lokr_training_loss_plot, + "lokr_export_path": lokr_export_path, + "lokr_export_epoch": lokr_export_epoch, + "refresh_lokr_export_epochs_btn": refresh_lokr_export_epochs_btn, + "export_lokr_btn": export_lokr_btn, + "lokr_export_status": lokr_export_status, + } diff --git a/acestep/ui/gradio/interfaces/training_lora_tab.py b/acestep/ui/gradio/interfaces/training_lora_tab.py new file mode 100644 index 00000000..c8654244 --- /dev/null +++ b/acestep/ui/gradio/interfaces/training_lora_tab.py @@ -0,0 +1,36 @@ +"""LoRA training-tab facade for the Gradio training interface.""" + +from __future__ import annotations + +import gradio as gr + +from 
from acestep.ui.gradio.help_content import create_help_button
from acestep.ui.gradio.i18n import t
from acestep.ui.gradio.interfaces.training_lora_tab_dataset import (
    build_lora_dataset_and_adapter_controls,
)
from acestep.ui.gradio.interfaces.training_lora_tab_run_export import (
    build_lora_run_and_export_controls,
)

# NOTE(review): this span of the patch was flattened (one physical line per
# many source lines), so the indentation below -- i.e. which component sits in
# which gr.Row / gr.Column -- is reconstructed from statement order. Confirm
# the container nesting against the repository before relying on it.


def create_training_lora_tab(
    *,
    epoch_min: int,
    epoch_step: int,
    epoch_default: int,
) -> dict[str, object]:
    """Create the LoRA training tab and return component handles for wiring.

    Opens a ``gr.Tab`` titled with the ``training.tab_train_lora`` translation,
    adds the ``training_train`` help button, and merges the handle mappings
    returned by the dataset/adapter builder and the run/export builder.

    Args:
        epoch_min: Forwarded to the run/export builder as ``epoch_min``.
        epoch_step: Forwarded to the run/export builder as ``epoch_step``.
        epoch_default: Forwarded to the run/export builder as ``epoch_default``.

    Returns:
        One dict combining both builders' ``name -> component`` mappings.
        Keys from the run/export builder are applied last.
    """

    with gr.Tab(t("training.tab_train_lora")):
        create_help_button("training_train")
        tab_controls: dict[str, object] = {}
        tab_controls.update(build_lora_dataset_and_adapter_controls())
        tab_controls.update(
            build_lora_run_and_export_controls(
                epoch_min=epoch_min,
                epoch_step=epoch_step,
                epoch_default=epoch_default,
            )
        )
    return tab_controls


# ---------------------------------------------------------------------------
# New file in the patch:
#   acestep/ui/gradio/interfaces/training_lora_tab_dataset.py
# Module header as given in the patch: the docstring below,
# `from __future__ import annotations` (must be the first statement of that
# module), then the two imports.
# ---------------------------------------------------------------------------
"""LoRA tab dataset and adapter-setting controls."""

import gradio as gr

from acestep.ui.gradio.i18n import t


def build_lora_dataset_and_adapter_controls() -> dict[str, object]:
    """Render LoRA dataset selector and adapter-parameter controls.

    Builds one row with two columns: a wider column (scale 2) holding the
    preprocessed-tensor directory box, the load button and a read-only info
    box, and a narrower column (scale 1) holding the rank, alpha and dropout
    sliders.

    Returns:
        Dict mapping ``training_tensor_dir``, ``load_dataset_btn``,
        ``training_dataset_info``, ``lora_rank``, ``lora_alpha`` and
        ``lora_dropout`` to the created components.
    """

    with gr.Row():
        with gr.Column(scale=2):
            # NOTE(review): only line breaks surround the heading text in the
            # patch as received; heading markup was presumably stripped during
            # extraction -- restore it from the repository if so.
            gr.HTML(f"\n\n📊 {t('training.train_section_tensors')}\n\n")
            gr.Markdown(t("training.train_tensor_selection_desc"))

            training_tensor_dir = gr.Textbox(
                label=t("training.preprocessed_tensors_dir"),
                placeholder="./datasets/preprocessed_tensors",
                value="./datasets/preprocessed_tensors",
                info=t("training.preprocessed_tensors_info"),
                elem_classes=["has-info-container"],
            )

            load_dataset_btn = gr.Button(t("training.load_dataset_btn"), variant="secondary")

            training_dataset_info = gr.Textbox(
                label=t("training.dataset_info"),
                interactive=False,
                lines=3,
            )

        with gr.Column(scale=1):
            # NOTE(review): same presumed markup loss as the heading above.
            gr.HTML(f"\n\n⚙️ {t('training.train_section_lora')}\n\n")

            lora_rank = gr.Slider(
                minimum=4,
                maximum=256,
                step=4,
                value=64,
                label=t("training.lora_rank"),
                info=t("training.lora_rank_info"),
                elem_classes=["has-info-container"],
            )

            lora_alpha = gr.Slider(
                minimum=4,
                maximum=512,
                step=4,
                value=128,
                label=t("training.lora_alpha"),
                info=t("training.lora_alpha_info"),
                elem_classes=["has-info-container"],
            )

            lora_dropout = gr.Slider(
                minimum=0.0,
                maximum=0.5,
                step=0.05,
                value=0.1,
                label=t("training.lora_dropout"),
            )

    return {
        "training_tensor_dir": training_tensor_dir,
        "load_dataset_btn": load_dataset_btn,
        "training_dataset_info": training_dataset_info,
        "lora_rank": lora_rank,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
    }


# ---------------------------------------------------------------------------
# New file in the patch:
#   acestep/ui/gradio/interfaces/training_lora_tab_run_export.py
# Module header as given in the patch: the docstring below,
# `from __future__ import annotations` (must be the first statement of that
# module), then the two imports.
# ---------------------------------------------------------------------------
"""LoRA tab run and export controls."""

import gradio as gr

from acestep.ui.gradio.i18n import t


def build_lora_run_and_export_controls(
    *,
    epoch_min: int,
    epoch_step: int,
    epoch_default: int,
) -> dict[str, object]:
    """Render LoRA training-run and export controls for the training tab.

    Lays out, in order: the hyper-parameter rows, the output and
    resume-checkpoint directory boxes, the start/stop buttons, the progress
    box, the log box beside the loss plot, and the export path/button/status.

    Args:
        epoch_min: Lower bound of the max-epochs slider.
        epoch_step: Step size of the max-epochs slider.
        epoch_default: Initial value of the max-epochs slider.

    Returns:
        Dict mapping each control's name (``learning_rate`` through
        ``export_status``) to the created component.
    """

    # NOTE(review): only line breaks surround the heading text in the patch as
    # received; rule/heading markup was presumably stripped during extraction.
    gr.HTML(f"\n\n🎛️ {t('training.train_section_params')}\n\n")

    with gr.Row():
        learning_rate = gr.Number(
            label=t("training.learning_rate"),
            value=3e-4,
            info=t("training.learning_rate_info"),
            elem_classes=["has-info-container"],
        )

        train_epochs = gr.Slider(
            minimum=epoch_min,
            maximum=4000,
            step=epoch_step,
            value=epoch_default,
            label=t("training.max_epochs"),
        )

        train_batch_size = gr.Slider(
            minimum=1,
            maximum=8,
            step=1,
            value=1,
            label=t("training.batch_size"),
            info=t("training.batch_size_info"),
            elem_classes=["has-info-container"],
        )

        gradient_accumulation = gr.Slider(
            minimum=1,
            maximum=16,
            step=1,
            value=1,
            label=t("training.gradient_accumulation"),
            info=t("training.gradient_accumulation_info"),
            elem_classes=["has-info-container"],
        )

    with gr.Row():
        save_every_n_epochs = gr.Slider(
            minimum=1,
            maximum=1000,
            step=1,
            value=10,
            label=t("training.save_every_n_epochs"),
        )

        training_shift = gr.Slider(
            minimum=1.0,
            maximum=5.0,
            step=0.5,
            value=3.0,
            label=t("training.shift"),
            info=t("training.shift_info"),
            elem_classes=["has-info-container"],
        )

        training_seed = gr.Number(
            label=t("training.seed"),
            value=42,
            precision=0,
        )

    with gr.Row():
        lora_output_dir = gr.Textbox(
            label=t("training.output_dir"),
            value="./lora_output",
            placeholder="./lora_output",
            info=t("training.output_dir_info"),
            elem_classes=["has-info-container"],
        )

    with gr.Row():
        # NOTE(review): label/info are hard-coded English while every other
        # control here goes through t(); consider adding i18n keys for them.
        resume_checkpoint_dir = gr.Textbox(
            label="Resume Checkpoint",
            placeholder="./lora_output/checkpoints/epoch_200",
            info="Directory of a saved LoRA checkpoint to resume from",
            elem_classes=["has-info-container"],
        )

    # NOTE(review): literal is a bare line break in the patch as received;
    # presumably a separator tag was stripped -- confirm.
    gr.HTML("\n")

    with gr.Row():
        with gr.Column(scale=1):
            start_training_btn = gr.Button(
                t("training.start_training_btn"),
                variant="primary",
                size="lg",
            )
        with gr.Column(scale=1):
            stop_training_btn = gr.Button(
                t("training.stop_training_btn"),
                variant="stop",
                size="lg",
            )

    training_progress = gr.Textbox(
        label=t("training.training_progress"),
        interactive=False,
        lines=2,
    )

    with gr.Row():
        training_log = gr.Textbox(
            label=t("training.training_log"),
            interactive=False,
            lines=10,
            max_lines=15,
            scale=1,
        )
        training_loss_plot = gr.Plot(
            label=t("training.training_loss_title"),
            scale=1,
        )

    # NOTE(review): same presumed markup loss as the first heading.
    gr.HTML(f"\n\n📦 {t('training.export_header')}\n\n")

    with gr.Row():
        export_path = gr.Textbox(
            label=t("training.export_path"),
            value="./lora_output/final_lora",
            placeholder="./lora_output/my_lora",
        )
        export_lora_btn = gr.Button(t("training.export_lora_btn"), variant="secondary")

    export_status = gr.Textbox(
        label=t("training.export_status"),
        interactive=False,
    )

    return {
        "learning_rate": learning_rate,
        "train_epochs": train_epochs,
        "train_batch_size": train_batch_size,
        "gradient_accumulation": gradient_accumulation,
        "save_every_n_epochs": save_every_n_epochs,
        "training_shift": training_shift,
        "training_seed": training_seed,
        "lora_output_dir": lora_output_dir,
        "resume_checkpoint_dir": resume_checkpoint_dir,
        "start_training_btn": start_training_btn,
        "stop_training_btn": stop_training_btn,
        "training_progress": training_progress,
        "training_log": training_log,
        "training_loss_plot": training_loss_plot,
        "export_path": export_path,
        "export_lora_btn": export_lora_btn,
        "export_status": export_status,
    }