AttributeError: module 'torch.amp' has no attribute 'GradScaler' #81

@VermiNew

Description

I encountered an AttributeError when running the XTTS-API-SERVER. The error message indicates that the torch.amp module does not have an attribute GradScaler. Below is the full traceback of the error:

Traceback (most recent call last):
  File "my_script.py", line 1, in <module>
    from xtts_api_server.server import app
  File "/opt/conda/lib/python3.10/site-packages/xtts_api_server/server.py", line 1, in <module>
    from TTS.api import TTS
  File "/opt/conda/lib/python3.10/site-packages/TTS/api.py", line 8, in <module>
    from TTS.config import load_config
  File "/opt/conda/lib/python3.10/site-packages/TTS/config/__init__.py", line 10, in <module>
    from TTS.config.shared_configs import *
  File "/opt/conda/lib/python3.10/site-packages/TTS/config/shared_configs.py", line 5, in <module>
    from trainer import TrainerConfig
  File "/opt/conda/lib/python3.10/site-packages/trainer/__init__.py", line 4, in <module>
    from trainer.model import *
  File "/opt/conda/lib/python3.10/site-packages/trainer/model.py", line 7, in <module>
    from trainer.trainer import Trainer
  File "/opt/conda/lib/python3.10/site-packages/trainer/trainer.py", line 63, in <module>
    class Trainer:
  File "/opt/conda/lib/python3.10/site-packages/trainer/trainer.py", line 947, in Trainer
    def _grad_clipping(self, grad_clip: float, optimizer: torch.optim.Optimizer, scaler: torch.amp.GradScaler):
AttributeError: module 'torch.amp' has no attribute 'GradScaler'

Steps to Reproduce:

  1. Install the XTTS-API-SERVER package and its dependencies.
  2. Import the app module from xtts_api_server.server.
  3. Run the script (a minimal reproduction is shown just below).
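
For reference, the failing reference is evaluated while the Trainer class body executes, so importing the server module is already enough to trigger the error. A minimal reproduction, matching my_script.py from the traceback above:

# my_script.py - the import alone fails, via xtts_api_server.server -> TTS -> trainer
from xtts_api_server.server import app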

Expected Behavior:
The server should start without any errors.

Actual Behavior:
An AttributeError is raised, indicating that torch.amp does not have the GradScaler attribute.

Environment:

  • OS: Ubuntu 20.04 (also reproduced on Windows 11)
  • Python Version: 3.10.11
  • XTTS-API-SERVER Version: 0.9.0
  • PyTorch Version: 2.1.1+cu118

Additional Information:
This error appears to be caused by a mismatch between the installed PyTorch version and the trainer code: in PyTorch 2.1.x, GradScaler is only available under torch.cuda.amp, while the device-agnostic torch.amp.GradScaler was only added in a later PyTorch release, so the annotation in trainer.py (line 947 in the traceback) cannot be resolved here.
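
A quick check in a Python shell shows where GradScaler is actually exposed in this environment (the comments note the output I would expect on PyTorch 2.1.1):

import torch

print(torch.__version__)                      # 2.1.1+cu118
print(hasattr(torch.amp, "GradScaler"))       # False on 2.1.x
print(hasattr(torch.cuda.amp, "GradScaler"))  # True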

Potential solution: update the GradScaler reference in trainer.py so it resolves on PyTorch 2.1, e.g. by importing it explicitly:

from torch.cuda.amp import GradScaler

and using the imported GradScaler in the _grad_clipping annotation instead of torch.amp.GradScaler.
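
If the goal is to stay compatible with both older and newer PyTorch releases, a version-tolerant import along these lines could also work (a sketch of one possible workaround, not the project's actual fix; newer PyTorch versions do expose torch.amp.GradScaler):

try:
    from torch.amp import GradScaler       # newer PyTorch releases
except ImportError:
    from torch.cuda.amp import GradScaler  # PyTorch 2.1.x and earlier, CUDA-specific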

Thank you for looking into this issue!
