diff --git a/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py index bd1db928..4019b416 100644 --- a/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py +++ b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py @@ -4,7 +4,7 @@ from typing import Callable, List import torch.nn as nn -from accelerate.utils.dataclasses import get_module_class_from_name +from accelerate import FullyShardedDataParallelPlugin from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from modalities.config.lookup_enum import LookupEnum