Description
I encountered an AttributeError when running XTTS-API-SERVER. The error message indicates that the torch.amp module does not have an attribute GradScaler. Below is the full traceback of the error:
Traceback (most recent call last):
  File "my_script.py", line 1, in <module>
    from xtts_api_server.server import app
  File "/opt/conda/lib/python3.10/site-packages/xtts_api_server/server.py", line 1, in <module>
    from TTS.api import TTS
  File "/opt/conda/lib/python3.10/site-packages/TTS/api.py", line 8, in <module>
    from TTS.config import load_config
  File "/opt/conda/lib/python3.10/site-packages/TTS/config/__init__.py", line 10, in <module>
    from TTS.config.shared_configs import *
  File "/opt/conda/lib/python3.10/site-packages/TTS/config/shared_configs.py", line 5, in <module>
    from trainer import TrainerConfig
  File "/opt/conda/lib/python3.10/site-packages/trainer/__init__.py", line 4, in <module>
    from trainer.model import *
  File "/opt/conda/lib/python3.10/site-packages/trainer/model.py", line 7, in <module>
    from trainer.trainer import Trainer
  File "/opt/conda/lib/python3.10/site-packages/trainer/trainer.py", line 63, in <module>
    class Trainer:
  File "/opt/conda/lib/python3.10/site-packages/trainer/trainer.py", line 947, in Trainer
    def _grad_clipping(self, grad_clip: float, optimizer: torch.optim.Optimizer, scaler: torch.amp.GradScaler):
AttributeError: module 'torch.amp' has no attribute 'GradScaler'
Steps to Reproduce:
- Install the XTTS-API-SERVER package and its dependencies.
- Import app from xtts_api_server.server (from xtts_api_server.server import app).
- Run the script.
Expected Behavior:
The server should start without any errors.
Actual Behavior:
An AttributeError is raised, indicating that torch.amp does not have the GradScaler attribute.
Environment:
- OS: Ubuntu 20.04 (also reproduced on Windows 11)
- Python Version: 3.10.11
- XTTS-API-SERVER Version: 0.9.0
- PyTorch Version: 2.1.1+cu118
Additional Information:
This error is likely caused by a mismatch between the installed PyTorch version and how the trainer package uses GradScaler. The trainer code references torch.amp.GradScaler, which only exists in newer PyTorch releases; in the installed PyTorch 2.1.1, GradScaler is available under torch.cuda.amp instead.
Potential solution: update trainer.py to use torch.cuda.amp.GradScaler instead of torch.amp.GradScaler, for example by importing it as:
from torch.cuda.amp import GradScaler
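Hard-coding torch.cuda.amp would break again on PyTorch builds that only expect torch.amp, so a version-tolerant lookup may be preferable if patching trainer.py locally. This is a minimal sketch; the helper name resolve_grad_scaler is hypothetical and is not part of trainer or PyTorch. It prefers torch.amp.GradScaler when the installed PyTorch provides it and otherwise falls back to torch.cuda.amp.GradScaler.

```python
from typing import Any


def resolve_grad_scaler(torch_module: Any) -> type:
    """Return the GradScaler class from whichever location this PyTorch build provides.

    Hypothetical helper: newer PyTorch releases expose torch.amp.GradScaler,
    while older ones such as 2.1.x only provide torch.cuda.amp.GradScaler.
    """
    # getattr with a default avoids the AttributeError seen in the traceback.
    amp = getattr(torch_module, "amp", None)
    scaler = getattr(amp, "GradScaler", None)
    if scaler is not None:
        return scaler
    # Fall back to the CUDA-specific location used by older releases.
    return torch_module.cuda.amp.GradScaler
```

In trainer.py this could be called once after import torch, e.g. GradScaler = resolve_grad_scaler(torch), with the scaler annotation in _grad_clipping changed to GradScaler. Upgrading PyTorch to a release that ships torch.amp.GradScaler should also avoid the error without patching.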
Thank you for looking into this issue!