Commit a64dde1

doc: disable TE by default
1 parent 7643c64

2 files changed: +13 −10 lines

ChatTTS/model/gpt.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -91,7 +91,7 @@ def __init__(
             ],
         )
 
-    def from_pretrained(self, file_path: str):
+    def from_pretrained(self, file_path: str, experimental=False):
         if self.is_vllm and platform.system().lower() == "linux":
             from safetensors.torch import save_file
 
@@ -134,12 +134,12 @@ def from_pretrained(self, file_path: str):
         self.load_state_dict(torch.load(file_path, weights_only=True, mmap=True))
 
         if (
-            "cuda" in str(self.device_gpt) and platform.system().lower() == "linux"
+            experimental and "cuda" in str(self.device_gpt) and platform.system().lower() == "linux"
         ):  # is TELlamaModel
             try:
                 from .cuda import TELlamaModel
 
-                self.logger.info("Linux with CUDA, try NVIDIA accelerated TELlamaModel")
+                self.logger.warning("Linux with CUDA, try NVIDIA accelerated TELlamaModel because experimental is enabled")
                 state_dict = self.gpt.state_dict()
                 vanilla = TELlamaModel.from_state_dict(state_dict, self.llama_config)
                 # Force mem release. Taken from huggingface code
```
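
For context, a minimal usage sketch of the changed signature. The `gpt` object and the checkpoint path below are illustrative assumptions, not part of this commit; only the `from_pretrained(file_path, experimental=...)` signature comes from the diff above.

```python
# Sketch only: assumes `gpt` is an already-constructed ChatTTS.model.gpt.GPT
# instance and "asset/GPT.pt" is a placeholder checkpoint path.

# Default after this commit: plain weight loading; the TELlamaModel
# conversion is skipped even on Linux with CUDA.
gpt.from_pretrained("asset/GPT.pt")

# Opt-in: on Linux with CUDA this attempts the experimental NVIDIA
# TransformerEngine path (TELlamaModel) inside a try block and logs a
# warning that the experimental path is enabled.
gpt.from_pretrained("asset/GPT.pt", experimental=True)
```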

README.md

Lines changed: 10 additions & 7 deletions

````diff
@@ -107,25 +107,28 @@ pip install safetensors vllm==0.2.7 torchaudio
 ```
 
 #### Unrecommended Optional: Install TransformerEngine if using NVIDIA GPU (Linux only)
-> [!Note]
-> The installation process is very slow.
-
 > [!Warning]
+> DO NOT INSTALL!
 > The adaptation of TransformerEngine is currently under development and CANNOT run properly now.
-> Only install it on developing purpose.
+> Only install it on developing purpose. See more details on at #672 #676
+
+> [!Note]
+> The installation process is very slow.
 
 ```bash
 pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
 ```
 
 #### Unrecommended Optional: Install FlashAttention-2 (mainly NVIDIA GPU)
-> [!Note]
-> See supported devices at the [Hugging Face Doc](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).
-
 > [!Warning]
+> DO NOT INSTALL!
 > Currently the FlashAttention-2 will slow down the generating speed according to [this issue](https://github.com/huggingface/transformers/issues/26990).
 > Only install it on developing purpose.
 
+> [!Note]
+> See supported devices at the [Hugging Face Doc](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).
+
+
 ```bash
 pip install flash-attn --no-build-isolation
 ```
````
