only support model id in api now
zhenglongjiepheonix committed Aug 14, 2024
1 parent 50fcfc0 commit 4114d3b
Showing 3 changed files with 36 additions and 38 deletions.
70 changes: 34 additions & 36 deletions optimum/fx/parallelization/api.py
@@ -15,10 +15,11 @@
 import importlib
 import os
 from functools import partial
-from typing import List, Union
+from typing import List
 
 import torch
 from torch.fx import GraphModule
+from transformers import AutoConfig
 
 from .core import Config, ParallelExecutionCtx
 from .passes import build_parallel_pass_pipeline
@@ -43,7 +44,7 @@ def parallelize_backend(
 
 
 def parallelize_model(
-    model: Union[torch.nn.Module, str],
+    model: str,
     parallel_ctx: ParallelExecutionCtx,
     *model_args,
     **kwargs,
@@ -52,8 +53,8 @@
     API for automatic model parallelism through Pytorch FX.
 
     Args:
-        model (Union[torch.nn.Module, str]):
-            Model to parallelize, could either be a module or a model id on the Huggingface Hub.
+        model (str):
+            Model to parallelize, a model id on the Huggingface Hub.
         parallel_ctx (ParallelExecutionCtx):
            Parallel execution context containing process groups the current process belongs to.
         *model_args (Any):
@@ -80,44 +81,41 @@ def parallelize_model(
             setattr(parallel_config, k, v)
             kwargs.pop(k)
 
-    if isinstance(model, str):
-        from transformers import AutoConfig
-
-        is_local = os.path.isdir(model)
-        if not is_local:
-            hf_folder = download_model_from_hf(
-                model_name_or_path=model,
-                cache_dir=cache_dir,
-                revision=revision,
-                local_files_only=local_files_only,
-                skip_download_weights=skip_load_weights,
-            )
-        else:
-            hf_folder = model
-
-        # should be able to load config using only local files
-        model_config, kwargs = AutoConfig.from_pretrained(
-            hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
-        )
-
-        # try getting model class info from config
-        model_arch = model_config.architectures
-        model_cls = getattr(importlib.import_module("transformers"), model_arch[0])
-
-        if not skip_load_weights:
-            parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder)
-
-        torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
-        if torch_dtype is not None:
-            dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)
-
-        with MetaAwareMethodsPatcher():
-            model = model_cls(model_config, *model_args, **kwargs)
-            # TODO: remove this once support training-time trace
-            model.eval()
-
-        if dtype_orig is not None:
-            torch.set_default_dtype(dtype_orig)
+    is_local = os.path.isdir(model)
+    if not is_local:
+        hf_folder = download_model_from_hf(
+            model_name_or_path=model,
+            cache_dir=cache_dir,
+            revision=revision,
+            local_files_only=local_files_only,
+            skip_download_weights=skip_load_weights,
+        )
+    else:
+        hf_folder = model
+
+    # should be able to load config using only local files
+    model_config, kwargs = AutoConfig.from_pretrained(
+        hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
+    )
+
+    # try getting model class info from config
+    model_arch = model_config.architectures
+    model_cls = getattr(importlib.import_module("transformers"), model_arch[0])
+
+    if not skip_load_weights:
+        parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder)
+
+    torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
+    if torch_dtype is not None:
+        dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)
+
+    with MetaAwareMethodsPatcher():
+        model = model_cls(model_config, *model_args, **kwargs)
+        # TODO: remove this once support training-time trace
+        model.eval()
+
+    if dtype_orig is not None:
+        torch.set_default_dtype(dtype_orig)
 
     move_model_to_device(model, device=parallel_ctx.current_device)
     initialize_parameter_meta(model)
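
For context, a minimal usage sketch of the narrowed API after this change. The import path, process-group setup, constructor arguments, and example model id below are assumptions for illustration, not taken from this commit:

import torch
import torch.distributed as dist

from optimum.fx.parallelization import ParallelExecutionCtx, parallelize_model

# assumes the script is launched with torchrun so a default process group can be created
dist.init_process_group(backend="nccl")

# assumed constructor arguments; the diff only shows that the context carries
# current_device, weight_map and example_inputs
ctx = ParallelExecutionCtx(
    tp_group=dist.new_group(),
    current_device=torch.device("cuda", dist.get_rank()),
)

# after this commit, `model` must be a model id on the Hugging Face Hub
# (or a local checkpoint directory); an instantiated nn.Module is no longer accepted
model = parallelize_model("gpt2", ctx)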
2 changes: 1 addition & 1 deletion optimum/fx/parallelization/op_registry/op_handlers.py
@@ -439,7 +439,7 @@ def propagate(self) -> List[int]:
         # last resort, if no input is being parallelized, then we make output also not parallelized,
         # this will give us relief on writing policies for strange ops which don't actually need
         # parallelization in most cases
-        if all([self.extract_axis(arg) is None for arg in self.node.all_input_nodes]):
+        if all(self.extract_axis(arg) is None for arg in self.node.all_input_nodes):
             return [None]
 
         raise NotImplementedError(f"don't know how to propagate axis for {self.node.target}")
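
Dropping the brackets matters for laziness: with a generator expression, all() can stop at the first input that turns out to be parallelized instead of calling extract_axis on every input up front. A small self-contained illustration; the stub below is hypothetical, standing in for the handler's extract_axis:

def extract_axis_stub(arg):
    # hypothetical stand-in: returns the parallel axis of an input, or None if not parallelized
    print(f"inspecting {arg}")
    return None if arg % 2 else 0

inputs = [2, 3, 5, 7]

# list form (old): builds the whole list, calling the stub for every input
all([extract_axis_stub(arg) is None for arg in inputs])

# generator form (new): short-circuits at the first input whose axis is not None
all(extract_axis_stub(arg) is None for arg in inputs)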
2 changes: 1 addition & 1 deletion optimum/fx/parallelization/passes.py
@@ -194,7 +194,7 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
         graph: Graph = decompose_and_functionalize(graph_module)(*ctx.example_inputs)
         stable_topological_sort(graph)
 
-        nodes = [node for node in graph.nodes]
+        nodes = list(graph.nodes)
 
         def search(idx: int):
             if idx == len(nodes):
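
The new form is behavior-preserving: graph.nodes is iterated once and snapshotted into a list, exactly as the comprehension did, just more idiomatically. A tiny stand-alone check under assumed inputs; the traced function is illustrative only:

import torch
from torch.fx import symbolic_trace

def fn(x):
    return torch.relu(x) + 1

gm = symbolic_trace(fn)

# both expressions snapshot the same nodes in the same order
assert list(gm.graph.nodes) == [node for node in gm.graph.nodes]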
