diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml
index 915811c7b0..a2a42796f2 100644
--- a/.github/workflows/python.yaml
+++ b/.github/workflows/python.yaml
@@ -180,6 +180,9 @@ jobs:
 ${{ env.SELF_HOST_PYTHON }} -m pip uninstall -y "faster_whisper"
 ${{ env.SELF_HOST_PYTHON }} -m pip install -U accelerate
 ${{ env.SELF_HOST_PYTHON }} -m pip install -U verovio
+ ${{ env.SELF_HOST_PYTHON }} -m pip install -U cachetools
+ ${{ env.SELF_HOST_PYTHON }} -m pip install -U silero-vad
+ ${{ env.SELF_HOST_PYTHON }} -m pip install -U pydantic
 ${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
 --disable-warnings \
 --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/core/tests/test_continuous_batching.py && \
diff --git a/setup.cfg b/setup.cfg
index 2d420de8fd..0bbd2f5021 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -129,6 +129,8 @@ all =
 natsort # For Fish Speech
 loralib # For Fish Speech
 ormsgpack # For Fish Speech
+ cachetools # For Fish Speech
+ silero-vad # For Fish Speech
 qwen-vl-utils # For qwen2-vl
 datamodel_code_generator # for minicpm-4B
 jsonschema # for minicpm-4B
@@ -210,6 +212,8 @@ audio =
 natsort # For Fish Speech
 loralib # For Fish Speech
 ormsgpack # For Fish Speech
+ cachetools # For Fish Speech
+ silero-vad # For Fish Speech
 doc =
 ipython>=6.5.0
 sphinx>=3.0.0
diff --git a/xinference/deploy/docker/requirements.txt b/xinference/deploy/docker/requirements.txt
index 6671d4bfb4..79d4e2defd 100644
--- a/xinference/deploy/docker/requirements.txt
+++ b/xinference/deploy/docker/requirements.txt
@@ -7,7 +7,7 @@ click
 tqdm>=4.27
 tabulate
 requests
-pydantic
+pydantic>2
 fastapi>=0.110.3
 uvicorn
 huggingface-hub>=0.19.4
@@ -72,6 +72,8 @@ loguru # For Fish Speech
 natsort # For Fish Speech
 loralib # For Fish Speech
 ormsgpack # For Fish Speech
+cachetools # For Fish Speech
+silero-vad # For Fish Speech
 qwen-vl-utils # For qwen2-vl
 datamodel_code_generator # for minicpm-4B
 jsonschema # for minicpm-4B
diff --git a/xinference/deploy/docker/requirements_cpu.txt b/xinference/deploy/docker/requirements_cpu.txt
index e4410a0358..4105a2e709 100644
--- a/xinference/deploy/docker/requirements_cpu.txt
+++ b/xinference/deploy/docker/requirements_cpu.txt
@@ -6,7 +6,7 @@ click
 tqdm>=4.27
 tabulate
 requests
-pydantic
+pydantic>2
 fastapi>=0.110.3
 uvicorn
 huggingface-hub>=0.19.4
@@ -67,6 +67,8 @@ loguru # For Fish Speech
 natsort # For Fish Speech
 loralib # For Fish Speech
 ormsgpack # For Fish Speech
+cachetools # For Fish Speech
+silero-vad # For Fish Speech
 qwen-vl-utils # For qwen2-vl
 datamodel_code_generator # for minicpm-4B
 jsonschema # for minicpm-4B
diff --git a/xinference/model/audio/model_spec.json b/xinference/model/audio/model_spec.json
index b7c436d474..912d8399e9 100644
--- a/xinference/model/audio/model_spec.json
+++ b/xinference/model/audio/model_spec.json
@@ -159,7 +159,7 @@
 "model_name": "FishSpeech-1.4",
 "model_family": "FishAudio",
 "model_id": "fishaudio/fish-speech-1.4",
- "model_revision": "3c49651b8e583b6b13f55e375432e0d57e1aa84d",
+ "model_revision": "069c573759936b35191d3380deb89183c0656f59",
 "model_ability": "text-to-audio",
 "multilingual": true
 }
diff --git a/xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff
--git a/xinference/thirdparty/fish_speech/fish_speech/conversation.py b/xinference/thirdparty/fish_speech/fish_speech/conversation.py index c9ca0ef918..9bbc1cdb6c 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/conversation.py +++ b/xinference/thirdparty/fish_speech/fish_speech/conversation.py @@ -1,2 +1,256 @@ +from dataclasses import dataclass, field +from typing import Literal + +import torch +from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast + +IM_START_TOKEN = "<|im_start|>" +IM_END_TOKEN = "<|im_end|>" SEMANTIC_TOKEN = "<|semantic|>" +MEL_TOKEN = "<|mel|>" +PHONEME_START_TOKEN = "<|phoneme_start|>" +PHONEME_END_TOKEN = "<|phoneme_end|>" +ALL_SPECIAL_TOKENS = [ + IM_START_TOKEN, + IM_END_TOKEN, + SEMANTIC_TOKEN, + MEL_TOKEN, + PHONEME_START_TOKEN, + PHONEME_END_TOKEN, +] + CODEBOOK_PAD_TOKEN_ID = 0 + + +class FishTokenizerConfig(PretrainedConfig): + share_codebook_embeddings: bool = True + codebook_size: int = 1024 + num_codebooks: int = 8 + + +class FishTokenizerFast(PreTrainedTokenizerFast): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True) + self.codebook_size = kwargs.pop("codebook_size", 1024) + self.num_codebooks = kwargs.pop("num_codebooks", 8) + + +AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast) + + +@dataclass(kw_only=True) +class BasePart: + pass + + +@dataclass(kw_only=True) +class VQPart(BasePart): + codes: torch.Tensor + + +@dataclass(kw_only=True) +class TextPart(BasePart): + text: str + + +@dataclass(kw_only=True) +class MelPart(BasePart): + mels: torch.Tensor + + +@dataclass(kw_only=True) +class EncodedMessage: + tokens: torch.Tensor + labels: torch.Tensor + vq_parts: list[torch.Tensor] + mel_parts: list[torch.Tensor] + vq_require_losses: torch.Tensor | None = None + + +@dataclass(kw_only=True) +class Message: + role: Literal["system", "user", "assistant"] + parts: list[VQPart | TextPart | MelPart] = field(default_factory=list) + add_im_start: bool = True + add_im_end: bool = True + cal_loss: bool = False + + # By default, ignore the loss of the auto-generated im_start token + ignore_im_start_loss: bool = True + + def encode( + self: "Message", + tokenizer: AutoTokenizer, + ) -> EncodedMessage: + all_tokens = [] + all_labels = [] + + # Multi-modal tokens + vq_parts = [] + mel_parts = [] + + semantic_id, mel_id = tokenizer.convert_tokens_to_ids( + [SEMANTIC_TOKEN, MEL_TOKEN] + ) + + parts = self.parts.copy() + if self.add_im_start: + parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n")) + + if self.add_im_end: + parts.append(TextPart(text="<|im_end|>")) + + for part in parts: + if isinstance(part, TextPart): + tokens = tokenizer.encode( + part.text, + add_special_tokens=False, + truncation=False, + return_tensors="pt", + ).int()[0] + elif isinstance(part, VQPart): + tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id + codes = part.codes.clone() + 1 + + if getattr(tokenizer, "share_codebook_embeddings", True) is False: + for i in range(len(codes)): + codes[i] += tokenizer.codebook_size * i + + vq_parts.append(codes) + elif isinstance(part, MelPart): + tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id + mel_parts.append(part.mels) + else: + raise ValueError(f"Unsupported part type: {type(part)}") + + all_tokens.append(tokens) + if self.cal_loss: + all_labels.append(tokens.clone()) + else: + all_labels.append(torch.full_like(tokens, -100)) + + 
tokens = torch.cat(all_tokens, dim=0) + labels = torch.cat(all_labels, dim=0) + assert tokens.shape == labels.shape + + if self.ignore_im_start_loss and self.add_im_start: + labels[: len(all_tokens[0])] = -100 + + return EncodedMessage( + tokens=tokens, + labels=labels, + vq_parts=vq_parts, + mel_parts=mel_parts, + ) + + +@dataclass +class Conversation: + messages: list[Message] + + def encode( + self: "Conversation", + tokenizer: AutoTokenizer, + add_shift: bool = True, + ) -> EncodedMessage: + # Build the input_ids and labels + tokens = [] + labels = [] + vq_parts = [] + mel_parts = [] + vq_require_losses = [] + + for message in self.messages: + encoded = message.encode( + tokenizer, + ) + tokens.append(encoded.tokens) + labels.append(encoded.labels) + vq_parts.extend(encoded.vq_parts) + mel_parts.extend(encoded.mel_parts) + vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts)) + + tokens = torch.cat(tokens, dim=0) + labels = torch.cat(labels, dim=0) + vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool) + + if add_shift: + tokens = tokens[:-1] + labels = labels[1:] + + assert tokens.dtype in [ + torch.int, + torch.long, + ], f"Invalid dtype: {tokens.dtype}, conv: {conversation}" + + return EncodedMessage( + tokens=tokens, + labels=labels, + vq_parts=vq_parts, + mel_parts=mel_parts, + vq_require_losses=vq_require_losses, + ) + + def encode_for_inference( + self: "Conversation", + tokenizer: AutoTokenizer, + num_codebooks: int, + ) -> EncodedMessage: + encoded = self.encode(tokenizer, add_shift=False) + tokens = encoded.tokens + values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int) + values[0] = tokens + + if encoded.vq_parts is None or len(encoded.vq_parts) == 0: + return values + + semantic_id, mel_id = tokenizer.convert_tokens_to_ids( + [SEMANTIC_TOKEN, MEL_TOKEN] + ) + vq_parts = encoded.vq_parts + vq_parts = torch.cat(vq_parts, dim=1) + values[1:, tokens == semantic_id] = vq_parts + return values + + def visualize(self: "Conversation", tokenizer: AutoTokenizer): + encoded = self.encode(tokenizer, add_shift=False) + + print_in_blue = lambda x: print("\033[94m" + x + "\033[0m", end="") + print_in_green = lambda x: print("\033[92m" + x + "\033[0m", end="") + + for tok, lab in zip(encoded.tokens, encoded.labels): + val = tokenizer.decode(tok, skip_special_tokens=False) + if val == "\n": + val = "\\n\n" + + if lab == -100: + print_in_green(val) + else: + print_in_blue(val) + + print() + + +if __name__ == "__main__": + message0 = Message( + role="user", + parts=[ + TextPart(text="Hello, how are you?"), + VQPart(codes=torch.zeros((4, 10))), + ], + cal_loss=False, + ) + + message1 = Message( + role="assistant", + parts=[TextPart(text="I'm fine, thank you.")], + cal_loss=True, + ) + conversation = Conversation([message0, message1]) + tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct") + conversation.visualize(tokenizer) + + encoded = conversation.encode(tokenizer) + print(encoded) + print(tokenizer.batch_decode(encoded.tokens)) diff --git a/xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git 
a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json index 6e280c236e..d36c774313 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +++ b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json @@ -118,5 +118,6 @@ "new": "new", "Realtime Transform Text": "Realtime Transform Text", "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)", - "Text Normalization": "Text Normalization" + "Text Normalization": "Text Normalization", + "Select Example Audio": "Select Example Audio" } diff --git a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json index 3285341f68..7a4757967d 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +++ b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json @@ -118,5 +118,6 @@ "new": "nuevo", "Realtime Transform Text": "Transformación de Texto en Tiempo Real", "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)", - "Text Normalization": "Normalización de Texto" + "Text Normalization": "Normalización de Texto", + "Select Example Audio": "Selecionar áudio de exemplo" } diff --git a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json index d30bac7bcd..863b8b0b41 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +++ b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json @@ -118,6 +118,6 @@ "new": "新規", "Realtime Transform Text": "リアルタイム変換テキスト", "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)", - "Text Normalization": "テキスト正規化" - + "Text Normalization": "テキスト正規化", + "Select Example Audio": "サンプル音声を選択" } diff --git a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json new file mode 100644 index 0000000000..180263874b --- /dev/null +++ b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json @@ -0,0 +1,123 @@ +{ + "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.", + "Accumulate Gradient Batches": "그라디언트 배치 누적", + "Add to Processing Area": "처리 영역에 추가", + "Added path successfully!": "경로가 성공적으로 추가되었습니다!", + "Advanced Config": "고급 설정", + "Base LLAMA Model": "기본 LLAMA 모델", + "Batch Inference": "배치 추론", + "Batch Size": "배치 크기", + "Changing with the Model Path": "모델 경로에 따라 변경 중", + "Chinese": "중국어", + "Compile Model": "모델 컴파일", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.", + "Copy": "복사", + "Data Preprocessing": "데이터 전처리", + "Data Preprocessing Path": "데이터 전처리 경로", + "Data Source": "데이터 소스", + "Decoder Model 
Config": "디코더 모델 설정", + "Decoder Model Path": "디코더 모델 경로", + "Disabled": "비활성화 됨", + "Enable Reference Audio": "참고 음성 활성화", + "English": "영어", + "Error Message": "오류 메시지", + "File Preprocessing": "파일 전처리", + "Generate": "생성", + "Generated Audio": "생성된 오디오", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.", + "Infer interface is closed": "추론 인터페이스가 닫혔습니다.", + "Inference Configuration": "추론 설정", + "Inference Server Configuration": "추론 서버 설정", + "Inference Server Error": "추론 서버 오류", + "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.", + "Initial Learning Rate": "초기 학습률", + "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로", + "Input Text": "입력 텍스트", + "Invalid path: {}": "유효하지 않은 경로: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.", + "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)", + "Japanese": "일본어", + "LLAMA Configuration": "LLAMA 설정", + "LLAMA Model Config": "LLAMA 모델 설정", + "LLAMA Model Path": "LLAMA 모델 경로", + "Labeling Device": "라벨링 장치", + "LoRA Model to be merged": "병합할 LoRA 모델", + "Maximum Audio Duration": "최대 오디오 길이", + "Maximum Length per Sample": "샘플당 최대 길이", + "Maximum Training Steps": "최대 학습 단계", + "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)", + "Merge": "병합", + "Merge LoRA": "LoRA 병합", + "Merge successfully": "성공적으로 병합 되었습니다.", + "Minimum Audio Duration": "최소 오디오 길이", + "Model Output Path": "모델 출력 경로", + "Model Size": "모델 크기", + "Move": "이동", + "Move files successfully": "파일이 성공적으로 이동되었습니다.", + "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.", + "No selected options": "옵션이 선택되지 않았습니다.", + "Number of Workers": "작업자 수", + "Open Inference Server": "추론 서버 열기", + "Open Labeler WebUI": "라벨러 WebUI 열기", + "Open Tensorboard": "Tensorboard 열기", + "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.", + "Optional Label Language": "선택적 라벨 언어", + "Optional online ver": "온라인 버전 선택", + "Output Path": "출력 경로", + "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.", + "Precision": "정밀도", + "Probability of applying Speaker Condition": "화자 조건 적용 확률", + "Put your text here.": "여기에 텍스트를 입력하세요.", + "Reference Audio": "참고 오디오", + "Reference Text": "참고 텍스트", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.", + "Remove Selected Data": "선택한 데이터 제거", + "Removed path successfully!": "경로가 성공적으로 제거되었습니다!", + "Repetition Penalty": "반복 패널티", + "Save model every n steps": "n 단계마다 모델 저장", + "Select LLAMA ckpt": "LLAMA ckpt 선택", + "Select VITS ckpt": "VITS ckpt 선택", + "Select VQGAN ckpt": "VQGAN ckpt 선택", + "Select source file processing method": "소스 파일 처리 방법 선택", + "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)", + "Selected: {}": "선택됨: {}", + "Speaker": "화자", + "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다", + "Start Training": "학습 시작", + "Streaming Audio": "스트리밍 오디오", + "Streaming Generate": "스트리밍 생성", + "Tensorboard Host": "Tensorboard 호스트", + "Tensorboard Log Path": "Tensorboard 로그 경로", + "Tensorboard Port": "Tensorboard 포트", + "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다", + "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.", + "Text is too long, please keep it under {} 
characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.", + "Training Configuration": "학습 설정", + "Training Error": "학습 오류", + "Training stopped": "학습이 중지되었습니다.", + "Type name of the speaker": "화자의 이름을 입력하세요.", + "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.", + "Use LoRA": "LoRA 사용", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.", + "Use filelist": "파일 목록 사용", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.", + "VITS Configuration": "VITS 설정", + "VQGAN Configuration": "VQGAN 설정", + "Validation Batch Size": "검증 배치 크기", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.", + "WebUI Host": "WebUI 호스트", + "WebUI Port": "WebUI 포트", + "Whisper Model": "Whisper 모델", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다", + "latest": "최신", + "new": "새로운", + "Realtime Transform Text": "실시간 텍스트 변환", + "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)", + "Text Normalization": "텍스트 정규화", + "Select Example Audio": "예시 오디오 선택" +} diff --git a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json index 3dd1a5cd1c..9068ef0b9a 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +++ b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json @@ -118,5 +118,6 @@ "new": "创建新的检查点", "Realtime Transform Text": "实时规范化文本", "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览", - "Text Normalization": "文本规范化" + "Text Normalization": "文本规范化", + "Select Example Audio": "选择参考音频" } diff --git a/xinference/thirdparty/fish_speech/fish_speech/models/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/models/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py b/xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py index 0725dfb9b7..6ea15e595f 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +++ b/xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py @@ -1,3 +1,4 @@ +import dataclasses import json import math from collections import OrderedDict @@ -57,6 +58,10 @@ class BaseModelArgs: # Initialize the model initializer_range: float = 0.02 + # Dummy vars + is_reward_model: bool = False + share_codebook_embeddings: bool = True + def __post_init__(self): if self.n_local_heads == -1: self.n_local_heads = self.n_head @@ -100,6 +105,28 @@ class NaiveModelArgs(BaseModelArgs): 
class DualARModelArgs(BaseModelArgs): model_type: str = "dual_ar" n_fast_layer: int = 4 + fast_dim: int | None = None + fast_n_head: int | None = None + fast_n_local_heads: int | None = None + fast_head_dim: int | None = None + fast_intermediate_size: int | None = None + fast_attention_qkv_bias: bool | None = None + + def __post_init__(self): + super().__post_init__() + + self.fast_dim = self.fast_dim or self.dim + self.fast_n_head = self.fast_n_head or self.n_head + self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads + self.fast_head_dim = self.fast_head_dim or self.head_dim + self.fast_intermediate_size = ( + self.fast_intermediate_size or self.intermediate_size + ) + self.fast_attention_qkv_bias = ( + self.fast_attention_qkv_bias + if self.fast_attention_qkv_bias is not None + else self.attention_qkv_bias + ) class KVCache(nn.Module): @@ -369,7 +396,10 @@ def from_pretrained( model = simple_quantizer.convert_for_runtime() weights = torch.load( - Path(path) / "model.pth", map_location="cpu", mmap=True + Path(path) / "model.pth", + map_location="cpu", + mmap=True, + weights_only=True, ) if "state_dict" in weights: @@ -471,20 +501,46 @@ class DualARTransformer(BaseTransformer): def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None: super().__init__(config, init_weights=False, tokenizer=tokenizer) + # Project to fast dim if needed + if config.fast_dim is not None and config.fast_dim != config.dim: + self.fast_project_in = nn.Linear(config.dim, config.fast_dim) + else: + self.fast_project_in = nn.Identity() + # Fast transformer - self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim) + self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim) # The equivalent bs is so large that sdpa doesn't work + override_config = dataclasses.replace( + config, + dim=config.fast_dim, + n_head=config.fast_n_head, + n_local_heads=config.fast_n_local_heads, + head_dim=config.fast_head_dim, + intermediate_size=config.fast_intermediate_size, + attention_qkv_bias=config.fast_attention_qkv_bias, + ) + self.fast_layers = nn.ModuleList( - TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer) + TransformerBlock(override_config, use_sdpa=False) + for _ in range(config.n_fast_layer) ) - self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps) self.fast_output = nn.Linear( - config.dim, + config.fast_dim, config.codebook_size, bias=False, ) + self.register_buffer( + "fast_freqs_cis", + precompute_freqs_cis( + config.num_codebooks, + config.fast_dim // config.fast_n_head, + config.rope_base, + ), + persistent=False, + ) self.apply(self._init_weights) def setup_caches( @@ -492,7 +548,7 @@ def setup_caches( ): super().setup_caches(max_batch_size, max_seq_len, dtype) - head_dim = self.config.dim // self.config.n_head + head_dim = self.config.fast_dim // self.config.fast_n_head # Fast transformer # The max seq len here is the number of codebooks @@ -500,7 +556,7 @@ def setup_caches( b.attention.kv_cache = KVCache( max_batch_size, self.config.num_codebooks, - self.config.n_local_heads, + self.config.fast_n_local_heads, head_dim, dtype=dtype, ) @@ -513,13 +569,13 @@ def forward( parent_result = super().forward(inp, key_padding_mask) token_logits = parent_result.logits x = parent_result.hidden_states + x = self.fast_project_in(x) # Fast transformer fast_seq_len = self.config.num_codebooks fast_mask = self.causal_mask[ None, None, :fast_seq_len, :fast_seq_len ] # (B, 
N, Q, K) - fast_freqs_cis = self.freqs_cis[:fast_seq_len] # Drop the last token and rotate left codebooks = inp[:, 1:-1, 1:] @@ -542,9 +598,11 @@ def forward( for layer in self.fast_layers: if self.config.use_gradient_checkpointing and self.training: - x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True) + x = checkpoint( + layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True + ) else: - x = layer(x, fast_freqs_cis, fast_mask) + x = layer(x, self.fast_freqs_cis, fast_mask) # unflatten the batch and num_codebooks fast_out = self.fast_norm(x) @@ -584,7 +642,7 @@ def forward_generate_fast( fast_mask = self.causal_mask[ None, None, input_pos, : self.config.num_codebooks ] # (B, N, Q, K) - fast_freqs_cis = self.freqs_cis[input_pos] + fast_freqs_cis = self.fast_freqs_cis[input_pos] for layer in self.fast_layers: x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos) @@ -595,6 +653,13 @@ def forward_generate_fast( return codebook_logits + def forward_generate( + self, x: Tensor, input_pos: Optional[Tensor] = None + ) -> TransformerForwardResult: + x = super().forward_generate(x, input_pos) + x.hidden_states = self.fast_project_in(x.hidden_states) + return x + class TransformerBlock(nn.Module): def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None: diff --git a/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py b/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py index aa21839b54..91fc9118cc 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +++ b/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py @@ -102,8 +102,8 @@ def weight_norm(self, name="weight", dim=0): self.conv = weight_norm(self.conv, name=name, dim=dim) return self - def remove_weight_norm(self): - self.conv = remove_parametrizations(self.conv) + def remove_parametrizations(self, name="weight"): + self.conv = remove_parametrizations(self.conv, name) return self @@ -128,8 +128,8 @@ def weight_norm(self, name="weight", dim=0): self.conv = weight_norm(self.conv, name=name, dim=dim) return self - def remove_weight_norm(self): - self.conv = remove_parametrizations(self.conv) + def remove_parametrizations(self, name="weight"): + self.conv = remove_parametrizations(self.conv, name) return self @@ -178,9 +178,9 @@ def forward(self, x): def remove_parametrizations(self): for conv in self.convs1: - remove_parametrizations(conv, tensor_name="weight") + conv.remove_parametrizations() for conv in self.convs2: - remove_parametrizations(conv, tensor_name="weight") + conv.remove_parametrizations() class ParallelBlock(nn.Module): @@ -288,11 +288,11 @@ def forward(self, x): def remove_parametrizations(self): for up in self.ups: - remove_parametrizations(up, tensor_name="weight") + up.remove_parametrizations() for block in self.resblocks: block.remove_parametrizations() - remove_parametrizations(self.conv_pre, tensor_name="weight") - remove_parametrizations(self.conv_post, tensor_name="weight") + self.conv_pre.remove_parametrizations() + self.conv_post.remove_parametrizations() # DropPath copied from timm library diff --git a/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py 
b/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py index 7ea4853376..954553bbfe 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +++ b/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py @@ -99,7 +99,7 @@ def forward(self, z) -> FSQResult: if diff > 0: result.z = F.pad(result.z, (left, right)) elif diff < 0: - result.z = result.z[..., left:-right] + result.z = result.z[..., -left:right] return result diff --git a/xinference/thirdparty/fish_speech/fish_speech/text/clean.py b/xinference/thirdparty/fish_speech/fish_speech/text/clean.py index c228dfcd13..dbaf843d78 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/text/clean.py +++ b/xinference/thirdparty/fish_speech/fish_speech/text/clean.py @@ -1,6 +1,8 @@ import re SYMBOLS_MAPPING = { + "\n": "", + "…": ".", "“": "'", "”": "'", "‘": "'", @@ -13,7 +15,19 @@ ")": "", "(": "", ")": "", - "・": "·", + "・": "", + "·": "", + "「": "'", + "」": "'", + "《": "'", + "》": "'", + "—": "", + "~": "", + "~": "", + ":": ",", + ";": ",", + ";": ",", + ":": ",", } REPLACE_SYMBOL_REGEX = re.compile( @@ -21,6 +35,17 @@ ) +EMOJI_REGEX = re.compile( + "[" + "\U0001F600-\U0001F64F" # emoticons + "\U0001F300-\U0001F5FF" # symbols & pictographs + "\U0001F680-\U0001F6FF" # transport & map symbols + "\U0001F1E0-\U0001F1FF" # flags (iOS) + "]+", + flags=re.UNICODE, +) + + def clean_text(text): # Clean the text text = text.strip() @@ -28,4 +53,10 @@ def clean_text(text): # Replace all chinese symbols with their english counterparts text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text) + # Remove emojis + text = EMOJI_REGEX.sub(r"", text) + + # Remove continuous periods (...) and commas (,,,) + text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text) + return text diff --git a/xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py index 05378519db..53cf2f2317 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +++ b/xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py @@ -5,7 +5,7 @@ from .logger import RankedLogger from .logging_utils import log_hyperparameters from .rich_utils import enforce_tags, print_config_tree -from .utils import extras, get_metric_value, task_wrapper +from .utils import extras, get_metric_value, set_seed, task_wrapper __all__ = [ "enforce_tags", @@ -20,4 +20,5 @@ "braceexpand", "get_latest_checkpoint", "autocast_exclude_mps", + "set_seed", ] diff --git a/xinference/thirdparty/fish_speech/fish_speech/utils/utils.py b/xinference/thirdparty/fish_speech/fish_speech/utils/utils.py index c546bfa1ed..5a34bdcfed 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +++ b/xinference/thirdparty/fish_speech/fish_speech/utils/utils.py @@ -1,7 +1,10 @@ +import random import warnings from importlib.util import find_spec from typing import Callable +import numpy as np +import torch from omegaconf import DictConfig from .logger import RankedLogger @@ -112,3 +115,22 @@ def get_metric_value(metric_dict: dict, metric_name: str) -> float: log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") return metric_value + + +def set_seed(seed: int): + if seed < 0: + seed = -seed + if seed > (1 << 31): + seed = 1 << 31 + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if torch.backends.cudnn.is_available(): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False diff --git a/xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py b/xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py index 2f57b595a2..790c0e632c 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +++ b/xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py @@ -114,7 +114,7 @@ def __init__( block_title_text_weight="600", block_border_width="3px", block_shadow="*shadow_drop_lg", - button_shadow="*shadow_drop_lg", + # button_shadow="*shadow_drop_lg", button_small_padding="0px", button_large_padding="3px", ) diff --git a/xinference/thirdparty/fish_speech/fish_speech/webui/manage.py b/xinference/thirdparty/fish_speech/fish_speech/webui/manage.py index 4ec3fcac25..c21233eee3 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +++ b/xinference/thirdparty/fish_speech/fish_speech/webui/manage.py @@ -794,7 +794,7 @@ def llama_quantify(llama_weight, quantify_mode): value="VQGAN", ) with gr.Row(): - with gr.Tabs(): + with gr.Column(): with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page: gr.HTML("You don't need to train this model!") diff --git a/xinference/thirdparty/fish_speech/tools/api.py b/xinference/thirdparty/fish_speech/tools/api.py index 7fcc9330ae..cc12f3a0fd 100644 --- a/xinference/thirdparty/fish_speech/tools/api.py +++ b/xinference/thirdparty/fish_speech/tools/api.py @@ -1,16 +1,16 @@ -import base64 import io -import json +import os import queue -import random -import sys +import re +import time import traceback import wave from argparse import ArgumentParser from http import HTTPStatus from pathlib import Path -from typing import Annotated, Any, Literal, Optional +from typing import Annotated, Any +import librosa import numpy as np import ormsgpack # import pyrootutils @@ -28,27 +28,74 @@ # Kui, # OpenAPI, # StreamResponse, +# request, # ) # from kui.asgi.routing import MultimethodRoutes from loguru import logger -from pydantic import BaseModel, Field, conint +from transformers import AutoTokenizer # pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +import struct +from threading import Lock + +import httpx +from cachetools import LRUCache, cached +from funasr import AutoModel +from silero_vad import get_speech_timestamps, load_silero_vad + +from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN +from fish_speech.models.text2semantic.llama import BaseModelArgs # from fish_speech.models.vqgan.lit_module import VQGAN from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture from fish_speech.text.chn_text_norm.text import Text as ChnNormedText -from fish_speech.utils import autocast_exclude_mps -from tools.commons import ServeReferenceAudio, ServeTTSRequest +from fish_speech.utils import autocast_exclude_mps, set_seed from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, 
read_ref_text from tools.llama.generate import ( GenerateRequest, GenerateResponse, WrappedGenerateResponse, launch_thread_safe_queue, + launch_thread_safe_queue_agent, +) +from tools.schema import ( + GLOBAL_NUM_SAMPLES, + ASRPackRequest, + ServeASRRequest, + ServeASRResponse, + ServeASRSegment, + ServeAudioPart, + ServeForwardMessage, + ServeMessage, + ServeRequest, + ServeResponse, + ServeStreamDelta, + ServeStreamResponse, + ServeTextPart, + ServeTimedASRResponse, + ServeTTSRequest, + ServeVQGANDecodeRequest, + ServeVQGANDecodeResponse, + ServeVQGANEncodeRequest, + ServeVQGANEncodeResponse, + ServeVQPart, ) from tools.vqgan.inference import load_model as load_decoder_model +global_lock = Lock() + +# Whether to disable keepalive (which is helpful if the server is in the same cluster) +DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true" +async_client = httpx.AsyncClient( + timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None) +) +backends = torchaudio.list_audio_backends() + +if "ffmpeg" in backends: + backend = "ffmpeg" +else: + backend = "soundfile" + def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): buffer = io.BytesIO() @@ -91,9 +138,7 @@ def load_audio(reference_audio, sr): audio_data = reference_audio reference_audio = io.BytesIO(audio_data) - waveform, original_sr = torchaudio.load( - reference_audio, backend="sox" if sys.platform == "linux" else "soundfile" - ) + waveform, original_sr = torchaudio.load(reference_audio, backend=backend) if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) @@ -167,9 +212,390 @@ def get_content_type(audio_format): return "application/octet-stream" +@torch.no_grad() +@torch.autocast(device_type="cuda", dtype=torch.half) +def batch_encode(model, audios: list[bytes | torch.Tensor]): + audios = [ + ( + torch.from_numpy( + librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0] + )[None] + if isinstance(audio, bytes) + else audio + ) + for audio in audios + ] + + # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios): + # raise ValueError("Single audio length is too long (>120s)") + + max_length = max(audio.shape[-1] for audio in audios) + print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s") + + lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device) + max_length = lengths.max().item() + padded = torch.stack( + [ + torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1])) + for audio in audios + ] + ).to(model.device) + + features, feature_lengths = model.encode(padded, audio_lengths=lengths) + features, feature_lengths = features.cpu(), feature_lengths.cpu() + + return [feature[..., :length] for feature, length in zip(features, feature_lengths)] + + +@cached( + cache=LRUCache(maxsize=10000), + key=lambda model, audios: (model.device, tuple(audios)), +) +def cached_vqgan_batch_encode(model, audios: list[bytes]): + return batch_encode(model, audios) + + +# @routes.http.post("/v1/vqgan/encode") +# def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]): +# +# start_time = time.time() +# tokens = cached_vqgan_batch_encode(decoder_model, payload.audios) +# logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms") +# +# return ormsgpack.packb( +# ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]), +# option=ormsgpack.OPT_SERIALIZE_PYDANTIC, +# ) + + +@torch.no_grad() 
+@torch.autocast(device_type="cuda", dtype=torch.half) +def vqgan_decode(model, features): + lengths = torch.tensor( + [feature.shape[-1] for feature in features], device=model.device + ) + max_length = lengths.max().item() + padded = torch.stack( + [ + torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1])) + for feature in features + ] + ).to(model.device) + + # If bs too large, we do micro batch decode + audios, audio_lengths = [], [] + for i in range(0, padded.shape[0], 8): + audio, audio_length = model.decode( + padded[i : i + 8], feature_lengths=lengths[i : i + 8] + ) + audios.append(audio) + audio_lengths.append(audio_length) + audios = torch.cat(audios, dim=0) + audio_lengths = torch.cat(audio_lengths, dim=0) + audios, audio_lengths = audios.cpu(), audio_lengths.cpu() + + return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)] + + +# @routes.http.post("/v1/vqgan/decode") +# def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]): +# tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens] +# start_time = time.time() +# audios = vqgan_decode(decoder_model, tokens) +# logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms") +# audios = [audio.astype(np.float16).tobytes() for audio in audios] +# return ormsgpack.packb( +# ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC +# ) + + +@torch.no_grad() +def batch_asr(model, audios, sr, language="auto"): + resampled_audios = [] + for audio in audios: + audio = torchaudio.functional.resample(audio, sr, 16000) + assert audio.ndim == 1 + resampled_audios.append(audio) + + with global_lock: + res = model.generate( + input=resampled_audios, + batch_size=len(resampled_audios), + language=language, + use_itn=True, + ) + + results = [] + for r, audio in zip(res, audios): + text = r["text"] + text = re.sub(r"<\|.*?\|>", "", text) + duration = len(audio) / sr * 1000 + huge_gap = False + + if "timestamp" in r and len(r["timestamp"]) > 2: + for timestamp_a, timestamp_b in zip( + r["timestamp"][:-1], r["timestamp"][1:] + ): + # If there is a gap of more than 5 seconds, we consider it as a huge gap + if timestamp_b[0] - timestamp_a[1] > 5000: + huge_gap = True + break + + # Doesn't make sense to have a huge gap at the end + if duration - r["timestamp"][-1][1] > 3000: + huge_gap = True + + results.append( + { + "text": text, + "duration": duration, + "huge_gap": huge_gap, + } + ) + + return results + + +# @routes.http.post("/v1/asr") +# def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]): +# start_time = time.time() +# audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios] +# audios = [torch.from_numpy(audio).float() for audio in audios] +# +# if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios): +# raise HTTPException(status_code=400, detail="Audio length is too long") +# +# transcriptions = batch_asr( +# asr_model, audios=audios, sr=payload.sample_rate, language=payload.language +# ) +# logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms") +# +# return ormsgpack.packb( +# ServeASRResponse(transcriptions=transcriptions), +# option=ormsgpack.OPT_SERIALIZE_PYDANTIC, +# ) + + +from fish_speech.conversation import Conversation, Message + + +def execute_request( + input_queue: queue.Queue, + tokenizer: AutoTokenizer, + config: BaseModelArgs, + request: ServeRequest, + device: str = "cuda:0", +): + semantic_id, im_end_id 
= tokenizer.convert_tokens_to_ids( + [SEMANTIC_TOKEN, IM_END_TOKEN] + ) + messages = [] + for message in request.messages: + messages.append(message.to_conversation_message()) + + assert len(messages) >= 1, "At least one message is required" + # assert messages[-1].role == "user", "The last message must be from the user" + + if messages[-1].role == "user": + messages.append(Message(role="assistant", parts=[], add_im_end=False)) + else: + assert ( + messages[-1].role == "assistant" + ), "The last message must be from the assistant" + messages[-1].add_im_end = False + + conv = Conversation(messages=messages) + prompt = conv.encode_for_inference( + tokenizer=tokenizer, num_codebooks=config.num_codebooks + ).to(device) + + if request.streaming: + for i in range(request.num_samples): + yield ServeStreamResponse( + sample_id=i, + delta=ServeStreamDelta( + role="assistant", + ), + ) + + req = { + "prompt": prompt, + "max_new_tokens": request.max_new_tokens, + "im_end_id": im_end_id, + "semantic_id": semantic_id, + "temperature": request.temperature, + "top_p": request.top_p, + "repetition_penalty": request.repetition_penalty, + "num_samples": request.num_samples, + "early_stop_threshold": request.early_stop_threshold, + } + + start = time.time() + response_queue = queue.Queue() + input_queue.put(GenerateRequest(req, response_queue)) + + # Decoding + decode_buffer = [[] for _ in range(request.num_samples)] + parts = [[] for _ in range(request.num_samples)] + + def send_reset_buffer(sample_id): + nonlocal decode_buffer + if len(decode_buffer[sample_id]) == 0: + return + + decoded = tokenizer.decode(decode_buffer[sample_id]) + part = ServeTextPart(text=decoded) + + if request.streaming: + yield ServeStreamResponse(delta=ServeStreamDelta(part=part)) + else: + parts[sample_id].append(part) + + decode_buffer[sample_id] = [] + + # Decode process + finished = [False for _ in range(request.num_samples)] + stats = {} + idx = 0 + while True: + response = response_queue.get() + + if response in ["stop", "error"]: + break + + for sample_id, tokens in enumerate(response): + if finished[sample_id]: + continue + + if tokens[0] == im_end_id: + finished[sample_id] = True + if request.streaming: + yield from send_reset_buffer(sample_id) + yield ServeStreamResponse( + sample_id=sample_id, + finish_reason="stop", + stats=stats, + ) + continue + + if tokens[0] == semantic_id and request.streaming: + yield from send_reset_buffer(sample_id) + # Streaming vq + _tokens = tokens[1:].clone() - 1 + + if config.share_codebook_embeddings is False: + for i in range(len(_tokens)): + _tokens[i] -= config.codebook_size * i + + yield ServeStreamResponse( + sample_id=sample_id, + delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())), + ) + continue + + # Not streaming vq + if tokens[0] == semantic_id: + yield from send_reset_buffer(sample_id) + # None streaming vq + if len(parts[sample_id]) == 0 or not isinstance( + parts[sample_id][-1], ServeVQPart + ): + _tokens = tokens[1:].clone() - 1 + + if config.share_codebook_embeddings is False: + for i in range(len(_tokens)): + _tokens[i] -= config.codebook_size * i + + parts[sample_id].append(ServeVQPart(codes=_tokens.tolist())) + else: + for codebook_id, value in enumerate(tokens[1:, :]): + val = value.item() - 1 + if config.share_codebook_embeddings is False: + val -= config.codebook_size * codebook_id + + parts[sample_id][-1].codes[codebook_id].append(val) + continue + + if tokens[0] != semantic_id: + # Stream text decode is not supported now + 
decode_buffer[sample_id].append(tokens[0, 0]) + + if idx == 0: + stats["time_to_first_token"] = (time.time() - start) * 1000 + + idx += 1 + + for sample_id in range(request.num_samples): + yield from send_reset_buffer(sample_id) + + stats["total_time"] = (time.time() - start) * 1000 + stats["total_tokens"] = idx + + if request.streaming: + for sample_id in range(request.num_samples): + if finished[sample_id]: + continue + yield ServeStreamResponse( + finish_reason=response, stats=stats, sample_id=sample_id + ) + return + + yield ServeResponse( + messages=[ + ServeMessage(role="assistant", parts=parts[i]) + for i in range(request.num_samples) + ], + finish_reason=response, + stats=stats, + ) + + +# @routes.http.post("/v1/chat") +# def api_invoke_chat( +# req: Annotated[ServeRequest, Body(exclusive=True)], +# ): +# """ +# Invoke model and generate audio +# """ +# +# # This makes torch compile happy +# assert ( +# req.num_samples == GLOBAL_NUM_SAMPLES +# ), f"num_samples must be {GLOBAL_NUM_SAMPLES}" +# +# content_type = request.headers.get("Content-Type", "application/json") +# json_mode = "application/json" in content_type +# +# async def wrapped_generator(): +# generator = execute_request(llama_queue, tokenizer, config, req, args.device) +# +# for i in generator: +# if json_mode: +# body = i.model_dump_json().encode("utf-8") +# yield b"data: " + body + b"\n\n" +# else: +# body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC) +# yield struct.pack("I", len(body)) + body +# +# # Naive mode +# if req.streaming is False: +# result = next(execute_request(llama_queue, tokenizer, config, req, args.device)) +# +# if json_mode: +# return JSONResponse(result.model_dump()) +# else: +# return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC) +# +# return StreamResponse( +# iterable=wrapped_generator(), content_type="text/event-stream" +# ) + + @torch.inference_mode() def inference(req: ServeTTSRequest): + global prompt_tokens, prompt_texts + idstr: str | None = req.reference_id if idstr is not None: ref_folder = Path("references") / idstr @@ -177,33 +603,47 @@ def inference(req: ServeTTSRequest): ref_audios = list_files( ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False ) - prompt_tokens = [ - encode_reference( - decoder_model=decoder_model, - reference_audio=audio_to_bytes(str(ref_audio)), - enable_reference_audio=True, - ) - for ref_audio in ref_audios - ] - prompt_texts = [ - read_ref_text(str(ref_audio.with_suffix(".lab"))) - for ref_audio in ref_audios - ] + + if req.use_memory_cache == "never" or ( + req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0 + ): + prompt_tokens = [ + encode_reference( + decoder_model=decoder_model, + reference_audio=audio_to_bytes(str(ref_audio)), + enable_reference_audio=True, + ) + for ref_audio in ref_audios + ] + prompt_texts = [ + read_ref_text(str(ref_audio.with_suffix(".lab"))) + for ref_audio in ref_audios + ] + else: + logger.info("Use same references") else: # Parse reference audio aka prompt refs = req.references - if refs is None: - refs = [] - prompt_tokens = [ - encode_reference( - decoder_model=decoder_model, - reference_audio=ref.audio, - enable_reference_audio=True, - ) - for ref in refs - ] - prompt_texts = [ref.text for ref in refs] + + if req.use_memory_cache == "never" or ( + req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0 + ): + prompt_tokens = [ + encode_reference( + decoder_model=decoder_model, + reference_audio=ref.audio, + enable_reference_audio=True, + ) + for ref in refs + ] + 
prompt_texts = [ref.text for ref in refs] + else: + logger.info("Use same references") + + if req.seed is not None: + set_seed(req.seed) + logger.warning(f"set seed: {req.seed}") # LLAMA Inference request = dict( @@ -220,7 +660,7 @@ def inference(req: ServeTTSRequest): compile=args.compile, iterative_prompt=req.chunk_length > 0, chunk_length=req.chunk_length, - max_length=2048, + max_length=4096, prompt_tokens=prompt_tokens, prompt_text=prompt_texts, ) @@ -342,6 +782,8 @@ async def buffer_to_async_generator(buffer): def parse_args(): parser = ArgumentParser() + parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts") + parser.add_argument("--load-asr-model", action="store_true") parser.add_argument( "--llama-checkpoint-path", type=str, @@ -367,18 +809,26 @@ def parse_args(): # openapi = OpenAPI( # { # "title": "Fish Speech API", +# "version": "1.4.2", # }, # ).routes # # # class MsgPackRequest(HttpRequest): -# async def data(self) -> Annotated[Any, ContentType("application/msgpack")]: +# async def data( +# self, +# ) -> Annotated[ +# Any, ContentType("application/msgpack"), ContentType("application/json") +# ]: # if self.content_type == "application/msgpack": # return ormsgpack.unpackb(await self.body) # +# elif self.content_type == "application/json": +# return await self.json +# # raise HTTPException( # HTTPStatus.UNSUPPORTED_MEDIA_TYPE, -# headers={"Accept": "application/msgpack"}, +# headers={"Accept": "application/msgpack, application/json"}, # ) # # @@ -393,48 +843,101 @@ def parse_args(): # ) -if __name__ == "__main__": +def load_asr_model(*, device="cuda", hub="ms"): + return AutoModel( + model="iic/SenseVoiceSmall", + device=device, + disable_pbar=True, + hub=hub, + ) - import uvicorn - args = parse_args() - args.precision = torch.half if args.half else torch.bfloat16 +# Each worker process created by Uvicorn has its own memory space, +# meaning that models and variables are not shared between processes. +# Therefore, any global variables (like `llama_queue` or `decoder_model`) +# will not be shared across workers. - logger.info("Loading Llama model...") - llama_queue = launch_thread_safe_queue( - checkpoint_path=args.llama_checkpoint_path, - device=args.device, - precision=args.precision, - compile=args.compile, - ) - logger.info("Llama model loaded, loading VQ-GAN model...") - decoder_model = load_decoder_model( - config_name=args.decoder_config_name, - checkpoint_path=args.decoder_checkpoint_path, - device=args.device, - ) +# Multi-threading for deep learning can cause issues, such as inconsistent +# outputs if multiple threads access the same buffers simultaneously. +# Instead, it's better to use multiprocessing or independent models per thread. 
+# @app.on_startup +# def initialize_app(app: Kui): +# +# global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts +# +# prompt_tokens, prompt_texts = [], [] +# +# args = parse_args() # args same as ones in other processes +# args.precision = torch.half if args.half else torch.bfloat16 +# +# if args.load_asr_model: +# logger.info(f"Loading ASR model...") +# asr_model = load_asr_model(device=args.device) +# +# logger.info("Loading Llama model...") +# +# if args.mode == "tts": +# llama_queue = launch_thread_safe_queue( +# checkpoint_path=args.llama_checkpoint_path, +# device=args.device, +# precision=args.precision, +# compile=args.compile, +# ) +# else: +# llama_queue, tokenizer, config = launch_thread_safe_queue_agent( +# checkpoint_path=args.llama_checkpoint_path, +# device=args.device, +# precision=args.precision, +# compile=args.compile, +# ) +# +# logger.info("Llama model loaded, loading VQ-GAN model...") +# +# decoder_model = load_decoder_model( +# config_name=args.decoder_config_name, +# checkpoint_path=args.decoder_checkpoint_path, +# device=args.device, +# ) +# +# logger.info("VQ-GAN model loaded, warming up...") +# +# vad_model = load_silero_vad() +# +# logger.info("VAD model loaded, warming up...") +# +# if args.mode == "tts": +# # Dry run to ensure models work and avoid first-time latency +# list( +# inference( +# ServeTTSRequest( +# text="Hello world.", +# references=[], +# reference_id=None, +# max_new_tokens=0, +# chunk_length=200, +# top_p=0.7, +# repetition_penalty=1.2, +# temperature=0.7, +# emotion=None, +# format="wav", +# ) +# ) +# ) +# +# logger.info(f"Warming up done, starting server at http://{args.listen}") - logger.info("VQ-GAN model loaded, warming up...") - - # Dry run to check if the model is loaded correctly and avoid the first-time latency - list( - inference( - ServeTTSRequest( - text="Hello world.", - references=[], - reference_id=None, - max_new_tokens=1024, - chunk_length=200, - top_p=0.7, - repetition_penalty=1.2, - temperature=0.7, - emotion=None, - format="wav", - ) - ) - ) - logger.info(f"Warming up done, starting server at http://{args.listen}") +if __name__ == "__main__": + + import uvicorn + + args = parse_args() host, port = args.listen.split(":") - uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info") + uvicorn.run( + "tools.api:app", + host=host, + port=int(port), + workers=args.workers, + log_level="info", + ) diff --git a/xinference/thirdparty/fish_speech/tools/commons.py b/xinference/thirdparty/fish_speech/tools/commons.py deleted file mode 100644 index f81cadec1e..0000000000 --- a/xinference/thirdparty/fish_speech/tools/commons.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Annotated, Literal, Optional - -from pydantic import BaseModel, Field, conint - - -class ServeReferenceAudio(BaseModel): - audio: bytes - text: str - - -class ServeTTSRequest(BaseModel): - text: str - chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200 - # Audio format - format: Literal["wav", "pcm", "mp3"] = "wav" - mp3_bitrate: Literal[64, 128, 192] = 128 - # References audios for in-context learning - references: list[ServeReferenceAudio] = [] - # Reference id - # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/ - # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1 - reference_id: str | None = None - # Normalize text for en & zh, this increase stability for numbers - normalize: bool = True - mp3_bitrate: Optional[int] = 64 - 
opus_bitrate: Optional[int] = -1000 - # Balance mode will reduce latency to 300ms, but may decrease stability - latency: Literal["normal", "balanced"] = "normal" - # not usually used below - streaming: bool = False - emotion: Optional[str] = None - max_new_tokens: int = 1024 - top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7 - repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2 - temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7 diff --git a/xinference/thirdparty/fish_speech/tools/e2e_webui.py b/xinference/thirdparty/fish_speech/tools/e2e_webui.py new file mode 100644 index 0000000000..37474fbd56 --- /dev/null +++ b/xinference/thirdparty/fish_speech/tools/e2e_webui.py @@ -0,0 +1,232 @@ +import io +import re +import wave + +import gradio as gr +import numpy as np + +from .fish_e2e import FishE2EAgent, FishE2EEventType +from .schema import ServeMessage, ServeTextPart, ServeVQPart + + +def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): + buffer = io.BytesIO() + + with wave.open(buffer, "wb") as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(bit_depth // 8) + wav_file.setframerate(sample_rate) + + wav_header_bytes = buffer.getvalue() + buffer.close() + return wav_header_bytes + + +class ChatState: + def __init__(self): + self.conversation = [] + self.added_systext = False + self.added_sysaudio = False + + def get_history(self): + results = [] + for msg in self.conversation: + results.append({"role": msg.role, "content": self.repr_message(msg)}) + + # Process assistant messages to extract questions and update user messages + for i, msg in enumerate(results): + if msg["role"] == "assistant": + match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"]) + if match and i > 0 and results[i - 1]["role"] == "user": + # Update previous user message with extracted question + results[i - 1]["content"] += "\n" + match.group(1) + # Remove the Question/Answer format from assistant message + msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1] + return results + + def repr_message(self, msg: ServeMessage): + response = "" + for part in msg.parts: + if isinstance(part, ServeTextPart): + response += part.text + elif isinstance(part, ServeVQPart): + response += f"