From 3367fca73253c85e386ef69af3068d42cea09e4f Mon Sep 17 00:00:00 2001 From: marijnfs Date: Mon, 6 May 2024 21:43:42 +0200 Subject: [PATCH] Gradio configuration parameters (#1591) * Gradio Configuration Settings * Making various Gradio variables configurable instead of hardcoded * Remove overwriting behavour of 'default tokens' that breaks tokenizer for llama3 * Fix type of gradio_temperature * revert un-necessary change and lint --------- Co-authored-by: Marijn Stollenga Co-authored-by: Marijn Stollenga Co-authored-by: Wing Lian --- src/axolotl/cli/__init__.py | 12 +++++++++--- .../utils/config/models/input/v0_4_1/__init__.py | 12 ++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 9f40bb476..7ec3f524a 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -264,8 +264,8 @@ def generate(instruction): with torch.no_grad(): generation_config = GenerationConfig( repetition_penalty=1.1, - max_new_tokens=1024, - temperature=0.9, + max_new_tokens=cfg.get("gradio_max_new_tokens", 1024), + temperature=cfg.get("gradio_temperature", 0.9), top_p=0.95, top_k=40, bos_token_id=tokenizer.bos_token_id, @@ -300,7 +300,13 @@ def generate(instruction): outputs="text", title=cfg.get("gradio_title", "Axolotl Gradio Interface"), ) - demo.queue().launch(show_api=False, share=True) + + demo.queue().launch( + show_api=False, + share=cfg.get("gradio_share", True), + server_name=cfg.get("gradio_server_name", "127.0.0.1"), + server_port=cfg.get("gradio_server_port", None), + ) def choose_config(path: Path): diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 72e82e823..419deee58 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -409,6 +409,17 @@ def check_wandb_run(cls, data): return data +class GradioConfig(BaseModel): + """Gradio configuration subset""" + + gradio_title: Optional[str] = None + gradio_share: Optional[bool] = None + gradio_server_name: Optional[str] = None + gradio_server_port: Optional[int] = None + gradio_max_new_tokens: Optional[int] = None + gradio_temperature: Optional[float] = None + + # pylint: disable=too-many-public-methods,too-many-ancestors class AxolotlInputConfig( ModelInputConfig, @@ -419,6 +430,7 @@ class AxolotlInputConfig( WandbConfig, MLFlowConfig, LISAConfig, + GradioConfig, RemappedParameters, DeprecatedParameters, BaseModel,