Skip to content

Commit

Permalink
Merge pull request #127 from Modalities/update_to_new_accelerate_api
Browse files Browse the repository at this point in the history
debug: adapt to new accelerate API
  • Loading branch information
mali-git committed May 10, 2024
2 parents fdcf36b + d46040a commit 435ab82
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/modalities/running_env/fsdp/fsdp_auto_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, List

import torch.nn as nn
from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
from accelerate.utils.dataclasses import get_module_class_from_name
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from modalities.config.lookup_enum import LookupEnum
Expand All @@ -28,7 +28,7 @@ def _get_fsdp_blocks_from_block_names(model: nn.Module, block_names: List[str])
for cls_block_name in block_names:
# TODO FullyShardedDataParallelPlugin from Accelerate uses string matching to find the correct
# block class. In the long-term we should implmement this ourselves in a robuster fashion.
block_type = FullyShardedDataParallelPlugin.get_module_class_from_name(model, cls_block_name)
block_type = get_module_class_from_name(model, cls_block_name)
if block_type is None:
raise ValueError(f"Could not find block with name {cls_block_name} in model")
fsdp_block_types.append(block_type)
Expand Down

0 comments on commit 435ab82

Please sign in to comment.