From 5e39787d4a62e750dba16c504519758e8a5e5f30 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Mon, 3 Jun 2024 07:02:12 +0200 Subject: [PATCH 01/25] WIP --- optimum/fx/parallelization/__init__.py | 0 optimum/fx/parallelization/analyze.py | 233 ++++++++++++++++++ optimum/fx/parallelization/chainable_pass.py | 134 ++++++++++ optimum/fx/parallelization/core/__init__.py | 2 + optimum/fx/parallelization/core/config.py | 43 ++++ optimum/fx/parallelization/core/context.py | 6 + .../parallel_layers/__init__.py | 1 + .../parallelization/parallel_layers/linear.py | 60 +++++ optimum/fx/parallelization/transform.py | 0 optimum/fx/parallelization/utils.py | 28 +++ 10 files changed, 507 insertions(+) create mode 100644 optimum/fx/parallelization/__init__.py create mode 100644 optimum/fx/parallelization/analyze.py create mode 100644 optimum/fx/parallelization/chainable_pass.py create mode 100644 optimum/fx/parallelization/core/__init__.py create mode 100644 optimum/fx/parallelization/core/config.py create mode 100644 optimum/fx/parallelization/core/context.py create mode 100644 optimum/fx/parallelization/parallel_layers/__init__.py create mode 100644 optimum/fx/parallelization/parallel_layers/linear.py create mode 100644 optimum/fx/parallelization/transform.py create mode 100644 optimum/fx/parallelization/utils.py diff --git a/optimum/fx/parallelization/__init__.py b/optimum/fx/parallelization/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/optimum/fx/parallelization/analyze.py b/optimum/fx/parallelization/analyze.py new file mode 100644 index 0000000000..d824dc058b --- /dev/null +++ b/optimum/fx/parallelization/analyze.py @@ -0,0 +1,233 @@ +from typing import Any, Dict, List, Type, Callable, Optional +from torch.fx import Graph, GraphModule, Node +from torch._inductor.pattern_matcher import stable_topological_sort +from functools import reduce +from collections import defaultdict +from .chainable_pass import ChainablePass +from .utils import is_linear, is_sdpa, is_activation, is_matmul + + +class AnalyzeBase(ChainablePass): + # place unique meta_key in `meta` to prevent duplicate fields + @property + def meta_key(self) -> str: + return f'{self.signature()}' + + def get_stored_field_info(self, node : Node, field : Any) -> Any: + if not self.already_executed_per_node(node): + return None + + info : Dict[Any, Any] = node.meta[self.meta_key] + if field not in info: + raise ValueError(f"Invalid query field {field} for {self.__name__}, valid fields are {list(info.keys())}") + + return info[field] + + def already_executed_per_node(self, node : Node) -> None: + return self.meta_key in node.meta + + def place_marker_per_node(self, node : Node, info : Dict[Any, Any]) -> None: + node.meta[self.meta_key] = info + + def clear_marker_per_node(self, node : Node) -> None: + if self.meta_key in node.meta: + node.meta.pop(self.meta_key) + + def clean_all(self, graph_module : GraphModule) -> None: + g : Graph = graph_module.graph + for node in g.nodes: + self.clear_marker_per_node(node) + + +class PostDominatorSolverPass(AnalyzeBase): + def __init__( + self, + node_filter : Callable[[Node], bool] = lambda x : True, + next: Optional[ChainablePass] = None) -> None: + super().__init__(next) + self.node_filter = node_filter + + def run(self, graph_module: GraphModule, **kwargs) -> GraphModule: + g : Graph = graph_module.graph + stable_topological_sort(g) + + for node in reversed(g.nodes): + doms = {node} + candidates = [] + for user in node.users: + dom = self.get_stored_field_info(user, 'post_doms') + assert dom is not None + candidates.append(dom) + if len(candidates): + doms = doms.union(reduce(lambda x, y: x.intersection(y), candidates)) + self.place_marker_per_node(node, {'post_doms' : doms}) + + for node in g.nodes: + if not self.node_filter(node): + self.clear_marker_per_node() + + return graph_module + + +class DependencySetSolverPass(AnalyzeBase): + def __init__( + self, + node_filter : Callable[[Node], bool] = lambda x : True, + next: Optional[ChainablePass] = None) -> None: + super().__init__(next) + self.node_filter = node_filter + def run(self, graph_module: GraphModule, **kwargs) -> GraphModule: + g : Graph = graph_module.graph + stable_topological_sort(g) + + for node in g.nodes: + deps = {node} + candidates = [] + for pred in node.all_input_nodes: + dep = self.get_stored_field_info(pred, 'dependency_nodes') + assert dep is not None + candidates.append(dep) + deps = reduce(lambda x, y: x.union(y), candidates, deps) + self.place_marker_per_node(node, {'dependency_nodes' : deps}) + + for node in g.nodes: + if not self.node_filter(node): + self.clear_marker_per_node() + + return graph_module + + +class ParallelLinearAnnotatePass(AnalyzeBase): + dependencies = [PostDominatorSolverPass, DependencySetSolverPass] + + def mark_attention_related_linears( + self, + graph : Graph, + dependency_set_solver_pass : AnalyzeBase, + post_dominator_solver_pass : AnalyzeBase, + downstream_linears : List[Node] + ) -> None: + deps, post_doms = [], [] + for linear in downstream_linears: + dep = dependency_set_solver_pass.get_stored_field_info(linear, field='dependency_nodes') + assert dep is not None, "`DependencySetSolverPass` must have run before `ParallelLinearAnnotatePass`" + deps.append(dep) + + post_dom = post_dominator_solver_pass.get_stored_field_info(linear, 'post_doms') + assert post_dom is not None, "`PostDominatorSolverPass` must have run before `ParallelLinearAnnotatePass`" + post_doms.append(post_dom) + + # Check 1: no dependencies between parallel linears + if {downstream_linears[0], downstream_linears[1]}.intersection(deps[2]) or \ + {downstream_linears[1], downstream_linears[2]}.intersection(deps[0]) or \ + {downstream_linears[0], downstream_linears[2]}.intersection(deps[1]): + return + + # Check 2: there is a Linear after these three Linears and it post-dominates these three linears + # Need topo-order here + node, last_node = downstream_linears[-1].next, next(reversed(graph.nodes)) + sdpas, matmul_2, matmul_3 = 0, 0, 0 + while node is not last_node and not is_linear(node): + if is_matmul(node): + doms = sum([int(node in post_dom) for post_dom in post_doms]) + if doms == 2: + # we find a matmul dominating the two linears(Q,K) out of all three linears + matmul_2 += 1 + elif doms == 3 and matmul_2 == 1: + # we find a matmul dominating the previous matmul and all three linears + matmul_3 += 1 + elif is_sdpa(node) and all([node in post_dom for post_dom in post_doms]): + sdpas += 1 + node = node.next + + if node is last_node or any([node not in post_dom for post_dom in post_doms]): + return + + # Check 3: there is two dominating matmuls or there is one dominating sdpa + if not ((sdpas == 1) ^ (matmul_2 == 1 and matmul_3 == 1)): + return + + # we can almost certainly say we have captured an self-attention pattern here, + # we will be fine as long as we are right under 99% of situations + for linear in downstream_linears: + self.place_marker_per_node(linear, {'replace_by' : 'column'}) + + self.place_marker_per_node(node, {'replace_by' : 'row'}) + + + def mark_mlp_related_linears( + self, + graph : Graph, + dependency_set_solver_pass : AnalyzeBase, + post_dominator_solver_pass : AnalyzeBase, + linears : List[Node] + ) -> None: + if any([self.already_executed_per_node(node) for node in linears]): + return + + deps, post_doms = [], [] + for linear in linears: + dep = dependency_set_solver_pass.get_stored_field_info(linear, field='dependency_nodes') + assert dep is not None, "`DependencySetSolverPass` must have run before `ParallelLinearAnnotatePass`" + deps.append(dep) + + post_dom = post_dominator_solver_pass.get_stored_field_info(linear, 'post_doms') + assert post_dom is not None, "`PostDominatorSolverPass` must have run before `ParallelLinearAnnotatePass`" + post_doms.append(post_dom) + + if len(linears) == 2 and linears[0] in deps[1] or linears[1] in deps[0]: + return + + node, last_node = linears[-1].next, next(reversed(graph.nodes)) + + activations = 0 + while node is not last_node and not is_linear(node): + if is_activation(node) and sum([int(node in post_dom) for post_dom in post_doms]): + activations += 1 + node = node.next + + if node is last_node or self.already_executed_per_node(node) or any([node not in post_dom for post_dom in post_doms]): + return + + # should have at least one activation node in between + if activations == 0: + return + + for linear in linears: + self.place_marker_per_node(linear, {'replace_by' : 'column'}) + + self.place_marker_per_node(node, {'replace_by' : 'row'}) + + + def run( + self, + graph_module: GraphModule, + passes : Dict[Type[ChainablePass], ChainablePass], + **kwargs + ) -> GraphModule: + g : Graph = graph_module.graph + stable_topological_sort(g) + + linear_groups : Dict[Node, List[Node]] = defaultdict(list) + for node in g.nodes: + if is_linear(node): + linear_groups[node.args[0]].append(node) + + dependency_set_solver_pass, post_dominator_solver_pass = self.extract_depending_passes(passes) + + # first process attention-related linears, q_proj, k_proj, v_proj, o_proj + for _, downstream_linears in linear_groups.items(): + if len(downstream_linears) == 3: + self.mark_attention_related_linears(g, dependency_set_solver_pass, post_dominator_solver_pass, downstream_linears) + + # then llama-style mlp + for _, downstream_linears in linear_groups.items(): + if len(downstream_linears) == 2: + self.mark_mlp_related_linears(g, dependency_set_solver_pass, post_dominator_solver_pass, downstream_linears) + + # finally classic-style mlp + for _, downstream_linears in linear_groups.items(): + if len(downstream_linears) == 1: + self.mark_mlp_related_linears(g, dependency_set_solver_pass, post_dominator_solver_pass, downstream_linears) + + return graph_module \ No newline at end of file diff --git a/optimum/fx/parallelization/chainable_pass.py b/optimum/fx/parallelization/chainable_pass.py new file mode 100644 index 0000000000..ae0a743c0a --- /dev/null +++ b/optimum/fx/parallelization/chainable_pass.py @@ -0,0 +1,134 @@ +from __future__ import annotations +from typing import Type, List, Dict, Optional, Any +from abc import ABC, abstractmethod +from torch.fx import GraphModule +from .core import ExecutionCtx, PassPipelineConfig +import warnings + + +class Chainable: + def __init__(self, next : Optional[Chainable]= None) -> None: + self._next = next + + @property + def next(self) -> Optional[Chainable]: + return self._next + + @next.setter + def next(self, next : Optional[Chainable] = None): + self._next = next + + +class PassBase(ABC): + dependencies : List[Type[PassBase]] = [] + + @property + def signature(self) -> int: + return id(self) + + @abstractmethod + def run(self, graph_module : GraphModule, **kwargs: Any) -> GraphModule: + raise NotImplementedError("Implement this first.") + + +class ChainablePass(Chainable, PassBase): + def __init__(self, next: Optional[ChainablePass] = None) -> None: + super().__init__(next) + super(Chainable, self).__init__() + + def extract_depending_passes( + self, + passes : Dict[Type[ChainablePass], List[ChainablePass]] + ) -> List[ChainablePass]: + depending_passes = [] + for dependency_pass_type in self.dependencies: + if dependency_pass_type not in passes: + raise RuntimeError( + f"No {dependency_pass_type.__name__} in the current pipeline, please considering adding it before {self.__class__.__name__}" + ) + elif len(passes[dependency_pass_type]) >= 2: + warnings.warn( + f"Multiple {dependency_pass_type.__name__} found in current pipeline, this might incur incorrect results" + ) + depending_passes.append(passes[dependency_pass_type][-1]) + return passes + + def __call__( + self, + graph_module: GraphModule, + passes: Dict[Type[ChainablePass], List[ChainablePass]] = {}, + ctx: ExecutionCtx = None, + lint_and_recompile: bool = True, + clean_markers_after_all_passes: bool = True, + **kwargs + ) -> GraphModule: + graph_module = self.run(graph_module, passes, ctx, **kwargs) + if lint_and_recompile: + graph_module.graph.lint() + graph_module.recompile() + if self.next: + passes[self.__class__].append(self) + graph_module = self.next(graph_module, passes, ctx, **kwargs) + + from .analyze import AnalyzeBase + if clean_markers_after_all_passes and isinstance(self, AnalyzeBase): + self.clean_all() + return graph_module + + +def build_passes_from_config(config : PassPipelineConfig) -> List[ChainablePass]: + # we traverse the all pass configs in dependency-aware order and collect them if they are active + + from .analyze import PostDominatorSolverPass, DependencySetSolverPass, ParallelLinearAnnotatePass + passes = [] + + if config.post_dominator_solver_config.is_active: + passes.append(PostDominatorSolverPass(node_filter=config.post_dominator_solver_config.node_filter)) + if config.dependency_set_solver_config.is_active: + passes.append(DependencySetSolverPass(node_filter=config.dependency_set_solver_config.node_filter)) + if config.parellel_linear_annotate_config.is_active: + passes.append(ParallelLinearAnnotatePass()) + return passes + + +class ChainablePassPipeline: + def __init__( + self, + passes : List[ChainablePass] = [], + config : PassPipelineConfig = None, + ) -> None: + if len(passes) and config is not None: + raise RuntimeError( + "You can't initiate both `passes` and `config` arguments because there might be" + " conflicts, and `ChainablePassPipeline` won't try detecting and correcting it." + ) + if config is not None: + passes = build_passes_from_config(config) + + self.lead = passes[0] if len(passes) else None + for (prev, next) in zip(passes[:-1], passes[1:]): + prev.next = next + + @classmethod + def from_config(cls, config : PassPipelineConfig): + return cls(config=config) + + def __call__( + self, + graph_module: GraphModule, + passes: Dict[Type[ChainablePass], List[ChainablePass]] = {}, + ctx: ExecutionCtx = None, + lint_and_recompile : bool = True, + clean_markers_after_all_passes : bool = True, + **kwargs: Any + ) -> GraphModule: + if self.lead is not None: + graph_module = self.lead( + graph_module, + passes=passes, + ctx=ctx, + lint_and_recompile=lint_and_recompile, + clean_markers_after_all_passes=clean_markers_after_all_passes, + **kwargs + ) + return graph_module \ No newline at end of file diff --git a/optimum/fx/parallelization/core/__init__.py b/optimum/fx/parallelization/core/__init__.py new file mode 100644 index 0000000000..6d84129776 --- /dev/null +++ b/optimum/fx/parallelization/core/__init__.py @@ -0,0 +1,2 @@ +from .context import ExecutionCtx +from .config import PassPipelineConfig \ No newline at end of file diff --git a/optimum/fx/parallelization/core/config.py b/optimum/fx/parallelization/core/config.py new file mode 100644 index 0000000000..59d7186fb0 --- /dev/null +++ b/optimum/fx/parallelization/core/config.py @@ -0,0 +1,43 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from dataclasses import dataclass +from typing import Callable +from torch.fx import Node + + +PARALLEL_INTERESTED_NODES = ( + ('call_module', nn.Linear), + ('call_module', nn.GELU), + ('call_module', nn.SiLU), + ('call_function', torch.matmul), + ('call_function', F.scaled_dot_product_attention), + ('call_function', F.gelu), + ('call_function', F.silu), +) + +@dataclass +class PassConfig: + is_active : bool = False + +@dataclass +class PostDominatorSolverConfig(PassConfig): + # only information of nodes satisfying `node_filter` will be kept + # for later uses in consideration of memory consumption + node_filter : Callable[[Node], bool] = lambda x : True + +@dataclass +class DependencySetSolverConfig(PassConfig): + # only information of nodes satisfying `node_filter` will be kept + # for later uses in consideration of memory consumption + node_filter : Callable[[Node], bool] = lambda x : True + +@dataclass +class ParallelLinearAnnotateConfig(PassConfig): + pass + +@dataclass +class PassPipelineConfig: + post_dominator_solver_config : PostDominatorSolverConfig = PostDominatorSolverConfig() + dependency_set_solver_config : DependencySetSolverConfig = DependencySetSolverConfig() + parellel_linear_annotate_config : ParallelLinearAnnotateConfig = ParallelLinearAnnotateConfig() \ No newline at end of file diff --git a/optimum/fx/parallelization/core/context.py b/optimum/fx/parallelization/core/context.py new file mode 100644 index 0000000000..28643eb212 --- /dev/null +++ b/optimum/fx/parallelization/core/context.py @@ -0,0 +1,6 @@ +from dataclasses import dataclass +import torch.distributed as dist + +@dataclass +class ExecutionCtx: + tp_group : dist.ProcessGroup \ No newline at end of file diff --git a/optimum/fx/parallelization/parallel_layers/__init__.py b/optimum/fx/parallelization/parallel_layers/__init__.py new file mode 100644 index 0000000000..2b5b54c39b --- /dev/null +++ b/optimum/fx/parallelization/parallel_layers/__init__.py @@ -0,0 +1 @@ +from .linear import RowParallelLinear, ColumnParallelLinear \ No newline at end of file diff --git a/optimum/fx/parallelization/parallel_layers/linear.py b/optimum/fx/parallelization/parallel_layers/linear.py new file mode 100644 index 0000000000..abd8cafe93 --- /dev/null +++ b/optimum/fx/parallelization/parallel_layers/linear.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +import torch.distributed as dist +from ..dist import ( + differentiable_all_gather, + differentiable_scatter, + differentiable_all_reduce_sum, +) + + +class ColumnParallelLinear(nn.Linear): + def __init__( + self, + process_group: dist.ProcessGroup, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + gather_output: bool = True, + ) -> None: + self.process_group = process_group + self.word_size = process_group.size() + assert out_features % self.word_size == 0 + + super().__init__(in_features, out_features // self.word_size, bias, device, dtype) + self.gather_output = gather_output + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = super().forward(input) + if self.gather_output: + output = differentiable_all_gather(output, self.process_group) + return output + + +class RowParallelLinear(nn.Linear): + def __init__( + self, + process_group: dist.ProcessGroup, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + input_is_parallel: bool = False, + ) -> None: + self.process_group = process_group + self.word_size = process_group.size() + assert in_features % self.word_size == 0 + + super().__init__(in_features // self.word_size, out_features, bias, device, dtype) + self.input_is_parallel = input_is_parallel + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not self.input_is_parallel: + input = differentiable_scatter(input, self.process_group) + + output = super().forward(input) + output = differentiable_all_reduce_sum(output, self.process_group) + return output diff --git a/optimum/fx/parallelization/transform.py b/optimum/fx/parallelization/transform.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py new file mode 100644 index 0000000000..73ab21d7b1 --- /dev/null +++ b/optimum/fx/parallelization/utils.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.fx import Node + +def is_linear(node : Node) -> bool: + if node.op != 'call_module': + return False + mod = node.graph.owning_module + return isinstance(mod.get_submodule(node.target), nn.Linear) + +def is_matmul(node : Node) -> bool: + if node.op != 'call_function': + return False + return node.target is torch.matmul + +def is_sdpa(node : Node) -> bool: + if node.op != 'call_function': + return False + return node.target is torch._C._nn.scaled_dot_product_attention + +def is_activation(node : Node) -> bool: + if node.op == 'call_function': + return node.target in {F.gelu, F.silu, F.sigmoid, F.relu, } + elif node.op == 'call_module': + mod = node.graph.owning_module + return isinstance(mod.get_submodule(node.target), (nn.GELU, nn.SiLU, nn.Sigmoid, nn.ReLU)) + return False \ No newline at end of file From 7a5d39404d1c444a7ab881e089d980b660c05383 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Wed, 12 Jun 2024 01:03:44 +0200 Subject: [PATCH 02/25] add dist ops --- .../parallelization/distributed/__init__.py | 6 ++ .../parallelization/distributed/dist_ops.py | 100 ++++++++++++++++++ .../parallelization/parallel_layers/linear.py | 2 +- 3 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 optimum/fx/parallelization/distributed/__init__.py create mode 100644 optimum/fx/parallelization/distributed/dist_ops.py diff --git a/optimum/fx/parallelization/distributed/__init__.py b/optimum/fx/parallelization/distributed/__init__.py new file mode 100644 index 0000000000..f4efcae471 --- /dev/null +++ b/optimum/fx/parallelization/distributed/__init__.py @@ -0,0 +1,6 @@ +from .dist_ops import ( + differentiable_all_gather, + differentiable_identity, + differentiable_all_reduce_sum, + differentiable_scatter, +) \ No newline at end of file diff --git a/optimum/fx/parallelization/distributed/dist_ops.py b/optimum/fx/parallelization/distributed/dist_ops.py new file mode 100644 index 0000000000..94eacb7bd9 --- /dev/null +++ b/optimum/fx/parallelization/distributed/dist_ops.py @@ -0,0 +1,100 @@ +import torch +import torch.distributed as dist + +def all_reduce(group: dist.ProcessGroup, tensor : torch.Tensor) -> torch.Tensor: + word_size = dist.get_world_size(group) + if word_size == 1: + return tensor + + dist.all_reduce(tensor, group=group) + return tensor + + +def all_gather(group: dist.ProcessGroup, tensor: torch.Tensor, gather_dim = -1) -> torch.Tensor: + word_size = dist.get_world_size(group) + if word_size == 1: + return tensor + rank = dist.get_rank(group = group) + + tensor = tensor.contiguous() + tensors = [torch.empty_like(tensor) for _ in range(word_size)] + tensors[rank] = tensor + + dist.all_gather(tensors, tensor, group=group) + return torch.cat(tensors, dim=gather_dim) + + +def split(group: dist.ProcessGroup, tensor: torch.Tensor, split_dim = -1) -> torch.Tensor: + word_size = dist.get_world_size(group) + if word_size == 1: + return tensor + + rank = dist.get_rank(group) + + assert tensor.size()[split_dim] % word_size == 0 + + tensors = torch.split(tensor, word_size, dim = split_dim) + + tensor = tensors[rank].contiguous() + + return tensor + + +class DifferentiableIdentity(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor, group: dist.ProcessGroup): + ctx.group = group + return tensor + + @staticmethod + def backward(ctx, grad_output): + group = ctx.group + return DifferentiableAllReduceSum.apply(grad_output, group), None + + +class DifferentiableAllReduceSum(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor: + ctx.group = group + return all_reduce(group=group, tensor=tensor) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> torch.Any: + return grad_output, None + + +class DifferentiableScatter(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup, dim = -1) -> torch.Tensor: + ctx.group = group + ctx.dim = dim + return split(group=group, tensor=tensor, split_dim = dim) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + return DifferentiableAllGather.apply(grad_output, group = ctx.group, dim = ctx.dim), None, None + + +class DifferentiableAllGather(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1) -> torch.Tensor: + ctx.group = group + ctx.dim = dim + return all_gather(group = group, tensor = tensor, gather_dim = dim) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + return DifferentiableScatter.apply(grad_output, group = ctx.group, dim = ctx.dim), None, None + + +def differentiable_all_reduce_sum(tensor: torch.Tensor, group: dist.ProcessGroup): + return DifferentiableAllReduceSum.apply(tensor, group) + +def differentiable_identity(tensor: torch.Tensor, group: dist.ProcessGroup): + return DifferentiableIdentity.apply(tensor, group) + +def differentiable_all_gather(tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1): + return DifferentiableAllGather.apply(tensor, group, dim) + +def differentiable_scatter(tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1): + return DifferentiableScatter.apply(tensor, group, dim) \ No newline at end of file diff --git a/optimum/fx/parallelization/parallel_layers/linear.py b/optimum/fx/parallelization/parallel_layers/linear.py index abd8cafe93..6799dcd79b 100644 --- a/optimum/fx/parallelization/parallel_layers/linear.py +++ b/optimum/fx/parallelization/parallel_layers/linear.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.distributed as dist -from ..dist import ( +from ..distributed import ( differentiable_all_gather, differentiable_scatter, differentiable_all_reduce_sum, From 98e58462f1282b88e66a3199eff5bf0d182a8a84 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Sat, 15 Jun 2024 20:27:21 +0200 Subject: [PATCH 03/25] add index propagation --- optimum/fx/parallelization/analyze.py | 276 +++++++++++++------ optimum/fx/parallelization/chainable_pass.py | 134 --------- optimum/fx/parallelization/core/config.py | 13 +- optimum/fx/parallelization/core/context.py | 7 +- optimum/fx/parallelization/pass_base.py | 100 +++++++ optimum/fx/parallelization/utils.py | 28 +- 6 files changed, 333 insertions(+), 225 deletions(-) delete mode 100644 optimum/fx/parallelization/chainable_pass.py create mode 100644 optimum/fx/parallelization/pass_base.py diff --git a/optimum/fx/parallelization/analyze.py b/optimum/fx/parallelization/analyze.py index d824dc058b..1a6267cf89 100644 --- a/optimum/fx/parallelization/analyze.py +++ b/optimum/fx/parallelization/analyze.py @@ -1,37 +1,62 @@ -from typing import Any, Dict, List, Type, Callable, Optional +from typing import Any, Dict, List, Callable from torch.fx import Graph, GraphModule, Node from torch._inductor.pattern_matcher import stable_topological_sort +from torch.fx.passes.shape_prop import ShapeProp from functools import reduce from collections import defaultdict -from .chainable_pass import ChainablePass -from .utils import is_linear, is_sdpa, is_activation, is_matmul - - -class AnalyzeBase(ChainablePass): - # place unique meta_key in `meta` to prevent duplicate fields - @property - def meta_key(self) -> str: - return f'{self.signature()}' - - def get_stored_field_info(self, node : Node, field : Any) -> Any: - if not self.already_executed_per_node(node): - return None - - info : Dict[Any, Any] = node.meta[self.meta_key] +from .pass_base import PassBase +from .utils import ( + is_linear, + is_sdpa, + is_activation, + is_matmul, + is_transpose, + is_permute, + is_getitem, +) +from .core import ExecutionCtx + + +class AnalyzeBase(PassBase): + # place class-wise unique meta_key in `meta` to prevent duplicate fields + @classmethod + def meta_key(cls) -> str: + return cls.signature() + + @classmethod + def get_stored_field_info(cls, node : Node, field : Any, must_have : bool = False) -> Any: + if not cls.already_executed_per_node(node): + if not must_have: + return None + else: + raise RuntimeError( + f"Can't find information related with {cls.__name__} in the current node `{node}`" + "make sure {cls.__name__} has run and marked it" + ) + + info : Dict[Any, Any] = node.meta[cls.meta_key()] if field not in info: - raise ValueError(f"Invalid query field {field} for {self.__name__}, valid fields are {list(info.keys())}") + raise KeyError(f"Invalid query field {field} for {cls.__name__}, valid fields are {list(info.keys())}") return info[field] - def already_executed_per_node(self, node : Node) -> None: - return self.meta_key in node.meta + @classmethod + def already_executed_per_node(cls, node : Node) -> None: + return cls.meta_key() in node.meta def place_marker_per_node(self, node : Node, info : Dict[Any, Any]) -> None: - node.meta[self.meta_key] = info + if self.already_executed_per_node(node): + raise RuntimeError( + f"Node {node} has already been marked by the current pass, check if " + "the current pass has already been executed in the pipeline" + ) + + node.meta[self.meta_key()] = info def clear_marker_per_node(self, node : Node) -> None: - if self.meta_key in node.meta: - node.meta.pop(self.meta_key) + key = self.meta_key() + if key in node.meta: + node.meta.pop(key) def clean_all(self, graph_module : GraphModule) -> None: g : Graph = graph_module.graph @@ -39,12 +64,16 @@ def clean_all(self, graph_module : GraphModule) -> None: self.clear_marker_per_node(node) +class ShapePropagationPass(AnalyzeBase): + def run(self, graph_module: GraphModule, ctx: ExecutionCtx, **kwargs) -> GraphModule: + example_inputs = ctx.example_inputs + ShapeProp(graph_module).propagate(*example_inputs) + return graph_module + + class PostDominatorSolverPass(AnalyzeBase): - def __init__( - self, - node_filter : Callable[[Node], bool] = lambda x : True, - next: Optional[ChainablePass] = None) -> None: - super().__init__(next) + def __init__(self, node_filter : Callable[[Node], bool] = lambda x : True) -> None: + super().__init__() self.node_filter = node_filter def run(self, graph_module: GraphModule, **kwargs) -> GraphModule: @@ -55,8 +84,7 @@ def run(self, graph_module: GraphModule, **kwargs) -> GraphModule: doms = {node} candidates = [] for user in node.users: - dom = self.get_stored_field_info(user, 'post_doms') - assert dom is not None + dom = self.get_stored_field_info(user, field='post_doms', must_have=True) candidates.append(dom) if len(candidates): doms = doms.union(reduce(lambda x, y: x.intersection(y), candidates)) @@ -64,17 +92,14 @@ def run(self, graph_module: GraphModule, **kwargs) -> GraphModule: for node in g.nodes: if not self.node_filter(node): - self.clear_marker_per_node() + self.clear_marker_per_node(node) return graph_module class DependencySetSolverPass(AnalyzeBase): - def __init__( - self, - node_filter : Callable[[Node], bool] = lambda x : True, - next: Optional[ChainablePass] = None) -> None: - super().__init__(next) + def __init__(self, node_filter : Callable[[Node], bool] = lambda x : True) -> None: + super().__init__() self.node_filter = node_filter def run(self, graph_module: GraphModule, **kwargs) -> GraphModule: g : Graph = graph_module.graph @@ -84,50 +109,39 @@ def run(self, graph_module: GraphModule, **kwargs) -> GraphModule: deps = {node} candidates = [] for pred in node.all_input_nodes: - dep = self.get_stored_field_info(pred, 'dependency_nodes') - assert dep is not None + dep = self.get_stored_field_info(pred, field='dependency_nodes', must_have=True) candidates.append(dep) deps = reduce(lambda x, y: x.union(y), candidates, deps) self.place_marker_per_node(node, {'dependency_nodes' : deps}) for node in g.nodes: if not self.node_filter(node): - self.clear_marker_per_node() + self.clear_marker_per_node(node) return graph_module class ParallelLinearAnnotatePass(AnalyzeBase): - dependencies = [PostDominatorSolverPass, DependencySetSolverPass] - - def mark_attention_related_linears( - self, - graph : Graph, - dependency_set_solver_pass : AnalyzeBase, - post_dominator_solver_pass : AnalyzeBase, - downstream_linears : List[Node] - ) -> None: + def mark_attention_related_linears(self, graph : Graph, linears : List[Node]) -> None: deps, post_doms = [], [] - for linear in downstream_linears: - dep = dependency_set_solver_pass.get_stored_field_info(linear, field='dependency_nodes') - assert dep is not None, "`DependencySetSolverPass` must have run before `ParallelLinearAnnotatePass`" + for linear in linears: + dep = DependencySetSolverPass.get_stored_field_info(linear, field='dependency_nodes', must_have=True) deps.append(dep) - post_dom = post_dominator_solver_pass.get_stored_field_info(linear, 'post_doms') - assert post_dom is not None, "`PostDominatorSolverPass` must have run before `ParallelLinearAnnotatePass`" + post_dom = PostDominatorSolverPass.get_stored_field_info(linear, field='post_doms', must_have=True) post_doms.append(post_dom) # Check 1: no dependencies between parallel linears - if {downstream_linears[0], downstream_linears[1]}.intersection(deps[2]) or \ - {downstream_linears[1], downstream_linears[2]}.intersection(deps[0]) or \ - {downstream_linears[0], downstream_linears[2]}.intersection(deps[1]): + if {linears[0], linears[1]}.intersection(deps[2]) or \ + {linears[1], linears[2]}.intersection(deps[0]) or \ + {linears[0], linears[2]}.intersection(deps[1]): return # Check 2: there is a Linear after these three Linears and it post-dominates these three linears # Need topo-order here - node, last_node = downstream_linears[-1].next, next(reversed(graph.nodes)) + node, last_node = linears[0].next, next(iter(reversed(graph.nodes))) sdpas, matmul_2, matmul_3 = 0, 0, 0 - while node is not last_node and not is_linear(node): + while node is not last_node and (node in linears or not is_linear(node)): if is_matmul(node): doms = sum([int(node in post_dom) for post_dom in post_doms]) if doms == 2: @@ -149,39 +163,31 @@ def mark_attention_related_linears( # we can almost certainly say we have captured an self-attention pattern here, # we will be fine as long as we are right under 99% of situations - for linear in downstream_linears: + for linear in linears: self.place_marker_per_node(linear, {'replace_by' : 'column'}) self.place_marker_per_node(node, {'replace_by' : 'row'}) - def mark_mlp_related_linears( - self, - graph : Graph, - dependency_set_solver_pass : AnalyzeBase, - post_dominator_solver_pass : AnalyzeBase, - linears : List[Node] - ) -> None: + def mark_mlp_related_linears(self, graph : Graph, linears : List[Node]) -> None: if any([self.already_executed_per_node(node) for node in linears]): return deps, post_doms = [], [] for linear in linears: - dep = dependency_set_solver_pass.get_stored_field_info(linear, field='dependency_nodes') - assert dep is not None, "`DependencySetSolverPass` must have run before `ParallelLinearAnnotatePass`" + dep = DependencySetSolverPass.get_stored_field_info(linear, field='dependency_nodes', must_have=True) deps.append(dep) - post_dom = post_dominator_solver_pass.get_stored_field_info(linear, 'post_doms') - assert post_dom is not None, "`PostDominatorSolverPass` must have run before `ParallelLinearAnnotatePass`" + post_dom = PostDominatorSolverPass.get_stored_field_info(linear, field='post_doms', must_have=True) post_doms.append(post_dom) - if len(linears) == 2 and linears[0] in deps[1] or linears[1] in deps[0]: + if len(linears) == 2 and (linears[0] in deps[1] or linears[1] in deps[0]): return - node, last_node = linears[-1].next, next(reversed(graph.nodes)) + node, last_node = linears[0], next(iter(reversed(graph.nodes))) activations = 0 - while node is not last_node and not is_linear(node): + while node is not last_node and (node in linears or not is_linear(node)): if is_activation(node) and sum([int(node in post_dom) for post_dom in post_doms]): activations += 1 node = node.next @@ -199,12 +205,7 @@ def mark_mlp_related_linears( self.place_marker_per_node(node, {'replace_by' : 'row'}) - def run( - self, - graph_module: GraphModule, - passes : Dict[Type[ChainablePass], ChainablePass], - **kwargs - ) -> GraphModule: + def run(self, graph_module: GraphModule, **kwargs) -> GraphModule: g : Graph = graph_module.graph stable_topological_sort(g) @@ -213,21 +214,130 @@ def run( if is_linear(node): linear_groups[node.args[0]].append(node) - dependency_set_solver_pass, post_dominator_solver_pass = self.extract_depending_passes(passes) - # first process attention-related linears, q_proj, k_proj, v_proj, o_proj for _, downstream_linears in linear_groups.items(): if len(downstream_linears) == 3: - self.mark_attention_related_linears(g, dependency_set_solver_pass, post_dominator_solver_pass, downstream_linears) + self.mark_attention_related_linears(g, downstream_linears) # then llama-style mlp for _, downstream_linears in linear_groups.items(): if len(downstream_linears) == 2: - self.mark_mlp_related_linears(g, dependency_set_solver_pass, post_dominator_solver_pass, downstream_linears) + self.mark_mlp_related_linears(g, downstream_linears) # finally classic-style mlp for _, downstream_linears in linear_groups.items(): if len(downstream_linears) == 1: - self.mark_mlp_related_linears(g, dependency_set_solver_pass, post_dominator_solver_pass, downstream_linears) + self.mark_mlp_related_linears(g, downstream_linears) + + return graph_module + + +class AttentionHeadIndexPropagationPass(AnalyzeBase): + def propagate_transpose(self, node: Node, head_idx: int) -> bool: + if 'dim0' in node.kwargs and 'dim1' in node.kwargs: + dim0, dim1, dims = node.kwargs['dim0'], node.kwargs['dim1'], len(node.meta['tensor_meta'].shape) + dim0 = (dim0 + dims) % dims + dim1 = (dim1 + dims) % dims + if dim0 == head_idx: + self.place_marker_per_node(node, {'head_idx' : dim1}) + return True + elif dim1 == head_idx: + self.place_marker_per_node(node, {'head_idx' : dim0}) + return True + return False + + if len(node.args) == 3: + dims = len(node.meta['tensor_meta'].shape) + if head_idx not in node.args and head_idx - dims not in node.args: + return False + for arg in node.args: + if isinstance(arg, int) and (arg + dims) % dims != head_idx: + self.place_marker_per_node(node, {'head_idx' : (arg + dims) % dims}) + return True + + return False + + def propagate_permute(self, node: Node, head_idx: int) -> bool: + if 'dims' in node.kwargs: + dims = node.kwargs['dims'] + else: + dims = list(node.args[1]) if isinstance(node.args[1], tuple) else [arg for arg in node.args if isinstance(arg,int)] + + dim_len = len(node.meta['tensor_meta'].shape) + dims = [dim + dim_len if dim < 0 else dim for dim in dims] + + for i,dim in enumerate(dims): + if dim == head_idx: + self.place_marker_per_node(node, {'head_idx' : i}) + return True + return False + + def propagate_getitem(self, node: Node, head_idx: int) -> bool: + slices = node.args[1] + dims = len(node.meta['tensor_meta'].shape) + assert head_idx < dims + inc, i, j = 0, 0, 0 + + while i < head_idx and j < len(slices): + if isinstance(slices[j], int): + inc -= 1 + i += 1 + elif slices[j] is None: + inc += 1 + elif slices[j] is Ellipsis: + i = dims + k = j + while k < len(slices): + if isinstance(slices[k], (slice, int)): + i -= 1 + k += 1 + else: + i += 1 + j += 1 + + if inc != 0: + assert head_idx + inc < dims and head_idx + inc >= 0 + self.place_marker_per_node(node, {'head_idx' : head_idx + inc}) + return True + return False + + def run(self, graph_module: GraphModule, ctx: ExecutionCtx, **kwargs) -> GraphModule: + g: Graph = graph_module.graph + stable_topological_sort(g) + for node in g.nodes: + if ParallelLinearAnnotatePass.already_executed_per_node(node): + # start propagating at ColumnLinear + replace_by = ParallelLinearAnnotatePass.get_stored_field_info(node, field='replace_by', must_have=True) + if replace_by == 'column': + self.place_marker_per_node(node, {'head_idx' : 2}) + # stop propagating at RowLinear, concluding the life cycle of attention heads + else: + continue + else: + already_marked_args, head_idx = [], None + for arg in node.all_input_nodes: + if not self.already_executed_per_node(arg): + continue + if head_idx is None: + head_idx = self.get_stored_field_info(arg, field='head_idx', must_have=True) + else: + assert head_idx == self.get_stored_field_info(arg, field='head_idx', must_have=True), \ + "`head_idx` should be equal for all arguments in any related ops" + already_marked_args.append(arg) + + if not already_marked_args: + continue + + marked = False + if is_transpose(node): + marked = self.propagate_transpose(node, head_idx) + elif is_permute(node): + marked = self.propagate_permute(node, head_idx) + elif is_getitem(node): + marked = self.propagate_getitem(node, head_idx) + + # fall back + if not marked: + self.place_marker_per_node(node, {'head_idx' : head_idx}) return graph_module \ No newline at end of file diff --git a/optimum/fx/parallelization/chainable_pass.py b/optimum/fx/parallelization/chainable_pass.py deleted file mode 100644 index ae0a743c0a..0000000000 --- a/optimum/fx/parallelization/chainable_pass.py +++ /dev/null @@ -1,134 +0,0 @@ -from __future__ import annotations -from typing import Type, List, Dict, Optional, Any -from abc import ABC, abstractmethod -from torch.fx import GraphModule -from .core import ExecutionCtx, PassPipelineConfig -import warnings - - -class Chainable: - def __init__(self, next : Optional[Chainable]= None) -> None: - self._next = next - - @property - def next(self) -> Optional[Chainable]: - return self._next - - @next.setter - def next(self, next : Optional[Chainable] = None): - self._next = next - - -class PassBase(ABC): - dependencies : List[Type[PassBase]] = [] - - @property - def signature(self) -> int: - return id(self) - - @abstractmethod - def run(self, graph_module : GraphModule, **kwargs: Any) -> GraphModule: - raise NotImplementedError("Implement this first.") - - -class ChainablePass(Chainable, PassBase): - def __init__(self, next: Optional[ChainablePass] = None) -> None: - super().__init__(next) - super(Chainable, self).__init__() - - def extract_depending_passes( - self, - passes : Dict[Type[ChainablePass], List[ChainablePass]] - ) -> List[ChainablePass]: - depending_passes = [] - for dependency_pass_type in self.dependencies: - if dependency_pass_type not in passes: - raise RuntimeError( - f"No {dependency_pass_type.__name__} in the current pipeline, please considering adding it before {self.__class__.__name__}" - ) - elif len(passes[dependency_pass_type]) >= 2: - warnings.warn( - f"Multiple {dependency_pass_type.__name__} found in current pipeline, this might incur incorrect results" - ) - depending_passes.append(passes[dependency_pass_type][-1]) - return passes - - def __call__( - self, - graph_module: GraphModule, - passes: Dict[Type[ChainablePass], List[ChainablePass]] = {}, - ctx: ExecutionCtx = None, - lint_and_recompile: bool = True, - clean_markers_after_all_passes: bool = True, - **kwargs - ) -> GraphModule: - graph_module = self.run(graph_module, passes, ctx, **kwargs) - if lint_and_recompile: - graph_module.graph.lint() - graph_module.recompile() - if self.next: - passes[self.__class__].append(self) - graph_module = self.next(graph_module, passes, ctx, **kwargs) - - from .analyze import AnalyzeBase - if clean_markers_after_all_passes and isinstance(self, AnalyzeBase): - self.clean_all() - return graph_module - - -def build_passes_from_config(config : PassPipelineConfig) -> List[ChainablePass]: - # we traverse the all pass configs in dependency-aware order and collect them if they are active - - from .analyze import PostDominatorSolverPass, DependencySetSolverPass, ParallelLinearAnnotatePass - passes = [] - - if config.post_dominator_solver_config.is_active: - passes.append(PostDominatorSolverPass(node_filter=config.post_dominator_solver_config.node_filter)) - if config.dependency_set_solver_config.is_active: - passes.append(DependencySetSolverPass(node_filter=config.dependency_set_solver_config.node_filter)) - if config.parellel_linear_annotate_config.is_active: - passes.append(ParallelLinearAnnotatePass()) - return passes - - -class ChainablePassPipeline: - def __init__( - self, - passes : List[ChainablePass] = [], - config : PassPipelineConfig = None, - ) -> None: - if len(passes) and config is not None: - raise RuntimeError( - "You can't initiate both `passes` and `config` arguments because there might be" - " conflicts, and `ChainablePassPipeline` won't try detecting and correcting it." - ) - if config is not None: - passes = build_passes_from_config(config) - - self.lead = passes[0] if len(passes) else None - for (prev, next) in zip(passes[:-1], passes[1:]): - prev.next = next - - @classmethod - def from_config(cls, config : PassPipelineConfig): - return cls(config=config) - - def __call__( - self, - graph_module: GraphModule, - passes: Dict[Type[ChainablePass], List[ChainablePass]] = {}, - ctx: ExecutionCtx = None, - lint_and_recompile : bool = True, - clean_markers_after_all_passes : bool = True, - **kwargs: Any - ) -> GraphModule: - if self.lead is not None: - graph_module = self.lead( - graph_module, - passes=passes, - ctx=ctx, - lint_and_recompile=lint_and_recompile, - clean_markers_after_all_passes=clean_markers_after_all_passes, - **kwargs - ) - return graph_module \ No newline at end of file diff --git a/optimum/fx/parallelization/core/config.py b/optimum/fx/parallelization/core/config.py index 59d7186fb0..958a19eff6 100644 --- a/optimum/fx/parallelization/core/config.py +++ b/optimum/fx/parallelization/core/config.py @@ -20,6 +20,11 @@ class PassConfig: is_active : bool = False + +@dataclass +class ShapePropagationConfig(PassConfig): + pass + @dataclass class PostDominatorSolverConfig(PassConfig): # only information of nodes satisfying `node_filter` will be kept @@ -36,8 +41,14 @@ class DependencySetSolverConfig(PassConfig): class ParallelLinearAnnotateConfig(PassConfig): pass +@dataclass +class AttentionHeadIndexPropagationConfig(PassConfig): + pass + @dataclass class PassPipelineConfig: + shape_propagation_config : ShapePropagationConfig = ShapePropagationConfig() post_dominator_solver_config : PostDominatorSolverConfig = PostDominatorSolverConfig() dependency_set_solver_config : DependencySetSolverConfig = DependencySetSolverConfig() - parellel_linear_annotate_config : ParallelLinearAnnotateConfig = ParallelLinearAnnotateConfig() \ No newline at end of file + parellel_linear_annotate_config : ParallelLinearAnnotateConfig = ParallelLinearAnnotateConfig() + attention_head_index_propagation_config : AttentionHeadIndexPropagationConfig = AttentionHeadIndexPropagationConfig() \ No newline at end of file diff --git a/optimum/fx/parallelization/core/context.py b/optimum/fx/parallelization/core/context.py index 28643eb212..0f28bb9b18 100644 --- a/optimum/fx/parallelization/core/context.py +++ b/optimum/fx/parallelization/core/context.py @@ -1,6 +1,9 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field import torch.distributed as dist +from typing import List, Any, List + @dataclass class ExecutionCtx: - tp_group : dist.ProcessGroup \ No newline at end of file + example_inputs : List[Any] = field(default_factory=list) + tp_group : dist.ProcessGroup = None \ No newline at end of file diff --git a/optimum/fx/parallelization/pass_base.py b/optimum/fx/parallelization/pass_base.py new file mode 100644 index 0000000000..b79cf1ac39 --- /dev/null +++ b/optimum/fx/parallelization/pass_base.py @@ -0,0 +1,100 @@ +from __future__ import annotations +from typing import List, Any +from abc import ABC, abstractmethod +from torch.fx import GraphModule +from .core import ExecutionCtx, PassPipelineConfig + + +class PassBase(ABC): + @classmethod + def signature(cls) -> str: + return cls.__name__ + + @abstractmethod + def run(self, graph_module : GraphModule, **kwargs: Any) -> GraphModule: + raise NotImplementedError("Implement this first.") + + def __call__( + self, + graph_module: GraphModule, + ctx: ExecutionCtx = ExecutionCtx(), + lint_and_recompile: bool = True, + **kwargs + ) -> GraphModule: + graph_module = self.run(graph_module, ctx=ctx, **kwargs) + if lint_and_recompile: + graph_module.graph.lint() + graph_module.recompile() + return graph_module + + +def build_passes_from_config(config : PassPipelineConfig) -> List[PassBase]: + # we traverse the all pass configs in dependency-aware order and collect them if they are active + + from .analyze import ( + ShapePropagationPass, + PostDominatorSolverPass, + DependencySetSolverPass, + ParallelLinearAnnotatePass, + AttentionHeadIndexPropagationPass, + ) + passes = [] + if config.shape_propagation_config.is_active: + passes.append(ShapePropagationPass()) + if config.post_dominator_solver_config.is_active: + passes.append(PostDominatorSolverPass(node_filter=config.post_dominator_solver_config.node_filter)) + if config.dependency_set_solver_config.is_active: + passes.append(DependencySetSolverPass(node_filter=config.dependency_set_solver_config.node_filter)) + if config.parellel_linear_annotate_config.is_active: + passes.append(ParallelLinearAnnotatePass()) + if config.attention_head_index_propagation_config.is_active: + passes.append(AttentionHeadIndexPropagationPass()) + + return passes + + +class PassPipeline: + def __init__( + self, + passes : List[PassBase] = [], + config : PassPipelineConfig = None, + ) -> None: + if len(passes) and config is not None: + raise RuntimeError( + "You can't initiate both `passes` and `config` arguments because there might be" + " conflicts, and `PassPipeline` won't try detecting and correcting it." + ) + if config is not None: + passes = build_passes_from_config(config) + + self._passes = passes + + @classmethod + def from_config(cls, config : PassPipelineConfig): + return cls(config=config) + + def __iter__(self,): + return self._passes.__iter__() + + def __call__( + self, + graph_module: GraphModule, + ctx: ExecutionCtx = ExecutionCtx(), + lint_and_recompile : bool = True, + clean_markers_after_all_passes : bool = True, + **kwargs: Any + ) -> GraphModule: + for PASS in self._passes: + graph_module = PASS( + graph_module=graph_module, + ctx=ctx, + lint_and_recompile=lint_and_recompile + ) + + from .analyze import AnalyzeBase + + if clean_markers_after_all_passes: + for PASS in self._passes: + if isinstance(PASS, AnalyzeBase): + PASS.clean_all(graph_module) + return graph_module \ No newline at end of file diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index 73ab21d7b1..d22de264cb 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -2,27 +2,45 @@ import torch.nn as nn import torch.nn.functional as F from torch.fx import Node +import operator -def is_linear(node : Node) -> bool: +def is_linear(node: Node) -> bool: if node.op != 'call_module': return False mod = node.graph.owning_module return isinstance(mod.get_submodule(node.target), nn.Linear) -def is_matmul(node : Node) -> bool: +def is_matmul(node: Node) -> bool: if node.op != 'call_function': return False return node.target is torch.matmul -def is_sdpa(node : Node) -> bool: +def is_sdpa(node: Node) -> bool: if node.op != 'call_function': return False return node.target is torch._C._nn.scaled_dot_product_attention -def is_activation(node : Node) -> bool: +def is_activation(node: Node) -> bool: if node.op == 'call_function': return node.target in {F.gelu, F.silu, F.sigmoid, F.relu, } elif node.op == 'call_module': mod = node.graph.owning_module return isinstance(mod.get_submodule(node.target), (nn.GELU, nn.SiLU, nn.Sigmoid, nn.ReLU)) - return False \ No newline at end of file + return False + +def is_transpose(node: Node) -> bool: + if node.op == 'call_method': + return node.target in {'transpose', 'transpose_'} + elif node.op == 'call_function': + return node.target is torch.transpose + return False + +def is_permute(node: Node) -> bool: + if node.op == 'call_method': + return node.target in {'permute'} + elif node.op == 'call_function': + return node.target is torch.permute + return False + +def is_getitem(node: Node) -> bool: + return node.op == 'call_function' and node.target is operator.getitem \ No newline at end of file From 2036dbb12455694477b5e9e2b63f6e7fcb49ad5d Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Mon, 1 Jul 2024 04:41:04 +0200 Subject: [PATCH 04/25] support tp for linears --- optimum/fx/parallelization/__init__.py | 13 + optimum/fx/parallelization/analyze.py | 343 ------------- optimum/fx/parallelization/core.py | 109 +++++ optimum/fx/parallelization/core/__init__.py | 2 - optimum/fx/parallelization/core/config.py | 54 --- optimum/fx/parallelization/core/context.py | 9 - .../parallelization/distributed/__init__.py | 1 + .../parallelization/distributed/dist_ops.py | 57 ++- .../parallelization/parallel_layers/linear.py | 173 ++++++- optimum/fx/parallelization/pass_base.py | 100 ---- optimum/fx/parallelization/passes.py | 449 ++++++++++++++++++ optimum/fx/parallelization/transform.py | 0 optimum/fx/parallelization/utils.py | 205 +++++++- tests/fx/parallelization/dist_utils.py | 55 +++ .../parallelization/test_tensor_parallel.py | 190 ++++++++ 15 files changed, 1189 insertions(+), 571 deletions(-) delete mode 100644 optimum/fx/parallelization/analyze.py create mode 100644 optimum/fx/parallelization/core.py delete mode 100644 optimum/fx/parallelization/core/__init__.py delete mode 100644 optimum/fx/parallelization/core/config.py delete mode 100644 optimum/fx/parallelization/core/context.py delete mode 100644 optimum/fx/parallelization/pass_base.py create mode 100644 optimum/fx/parallelization/passes.py delete mode 100644 optimum/fx/parallelization/transform.py create mode 100644 tests/fx/parallelization/dist_utils.py create mode 100644 tests/fx/parallelization/test_tensor_parallel.py diff --git a/optimum/fx/parallelization/__init__.py b/optimum/fx/parallelization/__init__.py index e69de29bb2..ee32f3915d 100644 --- a/optimum/fx/parallelization/__init__.py +++ b/optimum/fx/parallelization/__init__.py @@ -0,0 +1,13 @@ +import torch +from torch.fx import GraphModule +from typing import List +from .core import ParallelExecutionCtx, Config +from .passes import build_parallel_pass_pipeline + + +def parallelize_backend(graph_module: GraphModule, example_inputs: List[torch.Tensor], ctx: ParallelExecutionCtx, config: Config): + ctx.example_inputs = example_inputs + pass_pipeline = build_parallel_pass_pipeline() + graph_module = pass_pipeline(graph_module=graph_module, ctx=ctx, config=config) + ctx.compile_times += 1 + return graph_module diff --git a/optimum/fx/parallelization/analyze.py b/optimum/fx/parallelization/analyze.py deleted file mode 100644 index 1a6267cf89..0000000000 --- a/optimum/fx/parallelization/analyze.py +++ /dev/null @@ -1,343 +0,0 @@ -from typing import Any, Dict, List, Callable -from torch.fx import Graph, GraphModule, Node -from torch._inductor.pattern_matcher import stable_topological_sort -from torch.fx.passes.shape_prop import ShapeProp -from functools import reduce -from collections import defaultdict -from .pass_base import PassBase -from .utils import ( - is_linear, - is_sdpa, - is_activation, - is_matmul, - is_transpose, - is_permute, - is_getitem, -) -from .core import ExecutionCtx - - -class AnalyzeBase(PassBase): - # place class-wise unique meta_key in `meta` to prevent duplicate fields - @classmethod - def meta_key(cls) -> str: - return cls.signature() - - @classmethod - def get_stored_field_info(cls, node : Node, field : Any, must_have : bool = False) -> Any: - if not cls.already_executed_per_node(node): - if not must_have: - return None - else: - raise RuntimeError( - f"Can't find information related with {cls.__name__} in the current node `{node}`" - "make sure {cls.__name__} has run and marked it" - ) - - info : Dict[Any, Any] = node.meta[cls.meta_key()] - if field not in info: - raise KeyError(f"Invalid query field {field} for {cls.__name__}, valid fields are {list(info.keys())}") - - return info[field] - - @classmethod - def already_executed_per_node(cls, node : Node) -> None: - return cls.meta_key() in node.meta - - def place_marker_per_node(self, node : Node, info : Dict[Any, Any]) -> None: - if self.already_executed_per_node(node): - raise RuntimeError( - f"Node {node} has already been marked by the current pass, check if " - "the current pass has already been executed in the pipeline" - ) - - node.meta[self.meta_key()] = info - - def clear_marker_per_node(self, node : Node) -> None: - key = self.meta_key() - if key in node.meta: - node.meta.pop(key) - - def clean_all(self, graph_module : GraphModule) -> None: - g : Graph = graph_module.graph - for node in g.nodes: - self.clear_marker_per_node(node) - - -class ShapePropagationPass(AnalyzeBase): - def run(self, graph_module: GraphModule, ctx: ExecutionCtx, **kwargs) -> GraphModule: - example_inputs = ctx.example_inputs - ShapeProp(graph_module).propagate(*example_inputs) - return graph_module - - -class PostDominatorSolverPass(AnalyzeBase): - def __init__(self, node_filter : Callable[[Node], bool] = lambda x : True) -> None: - super().__init__() - self.node_filter = node_filter - - def run(self, graph_module: GraphModule, **kwargs) -> GraphModule: - g : Graph = graph_module.graph - stable_topological_sort(g) - - for node in reversed(g.nodes): - doms = {node} - candidates = [] - for user in node.users: - dom = self.get_stored_field_info(user, field='post_doms', must_have=True) - candidates.append(dom) - if len(candidates): - doms = doms.union(reduce(lambda x, y: x.intersection(y), candidates)) - self.place_marker_per_node(node, {'post_doms' : doms}) - - for node in g.nodes: - if not self.node_filter(node): - self.clear_marker_per_node(node) - - return graph_module - - -class DependencySetSolverPass(AnalyzeBase): - def __init__(self, node_filter : Callable[[Node], bool] = lambda x : True) -> None: - super().__init__() - self.node_filter = node_filter - def run(self, graph_module: GraphModule, **kwargs) -> GraphModule: - g : Graph = graph_module.graph - stable_topological_sort(g) - - for node in g.nodes: - deps = {node} - candidates = [] - for pred in node.all_input_nodes: - dep = self.get_stored_field_info(pred, field='dependency_nodes', must_have=True) - candidates.append(dep) - deps = reduce(lambda x, y: x.union(y), candidates, deps) - self.place_marker_per_node(node, {'dependency_nodes' : deps}) - - for node in g.nodes: - if not self.node_filter(node): - self.clear_marker_per_node(node) - - return graph_module - - -class ParallelLinearAnnotatePass(AnalyzeBase): - def mark_attention_related_linears(self, graph : Graph, linears : List[Node]) -> None: - deps, post_doms = [], [] - for linear in linears: - dep = DependencySetSolverPass.get_stored_field_info(linear, field='dependency_nodes', must_have=True) - deps.append(dep) - - post_dom = PostDominatorSolverPass.get_stored_field_info(linear, field='post_doms', must_have=True) - post_doms.append(post_dom) - - # Check 1: no dependencies between parallel linears - if {linears[0], linears[1]}.intersection(deps[2]) or \ - {linears[1], linears[2]}.intersection(deps[0]) or \ - {linears[0], linears[2]}.intersection(deps[1]): - return - - # Check 2: there is a Linear after these three Linears and it post-dominates these three linears - # Need topo-order here - node, last_node = linears[0].next, next(iter(reversed(graph.nodes))) - sdpas, matmul_2, matmul_3 = 0, 0, 0 - while node is not last_node and (node in linears or not is_linear(node)): - if is_matmul(node): - doms = sum([int(node in post_dom) for post_dom in post_doms]) - if doms == 2: - # we find a matmul dominating the two linears(Q,K) out of all three linears - matmul_2 += 1 - elif doms == 3 and matmul_2 == 1: - # we find a matmul dominating the previous matmul and all three linears - matmul_3 += 1 - elif is_sdpa(node) and all([node in post_dom for post_dom in post_doms]): - sdpas += 1 - node = node.next - - if node is last_node or any([node not in post_dom for post_dom in post_doms]): - return - - # Check 3: there is two dominating matmuls or there is one dominating sdpa - if not ((sdpas == 1) ^ (matmul_2 == 1 and matmul_3 == 1)): - return - - # we can almost certainly say we have captured an self-attention pattern here, - # we will be fine as long as we are right under 99% of situations - for linear in linears: - self.place_marker_per_node(linear, {'replace_by' : 'column'}) - - self.place_marker_per_node(node, {'replace_by' : 'row'}) - - - def mark_mlp_related_linears(self, graph : Graph, linears : List[Node]) -> None: - if any([self.already_executed_per_node(node) for node in linears]): - return - - deps, post_doms = [], [] - for linear in linears: - dep = DependencySetSolverPass.get_stored_field_info(linear, field='dependency_nodes', must_have=True) - deps.append(dep) - - post_dom = PostDominatorSolverPass.get_stored_field_info(linear, field='post_doms', must_have=True) - post_doms.append(post_dom) - - if len(linears) == 2 and (linears[0] in deps[1] or linears[1] in deps[0]): - return - - node, last_node = linears[0], next(iter(reversed(graph.nodes))) - - activations = 0 - while node is not last_node and (node in linears or not is_linear(node)): - if is_activation(node) and sum([int(node in post_dom) for post_dom in post_doms]): - activations += 1 - node = node.next - - if node is last_node or self.already_executed_per_node(node) or any([node not in post_dom for post_dom in post_doms]): - return - - # should have at least one activation node in between - if activations == 0: - return - - for linear in linears: - self.place_marker_per_node(linear, {'replace_by' : 'column'}) - - self.place_marker_per_node(node, {'replace_by' : 'row'}) - - - def run(self, graph_module: GraphModule, **kwargs) -> GraphModule: - g : Graph = graph_module.graph - stable_topological_sort(g) - - linear_groups : Dict[Node, List[Node]] = defaultdict(list) - for node in g.nodes: - if is_linear(node): - linear_groups[node.args[0]].append(node) - - # first process attention-related linears, q_proj, k_proj, v_proj, o_proj - for _, downstream_linears in linear_groups.items(): - if len(downstream_linears) == 3: - self.mark_attention_related_linears(g, downstream_linears) - - # then llama-style mlp - for _, downstream_linears in linear_groups.items(): - if len(downstream_linears) == 2: - self.mark_mlp_related_linears(g, downstream_linears) - - # finally classic-style mlp - for _, downstream_linears in linear_groups.items(): - if len(downstream_linears) == 1: - self.mark_mlp_related_linears(g, downstream_linears) - - return graph_module - - -class AttentionHeadIndexPropagationPass(AnalyzeBase): - def propagate_transpose(self, node: Node, head_idx: int) -> bool: - if 'dim0' in node.kwargs and 'dim1' in node.kwargs: - dim0, dim1, dims = node.kwargs['dim0'], node.kwargs['dim1'], len(node.meta['tensor_meta'].shape) - dim0 = (dim0 + dims) % dims - dim1 = (dim1 + dims) % dims - if dim0 == head_idx: - self.place_marker_per_node(node, {'head_idx' : dim1}) - return True - elif dim1 == head_idx: - self.place_marker_per_node(node, {'head_idx' : dim0}) - return True - return False - - if len(node.args) == 3: - dims = len(node.meta['tensor_meta'].shape) - if head_idx not in node.args and head_idx - dims not in node.args: - return False - for arg in node.args: - if isinstance(arg, int) and (arg + dims) % dims != head_idx: - self.place_marker_per_node(node, {'head_idx' : (arg + dims) % dims}) - return True - - return False - - def propagate_permute(self, node: Node, head_idx: int) -> bool: - if 'dims' in node.kwargs: - dims = node.kwargs['dims'] - else: - dims = list(node.args[1]) if isinstance(node.args[1], tuple) else [arg for arg in node.args if isinstance(arg,int)] - - dim_len = len(node.meta['tensor_meta'].shape) - dims = [dim + dim_len if dim < 0 else dim for dim in dims] - - for i,dim in enumerate(dims): - if dim == head_idx: - self.place_marker_per_node(node, {'head_idx' : i}) - return True - return False - - def propagate_getitem(self, node: Node, head_idx: int) -> bool: - slices = node.args[1] - dims = len(node.meta['tensor_meta'].shape) - assert head_idx < dims - inc, i, j = 0, 0, 0 - - while i < head_idx and j < len(slices): - if isinstance(slices[j], int): - inc -= 1 - i += 1 - elif slices[j] is None: - inc += 1 - elif slices[j] is Ellipsis: - i = dims - k = j - while k < len(slices): - if isinstance(slices[k], (slice, int)): - i -= 1 - k += 1 - else: - i += 1 - j += 1 - - if inc != 0: - assert head_idx + inc < dims and head_idx + inc >= 0 - self.place_marker_per_node(node, {'head_idx' : head_idx + inc}) - return True - return False - - def run(self, graph_module: GraphModule, ctx: ExecutionCtx, **kwargs) -> GraphModule: - g: Graph = graph_module.graph - stable_topological_sort(g) - - for node in g.nodes: - if ParallelLinearAnnotatePass.already_executed_per_node(node): - # start propagating at ColumnLinear - replace_by = ParallelLinearAnnotatePass.get_stored_field_info(node, field='replace_by', must_have=True) - if replace_by == 'column': - self.place_marker_per_node(node, {'head_idx' : 2}) - # stop propagating at RowLinear, concluding the life cycle of attention heads - else: - continue - else: - already_marked_args, head_idx = [], None - for arg in node.all_input_nodes: - if not self.already_executed_per_node(arg): - continue - if head_idx is None: - head_idx = self.get_stored_field_info(arg, field='head_idx', must_have=True) - else: - assert head_idx == self.get_stored_field_info(arg, field='head_idx', must_have=True), \ - "`head_idx` should be equal for all arguments in any related ops" - already_marked_args.append(arg) - - if not already_marked_args: - continue - - marked = False - if is_transpose(node): - marked = self.propagate_transpose(node, head_idx) - elif is_permute(node): - marked = self.propagate_permute(node, head_idx) - elif is_getitem(node): - marked = self.propagate_getitem(node, head_idx) - - # fall back - if not marked: - self.place_marker_per_node(node, {'head_idx' : head_idx}) - return graph_module \ No newline at end of file diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py new file mode 100644 index 0000000000..c24876bc1f --- /dev/null +++ b/optimum/fx/parallelization/core.py @@ -0,0 +1,109 @@ +from dataclasses import dataclass, field +from typing import List, Any, List, Dict, Callable +import torch +import torch.nn as nn +import torch.distributed as dist +from functools import partial + +class HashableSlice: + def __init__(self, start : int, stop : int, step : int) -> None: + self.start = start + self.stop = stop + self.step = step + + def __hash__(self) -> int: + return hash(f'{self.start},{self.stop},{self.step}') + + def __eq__(self, value: object) -> bool: + return isinstance(value, HashableSlice) and self.start == value.start and \ + self.stop == value.stop and self.step == value.step + + def to_slice(self) -> None: + return slice(self.start, self.stop, self.step) + + +@dataclass +class ParameterMeta: + # parameter name + source : str = None + # which axis to index + dim : int = None + # index to slice the tensor + index : slice = None + + +@dataclass +class ParameterMapping: + id : int = None + meta : ParameterMeta = None + + +@dataclass +class ParallelParameterMapping(ParameterMapping): + # the axis being parallelized + parallel_dim : int = None + # for multi-source parameter mapping + mapping : Dict[HashableSlice, ParameterMeta] = field(default_factory=dict) + + +@dataclass +class ParallelExecutionCtx: + """ + Parallel execution context which contains runtime information. + + - example_inputs + A list of tensors which are used as example inputs for graphs captured by dynamo. + + - parallel_layer_cache + Cache which maps layers(`nn.Linear`, `nn.Embedding`) to their parallel counterparts. + Note that we will build the cache in the first compilation process, and for recompilations + later on, we will directly replace the modules with their parallel counterparts in the cache, + because we have to make sure we don't initiate new parameters and replace original ones when + recompilation happens in training process. + + - parameter_mapping + Mapping between parameter ids and their correponding names in the original module. Note + that it changes as we create new parameters to replace original ones in the first compilation + process. It's useful because dynamo flattens the graph(which invalidates the parameter name + hierarchy) but the original parameters are kept. + + - weight_map + Mapping between parameter names and their locations on disk, useful when loading weights + from disk. + + - tp_group + Tensor parallel process group the current process belongs to. + + - compile_times + Number of compilation times happened during the whole process. + + - current_device + Device correpsonding to the current process. + """ + example_inputs : List[Any] = field(default_factory=list) + parallel_layer_cache : Dict[int, nn.Module] = field(default_factory=dict) + parameter_mapping : Dict[int, ParameterMapping] = field(default_factory=dict) + weight_map : Dict[str, str] = field(default_factory=dict) + tp_group : dist.ProcessGroup = None + compile_times : int = 0 + current_device : torch.device = None + + +@dataclass +class Config: + """ + Static config which contains instructions which do not change in runtime. + + - lint_and_recompile + Whether to run graph linting and module recompilation after every pass. + + - clean_markers_after_all_passes + Whether to clean markers of analytical passes after all passes have run. + + - weight_init_fn + Initialization function of weights in `nn.Linear` and `nn.Embedding` layers, + if not provided weights loading path. + """ + lint_and_recompile : bool = True + clean_markers_after_all_passes : bool = True + weight_init_fn : Callable = partial(nn.init.normal_, std=0.02) diff --git a/optimum/fx/parallelization/core/__init__.py b/optimum/fx/parallelization/core/__init__.py deleted file mode 100644 index 6d84129776..0000000000 --- a/optimum/fx/parallelization/core/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .context import ExecutionCtx -from .config import PassPipelineConfig \ No newline at end of file diff --git a/optimum/fx/parallelization/core/config.py b/optimum/fx/parallelization/core/config.py deleted file mode 100644 index 958a19eff6..0000000000 --- a/optimum/fx/parallelization/core/config.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from dataclasses import dataclass -from typing import Callable -from torch.fx import Node - - -PARALLEL_INTERESTED_NODES = ( - ('call_module', nn.Linear), - ('call_module', nn.GELU), - ('call_module', nn.SiLU), - ('call_function', torch.matmul), - ('call_function', F.scaled_dot_product_attention), - ('call_function', F.gelu), - ('call_function', F.silu), -) - -@dataclass -class PassConfig: - is_active : bool = False - - -@dataclass -class ShapePropagationConfig(PassConfig): - pass - -@dataclass -class PostDominatorSolverConfig(PassConfig): - # only information of nodes satisfying `node_filter` will be kept - # for later uses in consideration of memory consumption - node_filter : Callable[[Node], bool] = lambda x : True - -@dataclass -class DependencySetSolverConfig(PassConfig): - # only information of nodes satisfying `node_filter` will be kept - # for later uses in consideration of memory consumption - node_filter : Callable[[Node], bool] = lambda x : True - -@dataclass -class ParallelLinearAnnotateConfig(PassConfig): - pass - -@dataclass -class AttentionHeadIndexPropagationConfig(PassConfig): - pass - -@dataclass -class PassPipelineConfig: - shape_propagation_config : ShapePropagationConfig = ShapePropagationConfig() - post_dominator_solver_config : PostDominatorSolverConfig = PostDominatorSolverConfig() - dependency_set_solver_config : DependencySetSolverConfig = DependencySetSolverConfig() - parellel_linear_annotate_config : ParallelLinearAnnotateConfig = ParallelLinearAnnotateConfig() - attention_head_index_propagation_config : AttentionHeadIndexPropagationConfig = AttentionHeadIndexPropagationConfig() \ No newline at end of file diff --git a/optimum/fx/parallelization/core/context.py b/optimum/fx/parallelization/core/context.py deleted file mode 100644 index 0f28bb9b18..0000000000 --- a/optimum/fx/parallelization/core/context.py +++ /dev/null @@ -1,9 +0,0 @@ -from dataclasses import dataclass, field -import torch.distributed as dist -from typing import List, Any, List - - -@dataclass -class ExecutionCtx: - example_inputs : List[Any] = field(default_factory=list) - tp_group : dist.ProcessGroup = None \ No newline at end of file diff --git a/optimum/fx/parallelization/distributed/__init__.py b/optimum/fx/parallelization/distributed/__init__.py index f4efcae471..45b9d2837a 100644 --- a/optimum/fx/parallelization/distributed/__init__.py +++ b/optimum/fx/parallelization/distributed/__init__.py @@ -3,4 +3,5 @@ differentiable_identity, differentiable_all_reduce_sum, differentiable_scatter, + scatter, ) \ No newline at end of file diff --git a/optimum/fx/parallelization/distributed/dist_ops.py b/optimum/fx/parallelization/distributed/dist_ops.py index 94eacb7bd9..cb4c93569b 100644 --- a/optimum/fx/parallelization/distributed/dist_ops.py +++ b/optimum/fx/parallelization/distributed/dist_ops.py @@ -2,43 +2,56 @@ import torch.distributed as dist def all_reduce(group: dist.ProcessGroup, tensor : torch.Tensor) -> torch.Tensor: - word_size = dist.get_world_size(group) - if word_size == 1: + world_size = dist.get_world_size(group) + if world_size == 1: return tensor dist.all_reduce(tensor, group=group) return tensor - -def all_gather(group: dist.ProcessGroup, tensor: torch.Tensor, gather_dim = -1) -> torch.Tensor: - word_size = dist.get_world_size(group) - if word_size == 1: +def all_gather(group: dist.ProcessGroup, tensor: torch.Tensor, gather_dim: int = -1) -> torch.Tensor: + world_size = dist.get_world_size(group) + if world_size == 1: return tensor rank = dist.get_rank(group = group) tensor = tensor.contiguous() - tensors = [torch.empty_like(tensor) for _ in range(word_size)] + tensors = [torch.empty_like(tensor) for _ in range(world_size)] tensors[rank] = tensor dist.all_gather(tensors, tensor, group=group) return torch.cat(tensors, dim=gather_dim) - -def split(group: dist.ProcessGroup, tensor: torch.Tensor, split_dim = -1) -> torch.Tensor: - word_size = dist.get_world_size(group) - if word_size == 1: +def split(group: dist.ProcessGroup, tensor: torch.Tensor, split_dim: int = -1) -> torch.Tensor: + world_size = dist.get_world_size(group) + if world_size == 1: return tensor rank = dist.get_rank(group) - - assert tensor.size()[split_dim] % word_size == 0 - - tensors = torch.split(tensor, word_size, dim = split_dim) - + size = tensor.size() + assert size[split_dim] % world_size == 0 + tensors = torch.split(tensor, size[split_dim] // world_size, dim = split_dim) tensor = tensors[rank].contiguous() return tensor +def scatter(group: dist.ProcessGroup, tensor: torch.Tensor, output_tensor: torch.Tensor, scatter_dim: int = 0) -> torch.Tensor: + world_size = dist.get_world_size(group) + if world_size == 1: + return tensor + + rank = dist.get_rank(group) + if rank == 0: + size = tensor.size() + assert size[scatter_dim] % world_size == 0 + tensors = torch.split(tensor, size[scatter_dim] // world_size, dim=scatter_dim) + scatter_list = [tensor.contiguous() for tensor in tensors] + output_tensor = scatter_list[rank] + else: + scatter_list = None + dist.scatter(tensor=output_tensor, scatter_list=scatter_list, src=0, group=group) + return output_tensor + class DifferentiableIdentity(torch.autograd.Function): @staticmethod @@ -65,26 +78,26 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Any: class DifferentiableScatter(torch.autograd.Function): @staticmethod - def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup, dim = -1) -> torch.Tensor: + def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup, dim: int = -1) -> torch.Tensor: ctx.group = group ctx.dim = dim - return split(group=group, tensor=tensor, split_dim = dim) + return split(group=group, tensor=tensor, split_dim=dim) @staticmethod def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: - return DifferentiableAllGather.apply(grad_output, group = ctx.group, dim = ctx.dim), None, None + return DifferentiableAllGather.apply(grad_output, group=ctx.group, dim=ctx.dim), None, None class DifferentiableAllGather(torch.autograd.Function): @staticmethod - def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1) -> torch.Tensor: + def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup, dim: int = -1) -> torch.Tensor: ctx.group = group ctx.dim = dim - return all_gather(group = group, tensor = tensor, gather_dim = dim) + return all_gather(group=group, tensor=tensor, gather_dim=dim) @staticmethod def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: - return DifferentiableScatter.apply(grad_output, group = ctx.group, dim = ctx.dim), None, None + return DifferentiableScatter.apply(grad_output, group=ctx.group, dim=ctx.dim), None, None def differentiable_all_reduce_sum(tensor: torch.Tensor, group: dist.ProcessGroup): diff --git a/optimum/fx/parallelization/parallel_layers/linear.py b/optimum/fx/parallelization/parallel_layers/linear.py index 6799dcd79b..2ab5c11849 100644 --- a/optimum/fx/parallelization/parallel_layers/linear.py +++ b/optimum/fx/parallelization/parallel_layers/linear.py @@ -1,32 +1,103 @@ import torch import torch.nn as nn +import torch.nn.functional as F import torch.distributed as dist +from functools import partial +from typing import Callable +from ..core import ( + ParallelExecutionCtx, + ParallelParameterMapping, + ParameterMeta, +) from ..distributed import ( + differentiable_identity, differentiable_all_gather, differentiable_scatter, differentiable_all_reduce_sum, + scatter, ) class ColumnParallelLinear(nn.Linear): + """ + Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + ctx: parallel execution context which contains runtime information. + linear: the original linear module being replaced. + gather_output: whether gathering output in the end of forward. + init_fn: weight initialization function. + """ def __init__( self, - process_group: dist.ProcessGroup, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, + ctx: ParallelExecutionCtx, + linear: nn.Linear, gather_output: bool = True, + init_fn: Callable = partial(nn.init.normal_, mean=0, std=0.02), ) -> None: - self.process_group = process_group - self.word_size = process_group.size() - assert out_features % self.word_size == 0 + self.process_group = ctx.tp_group + world_size = dist.get_world_size(self.process_group) + assert linear.out_features % world_size == 0 - super().__init__(in_features, out_features // self.word_size, bias, device, dtype) + in_features = linear.in_features + out_features = linear.out_features // world_size + bias = linear.bias is not None + device = ctx.current_device + dtype = linear.weight.dtype + + super().__init__(in_features, out_features, bias, device, dtype) self.gather_output = gather_output + tp_rank = dist.get_rank(self.process_group) + + parameter_mapping, key = ctx.parameter_mapping, id(linear.weight) + assert key in parameter_mapping, "should have run `initialize_paramter_mapping` after moving model to current device" + original_linear_weight_meta = parameter_mapping[key].meta + + # initialize the weight if not in weight_map + need_intialize = original_linear_weight_meta.source not in ctx.weight_map + if need_intialize: + # initialize on cpu + master_weight = torch.empty_like(linear.weight, device='cpu') + init_fn(master_weight) + with torch.no_grad(): + self.weight.copy_(master_weight[tp_rank * out_features : (tp_rank + 1) * out_features, :]) + + # update parameter mapping corresponding to original linear weight and bias + linear_weight_mapping = ParallelParameterMapping( + id=id(self.weight), + meta=ParameterMeta( + source=original_linear_weight_meta.source, + dim=0, + index=slice(tp_rank * out_features, (tp_rank + 1) * out_features) + ), + parallel_dim=0 + ) + parameter_mapping.pop(key) + parameter_mapping[linear_weight_mapping.id] = linear_weight_mapping + + if bias: + key = id(linear.bias) + assert key in parameter_mapping + original_linear_bias_meta = parameter_mapping[key].meta + linear_bias_mapping = ParallelParameterMapping( + id=id(self.bias), + meta=ParameterMeta( + source=original_linear_bias_meta.source, + dim=0, + index=slice(tp_rank * out_features, (tp_rank + 1) * out_features) + ), + parallel_dim=0 + ) + + parameter_mapping.pop(key) + parameter_mapping[linear_bias_mapping.id] = linear_bias_mapping + self.bias.zero_() def forward(self, input: torch.Tensor) -> torch.Tensor: + input = differentiable_identity(input, self.process_group) output = super().forward(input) if self.gather_output: output = differentiable_all_gather(output, self.process_group) @@ -34,27 +105,87 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class RowParallelLinear(nn.Linear): + """ + Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + ctx: parallel execution context which contains runtime information. + linear: the original lineat module being replaced. + input_is_parallel: whether the input tensor has already been parallelized. + init_fn: weight initialization function. + """ def __init__( self, - process_group: dist.ProcessGroup, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, + ctx: ParallelExecutionCtx, + linear: nn.Linear, input_is_parallel: bool = False, + init_fn: Callable = partial(nn.init.normal_, mean=0, std=0.02), ) -> None: - self.process_group = process_group - self.word_size = process_group.size() - assert in_features % self.word_size == 0 + self.process_group = ctx.tp_group + world_size = dist.get_world_size(self.process_group) + assert linear.in_features % world_size == 0 + + in_features = linear.in_features // world_size + out_features = linear.out_features + bias = linear.bias is not None + device = ctx.current_device + dtype = linear.weight.dtype - super().__init__(in_features // self.word_size, out_features, bias, device, dtype) + super().__init__(in_features, out_features, bias, device, dtype) self.input_is_parallel = input_is_parallel + tp_rank = dist.get_rank(self.process_group) + + parameter_mapping, key = ctx.parameter_mapping, id(linear.weight) + assert key in parameter_mapping, "should have run `initialize_paramter_mapping` after moving model to current device" + original_linear_weight_meta = parameter_mapping[key].meta + + need_intialize = original_linear_weight_meta.source not in ctx.weight_map + if need_intialize: + # initialize on cpu + master_weight = torch.empty_like(linear.weight, device='cpu') + init_fn(master_weight) + with torch.no_grad(): + self.weight.copy_(master_weight[:, tp_rank * in_features : (tp_rank + 1) * in_features]) + + # update parameter mapping corresponding to original linear weight and bias + linear_weight_mapping = ParallelParameterMapping( + id=id(self.weight), + meta=ParameterMeta( + source=original_linear_weight_meta.source, + dim=1, + index=slice(tp_rank * in_features, (tp_rank + 1) * in_features) + ), + parallel_dim=1 + ) + parameter_mapping.pop(key) + parameter_mapping[linear_weight_mapping.id] = linear_weight_mapping + + if bias: + key = id(linear.bias) + assert key in parameter_mapping + linear_bias_mapping = parameter_mapping[key] + parameter_mapping.pop(key) + linear_bias_mapping.id = id(self.bias) + parameter_mapping[linear_bias_mapping.id] = linear_bias_mapping + self.bias.zero_() + def forward(self, input: torch.Tensor) -> torch.Tensor: if not self.input_is_parallel: input = differentiable_scatter(input, self.process_group) - output = super().forward(input) + output = F.linear(input, self.weight) output = differentiable_all_reduce_sum(output, self.process_group) + + if self.bias is not None: + output = output + self.bias return output diff --git a/optimum/fx/parallelization/pass_base.py b/optimum/fx/parallelization/pass_base.py deleted file mode 100644 index b79cf1ac39..0000000000 --- a/optimum/fx/parallelization/pass_base.py +++ /dev/null @@ -1,100 +0,0 @@ -from __future__ import annotations -from typing import List, Any -from abc import ABC, abstractmethod -from torch.fx import GraphModule -from .core import ExecutionCtx, PassPipelineConfig - - -class PassBase(ABC): - @classmethod - def signature(cls) -> str: - return cls.__name__ - - @abstractmethod - def run(self, graph_module : GraphModule, **kwargs: Any) -> GraphModule: - raise NotImplementedError("Implement this first.") - - def __call__( - self, - graph_module: GraphModule, - ctx: ExecutionCtx = ExecutionCtx(), - lint_and_recompile: bool = True, - **kwargs - ) -> GraphModule: - graph_module = self.run(graph_module, ctx=ctx, **kwargs) - if lint_and_recompile: - graph_module.graph.lint() - graph_module.recompile() - return graph_module - - -def build_passes_from_config(config : PassPipelineConfig) -> List[PassBase]: - # we traverse the all pass configs in dependency-aware order and collect them if they are active - - from .analyze import ( - ShapePropagationPass, - PostDominatorSolverPass, - DependencySetSolverPass, - ParallelLinearAnnotatePass, - AttentionHeadIndexPropagationPass, - ) - passes = [] - if config.shape_propagation_config.is_active: - passes.append(ShapePropagationPass()) - if config.post_dominator_solver_config.is_active: - passes.append(PostDominatorSolverPass(node_filter=config.post_dominator_solver_config.node_filter)) - if config.dependency_set_solver_config.is_active: - passes.append(DependencySetSolverPass(node_filter=config.dependency_set_solver_config.node_filter)) - if config.parellel_linear_annotate_config.is_active: - passes.append(ParallelLinearAnnotatePass()) - if config.attention_head_index_propagation_config.is_active: - passes.append(AttentionHeadIndexPropagationPass()) - - return passes - - -class PassPipeline: - def __init__( - self, - passes : List[PassBase] = [], - config : PassPipelineConfig = None, - ) -> None: - if len(passes) and config is not None: - raise RuntimeError( - "You can't initiate both `passes` and `config` arguments because there might be" - " conflicts, and `PassPipeline` won't try detecting and correcting it." - ) - if config is not None: - passes = build_passes_from_config(config) - - self._passes = passes - - @classmethod - def from_config(cls, config : PassPipelineConfig): - return cls(config=config) - - def __iter__(self,): - return self._passes.__iter__() - - def __call__( - self, - graph_module: GraphModule, - ctx: ExecutionCtx = ExecutionCtx(), - lint_and_recompile : bool = True, - clean_markers_after_all_passes : bool = True, - **kwargs: Any - ) -> GraphModule: - for PASS in self._passes: - graph_module = PASS( - graph_module=graph_module, - ctx=ctx, - lint_and_recompile=lint_and_recompile - ) - - from .analyze import AnalyzeBase - - if clean_markers_after_all_passes: - for PASS in self._passes: - if isinstance(PASS, AnalyzeBase): - PASS.clean_all(graph_module) - return graph_module \ No newline at end of file diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py new file mode 100644 index 0000000000..2d42f53441 --- /dev/null +++ b/optimum/fx/parallelization/passes.py @@ -0,0 +1,449 @@ +from __future__ import annotations +from typing import List, Any, Dict +from abc import ABC, abstractmethod +from torch.fx import GraphModule, Graph, Node +import torch.nn as nn +from .utils import ( + stable_topological_sort, + is_transpose, + is_permute, + is_linear, + is_shape_consumer, + is_shape_generator, +) + +from .core import ParallelExecutionCtx, Config +from .parallel_layers import ColumnParallelLinear, RowParallelLinear + + +class PassBase(ABC): + """ + Base class for parallelization targeted passes + """ + @classmethod + def signature(cls) -> str: + return cls.__name__ + + @abstractmethod + def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + """ + Args: + graph_module (`GraphModule`): + graph module before processing. + ctx (`ParallelExecutionCtx`): + dynamic execution context which gathers and preserves information along processing. + config (`Config`): + static config to include instructions which persists the whole process. + + Returns: + GraphModule: graph module after processed by the current pass. + """ + raise NotImplementedError + + def __call__(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + graph_module = self.run(graph_module, ctx=ctx, config=config) + if config.lint_and_recompile: + graph_module.graph.lint() + graph_module.recompile() + return graph_module + + +class AnalyzeBase(PassBase): + """ + Base class for passes which only run for analytical purposes and preserve graph structure + during processing. Analytical passes are often prerequisite passes which provide information + for passes later on to actually change the graph. + + Passes inheriting from `AnalyBase` places the class signature as a meta key in `node.meta`, + which is a dict storing meta information related with a fx Node, such as the shape and dtype of + output. Look-up APIs are exposed as classmethod so that passes using them won't need to create + concrete instances. + """ + @classmethod + def meta_key(cls) -> str: + # place class-wise unique meta_key in `meta` to prevent duplicate fields + return cls.signature() + + @classmethod + def get_stored_field_info(cls, node: Node, field: Any, must_have: bool = False) -> Any: + if not cls.already_executed_per_node(node): + if not must_have: + return None + else: + raise RuntimeError( + f"Can't find information related with {cls.__name__} in the current node `{node}` " + f"make sure {cls.__name__} has run and marked it" + ) + + info : Dict[Any, Any] = node.meta[cls.meta_key()] + if field not in info: + raise KeyError(f"Invalid query field {field} for {cls.__name__}, valid fields are {list(info.keys())}") + + return info[field] + + @classmethod + def already_executed_per_node(cls, node: Node) -> None: + return cls.meta_key() in node.meta + + def place_marker_per_node(self, node: Node, info: Dict[Any, Any]) -> None: + if self.already_executed_per_node(node): + raise RuntimeError( + f"Node {node} has already been marked by the current pass, check if " + "the current pass has already been executed in the pipeline" + ) + + node.meta[self.meta_key()] = info + + def clear_marker_per_node(self, node: Node) -> None: + key = self.meta_key() + if key in node.meta: + node.meta.pop(key) + + def clean_all(self, graph_module: GraphModule) -> None: + g : Graph = graph_module.graph + for node in g.nodes: + self.clear_marker_per_node(node) + + +class ParallelLinearAnnotatePass(AnalyzeBase): + """ + A pass which tries to automatically identify parallel linears in the graph by grouping linears as + `upstream` nodes and `downstream` nodes, and `upstream` nodes are marked as `ColumnLinear`, `downstream` + nodes are marked as `RowLinear`. + + Typical examples in transformer models: + + Attention Bert-style MLP Llama-style MLP + __________________________________________________________________________ + Linear Linear Linear Linear + \\ / | \\ --> upstream + Matmul Linear Activation Activation Linear + __________________________________________________________________________ + \\ / | \\ / + \\ / ___________ \\ / + Matmul / Linear \ Mul + | / \ | + _______________________________/ \___________________________ + Linear Linear --> downstream + + Note that there are some patterns that can not be clearly marked, like this one: + + Linear + | \\ + | Linear <-- which label should we mark for the intermediate linear, `upstream` or `downstream` + | / + Add + | + Linear + + For patterns like this we will be preservative and raise errors directly because we don't know how to parallelize + it. Another concern is about the correctness, it's possible that we might end up with a wrong parallelization solution + even if the pattern itself is clear, but for now we are mainly targeting on transformer models and the current solution + should work fairly well. + """ + def try_form_parallel_linear_groups(self, linear: Node) -> None: + """ + We try to form linears by forming closures in a greedy way, we start with an unmarked linear node, and traverses down + recusively to find all the potential `downstream` linears, note that once we have reached a linear, the recursion stops. + And the newly found `downstream` linears are used as new seeds to traverse upwards to find all the potential `upstream` + linears, the process goes on until number of linears on both sides converges. + Args: + linear (Node): the first linear node used as `upstream` node seed to form closure. + + Raises: + RuntimeError: + raises runtime error when the pattern itself is not clear, there are no clear boundaries that can be drawn. + """ + upstream_nodes, downstream_nodes = {linear}, set() + + seeds, next_seeds = [(linear, 'down')], [] + + def traverse(start: Node, cur: Node, direction = 'down'): + if is_linear(cur) and cur is not start: + if direction == 'up' and cur not in upstream_nodes: + upstream_nodes.add(cur) + next_seeds.append((cur, 'down')) + elif direction == 'down' and cur not in downstream_nodes: + downstream_nodes.add(cur) + next_seeds.append((cur, 'up')) + return + + + next_nodes = cur.all_input_nodes if direction == 'up' else cur.users + for node in next_nodes: + # we should ignore shape-related dependencies + if is_shape_generator(node): + continue + traverse(start, node, direction) + + while seeds: + next_seeds = [] + for node, direction in seeds: + traverse(start=node, cur=node, direction=direction) + seeds = next_seeds + + if any([self.already_executed_per_node(node) for node in (upstream_nodes | downstream_nodes)]) or \ + (upstream_nodes & downstream_nodes): + raise RuntimeError( + "Failed to automatically group and parallelize ops in graph in greedy way: " + "no clear boudaries between `upstream` and `downstream` ops." + ) + + for node in upstream_nodes: + self.place_marker_per_node(node, {'axis' : 'column', 'gather_output' : False if downstream_nodes else True}) + + for node in downstream_nodes: + self.place_marker_per_node(node, {'axis' : 'row', 'input_is_parallel' : True}) + + + def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + graph: Graph = graph_module.graph + stable_topological_sort(graph) + for node in graph.nodes: + if is_linear(node) and not self.already_executed_per_node(node): + self.try_form_parallel_linear_groups(node) + + return graph_module + + +class ParallelAxisPropagationPass(AnalyzeBase): + """ + A pass tries to track which axis is being parallelized in the dataflow. For transformer models, the + axis being paralled for tensor parallism is almost always 2, i.e., the attention head axis, except for + Q and K matrice which need to swap the sequence length axis and head axis to do the attention computation, + so we focus on operations like `transpose` or `permute` which swaps axis, and try inducting the parallel + axis after these operations. + """ + def propagate_transpose(self, node: Node, parallel_axis: int) -> bool: + dims = node.meta['example_value'].dim() + if 'dim0' in node.kwargs and 'dim1' in node.kwargs: + dim0, dim1, dims = node.kwargs['dim0'], node.kwargs['dim1'] + dim0 = (dim0 + dims) % dims + dim1 = (dim1 + dims) % dims + if dim0 == parallel_axis: + self.place_marker_per_node(node, {'parallel_axis' : dim1}) + return True + elif dim1 == parallel_axis: + self.place_marker_per_node(node, {'parallel_axis' : dim0}) + return True + return False + + if len(node.args) == 3: + if parallel_axis not in node.args and parallel_axis - dims not in node.args: + return False + for arg in node.args: + if isinstance(arg, int) and (arg + dims) % dims != parallel_axis: + self.place_marker_per_node(node, {'parallel_axis' : (arg + dims) % dims}) + return True + + return False + + def propagate_permute(self, node: Node, parallel_axis: int) -> bool: + if 'dims' in node.kwargs: + dims = node.kwargs['dims'] + else: + dims = list(node.args[1]) if isinstance(node.args[1], tuple) else [arg for arg in node.args if isinstance(arg,int)] + + dim_len = node.meta['example_value'].dim() + dims = [dim + dim_len if dim < 0 else dim for dim in dims] + + for i,dim in enumerate(dims): + if dim == parallel_axis: + self.place_marker_per_node(node, {'parallel_axis' : i}) + return True + return False + + def propagate_getitem(self, node: Node, parallel_axis: int) -> bool: + slices = node.args[1] + dims = node.meta['example_value'].dim() + assert parallel_axis < dims + inc, i, j = 0, 0, 0 + + while i < parallel_axis and j < len(slices): + if isinstance(slices[j], int): + inc -= 1 + i += 1 + elif slices[j] is None: + inc += 1 + elif slices[j] is Ellipsis: + i = dims + k = j + while k < len(slices): + if slices[k] is not Ellipsis: + i -= 1 + k += 1 + else: + i += 1 + j += 1 + + if inc != 0: + assert parallel_axis + inc < dims and parallel_axis + inc >= 0 + self.place_marker_per_node(node, {'parallel_axis' : parallel_axis + inc}) + return True + return False + + def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + g: Graph = graph_module.graph + stable_topological_sort(g) + + for node in g.nodes: + if ParallelLinearAnnotatePass.already_executed_per_node(node): + # start propagating at ColumnLinear, marking the beginning of parallelized region + axis = ParallelLinearAnnotatePass.get_stored_field_info(node, field='axis', must_have=True) + if axis == 'column': + self.place_marker_per_node(node, {'parallel_axis' : 2}) + # stop propagating at RowLinear, concluding the ending of parallelized region + else: + continue + else: + already_marked_args, parallel_axis = [], None + for arg in node.all_input_nodes: + if not self.already_executed_per_node(arg): + continue + if parallel_axis is None: + parallel_axis = self.get_stored_field_info(arg, field='parallel_axis', must_have=True) + else: + assert parallel_axis == self.get_stored_field_info(arg, field='parallel_axis', must_have=True), \ + "`parallel_axis` should be equal for all arguments in any related ops" + already_marked_args.append(arg) + + if not already_marked_args: + continue + + marked = False + if is_transpose(node): + marked = self.propagate_transpose(node, parallel_axis) + elif is_permute(node): + marked = self.propagate_permute(node, parallel_axis) + + # fall back + if not marked: + self.place_marker_per_node(node, {'parallel_axis' : parallel_axis}) + return graph_module + + +class ParallelLinearReplacePass(PassBase): + """ + A pass which modifies graph according to information provided by previous analytical passes, + in general it does two things for now: + 1. replace linears with their parallel counterparts. + 2. modify hard-coded arguments like the number of attenton heads in the graph by dividing it by parallelism level. + """ + @staticmethod + def handle_linear(node: Node, ctx: ParallelExecutionCtx, config: Config) -> None: + graph_module = node.graph.owning_module + axis = ParallelLinearAnnotatePass.get_stored_field_info(node, field='axis') + if axis is None: + return + + assert axis in {'column', 'row'} + prefix_and_field = node.target.rsplit(".", maxsplit=1) + if len(prefix_and_field) == 2: + parent_mod = graph_module.get_submodule(prefix_and_field[0]) + field = prefix_and_field[1] + else: + parent_mod = graph_module + field = node.target + + mod : nn.Linear = graph_module.get_submodule(node.target) + key, layer_cache = id(mod), ctx.parallel_layer_cache + if key in layer_cache: + new_mod = layer_cache[key] + else: + if axis == 'column': + gather_output = ParallelLinearAnnotatePass.get_stored_field_info(node, field='gather_output', must_have=True) + new_mod = ColumnParallelLinear(ctx, mod, gather_output, config.weight_init_fn) + else: + input_is_parallel = ParallelLinearAnnotatePass.get_stored_field_info(node, field='input_is_parallel', must_have=True) + new_mod = RowParallelLinear(ctx, mod, input_is_parallel, config.weight_init_fn) + layer_cache[key] = new_mod + setattr(parent_mod, field, new_mod) + + + @staticmethod + def handle_hard_coded_axis_param(node: Node, ctx: ParallelExecutionCtx) -> None: + + def extract_shape_from_node(node: Node) -> List[Any]: + if 'size' in node.kwargs: + return list(node.kwargs['size']) + elif 'shape' in node.kwargs: + return list(node.kwargs['shape']) + elif isinstance(node.args[1], tuple): + return [idx for idx in node.args[1]] + else: + return [idx for idx in node.args[1:]] + + def update(node: Node, new_shape: List[Any], parallel_axis: int): + if 'size' in node.kwargs: + node.update_kwarg('size', tuple(new_shape)) + elif 'shape' in node.kwargs: + node.update_kwarg('shape', tuple(new_shape)) + elif isinstance(node.args[1], tuple): + node.update_arg(1, tuple(new_shape)) + else: + node.update_arg(parallel_axis + 1, shape[parallel_axis]) + + parallel_axis = ParallelAxisPropagationPass.get_stored_field_info(node, field='parallel_axis') + if parallel_axis is None: + return + + shape = extract_shape_from_node(node) + assert parallel_axis < len(shape) + if not isinstance(shape[parallel_axis], int) or shape[parallel_axis] == -1: + return + world_size = ctx.tp_group.size() + assert shape[parallel_axis] % world_size == 0 + shape[parallel_axis] = shape[parallel_axis] // world_size + update(node, shape, parallel_axis) + + def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + for node in graph_module.graph.nodes: + if is_linear(node): + self.handle_linear(node, ctx, config) + # correct the attention head num in parallel setting + elif is_shape_consumer(node): + self.handle_hard_coded_axis_param(node, ctx) + return graph_module + + +def build_parallel_pass_pipeline() -> PassPipeline: + """ + Ensemble a pass pipeline which contains the following passes: + + 1. `ParallelLinearAnnotatePass` to annoate which linears are `ColumnLinear`, which are `RowLinear` + 2. `ParallelAxisPropagationPass` to propate parallel axis along the data flow + 3. `ParallelLinearReplacePass` to do the actual replacement and modification of hard-coded attributes + + Returns: + PassPipeline: the pipeline used for automatic parallelism. + """ + return PassPipeline([ + ParallelLinearAnnotatePass(), + ParallelAxisPropagationPass(), + ParallelLinearReplacePass() + ]) + + +class PassPipeline: + """ + `PassPipeline` ensembles a list of passes and execute them one by one as provided in the list, + it can be iterated and appended after initialization for flexibility. + """ + def __init__(self, passes : List[PassBase] = []) -> None: + self._passes = passes + + def __iter__(self,): + return self._passes.__iter__() + + def append(self, PASS: PassBase): + self._passes.append(PASS) + + def __call__(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + for PASS in self._passes: + graph_module = PASS(graph_module=graph_module, ctx=ctx, config=config) + + if config.clean_markers_after_all_passes: + for PASS in self._passes: + if isinstance(PASS, AnalyzeBase): + PASS.clean_all(graph_module) + return graph_module diff --git a/optimum/fx/parallelization/transform.py b/optimum/fx/parallelization/transform.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index d22de264cb..5c64568e0b 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -1,8 +1,19 @@ +import operator +import importlib + import torch import torch.nn as nn import torch.nn.functional as F -from torch.fx import Node -import operator +from typing import Dict, Callable, List, Union +from torch.fx import Node, Graph +from functools import wraps +from collections import defaultdict +from itertools import chain +from .core import ( + ParallelExecutionCtx, + ParameterMapping, + ParameterMeta, +) def is_linear(node: Node) -> bool: if node.op != 'call_module': @@ -10,23 +21,11 @@ def is_linear(node: Node) -> bool: mod = node.graph.owning_module return isinstance(mod.get_submodule(node.target), nn.Linear) -def is_matmul(node: Node) -> bool: - if node.op != 'call_function': - return False - return node.target is torch.matmul - -def is_sdpa(node: Node) -> bool: - if node.op != 'call_function': - return False - return node.target is torch._C._nn.scaled_dot_product_attention - -def is_activation(node: Node) -> bool: - if node.op == 'call_function': - return node.target in {F.gelu, F.silu, F.sigmoid, F.relu, } - elif node.op == 'call_module': - mod = node.graph.owning_module - return isinstance(mod.get_submodule(node.target), (nn.GELU, nn.SiLU, nn.Sigmoid, nn.ReLU)) - return False +def is_shape_consumer(node: Node) -> bool: + if node.op == 'call_method': + return node.target in {'view', 'reshape', 'expand', 'resize', 'resize_'} + elif node.op == 'call_function': + return node.target in {torch.reshape} def is_transpose(node: Node) -> bool: if node.op == 'call_method': @@ -43,4 +42,170 @@ def is_permute(node: Node) -> bool: return False def is_getitem(node: Node) -> bool: - return node.op == 'call_function' and node.target is operator.getitem \ No newline at end of file + return node.op == 'call_function' and node.target is operator.getitem + +def is_output(node: Node) -> bool: + return node.op == 'output' + +def is_shape_generator(node: Node) -> bool: + return node.op == 'call_method' and node.target == 'size' + +def stable_topological_sort(graph: Graph): + + def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]: + args: List[torch.fx.node.Argument] = list() + torch.fx.map_arg((n.args, n.kwargs), args.append) + return args + + # Nodes are in exactly one of these three collections: + + # - Nodes in `pending` are waiting to be processed (in reverse order): + pending = list(reversed(graph.nodes)) + + # - Nodes in `ready` have been processed and are already in the correct + # order. + ready = set() + + # - `waiting` is a mapping from a dependency to nodes which depend on that + # dependency. + waiting = defaultdict(list) + + # The cursor indicates the last processed node so we can add new nodes + # after it. + cursor = None + while pending: + node = pending.pop() + waiting_for = [x for x in _args(node) if x not in ready] + if waiting_for: + # We have unprocessed input nodes. Might as well wait for the last + # arg so an already sorted list will only recheck this node once. + waiting[waiting_for[-1]].append(node) + else: + ready.add(node) + if cursor and cursor.next is not node: + cursor.append(node) + cursor = node + # Mark the nodes that have been waiting for this node to finish as + # ready to check again. + pending.extend(reversed(waiting.pop(node, ()))) + + assert not waiting and len(ready) == len(graph.nodes) + +def meta_init(init_fn): + @wraps(init_fn) + def wrapper(*args, **kwargs): + kwargs["device"] = kwargs.pop("device", torch.device("meta")) + return init_fn(*args, **kwargs) + + return wrapper + +@wraps(nn.Linear.forward) +def meta_aware_linear_forward(*args, **kwargs): + self = args[0] + input = args[1] + + if self.weight.device != torch.device('meta'): + return F.linear(input, self.weight, self.bias) + + orig_device = input.device + input = input.to("meta") + meta_output = F.linear(input, self.weight, self.bias) + return torch.empty_like(meta_output, device=orig_device) + + +class MetaAwareMethodsPatcher: + """ + A patcher class which patches `__init__` and `forward` methods on modules which will be put on meta + devices for memory efficiency purposes during initialization. + + Note that for `__init__` method, it can be unpatched once we have finished the initialization of the + model, however, for `forward`, we need it to constantly being patched during the whole process in case + recompile happens and torch dynamo needs meta-aware `forward` to be able to re-capture the graph. + """ + methods_to_patch : Dict[str, Callable] = [ + ("torch.nn.Linear.__init__", meta_init(torch.nn.Linear.__init__)), + ("torch.nn.Linear.forward", meta_aware_linear_forward), + ] + + def __init__(self) -> None: + self.patching_specs = [] + for orig, patch_fn in self.methods_to_patch: + module_qualified_name, attribute_name = orig.rsplit(".", maxsplit=1) + try: + module = importlib.import_module(module_qualified_name) + except ModuleNotFoundError as e: + module_qualified_name, module_attribute_name = module_qualified_name.rsplit( + ".", maxsplit=1 + ) + module = importlib.import_module(module_qualified_name) + try: + module = getattr(module, module_attribute_name) + except AttributeError: + raise e + orig_fn = getattr(module, attribute_name) + + # Module, Attribute, Patchee, Patcher, Status + self.patching_specs.append([module, attribute_name, orig_fn, patch_fn, False]) + + def _patch(self, identifier: str): + for spec in self.patching_specs: + # already patched + if spec[-1]: + continue + if identifier in spec[1]: + setattr(spec[0], spec[1], spec[3]) + spec[-1] = True + + def _unpatch(self, identifier: str): + for spec in self.patching_specs: + # already patched + if not spec[-1]: + continue + if identifier in spec[1]: + setattr(spec[0], spec[1], spec[2]) + spec[-1] = False + + def patch_meta_init(self,): + self._patch("init") + + def patch_meta_forward(self,): + self._patch("forward") + + def unpatch_meta_init(self,): + self._unpatch("init") + + def unpatch_meta_forward(self,): + self._unpatch("forward") + + def __enter__(self,): + self.patch_meta_init() + self.patch_meta_forward() + + def __exit__(self, exc_type, exc_value, traceback): + self.unpatch_meta_init() + + +def initialize_parameter_mapping(model: nn.Module, ctx: ParallelExecutionCtx) -> None: + mapping = ctx.parameter_mapping + + for name, tensor in chain(model.named_parameters(), model.named_buffers()): + mapping[id(tensor)] = ParameterMapping(id = id(tensor), meta = ParameterMeta(source=name)) + +def move_model_to_device(model: nn.Module, device: Union[torch.device, str]): + # move everything except tensors on meta devices on current device + # this function should be called before `intialize_parameter_mapping` + for name, tensor in chain(model.named_parameters(), model.named_buffers()): + if tensor.device == torch.device("meta"): + continue + splits = name.rsplit(".", maxsplit=1) + if len(splits) == 1: + parent_mod = model + attr_name = splits[0] + else: + qualified_name = splits[0] + parent_mod = model.get_submodule(qualified_name) + attr_name = splits[1] + new_tensor = tensor.to(device) + if isinstance(tensor, nn.Parameter): + new_tensor = nn.Parameter(new_tensor) + setattr(parent_mod, attr_name, new_tensor) diff --git a/tests/fx/parallelization/dist_utils.py b/tests/fx/parallelization/dist_utils.py new file mode 100644 index 0000000000..a9abe4dd34 --- /dev/null +++ b/tests/fx/parallelization/dist_utils.py @@ -0,0 +1,55 @@ +import os +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from typing import Callable, List, Optional +from transformers import set_seed + +SEED = 42 +NUM_AVAILABLE_DEVICES = torch.cuda.device_count() + + +def dist_init( + rank: int, + world_size: int, + backend: str = 'nccl', + master_addr: str = '127.0.0.1', + master_port: str = '29500', +): + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = master_port + + dist.init_process_group( + backend=backend, + init_method="env://", + world_size=world_size, + rank=rank, + ) + + torch.cuda.set_device(rank) + +def runner(rank: int, fn:Callable, deterministic: bool, *args, **kwargs): + if deterministic: + set_seed(SEED) + fn(rank, *args, **kwargs) + +def spawn(world_size: int, fn: Callable, *args, deterministic: bool = False): + mp.spawn(fn=runner, args=(fn, deterministic, world_size, *args), nprocs=world_size, join=True) + +def tearDown(group: Optional[dist.ProcessGroup] = None): + dist.destroy_process_group(group) + +def gather_at_main_process(tensor: torch.Tensor, group: dist.ProcessGroup, rank: int, world_size: int) -> List[torch.Tensor]: + if world_size == 1: + return [tensor] + + tensor = tensor.contiguous() + if rank == 0: + tensors = [torch.empty_like(tensor) for _ in range(world_size)] + tensors[rank] = tensor + else: + tensors = None + dist.gather(tensor=tensor, gather_list=tensors, dst=0, group=group) + return tensors diff --git a/tests/fx/parallelization/test_tensor_parallel.py b/tests/fx/parallelization/test_tensor_parallel.py new file mode 100644 index 0000000000..f865f3d35b --- /dev/null +++ b/tests/fx/parallelization/test_tensor_parallel.py @@ -0,0 +1,190 @@ +import unittest +import torch +import torch.distributed as dist +from typing import Type +from functools import partial +from transformers import ( + PretrainedConfig, + PreTrainedModel, + LlamaConfig, + MistralConfig, + LlamaForCausalLM, + MistralForCausalLM, + set_seed, +) +from parameterized import parameterized +from optimum.fx.parallelization import parallelize_backend, ParallelExecutionCtx, Config +from optimum.fx.parallelization.utils import MetaAwareMethodsPatcher, move_model_to_device, initialize_parameter_mapping +from dist_utils import ( + dist_init, + tearDown, + spawn, + gather_at_main_process, + NUM_AVAILABLE_DEVICES, + SEED +) + + +DUMMY_MODELS_TO_TEST = ( + (LlamaForCausalLM, LlamaConfig(), ), + (MistralForCausalLM, MistralConfig(), ), +) + + +def dummify(config: PretrainedConfig): + config.num_hidden_layers = 2 + config.use_cache = False + config.output_attentions = False + config.output_hidden_states = False + +def run_test_all_rank_results_match(rank: int, world_size: int, model_cls: Type[PreTrainedModel], model_config: PretrainedConfig): + dummify(model_config) + + # initialize default group + dist_init(rank, world_size) + tp_group = dist.new_group() + + # prepare config and context + device = torch.device(type='cuda', index=torch.cuda.current_device()) + ctx, cfg = ParallelExecutionCtx(tp_group=tp_group, current_device=device), Config() + + inputs = { + "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(1, 10), device=device), + "attention_mask": torch.ones((1, 10), dtype=torch.int64, device=device), + "position_ids": torch.arange(0, 10, device=device).unsqueeze(0), + } + + # this will initialize all linears on meta device + with MetaAwareMethodsPatcher(): + model = model_cls(model_config) + model.eval() + # move model to current device, with linears still on meta, and intialize parameter mapping + move_model_to_device(model, device=device) + initialize_parameter_mapping(model, ctx=ctx) + + model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) + logits = model(**inputs)[0] + tensors = gather_at_main_process(tensor=logits, group=tp_group, rank=rank, world_size=world_size) + + # check results at main worker process + if rank == 0: + assert len(tensors) == world_size + for i in range(1, world_size): + torch.testing.assert_close(tensors[i - 1].cpu(), tensors[i].cpu(), rtol=1e-4, atol=1e-4) + + dist.barrier(tp_group) + tearDown(tp_group) + +def run_test_parameters_persist_bewteen_recompile(rank: int, world_size: int, model_cls: Type[PreTrainedModel], model_config: PretrainedConfig): + dummify(model_config) + + # initialize default group + dist_init(rank, world_size) + tp_group = dist.new_group() + + # prepare config and context + device = torch.device(type='cuda', index=torch.cuda.current_device()) + ctx, cfg = ParallelExecutionCtx(tp_group=tp_group, current_device=device), Config() + + inputs = { + "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(1, 10), device=device), + "attention_mask": torch.ones((1, 10), dtype=torch.int64, device=device), + "position_ids": torch.arange(0, 10, device=device).unsqueeze(0), + } + + # different shape to trigger recompile + another_inputs = { + "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(1, 11), device=device), + "attention_mask": torch.ones((1, 11), dtype=torch.int64, device=device), + "position_ids": torch.arange(0, 11, device=device).unsqueeze(0), + } + + # this will initialize all linears on meta device + with MetaAwareMethodsPatcher(): + model = model_cls(model_config) + model.eval() + # move model to current device, with linears still on meta + move_model_to_device(model, device=device) + initialize_parameter_mapping(model, ctx=ctx) + + model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) + model(**inputs) + + parameter_ids = set([id(param) for _, param in model.named_parameters()]) + model(**another_inputs) + + parameter_ids_after_recompile = set([id(param) for _, param in model.named_parameters()]) + assert parameter_ids == parameter_ids_after_recompile + + dist.barrier(tp_group) + tearDown(tp_group) + +def run_test_parallel_results_matches_non_parallel(rank: int, world_size: int, model_cls: Type[PreTrainedModel], model_config: PretrainedConfig): + dummify(model_config) + + dist_init(rank, world_size) + tp_group = dist.new_group(ranks=[rank]) + + # prepare config and context + device = torch.device(type='cuda', index=torch.cuda.current_device()) + ctx, cfg = ParallelExecutionCtx(tp_group=tp_group, current_device=device), Config() + + inputs = { + "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(1, 10), device=device), + "attention_mask": torch.ones((1, 10), dtype=torch.int64, device=device), + "position_ids": torch.arange(0, 10, device=device).unsqueeze(0), + } + + set_seed(SEED) + # non-parallel local forward + with MetaAwareMethodsPatcher(): + model = model_cls(model_config) + model.eval() + + # move model to current device, with linears still on meta + move_model_to_device(model, device=device) + initialize_parameter_mapping(model, ctx=ctx) + + model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) + logits = model(**inputs)[0] + + del model + + tp_group = dist.new_group() + set_seed(SEED) + ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device) + with MetaAwareMethodsPatcher(): + model = model_cls(model_config) + model.eval() + + # move model to current device, with linears still on meta + move_model_to_device(model, device=device) + initialize_parameter_mapping(model, ctx=ctx) + + model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) + parallel_logits = model(**inputs)[0] + + torch.testing.assert_close(logits.cpu(), parallel_logits.cpu(), rtol=1e-4, atol=1e-4) + + dist.barrier(tp_group) + tearDown() + +@parameterized.expand(DUMMY_MODELS_TO_TEST) +@unittest.skipIf(not torch.cuda.is_available(), "requires gpu to run") +def test_all_rank_results_match(model_cls, config, ): + for world_size in [1, 2, 4, 8]: + if world_size <= NUM_AVAILABLE_DEVICES: + spawn(world_size, run_test_all_rank_results_match, model_cls, config, deterministic=True) + +@parameterized.expand(DUMMY_MODELS_TO_TEST) +@unittest.skipIf(not torch.cuda.is_available(), "requires gpu to run") +def test_parameters_persist_bewteen_recompile(model_cls, config, ): + for world_size in [1, 2, 4, 8]: + if world_size <= NUM_AVAILABLE_DEVICES: + spawn(world_size, run_test_parameters_persist_bewteen_recompile, model_cls, config, deterministic=False) + +@parameterized.expand(DUMMY_MODELS_TO_TEST) +@unittest.skipIf(not torch.cuda.is_available(), "requires gpu to run") +def test_parallel_results_matches_non_parallel(model_cls, config, ): + # world_size == 2 is enough + spawn(2, run_test_parallel_results_matches_non_parallel, model_cls, config, deterministic=True) \ No newline at end of file From 0876f5d6246f00f4dde55b92c4729e8bbb7bb442 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 9 Jul 2024 00:08:59 +0200 Subject: [PATCH 05/25] add embedding & weight tie --- optimum/fx/parallelization/__init__.py | 24 +- optimum/fx/parallelization/core.py | 177 +++++--- .../parallelization/distributed/__init__.py | 18 +- .../parallelization/distributed/dist_ops.py | 51 ++- .../parallel_layers/__init__.py | 17 +- .../parallel_layers/embedding.py | 95 +++++ .../parallelization/parallel_layers/linear.py | 200 ++++----- optimum/fx/parallelization/passes.py | 389 ++++++++++++------ optimum/fx/parallelization/utils.py | 172 ++++++-- optimum/onnxruntime/runs/__init__.py | 6 +- tests/fx/parallelization/dist_utils.py | 34 +- .../parallelization/test_tensor_parallel.py | 148 ++++--- 12 files changed, 931 insertions(+), 400 deletions(-) create mode 100644 optimum/fx/parallelization/parallel_layers/embedding.py diff --git a/optimum/fx/parallelization/__init__.py b/optimum/fx/parallelization/__init__.py index ee32f3915d..7f3d0e737b 100644 --- a/optimum/fx/parallelization/__init__.py +++ b/optimum/fx/parallelization/__init__.py @@ -1,11 +1,29 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + import torch from torch.fx import GraphModule -from typing import List -from .core import ParallelExecutionCtx, Config + +from .core import Config, ParallelExecutionCtx from .passes import build_parallel_pass_pipeline -def parallelize_backend(graph_module: GraphModule, example_inputs: List[torch.Tensor], ctx: ParallelExecutionCtx, config: Config): +def parallelize_backend( + graph_module: GraphModule, example_inputs: List[torch.Tensor], ctx: ParallelExecutionCtx, config: Config +) -> GraphModule: ctx.example_inputs = example_inputs pass_pipeline = build_parallel_pass_pipeline() graph_module = pass_pipeline(graph_module=graph_module, ctx=ctx, config=config) diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py index c24876bc1f..a040123bfe 100644 --- a/optimum/fx/parallelization/core.py +++ b/optimum/fx/parallelization/core.py @@ -1,49 +1,97 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from dataclasses import dataclass, field -from typing import List, Any, List, Dict, Callable +from functools import partial +from typing import Any, Callable, Dict, List, Optional + import torch -import torch.nn as nn import torch.distributed as dist -from functools import partial +import torch.nn as nn + class HashableSlice: - def __init__(self, start : int, stop : int, step : int) -> None: + def __init__(self, start: Optional[int] = None, stop: Optional[int] = None, step: Optional[int] = None) -> None: self.start = start self.stop = stop self.step = step def __hash__(self) -> int: - return hash(f'{self.start},{self.stop},{self.step}') + return hash(f"{self.start},{self.stop},{self.step}") def __eq__(self, value: object) -> bool: - return isinstance(value, HashableSlice) and self.start == value.start and \ - self.stop == value.stop and self.step == value.step - - def to_slice(self) -> None: + return ( + isinstance(value, HashableSlice) + and self.start == value.start + and self.stop == value.stop + and self.step == value.step + ) + + def to_slice(self) -> slice: return slice(self.start, self.stop, self.step) @dataclass -class ParameterMeta: - # parameter name - source : str = None - # which axis to index - dim : int = None - # index to slice the tensor - index : slice = None +class ParameterSlice: + """ + A slice of parameter which corresponds to a tensor in weight dict. Only support slicing + along a specific axis (the potential parallel axis) right now. + Attributes: + - source (`Optional[str]`): + Original parameter name which can be found in the weight dict. -@dataclass -class ParameterMapping: - id : int = None - meta : ParameterMeta = None + - index (`Optional[slice]`): + Index to slice the tensor on the parallel axis. Assume tensor in weight dict has the same + layout as their correspondings in memory. + """ + + source: Optional[str] = None + index: Optional[slice] = None @dataclass -class ParallelParameterMapping(ParameterMapping): - # the axis being parallelized - parallel_dim : int = None - # for multi-source parameter mapping - mapping : Dict[HashableSlice, ParameterMeta] = field(default_factory=dict) +class ParameterMeta: + """ + Parameter meta information. + + Attributes: + - is_tied (`bool`, defaults to `False`): + Whether the parameter is shared accross multiple modules. + + - is_modified_meta (`bool`, defaults to `False`): + Whether the meta has already been modified since initialization. + + - need_initialize (`bool`, defaults to `False`): + Whether need to manually initialize weights if not provided in weight map. + + - init_fn (`Optional[Callable]`): + Initialization function, can override `weight_init_fn` in `Config` if not None. + + - dim (`int`, defaults to `0`): + Axis on which `mapping` is based. + + - mapping (`Dict[HashableSlice, ParameterSlice]`): + Mapping between the current parameter and weight tensor stored in weight map. + """ + + is_tied: bool = False + is_modified_meta: bool = False + need_initialize: bool = False + init_fn: Optional[Callable] = None + dim: int = 0 + mapping: Dict[HashableSlice, ParameterSlice] = field(default_factory=dict) @dataclass @@ -51,42 +99,37 @@ class ParallelExecutionCtx: """ Parallel execution context which contains runtime information. - - example_inputs - A list of tensors which are used as example inputs for graphs captured by dynamo. - - - parallel_layer_cache - Cache which maps layers(`nn.Linear`, `nn.Embedding`) to their parallel counterparts. - Note that we will build the cache in the first compilation process, and for recompilations - later on, we will directly replace the modules with their parallel counterparts in the cache, - because we have to make sure we don't initiate new parameters and replace original ones when - recompilation happens in training process. + Attributes: + - tp_group (`dist.ProcessGroup`): + Tensor parallel process group the current process belongs to. - - parameter_mapping - Mapping between parameter ids and their correponding names in the original module. Note - that it changes as we create new parameters to replace original ones in the first compilation - process. It's useful because dynamo flattens the graph(which invalidates the parameter name - hierarchy) but the original parameters are kept. + - current_device (`torch.device`): + Device correpsonding to the current process. - - weight_map - Mapping between parameter names and their locations on disk, useful when loading weights - from disk. + - example_inputs (`List[Any]`): + A list of tensors which are used as example inputs for graphs captured by dynamo. - - tp_group - Tensor parallel process group the current process belongs to. + - parallel_layer_cache (`Dict[int, nn.Module]`): + Cache which maps layers(`nn.Linear`, `nn.Embedding`) to their parallel counterparts. + Note that we will build the cache in the first compilation process, and for recompilations + later on, we will directly replace the modules with their parallel counterparts in the cache, + because we have to make sure we don't initiate new parameters and replace original ones when + recompilation happens in training process. - - compile_times - Number of compilation times happened during the whole process. + - weight_map (`Dict[str, str]`): + Mapping between parameter names and their locations on disk, useful when loading weights + from disk. - - current_device - Device correpsonding to the current process. + - compile_times (`int`, defaults to `0`): + Number of compilation times happened during the whole process. """ - example_inputs : List[Any] = field(default_factory=list) - parallel_layer_cache : Dict[int, nn.Module] = field(default_factory=dict) - parameter_mapping : Dict[int, ParameterMapping] = field(default_factory=dict) - weight_map : Dict[str, str] = field(default_factory=dict) - tp_group : dist.ProcessGroup = None - compile_times : int = 0 - current_device : torch.device = None + + tp_group: dist.ProcessGroup + current_device: torch.device + example_inputs: List[Any] = field(default_factory=list) + parallel_layer_cache: Dict[int, nn.Module] = field(default_factory=dict) + weight_map: Dict[str, str] = field(default_factory=dict) + compile_times: int = 0 @dataclass @@ -94,16 +137,18 @@ class Config: """ Static config which contains instructions which do not change in runtime. - - lint_and_recompile - Whether to run graph linting and module recompilation after every pass. + Attributes: + - lint_and_recompile (`bool`, defaults to `True`): + Whether to run graph linting and module recompilation after every pass. + + - clean_markers_after_all_passes (`bool`, defaults to `True`): + Whether to clean markers of analytical passes after all passes have run. - - clean_markers_after_all_passes - Whether to clean markers of analytical passes after all passes have run. - - - weight_init_fn - Initialization function of weights in `nn.Linear` and `nn.Embedding` layers, - if not provided weights loading path. + - weight_init_fn (`Callable`, defaults to `partial(nn.init.normal_, std=0.02)`) + Initialization function of weights in `nn.Linear` and `nn.Embedding` layers, + if not provided weights loading path. """ - lint_and_recompile : bool = True - clean_markers_after_all_passes : bool = True - weight_init_fn : Callable = partial(nn.init.normal_, std=0.02) + + lint_and_recompile: bool = True + clean_markers_after_all_passes: bool = True + weight_init_fn: Callable = partial(nn.init.normal_, std=0.02) diff --git a/optimum/fx/parallelization/distributed/__init__.py b/optimum/fx/parallelization/distributed/__init__.py index 45b9d2837a..3734013669 100644 --- a/optimum/fx/parallelization/distributed/__init__.py +++ b/optimum/fx/parallelization/distributed/__init__.py @@ -1,7 +1,21 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .dist_ops import ( differentiable_all_gather, - differentiable_identity, differentiable_all_reduce_sum, + differentiable_identity, differentiable_scatter, scatter, -) \ No newline at end of file +) diff --git a/optimum/fx/parallelization/distributed/dist_ops.py b/optimum/fx/parallelization/distributed/dist_ops.py index cb4c93569b..69abe68bca 100644 --- a/optimum/fx/parallelization/distributed/dist_ops.py +++ b/optimum/fx/parallelization/distributed/dist_ops.py @@ -1,27 +1,44 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import torch import torch.distributed as dist -def all_reduce(group: dist.ProcessGroup, tensor : torch.Tensor) -> torch.Tensor: + +def all_reduce(group: dist.ProcessGroup, tensor: torch.Tensor) -> torch.Tensor: world_size = dist.get_world_size(group) if world_size == 1: return tensor - + dist.all_reduce(tensor, group=group) return tensor + def all_gather(group: dist.ProcessGroup, tensor: torch.Tensor, gather_dim: int = -1) -> torch.Tensor: world_size = dist.get_world_size(group) if world_size == 1: return tensor - rank = dist.get_rank(group = group) + rank = dist.get_rank(group=group) tensor = tensor.contiguous() tensors = [torch.empty_like(tensor) for _ in range(world_size)] tensors[rank] = tensor - + dist.all_gather(tensors, tensor, group=group) return torch.cat(tensors, dim=gather_dim) + def split(group: dist.ProcessGroup, tensor: torch.Tensor, split_dim: int = -1) -> torch.Tensor: world_size = dist.get_world_size(group) if world_size == 1: @@ -30,12 +47,15 @@ def split(group: dist.ProcessGroup, tensor: torch.Tensor, split_dim: int = -1) - rank = dist.get_rank(group) size = tensor.size() assert size[split_dim] % world_size == 0 - tensors = torch.split(tensor, size[split_dim] // world_size, dim = split_dim) + tensors = torch.split(tensor, size[split_dim] // world_size, dim=split_dim) tensor = tensors[rank].contiguous() return tensor -def scatter(group: dist.ProcessGroup, tensor: torch.Tensor, output_tensor: torch.Tensor, scatter_dim: int = 0) -> torch.Tensor: + +def scatter( + group: dist.ProcessGroup, tensor: torch.Tensor, output_tensor: torch.Tensor, scatter_dim: int = 0 +) -> torch.Tensor: world_size = dist.get_world_size(group) if world_size == 1: return tensor @@ -46,7 +66,7 @@ def scatter(group: dist.ProcessGroup, tensor: torch.Tensor, output_tensor: torch assert size[scatter_dim] % world_size == 0 tensors = torch.split(tensor, size[scatter_dim] // world_size, dim=scatter_dim) scatter_list = [tensor.contiguous() for tensor in tensors] - output_tensor = scatter_list[rank] + output_tensor.copy_(scatter_list[rank]) else: scatter_list = None dist.scatter(tensor=output_tensor, scatter_list=scatter_list, src=0, group=group) @@ -70,7 +90,7 @@ class DifferentiableAllReduceSum(torch.autograd.Function): def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor: ctx.group = group return all_reduce(group=group, tensor=tensor) - + @staticmethod def backward(ctx, grad_output: torch.Tensor) -> torch.Any: return grad_output, None @@ -94,20 +114,23 @@ def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup, dim: int = -1) ctx.group = group ctx.dim = dim return all_gather(group=group, tensor=tensor, gather_dim=dim) - + @staticmethod def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: return DifferentiableScatter.apply(grad_output, group=ctx.group, dim=ctx.dim), None, None -def differentiable_all_reduce_sum(tensor: torch.Tensor, group: dist.ProcessGroup): +def differentiable_all_reduce_sum(tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor: return DifferentiableAllReduceSum.apply(tensor, group) -def differentiable_identity(tensor: torch.Tensor, group: dist.ProcessGroup): + +def differentiable_identity(tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor: return DifferentiableIdentity.apply(tensor, group) -def differentiable_all_gather(tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1): + +def differentiable_all_gather(tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1) -> torch.Tensor: return DifferentiableAllGather.apply(tensor, group, dim) -def differentiable_scatter(tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1): - return DifferentiableScatter.apply(tensor, group, dim) \ No newline at end of file + +def differentiable_scatter(tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1) -> torch.Tensor: + return DifferentiableScatter.apply(tensor, group, dim) diff --git a/optimum/fx/parallelization/parallel_layers/__init__.py b/optimum/fx/parallelization/parallel_layers/__init__.py index 2b5b54c39b..9bfb13afdf 100644 --- a/optimum/fx/parallelization/parallel_layers/__init__.py +++ b/optimum/fx/parallelization/parallel_layers/__init__.py @@ -1 +1,16 @@ -from .linear import RowParallelLinear, ColumnParallelLinear \ No newline at end of file +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .embedding import VocabParallelEmbedding +from .linear import ColumnParallelLinear, RowParallelLinear diff --git a/optimum/fx/parallelization/parallel_layers/embedding.py b/optimum/fx/parallelization/parallel_layers/embedding.py new file mode 100644 index 0000000000..4cd21f9ebc --- /dev/null +++ b/optimum/fx/parallelization/parallel_layers/embedding.py @@ -0,0 +1,95 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Callable + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from ..core import ParallelExecutionCtx, ParameterMeta +from ..distributed import differentiable_all_reduce_sum +from ..utils import ensure_divisibility + + +class VocabParallelEmbedding(nn.Module): + """ + Embedding layer parallelized in vocabulary dimension. + + Arguments: + ctx: parallel execution context which contains runtime information. + embedding: the original embedding module being replaced. + init_fn: weight initialization function. + """ + + def __init__( + self, + ctx: ParallelExecutionCtx, + embedding: nn.Embedding, + init_fn: Callable[[torch.Tensor], torch.Tensor] = partial(nn.init.normal_, mean=0, std=0.02), + ): + super(VocabParallelEmbedding, self).__init__() + self.process_group = ctx.tp_group + world_size = dist.get_world_size(self.process_group) + tp_rank = dist.get_rank(self.process_group) + ensure_divisibility(embedding.num_embeddings, world_size) + + num_embeddings = embedding.num_embeddings // world_size + + self.padding_idx = embedding.padding_idx + self.max_norm = embedding.max_norm + self.norm_type = embedding.norm_type + self.scale_grad_by_freq = embedding.scale_grad_by_freq + self.sparse = embedding.sparse + self.vocab_start_idx = tp_rank * num_embeddings + self.vocab_end_idx = (tp_rank + 1) * num_embeddings + + # modify meta information + weight_meta = getattr(embedding.weight, "meta", None) + assert isinstance( + weight_meta, ParameterMeta + ), "should have run `initialize_parameter_meta` after moving model to current device" + if weight_meta.is_modified_meta: + assert weight_meta.is_tied, "only tied parameters could already have modified meta" + else: + weight_meta.need_initialize = True + weight_meta.dim = 0 + weight_meta.init_fn = init_fn + for _, Slice in weight_meta.mapping.items(): + Slice.index = slice(self.vocab_start_idx, self.vocab_end_idx) + weight_meta.is_modified_meta = True + + # skip creating actual parameters + self.weight = embedding.weight + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input_mask = (input < self.vocab_start_idx) | (input >= self.vocab_end_idx) + masked_input = input.clone() - self.vocab_start_idx + masked_input[input_mask] = 0 + + output = F.embedding( + masked_input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + output[input_mask, :] = 0.0 + output = differentiable_all_reduce_sum(output, self.process_group) + return output diff --git a/optimum/fx/parallelization/parallel_layers/linear.py b/optimum/fx/parallelization/parallel_layers/linear.py index 2ab5c11849..71c7d9d1b5 100644 --- a/optimum/fx/parallelization/parallel_layers/linear.py +++ b/optimum/fx/parallelization/parallel_layers/linear.py @@ -1,24 +1,39 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Callable + import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -import torch.distributed as dist -from functools import partial -from typing import Callable + from ..core import ( ParallelExecutionCtx, - ParallelParameterMapping, ParameterMeta, ) from ..distributed import ( - differentiable_identity, differentiable_all_gather, - differentiable_scatter, differentiable_all_reduce_sum, - scatter, + differentiable_identity, + differentiable_scatter, ) +from ..utils import ensure_divisibility -class ColumnParallelLinear(nn.Linear): +class ColumnParallelLinear(nn.Module): """ Linear layer with column parallelism. @@ -31,80 +46,71 @@ class ColumnParallelLinear(nn.Linear): gather_output: whether gathering output in the end of forward. init_fn: weight initialization function. """ + def __init__( self, ctx: ParallelExecutionCtx, linear: nn.Linear, gather_output: bool = True, - init_fn: Callable = partial(nn.init.normal_, mean=0, std=0.02), + init_fn: Callable[[torch.Tensor], torch.Tensor] = partial(nn.init.normal_, mean=0, std=0.02), ) -> None: + super(ColumnParallelLinear, self).__init__() self.process_group = ctx.tp_group world_size = dist.get_world_size(self.process_group) - assert linear.out_features % world_size == 0 + tp_rank = dist.get_rank(self.process_group) + ensure_divisibility(linear.out_features, world_size) - in_features = linear.in_features out_features = linear.out_features // world_size bias = linear.bias is not None - device = ctx.current_device - dtype = linear.weight.dtype - super().__init__(in_features, out_features, bias, device, dtype) + # modify meta information + weight_meta = getattr(linear.weight, "meta", None) + assert isinstance( + weight_meta, ParameterMeta + ), "should have run `initialize_parameter_meta` after moving model to current device" + + if weight_meta.is_modified_meta: + assert weight_meta.is_tied, "only tied parameters could already have modified meta" + else: + weight_meta.need_initialize = True + weight_meta.dim = 0 + weight_meta.init_fn = init_fn + for _, Slice in weight_meta.mapping.items(): + Slice.index = slice(tp_rank * out_features, (tp_rank + 1) * out_features) + weight_meta.is_modified_meta = True + + # skip creating actual parameters + self.weight = linear.weight self.gather_output = gather_output - tp_rank = dist.get_rank(self.process_group) - - parameter_mapping, key = ctx.parameter_mapping, id(linear.weight) - assert key in parameter_mapping, "should have run `initialize_paramter_mapping` after moving model to current device" - original_linear_weight_meta = parameter_mapping[key].meta - - # initialize the weight if not in weight_map - need_intialize = original_linear_weight_meta.source not in ctx.weight_map - if need_intialize: - # initialize on cpu - master_weight = torch.empty_like(linear.weight, device='cpu') - init_fn(master_weight) - with torch.no_grad(): - self.weight.copy_(master_weight[tp_rank * out_features : (tp_rank + 1) * out_features, :]) - - # update parameter mapping corresponding to original linear weight and bias - linear_weight_mapping = ParallelParameterMapping( - id=id(self.weight), - meta=ParameterMeta( - source=original_linear_weight_meta.source, - dim=0, - index=slice(tp_rank * out_features, (tp_rank + 1) * out_features) - ), - parallel_dim=0 - ) - parameter_mapping.pop(key) - parameter_mapping[linear_weight_mapping.id] = linear_weight_mapping if bias: - key = id(linear.bias) - assert key in parameter_mapping - original_linear_bias_meta = parameter_mapping[key].meta - linear_bias_mapping = ParallelParameterMapping( - id=id(self.bias), - meta=ParameterMeta( - source=original_linear_bias_meta.source, - dim=0, - index=slice(tp_rank * out_features, (tp_rank + 1) * out_features) - ), - parallel_dim=0 - ) - - parameter_mapping.pop(key) - parameter_mapping[linear_bias_mapping.id] = linear_bias_mapping - self.bias.zero_() + bias_meta = getattr(linear.bias, "meta", None) + assert isinstance( + bias_meta, ParameterMeta + ), "should have run `initialize_parameter_meta` after moving model to current device" + + if bias_meta.is_modified_meta: + assert bias_meta.is_tied, "only tied parameters could already have modified meta" + else: + bias_meta.need_initialize = True + bias_meta.init_fn = torch.zero_ + bias_meta.dim = 0 + for _, Slice in bias_meta.mapping.items(): + Slice.index = slice(tp_rank * out_features, (tp_rank + 1) * out_features) + bias_meta.is_modified_meta = True + self.bias = linear.bias + else: + self.register_parameter("bias", None) def forward(self, input: torch.Tensor) -> torch.Tensor: input = differentiable_identity(input, self.process_group) - output = super().forward(input) + output = F.linear(input, self.weight, self.bias) if self.gather_output: output = differentiable_all_gather(output, self.process_group) return output -class RowParallelLinear(nn.Linear): +class RowParallelLinear(nn.Module): """ Linear layer with row parallelism. @@ -123,61 +129,57 @@ class RowParallelLinear(nn.Linear): input_is_parallel: whether the input tensor has already been parallelized. init_fn: weight initialization function. """ + def __init__( self, ctx: ParallelExecutionCtx, linear: nn.Linear, input_is_parallel: bool = False, - init_fn: Callable = partial(nn.init.normal_, mean=0, std=0.02), + init_fn: Callable[[torch.Tensor], torch.Tensor] = partial(nn.init.normal_, mean=0, std=0.02), ) -> None: + super(RowParallelLinear, self).__init__() self.process_group = ctx.tp_group world_size = dist.get_world_size(self.process_group) - assert linear.in_features % world_size == 0 + tp_rank = dist.get_rank(self.process_group) + ensure_divisibility(linear.in_features, world_size) in_features = linear.in_features // world_size - out_features = linear.out_features bias = linear.bias is not None - device = ctx.current_device - dtype = linear.weight.dtype - super().__init__(in_features, out_features, bias, device, dtype) + # modify meta information + weight_meta = getattr(linear.weight, "meta", None) + assert isinstance( + weight_meta, ParameterMeta + ), "should have run `initialize_parameter_meta` after moving model to current device" + + if weight_meta.is_modified_meta: + assert weight_meta.is_tied, "only tied parameters could already have modified meta" + else: + weight_meta.need_initialize = True + weight_meta.dim = 1 + weight_meta.init_fn = init_fn + for _, Slice in weight_meta.mapping.items(): + Slice.index = slice(tp_rank * in_features, (tp_rank + 1) * in_features) + weight_meta.is_modified_meta = True + + # skip creating actual parameters + self.weight = linear.weight self.input_is_parallel = input_is_parallel - tp_rank = dist.get_rank(self.process_group) - - parameter_mapping, key = ctx.parameter_mapping, id(linear.weight) - assert key in parameter_mapping, "should have run `initialize_paramter_mapping` after moving model to current device" - original_linear_weight_meta = parameter_mapping[key].meta - - need_intialize = original_linear_weight_meta.source not in ctx.weight_map - if need_intialize: - # initialize on cpu - master_weight = torch.empty_like(linear.weight, device='cpu') - init_fn(master_weight) - with torch.no_grad(): - self.weight.copy_(master_weight[:, tp_rank * in_features : (tp_rank + 1) * in_features]) - - # update parameter mapping corresponding to original linear weight and bias - linear_weight_mapping = ParallelParameterMapping( - id=id(self.weight), - meta=ParameterMeta( - source=original_linear_weight_meta.source, - dim=1, - index=slice(tp_rank * in_features, (tp_rank + 1) * in_features) - ), - parallel_dim=1 - ) - parameter_mapping.pop(key) - parameter_mapping[linear_weight_mapping.id] = linear_weight_mapping if bias: - key = id(linear.bias) - assert key in parameter_mapping - linear_bias_mapping = parameter_mapping[key] - parameter_mapping.pop(key) - linear_bias_mapping.id = id(self.bias) - parameter_mapping[linear_bias_mapping.id] = linear_bias_mapping - self.bias.zero_() - + bias_meta = getattr(linear.bias, "meta", None) + assert isinstance( + bias_meta, ParameterMeta + ), "should have run `initialize_parameter_meta` after moving model to current device" + if bias_meta.is_modified_meta: + assert bias_meta.is_tied, "only tied parameters could already have modified meta" + else: + bias_meta.need_initialize = True + bias_meta.init_fn = torch.zero_ + bias_meta.is_modified_meta = True + self.bias = linear.bias + else: + self.register_parameter("bias", None) def forward(self, input: torch.Tensor) -> torch.Tensor: if not self.input_is_parallel: diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index 2d42f53441..7c394cc4b7 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -1,25 +1,48 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from __future__ import annotations -from typing import List, Any, Dict + from abc import ABC, abstractmethod -from torch.fx import GraphModule, Graph, Node +from typing import Any, Dict, List + +import torch +import torch.distributed as dist import torch.nn as nn +from torch.fx import Graph, GraphModule, Node + +from .core import Config, ParallelExecutionCtx, ParameterMeta +from .distributed import scatter +from .parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding from .utils import ( - stable_topological_sort, - is_transpose, - is_permute, + is_embedding, is_linear, + is_permute, is_shape_consumer, is_shape_generator, + is_transpose, + stable_topological_sort, ) -from .core import ParallelExecutionCtx, Config -from .parallel_layers import ColumnParallelLinear, RowParallelLinear - class PassBase(ABC): """ - Base class for parallelization targeted passes + Base class for parallelization targeted passes. """ + + need_rerun_when_recompile: bool = True + @classmethod def signature(cls) -> str: return cls.__name__ @@ -41,6 +64,10 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf raise NotImplementedError def __call__(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + # skip running when recompilation happens + if not self.need_rerun_when_recompile and ctx.compile_times > 0: + return graph_module + graph_module = self.run(graph_module, ctx=ctx, config=config) if config.lint_and_recompile: graph_module.graph.lint() @@ -50,20 +77,21 @@ def __call__(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: class AnalyzeBase(PassBase): """ - Base class for passes which only run for analytical purposes and preserve graph structure + Base class for passes which only runs for analytical purposes and preserve graph structure during processing. Analytical passes are often prerequisite passes which provide information for passes later on to actually change the graph. - Passes inheriting from `AnalyBase` places the class signature as a meta key in `node.meta`, + Passes inheriting from `AnalyzeBase` places the class signature as a meta key in `node.meta`, which is a dict storing meta information related with a fx Node, such as the shape and dtype of output. Look-up APIs are exposed as classmethod so that passes using them won't need to create concrete instances. """ + @classmethod def meta_key(cls) -> str: # place class-wise unique meta_key in `meta` to prevent duplicate fields return cls.signature() - + @classmethod def get_stored_field_info(cls, node: Node, field: Any, must_have: bool = False) -> Any: if not cls.already_executed_per_node(node): @@ -74,17 +102,19 @@ def get_stored_field_info(cls, node: Node, field: Any, must_have: bool = False) f"Can't find information related with {cls.__name__} in the current node `{node}` " f"make sure {cls.__name__} has run and marked it" ) - - info : Dict[Any, Any] = node.meta[cls.meta_key()] + + info: Dict[Any, Any] = node.meta[cls.meta_key()] if field not in info: - raise KeyError(f"Invalid query field {field} for {cls.__name__}, valid fields are {list(info.keys())}") + if must_have: + raise KeyError(f"Invalid query field {field} for {cls.__name__}, valid fields are {list(info.keys())}") + return None return info[field] - + @classmethod - def already_executed_per_node(cls, node: Node) -> None: + def already_executed_per_node(cls, node: Node) -> bool: return cls.meta_key() in node.meta - + def place_marker_per_node(self, node: Node, info: Dict[Any, Any]) -> None: if self.already_executed_per_node(node): raise RuntimeError( @@ -100,17 +130,24 @@ def clear_marker_per_node(self, node: Node) -> None: node.meta.pop(key) def clean_all(self, graph_module: GraphModule) -> None: - g : Graph = graph_module.graph + g: Graph = graph_module.graph for node in g.nodes: self.clear_marker_per_node(node) -class ParallelLinearAnnotatePass(AnalyzeBase): +class ParallelLayerAnnotatePass(AnalyzeBase): """ - A pass which tries to automatically identify parallel linears in the graph by grouping linears as - `upstream` nodes and `downstream` nodes, and `upstream` nodes are marked as `ColumnLinear`, `downstream` - nodes are marked as `RowLinear`. - + A pass which tries to automatically identify parallel layers in the graph. Note that for simplicity + we only consider classical ways of parallelizing layers in transformers architecture for now, we are not + solving an optimization problem which tries to give a best solution of parallelizing any model under + memory/hardware constraints. + + For `nn.Embedding` layers, we parallelize them on the vocabulary dim by default, because they are often tied + to the `lm_head` of the model, which is usually a `ColumnLinear`(parallelized on vocab dim). + + For `nn.Linear` layers, we parallelize them by grouping them as `upstream` nodes and `downstream` nodes, and + `upstream` nodes are marked as `ColumnLinear`, `downstream` nodes are marked as `RowLinear`. + Typical examples in transformer models: Attention Bert-style MLP Llama-style MLP @@ -120,7 +157,7 @@ class ParallelLinearAnnotatePass(AnalyzeBase): Matmul Linear Activation Activation Linear __________________________________________________________________________ \\ / | \\ / - \\ / ___________ \\ / + \\ / ___________ \\ / Matmul / Linear \ Mul | / \ | _______________________________/ \___________________________ @@ -129,18 +166,19 @@ class ParallelLinearAnnotatePass(AnalyzeBase): Note that there are some patterns that can not be clearly marked, like this one: Linear - | \\ + | \\ | Linear <-- which label should we mark for the intermediate linear, `upstream` or `downstream` | / Add | Linear - - For patterns like this we will be preservative and raise errors directly because we don't know how to parallelize + + For patterns like this we will be conservative and raise errors directly because we don't know how to parallelize it. Another concern is about the correctness, it's possible that we might end up with a wrong parallelization solution even if the pattern itself is clear, but for now we are mainly targeting on transformer models and the current solution should work fairly well. """ + def try_form_parallel_linear_groups(self, linear: Node) -> None: """ We try to form linears by forming closures in a greedy way, we start with an unmarked linear node, and traverses down @@ -156,45 +194,44 @@ def try_form_parallel_linear_groups(self, linear: Node) -> None: """ upstream_nodes, downstream_nodes = {linear}, set() - seeds, next_seeds = [(linear, 'down')], [] + seeds, next_seeds = [(linear, "down")], [] - def traverse(start: Node, cur: Node, direction = 'down'): + def traverse(start: Node, cur: Node, direction: str = "down"): if is_linear(cur) and cur is not start: - if direction == 'up' and cur not in upstream_nodes: + if direction == "up" and cur not in upstream_nodes: upstream_nodes.add(cur) - next_seeds.append((cur, 'down')) - elif direction == 'down' and cur not in downstream_nodes: + next_seeds.append((cur, "down")) + elif direction == "down" and cur not in downstream_nodes: downstream_nodes.add(cur) - next_seeds.append((cur, 'up')) + next_seeds.append((cur, "up")) return - - next_nodes = cur.all_input_nodes if direction == 'up' else cur.users + next_nodes = cur.all_input_nodes if direction == "up" else cur.users for node in next_nodes: # we should ignore shape-related dependencies if is_shape_generator(node): continue traverse(start, node, direction) - + while seeds: next_seeds = [] for node, direction in seeds: traverse(start=node, cur=node, direction=direction) seeds = next_seeds - if any([self.already_executed_per_node(node) for node in (upstream_nodes | downstream_nodes)]) or \ - (upstream_nodes & downstream_nodes): + if any(self.already_executed_per_node(node) for node in (upstream_nodes | downstream_nodes)) or ( + upstream_nodes & downstream_nodes + ): raise RuntimeError( "Failed to automatically group and parallelize ops in graph in greedy way: " - "no clear boudaries between `upstream` and `downstream` ops." + "no clear boudaries between `upstream` and `downstream` ops." ) - + for node in upstream_nodes: - self.place_marker_per_node(node, {'axis' : 'column', 'gather_output' : False if downstream_nodes else True}) + self.place_marker_per_node(node, {"axis": "column", "gather_output": False if downstream_nodes else True}) for node in downstream_nodes: - self.place_marker_per_node(node, {'axis' : 'row', 'input_is_parallel' : True}) - + self.place_marker_per_node(node, {"axis": "row", "input_is_parallel": True}) def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: graph: Graph = graph_module.graph @@ -202,60 +239,62 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf for node in graph.nodes: if is_linear(node) and not self.already_executed_per_node(node): self.try_form_parallel_linear_groups(node) - + elif is_embedding(node): + # directly mark `nn.Embedding` layers + self.place_marker_per_node(node, {"axis": "vocab"}) + return graph_module class ParallelAxisPropagationPass(AnalyzeBase): """ - A pass tries to track which axis is being parallelized in the dataflow. For transformer models, the + A pass which tries to track which axis is being parallelized in the dataflow. For transformer models, the axis being paralled for tensor parallism is almost always 2, i.e., the attention head axis, except for - Q and K matrice which need to swap the sequence length axis and head axis to do the attention computation, + Q and K matrices which need to swap the sequence length axis and head axis to do the attention computation, so we focus on operations like `transpose` or `permute` which swaps axis, and try inducting the parallel axis after these operations. """ + def propagate_transpose(self, node: Node, parallel_axis: int) -> bool: - dims = node.meta['example_value'].dim() - if 'dim0' in node.kwargs and 'dim1' in node.kwargs: - dim0, dim1, dims = node.kwargs['dim0'], node.kwargs['dim1'] - dim0 = (dim0 + dims) % dims - dim1 = (dim1 + dims) % dims - if dim0 == parallel_axis: - self.place_marker_per_node(node, {'parallel_axis' : dim1}) - return True - elif dim1 == parallel_axis: - self.place_marker_per_node(node, {'parallel_axis' : dim0}) - return True - return False - - if len(node.args) == 3: - if parallel_axis not in node.args and parallel_axis - dims not in node.args: - return False - for arg in node.args: - if isinstance(arg, int) and (arg + dims) % dims != parallel_axis: - self.place_marker_per_node(node, {'parallel_axis' : (arg + dims) % dims}) - return True - + dims = node.meta["example_value"].dim() + if "dim0" in node.kwargs and "dim1" in node.kwargs: + dim0, dim1 = node.kwargs["dim0"], node.kwargs["dim1"] + elif len(node.args) == 3: + dim0, dim1 = node.args[1:] + + dim0 = (dim0 + dims) % dims + dim1 = (dim1 + dims) % dims + + if dim0 == parallel_axis: + self.place_marker_per_node(node, {"parallel_axis": dim1}) + return True + elif dim1 == parallel_axis: + self.place_marker_per_node(node, {"parallel_axis": dim0}) + return True return False def propagate_permute(self, node: Node, parallel_axis: int) -> bool: - if 'dims' in node.kwargs: - dims = node.kwargs['dims'] + if "dims" in node.kwargs: + dims = node.kwargs["dims"] else: - dims = list(node.args[1]) if isinstance(node.args[1], tuple) else [arg for arg in node.args if isinstance(arg,int)] - - dim_len = node.meta['example_value'].dim() + dims = ( + list(node.args[1]) + if isinstance(node.args[1], tuple) + else [arg for arg in node.args if isinstance(arg, int)] + ) + + dim_len = node.meta["example_value"].dim() dims = [dim + dim_len if dim < 0 else dim for dim in dims] - for i,dim in enumerate(dims): + for i, dim in enumerate(dims): if dim == parallel_axis: - self.place_marker_per_node(node, {'parallel_axis' : i}) + self.place_marker_per_node(node, {"parallel_axis": i}) return True return False def propagate_getitem(self, node: Node, parallel_axis: int) -> bool: slices = node.args[1] - dims = node.meta['example_value'].dim() + dims = node.meta["example_value"].dim() assert parallel_axis < dims inc, i, j = 0, 0, 0 @@ -278,7 +317,7 @@ def propagate_getitem(self, node: Node, parallel_axis: int) -> bool: if inc != 0: assert parallel_axis + inc < dims and parallel_axis + inc >= 0 - self.place_marker_per_node(node, {'parallel_axis' : parallel_axis + inc}) + self.place_marker_per_node(node, {"parallel_axis": parallel_axis + inc}) return True return False @@ -287,11 +326,12 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf stable_topological_sort(g) for node in g.nodes: - if ParallelLinearAnnotatePass.already_executed_per_node(node): + if ParallelLayerAnnotatePass.already_executed_per_node(node): # start propagating at ColumnLinear, marking the beginning of parallelized region - axis = ParallelLinearAnnotatePass.get_stored_field_info(node, field='axis', must_have=True) - if axis == 'column': - self.place_marker_per_node(node, {'parallel_axis' : 2}) + axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis", must_have=True) + gather_output = ParallelLayerAnnotatePass.get_stored_field_info(node, field="gather_output") + if axis == "column" and not gather_output: + self.place_marker_per_node(node, {"parallel_axis": 2}) # stop propagating at RowLinear, concluding the ending of parallelized region else: continue @@ -301,12 +341,13 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf if not self.already_executed_per_node(arg): continue if parallel_axis is None: - parallel_axis = self.get_stored_field_info(arg, field='parallel_axis', must_have=True) + parallel_axis = self.get_stored_field_info(arg, field="parallel_axis", must_have=True) else: - assert parallel_axis == self.get_stored_field_info(arg, field='parallel_axis', must_have=True), \ - "`parallel_axis` should be equal for all arguments in any related ops" + assert parallel_axis == self.get_stored_field_info( + arg, field="parallel_axis", must_have=True + ), "`parallel_axis` should be equal for all arguments in any related ops" already_marked_args.append(arg) - + if not already_marked_args: continue @@ -315,28 +356,29 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf marked = self.propagate_transpose(node, parallel_axis) elif is_permute(node): marked = self.propagate_permute(node, parallel_axis) - + # fall back if not marked: - self.place_marker_per_node(node, {'parallel_axis' : parallel_axis}) + self.place_marker_per_node(node, {"parallel_axis": parallel_axis}) return graph_module -class ParallelLinearReplacePass(PassBase): +class ParallelLayerReplacePass(PassBase): """ A pass which modifies graph according to information provided by previous analytical passes, in general it does two things for now: - 1. replace linears with their parallel counterparts. - 2. modify hard-coded arguments like the number of attenton heads in the graph by dividing it by parallelism level. + 1. replaces linears and embedding layers with their parallel counterparts. + 2. modifies hard-coded arguments like the number of attention heads in the graph by dividing it by parallelism level. """ + @staticmethod def handle_linear(node: Node, ctx: ParallelExecutionCtx, config: Config) -> None: graph_module = node.graph.owning_module - axis = ParallelLinearAnnotatePass.get_stored_field_info(node, field='axis') + axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis") if axis is None: return - - assert axis in {'column', 'row'} + + assert axis in {"column", "row"} prefix_and_field = node.target.rsplit(".", maxsplit=1) if len(prefix_and_field) == 2: parent_mod = graph_module.get_submodule(prefix_and_field[0]) @@ -345,45 +387,73 @@ def handle_linear(node: Node, ctx: ParallelExecutionCtx, config: Config) -> None parent_mod = graph_module field = node.target - mod : nn.Linear = graph_module.get_submodule(node.target) + mod: nn.Linear = graph_module.get_submodule(node.target) key, layer_cache = id(mod), ctx.parallel_layer_cache if key in layer_cache: new_mod = layer_cache[key] else: - if axis == 'column': - gather_output = ParallelLinearAnnotatePass.get_stored_field_info(node, field='gather_output', must_have=True) - new_mod = ColumnParallelLinear(ctx, mod, gather_output, config.weight_init_fn) + if axis == "column": + gather_output = ParallelLayerAnnotatePass.get_stored_field_info( + node, field="gather_output", must_have=True + ) + new_mod = ColumnParallelLinear(ctx, mod, gather_output, config.weight_init_fn) else: - input_is_parallel = ParallelLinearAnnotatePass.get_stored_field_info(node, field='input_is_parallel', must_have=True) + input_is_parallel = ParallelLayerAnnotatePass.get_stored_field_info( + node, field="input_is_parallel", must_have=True + ) new_mod = RowParallelLinear(ctx, mod, input_is_parallel, config.weight_init_fn) layer_cache[key] = new_mod setattr(parent_mod, field, new_mod) + @staticmethod + def handle_embedding(node: Node, ctx: ParallelExecutionCtx, config: Config) -> None: + graph_module = node.graph.owning_module + axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis") + if axis is None: + return + + assert axis in {"vocab"}, "Only support parallelization on vocab dim for now." + prefix_and_field = node.target.rsplit(".", maxsplit=1) + if len(prefix_and_field) == 2: + parent_mod = graph_module.get_submodule(prefix_and_field[0]) + field = prefix_and_field[1] + else: + parent_mod = graph_module + field = node.target + + mod: nn.Embedding = graph_module.get_submodule(node.target) + key, layer_cache = id(mod), ctx.parallel_layer_cache + if key in layer_cache: + new_mod = layer_cache[key] + else: + new_mod = VocabParallelEmbedding(ctx, mod, config.weight_init_fn) + layer_cache[key] = new_mod + setattr(parent_mod, field, new_mod) @staticmethod def handle_hard_coded_axis_param(node: Node, ctx: ParallelExecutionCtx) -> None: def extract_shape_from_node(node: Node) -> List[Any]: - if 'size' in node.kwargs: - return list(node.kwargs['size']) - elif 'shape' in node.kwargs: - return list(node.kwargs['shape']) + if "size" in node.kwargs: + return list(node.kwargs["size"]) + elif "shape" in node.kwargs: + return list(node.kwargs["shape"]) elif isinstance(node.args[1], tuple): - return [idx for idx in node.args[1]] + return list(node.args[1]) else: - return [idx for idx in node.args[1:]] + return list(node.args[1:]) def update(node: Node, new_shape: List[Any], parallel_axis: int): - if 'size' in node.kwargs: - node.update_kwarg('size', tuple(new_shape)) - elif 'shape' in node.kwargs: - node.update_kwarg('shape', tuple(new_shape)) + if "size" in node.kwargs: + node.update_kwarg("size", tuple(new_shape)) + elif "shape" in node.kwargs: + node.update_kwarg("shape", tuple(new_shape)) elif isinstance(node.args[1], tuple): node.update_arg(1, tuple(new_shape)) else: node.update_arg(parallel_axis + 1, shape[parallel_axis]) - parallel_axis = ParallelAxisPropagationPass.get_stored_field_info(node, field='parallel_axis') + parallel_axis = ParallelAxisPropagationPass.get_stored_field_info(node, field="parallel_axis") if parallel_axis is None: return @@ -400,28 +470,104 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf for node in graph_module.graph.nodes: if is_linear(node): self.handle_linear(node, ctx, config) + elif is_embedding(node): + self.handle_embedding(node, ctx, config) # correct the attention head num in parallel setting elif is_shape_consumer(node): self.handle_hard_coded_axis_param(node, ctx) return graph_module +class InitializeOrLoadWeightsPass(PassBase): + """ + Make weights loading/initialization a seperate pass for cleaner logic and easier extensibility. This + pass will only run once in the very first compilation step. + """ + + need_rerun_when_recompile = False + + def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + world_size = dist.get_world_size(ctx.tp_group) + tp_rank = dist.get_rank(ctx.tp_group) + + new_parameters, tied_parameters = [], {} + for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)): + param_meta: ParameterMeta = getattr(param, "meta") + # skip already initialized parameters + if not param_meta.need_initialize: + continue + if param_meta.is_tied and id(param) in tied_parameters: + new_parameters.append((name, tied_parameters[id(param)])) + continue + + shape = [ + param.size(dim) // world_size if dim == param_meta.dim else param.size(dim) + for dim in range(param.ndim) + ] + new_param = nn.Parameter( + torch.randn(*shape, dtype=param.dtype, device=ctx.current_device), requires_grad=param.requires_grad + ) + for source, target in sorted(param_meta.mapping.items()): + if target.source in ctx.weight_map: + # TODO: add weights loading logic + continue + if tp_rank == 0: + # initialize weight on master rank + start = source.start if source.start else 0 + stop = source.stop if source.stop else start + param.size(param_meta.dim) // world_size + shape = [ + (stop - start) * world_size if dim == param_meta.dim else param.size(dim) + for dim in range(param.ndim) + ] + weight = torch.empty(*shape, dtype=param.dtype, device="cpu") + init_fn = param_meta.init_fn if param_meta.init_fn else config.weight_init_fn + init_fn(weight) + weight = weight.to(ctx.current_device) + else: + weight = None + with torch.no_grad(): + index = [ + source.to_slice() if dim == param_meta.dim else slice(None, None, None) + for dim in range(param.ndim) + ] + scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim) + + new_parameters.append((name, new_param)) + if param_meta.is_tied: + tied_parameters[id(param)] = new_param + + for name, new_param in new_parameters: + prefix_and_field = name.rsplit(".", maxsplit=1) + if len(prefix_and_field) == 2: + parent_mod = graph_module.get_submodule(prefix_and_field[0]) + field = prefix_and_field[1] + else: + parent_mod = graph_module + field = name + setattr(parent_mod, field, new_param) + + return graph_module + + def build_parallel_pass_pipeline() -> PassPipeline: """ Ensemble a pass pipeline which contains the following passes: - - 1. `ParallelLinearAnnotatePass` to annoate which linears are `ColumnLinear`, which are `RowLinear` + 1. `ParallelLayerAnnotatePass` to annoate which linears are `ColumnLinear`, which are `RowLinear` 2. `ParallelAxisPropagationPass` to propate parallel axis along the data flow 3. `ParallelLinearReplacePass` to do the actual replacement and modification of hard-coded attributes + 4. `InitializeOrLoadWeightsPass` to load or initialize weights for parameters Returns: PassPipeline: the pipeline used for automatic parallelism. """ - return PassPipeline([ - ParallelLinearAnnotatePass(), - ParallelAxisPropagationPass(), - ParallelLinearReplacePass() - ]) + return PassPipeline( + [ + ParallelLayerAnnotatePass(), + ParallelAxisPropagationPass(), + ParallelLayerReplacePass(), + InitializeOrLoadWeightsPass(), + ] + ) class PassPipeline: @@ -429,19 +575,22 @@ class PassPipeline: `PassPipeline` ensembles a list of passes and execute them one by one as provided in the list, it can be iterated and appended after initialization for flexibility. """ - def __init__(self, passes : List[PassBase] = []) -> None: + + def __init__(self, passes: List[PassBase] = []) -> None: self._passes = passes - def __iter__(self,): + def __iter__( + self, + ): return self._passes.__iter__() def append(self, PASS: PassBase): self._passes.append(PASS) - + def __call__(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: for PASS in self._passes: graph_module = PASS(graph_module=graph_module, ctx=ctx, config=config) - + if config.clean_markers_after_all_passes: for PASS in self._passes: if isinstance(PASS, AnalyzeBase): diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index 5c64568e0b..00e44868c0 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -1,59 +1,98 @@ -import operator +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import importlib +import operator +from collections import defaultdict +from functools import wraps +from itertools import chain +from typing import Callable, Dict, List, Union import torch import torch.nn as nn import torch.nn.functional as F -from typing import Dict, Callable, List, Union -from torch.fx import Node, Graph -from functools import wraps -from collections import defaultdict -from itertools import chain +from torch.fx import Graph, Node + from .core import ( - ParallelExecutionCtx, - ParameterMapping, + HashableSlice, ParameterMeta, + ParameterSlice, ) + +def ensure_divisibility(numerator: int, denominator: int) -> None: + if numerator % denominator != 0: + raise RuntimeError( + f"{numerator} is not divisible by {denominator}, check if the parallel dimension of weight parameters is divisible " + "by parallelism level(world size of tensor parallel group)" + ) + + def is_linear(node: Node) -> bool: - if node.op != 'call_module': + if node.op != "call_module": return False mod = node.graph.owning_module return isinstance(mod.get_submodule(node.target), nn.Linear) + +def is_embedding(node: Node) -> bool: + if node.op != "call_module": + return False + mod = node.graph.owning_module + return isinstance(mod.get_submodule(node.target), nn.Embedding) + + def is_shape_consumer(node: Node) -> bool: - if node.op == 'call_method': - return node.target in {'view', 'reshape', 'expand', 'resize', 'resize_'} - elif node.op == 'call_function': + if node.op == "call_method": + return node.target in {"view", "reshape", "expand", "resize", "resize_"} + elif node.op == "call_function": return node.target in {torch.reshape} + return False + def is_transpose(node: Node) -> bool: - if node.op == 'call_method': - return node.target in {'transpose', 'transpose_'} - elif node.op == 'call_function': + if node.op == "call_method": + return node.target in {"transpose", "transpose_"} + elif node.op == "call_function": return node.target is torch.transpose return False + def is_permute(node: Node) -> bool: - if node.op == 'call_method': - return node.target in {'permute'} - elif node.op == 'call_function': + if node.op == "call_method": + return node.target in {"permute"} + elif node.op == "call_function": return node.target is torch.permute return False + def is_getitem(node: Node) -> bool: - return node.op == 'call_function' and node.target is operator.getitem + return node.op == "call_function" and node.target is operator.getitem + def is_output(node: Node) -> bool: - return node.op == 'output' + return node.op == "output" + def is_shape_generator(node: Node) -> bool: - return node.op == 'call_method' and node.target == 'size' + return node.op == "call_method" and node.target == "size" + def stable_topological_sort(graph: Graph): def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]: - args: List[torch.fx.node.Argument] = list() + args: List[torch.fx.node.Argument] = [] torch.fx.map_arg((n.args, n.kwargs), args.append) return args @@ -91,6 +130,7 @@ def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]: assert not waiting and len(ready) == len(graph.nodes) + def meta_init(init_fn): @wraps(init_fn) def wrapper(*args, **kwargs): @@ -99,32 +139,66 @@ def wrapper(*args, **kwargs): return wrapper + @wraps(nn.Linear.forward) def meta_aware_linear_forward(*args, **kwargs): self = args[0] input = args[1] - if self.weight.device != torch.device('meta'): + if self.weight.device != torch.device("meta"): return F.linear(input, self.weight, self.bias) - + orig_device = input.device input = input.to("meta") meta_output = F.linear(input, self.weight, self.bias) return torch.empty_like(meta_output, device=orig_device) +@wraps(nn.Embedding.forward) +def meta_aware_embedding_forward(*args, **kwargs): + self = args[0] + input = args[1] + + if self.weight.device != torch.device("meta"): + return F.embedding( + input=input, + weight=self.weight, + padding_idx=self.padding_idx, + max_norm=self.max_norm, + norm_type=self.norm_type, + scale_grad_by_freq=self.scale_grad_by_freq, + sparse=self.sparse, + ) + + orig_device = input.device + input = input.to("meta") + meta_output = F.embedding( + input=input, + weight=self.weight, + padding_idx=self.padding_idx, + max_norm=self.max_norm, + norm_type=self.norm_type, + scale_grad_by_freq=self.scale_grad_by_freq, + sparse=self.sparse, + ) + return torch.empty_like(meta_output, device=orig_device) + + class MetaAwareMethodsPatcher: """ A patcher class which patches `__init__` and `forward` methods on modules which will be put on meta devices for memory efficiency purposes during initialization. - + Note that for `__init__` method, it can be unpatched once we have finished the initialization of the model, however, for `forward`, we need it to constantly being patched during the whole process in case recompile happens and torch dynamo needs meta-aware `forward` to be able to re-capture the graph. """ - methods_to_patch : Dict[str, Callable] = [ - ("torch.nn.Linear.__init__", meta_init(torch.nn.Linear.__init__)), + + methods_to_patch: Dict[str, Callable] = [ + ("torch.nn.Linear.__init__", meta_init(nn.Linear.__init__)), + ("torch.nn.Embedding.__init__", meta_init(nn.Embedding.__init__)), ("torch.nn.Linear.forward", meta_aware_linear_forward), + ("torch.nn.Embedding.forward", meta_aware_embedding_forward), ] def __init__(self) -> None: @@ -134,9 +208,7 @@ def __init__(self) -> None: try: module = importlib.import_module(module_qualified_name) except ModuleNotFoundError as e: - module_qualified_name, module_attribute_name = module_qualified_name.rsplit( - ".", maxsplit=1 - ) + module_qualified_name, module_attribute_name = module_qualified_name.rsplit(".", maxsplit=1) module = importlib.import_module(module_qualified_name) try: module = getattr(module, module_attribute_name) @@ -165,19 +237,29 @@ def _unpatch(self, identifier: str): setattr(spec[0], spec[1], spec[2]) spec[-1] = False - def patch_meta_init(self,): + def patch_meta_init( + self, + ): self._patch("init") - def patch_meta_forward(self,): + def patch_meta_forward( + self, + ): self._patch("forward") - def unpatch_meta_init(self,): + def unpatch_meta_init( + self, + ): self._unpatch("init") - def unpatch_meta_forward(self,): + def unpatch_meta_forward( + self, + ): self._unpatch("forward") - def __enter__(self,): + def __enter__( + self, + ): self.patch_meta_init() self.patch_meta_forward() @@ -185,18 +267,28 @@ def __exit__(self, exc_type, exc_value, traceback): self.unpatch_meta_init() -def initialize_parameter_mapping(model: nn.Module, ctx: ParallelExecutionCtx) -> None: - mapping = ctx.parameter_mapping +def initialize_parameter_meta(model: nn.Module) -> None: + parameter_ids = set() + for name, tensor in model.named_parameters(remove_duplicate=False): + key = id(tensor) + if key not in parameter_ids: + setattr( + tensor, + "meta", + ParameterMeta(dim=0, mapping={HashableSlice(None, None, None): ParameterSlice(source=name)}), + ) + parameter_ids.add(key) + else: + tensor.meta.is_tied = True - for name, tensor in chain(model.named_parameters(), model.named_buffers()): - mapping[id(tensor)] = ParameterMapping(id = id(tensor), meta = ParameterMeta(source=name)) +@torch.no_grad def move_model_to_device(model: nn.Module, device: Union[torch.device, str]): # move everything except tensors on meta devices on current device # this function should be called before `intialize_parameter_mapping` for name, tensor in chain(model.named_parameters(), model.named_buffers()): if tensor.device == torch.device("meta"): - continue + continue splits = name.rsplit(".", maxsplit=1) if len(splits) == 1: parent_mod = model diff --git a/optimum/onnxruntime/runs/__init__.py b/optimum/onnxruntime/runs/__init__.py index 1d98294934..d21db2a4ac 100644 --- a/optimum/onnxruntime/runs/__init__.py +++ b/optimum/onnxruntime/runs/__init__.py @@ -110,9 +110,9 @@ def __init__(self, run_config): model_class = FeaturesManager.get_model_class_for_feature(get_autoclass_name(self.task)) self.torch_model = model_class.from_pretrained(run_config["model_name_or_path"]) - self.return_body[ - "model_type" - ] = self.torch_model.config.model_type # return_body is initialized in parent class + self.return_body["model_type"] = ( + self.torch_model.config.model_type + ) # return_body is initialized in parent class def _launch_time(self, trial): batch_size = trial.suggest_categorical("batch_size", self.batch_sizes) diff --git a/tests/fx/parallelization/dist_utils.py b/tests/fx/parallelization/dist_utils.py index a9abe4dd34..ef35fb33d0 100644 --- a/tests/fx/parallelization/dist_utils.py +++ b/tests/fx/parallelization/dist_utils.py @@ -1,10 +1,26 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os +from typing import Callable, List, Optional + import torch import torch.distributed as dist import torch.multiprocessing as mp -from typing import Callable, List, Optional from transformers import set_seed + SEED = 42 NUM_AVAILABLE_DEVICES = torch.cuda.device_count() @@ -12,9 +28,9 @@ def dist_init( rank: int, world_size: int, - backend: str = 'nccl', - master_addr: str = '127.0.0.1', - master_port: str = '29500', + backend: str = "nccl", + master_addr: str = "127.0.0.1", + master_port: str = "29501", ): os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) @@ -30,18 +46,24 @@ def dist_init( torch.cuda.set_device(rank) -def runner(rank: int, fn:Callable, deterministic: bool, *args, **kwargs): + +def runner(rank: int, fn: Callable, deterministic: bool, *args, **kwargs): if deterministic: set_seed(SEED) fn(rank, *args, **kwargs) + def spawn(world_size: int, fn: Callable, *args, deterministic: bool = False): mp.spawn(fn=runner, args=(fn, deterministic, world_size, *args), nprocs=world_size, join=True) + def tearDown(group: Optional[dist.ProcessGroup] = None): dist.destroy_process_group(group) -def gather_at_main_process(tensor: torch.Tensor, group: dist.ProcessGroup, rank: int, world_size: int) -> List[torch.Tensor]: + +def gather_at_main_process( + tensor: torch.Tensor, group: dist.ProcessGroup, rank: int, world_size: int +) -> List[torch.Tensor]: if world_size == 1: return [tensor] diff --git a/tests/fx/parallelization/test_tensor_parallel.py b/tests/fx/parallelization/test_tensor_parallel.py index f865f3d35b..a1a3ac09cb 100644 --- a/tests/fx/parallelization/test_tensor_parallel.py +++ b/tests/fx/parallelization/test_tensor_parallel.py @@ -1,51 +1,79 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import unittest +from functools import partial +from typing import Type + import torch import torch.distributed as dist -from typing import Type -from functools import partial +from dist_utils import NUM_AVAILABLE_DEVICES, SEED, dist_init, gather_at_main_process, spawn, tearDown +from packaging import version +from parameterized import parameterized from transformers import ( - PretrainedConfig, - PreTrainedModel, LlamaConfig, - MistralConfig, LlamaForCausalLM, + MistralConfig, MistralForCausalLM, + PretrainedConfig, + PreTrainedModel, set_seed, ) -from parameterized import parameterized -from optimum.fx.parallelization import parallelize_backend, ParallelExecutionCtx, Config -from optimum.fx.parallelization.utils import MetaAwareMethodsPatcher, move_model_to_device, initialize_parameter_mapping -from dist_utils import ( - dist_init, - tearDown, - spawn, - gather_at_main_process, - NUM_AVAILABLE_DEVICES, - SEED -) + +from optimum.fx.parallelization import Config, ParallelExecutionCtx, parallelize_backend +from optimum.fx.parallelization.utils import MetaAwareMethodsPatcher, initialize_parameter_meta, move_model_to_device DUMMY_MODELS_TO_TEST = ( - (LlamaForCausalLM, LlamaConfig(), ), - (MistralForCausalLM, MistralConfig(), ), + ( + LlamaForCausalLM, + LlamaConfig, + ), + ( + MistralForCausalLM, + MistralConfig, + ), ) +def is_gpu_available(): + return torch.cuda.is_available() + + +def is_torch_compile_available(): + return version.parse(torch.__version__) >= version.parse("2.3.0") + + def dummify(config: PretrainedConfig): config.num_hidden_layers = 2 config.use_cache = False config.output_attentions = False config.output_hidden_states = False -def run_test_all_rank_results_match(rank: int, world_size: int, model_cls: Type[PreTrainedModel], model_config: PretrainedConfig): + +def run_test_all_rank_results_match( + rank: int, world_size: int, model_cls: Type[PreTrainedModel], config_cls: Type[PretrainedConfig] +): + model_config = config_cls() dummify(model_config) # initialize default group dist_init(rank, world_size) tp_group = dist.new_group() - + # prepare config and context - device = torch.device(type='cuda', index=torch.cuda.current_device()) + device = torch.device(type="cuda", index=torch.cuda.current_device()) ctx, cfg = ParallelExecutionCtx(tp_group=tp_group, current_device=device), Config() inputs = { @@ -60,7 +88,7 @@ def run_test_all_rank_results_match(rank: int, world_size: int, model_cls: Type[ model.eval() # move model to current device, with linears still on meta, and intialize parameter mapping move_model_to_device(model, device=device) - initialize_parameter_mapping(model, ctx=ctx) + initialize_parameter_meta(model) model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) logits = model(**inputs)[0] @@ -71,19 +99,23 @@ def run_test_all_rank_results_match(rank: int, world_size: int, model_cls: Type[ assert len(tensors) == world_size for i in range(1, world_size): torch.testing.assert_close(tensors[i - 1].cpu(), tensors[i].cpu(), rtol=1e-4, atol=1e-4) - + dist.barrier(tp_group) tearDown(tp_group) -def run_test_parameters_persist_bewteen_recompile(rank: int, world_size: int, model_cls: Type[PreTrainedModel], model_config: PretrainedConfig): + +def run_test_parameters_persist_bewteen_recompile( + rank: int, world_size: int, model_cls: Type[PreTrainedModel], config_cls: Type[PretrainedConfig] +): + model_config = config_cls() dummify(model_config) # initialize default group dist_init(rank, world_size) tp_group = dist.new_group() - + # prepare config and context - device = torch.device(type='cuda', index=torch.cuda.current_device()) + device = torch.device(type="cuda", index=torch.cuda.current_device()) ctx, cfg = ParallelExecutionCtx(tp_group=tp_group, current_device=device), Config() inputs = { @@ -91,7 +123,7 @@ def run_test_parameters_persist_bewteen_recompile(rank: int, world_size: int, mo "attention_mask": torch.ones((1, 10), dtype=torch.int64, device=device), "position_ids": torch.arange(0, 10, device=device).unsqueeze(0), } - + # different shape to trigger recompile another_inputs = { "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(1, 11), device=device), @@ -105,30 +137,34 @@ def run_test_parameters_persist_bewteen_recompile(rank: int, world_size: int, mo model.eval() # move model to current device, with linears still on meta move_model_to_device(model, device=device) - initialize_parameter_mapping(model, ctx=ctx) + initialize_parameter_meta(model) model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) model(**inputs) - parameter_ids = set([id(param) for _, param in model.named_parameters()]) + parameter_ids = {id(param) for _, param in model.named_parameters()} model(**another_inputs) - parameter_ids_after_recompile = set([id(param) for _, param in model.named_parameters()]) + parameter_ids_after_recompile = {id(param) for _, param in model.named_parameters()} assert parameter_ids == parameter_ids_after_recompile dist.barrier(tp_group) tearDown(tp_group) -def run_test_parallel_results_matches_non_parallel(rank: int, world_size: int, model_cls: Type[PreTrainedModel], model_config: PretrainedConfig): + +def run_test_parallel_results_matches_non_parallel( + rank: int, world_size: int, model_cls: Type[PreTrainedModel], config_cls: Type[PretrainedConfig] +): + model_config = config_cls() dummify(model_config) dist_init(rank, world_size) tp_group = dist.new_group(ranks=[rank]) - + # prepare config and context - device = torch.device(type='cuda', index=torch.cuda.current_device()) + device = torch.device(type="cuda", index=torch.cuda.current_device()) ctx, cfg = ParallelExecutionCtx(tp_group=tp_group, current_device=device), Config() - + inputs = { "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(1, 10), device=device), "attention_mask": torch.ones((1, 10), dtype=torch.int64, device=device), @@ -143,7 +179,7 @@ def run_test_parallel_results_matches_non_parallel(rank: int, world_size: int, m # move model to current device, with linears still on meta move_model_to_device(model, device=device) - initialize_parameter_mapping(model, ctx=ctx) + initialize_parameter_meta(model) model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) logits = model(**inputs)[0] @@ -159,7 +195,7 @@ def run_test_parallel_results_matches_non_parallel(rank: int, world_size: int, m # move model to current device, with linears still on meta move_model_to_device(model, device=device) - initialize_parameter_mapping(model, ctx=ctx) + initialize_parameter_meta(model) model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) parallel_logits = model(**inputs)[0] @@ -169,22 +205,42 @@ def run_test_parallel_results_matches_non_parallel(rank: int, world_size: int, m dist.barrier(tp_group) tearDown() + @parameterized.expand(DUMMY_MODELS_TO_TEST) -@unittest.skipIf(not torch.cuda.is_available(), "requires gpu to run") -def test_all_rank_results_match(model_cls, config, ): +@unittest.skipIf( + not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run" +) +def test_all_rank_results_match( + model_cls, + config_cls, +): for world_size in [1, 2, 4, 8]: if world_size <= NUM_AVAILABLE_DEVICES: - spawn(world_size, run_test_all_rank_results_match, model_cls, config, deterministic=True) + spawn(world_size, run_test_all_rank_results_match, model_cls, config_cls, deterministic=True) + @parameterized.expand(DUMMY_MODELS_TO_TEST) -@unittest.skipIf(not torch.cuda.is_available(), "requires gpu to run") -def test_parameters_persist_bewteen_recompile(model_cls, config, ): - for world_size in [1, 2, 4, 8]: +@unittest.skipIf( + not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run" +) +def test_parameters_persist_bewteen_recompile( + model_cls, + config_cls, +): + for world_size in [1, 2]: if world_size <= NUM_AVAILABLE_DEVICES: - spawn(world_size, run_test_parameters_persist_bewteen_recompile, model_cls, config, deterministic=False) + spawn( + world_size, run_test_parameters_persist_bewteen_recompile, model_cls, config_cls, deterministic=False + ) + @parameterized.expand(DUMMY_MODELS_TO_TEST) -@unittest.skipIf(not torch.cuda.is_available(), "requires gpu to run") -def test_parallel_results_matches_non_parallel(model_cls, config, ): +@unittest.skipIf( + not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run" +) +def test_parallel_results_matches_non_parallel( + model_cls, + config_cls, +): # world_size == 2 is enough - spawn(2, run_test_parallel_results_matches_non_parallel, model_cls, config, deterministic=True) \ No newline at end of file + spawn(2, run_test_parallel_results_matches_non_parallel, model_cls, config_cls, deterministic=True) From ae6d9d27164e8818738443fa34bc306b6950d373 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 9 Jul 2024 00:49:36 +0200 Subject: [PATCH 06/25] address comments --- .../parallelization/distributed/dist_ops.py | 24 +++++++++++++------ optimum/fx/parallelization/passes.py | 1 - optimum/fx/parallelization/utils.py | 7 +++--- optimum/onnxruntime/runs/__init__.py | 6 ++--- 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/optimum/fx/parallelization/distributed/dist_ops.py b/optimum/fx/parallelization/distributed/dist_ops.py index 69abe68bca..e55c1ac707 100644 --- a/optimum/fx/parallelization/distributed/dist_ops.py +++ b/optimum/fx/parallelization/distributed/dist_ops.py @@ -14,6 +14,7 @@ # limitations under the License. import torch import torch.distributed as dist +from ..utils import ensure_divisibility def all_reduce(group: dist.ProcessGroup, tensor: torch.Tensor) -> torch.Tensor: @@ -32,11 +33,20 @@ def all_gather(group: dist.ProcessGroup, tensor: torch.Tensor, gather_dim: int = rank = dist.get_rank(group=group) tensor = tensor.contiguous() - tensors = [torch.empty_like(tensor) for _ in range(world_size)] - tensors[rank] = tensor - - dist.all_gather(tensors, tensor, group=group) - return torch.cat(tensors, dim=gather_dim) + gather_dim = (gather_dim + tensor.ndim) % tensor.ndim + shape = tuple( + tensor.size(dim) * world_size if dim == gather_dim else tensor.size(dim) for dim in range(tensor.ndim) + ) + index = list( + slice(rank * tensor.size(dim), (rank + 1) * tensor.size(dim), None) + if dim == gather_dim + else slice(None, None, None) + for dim in range(tensor.ndim) + ) + tensors = torch.empty(*shape, dtype=tensor.dtype, device=tensor.device) + tensors[index] = tensor + dist.all_gather_into_tensor(tensors, tensor, group=group) + return tensors def split(group: dist.ProcessGroup, tensor: torch.Tensor, split_dim: int = -1) -> torch.Tensor: @@ -46,7 +56,7 @@ def split(group: dist.ProcessGroup, tensor: torch.Tensor, split_dim: int = -1) - rank = dist.get_rank(group) size = tensor.size() - assert size[split_dim] % world_size == 0 + ensure_divisibility(size[split_dim], world_size) tensors = torch.split(tensor, size[split_dim] // world_size, dim=split_dim) tensor = tensors[rank].contiguous() @@ -63,7 +73,7 @@ def scatter( rank = dist.get_rank(group) if rank == 0: size = tensor.size() - assert size[scatter_dim] % world_size == 0 + ensure_divisibility(size[scatter_dim], world_size) tensors = torch.split(tensor, size[scatter_dim] // world_size, dim=scatter_dim) scatter_list = [tensor.contiguous() for tensor in tensors] output_tensor.copy_(scatter_list[rank]) diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index 7c394cc4b7..eda1d02b69 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -432,7 +432,6 @@ def handle_embedding(node: Node, ctx: ParallelExecutionCtx, config: Config) -> N @staticmethod def handle_hard_coded_axis_param(node: Node, ctx: ParallelExecutionCtx) -> None: - def extract_shape_from_node(node: Node) -> List[Any]: if "size" in node.kwargs: return list(node.kwargs["size"]) diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index 00e44868c0..59c968ec5b 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -90,7 +90,6 @@ def is_shape_generator(node: Node) -> bool: def stable_topological_sort(graph: Graph): - def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]: args: List[torch.fx.node.Argument] = [] torch.fx.map_arg((n.args, n.kwargs), args.append) @@ -284,8 +283,10 @@ def initialize_parameter_meta(model: nn.Module) -> None: @torch.no_grad def move_model_to_device(model: nn.Module, device: Union[torch.device, str]): - # move everything except tensors on meta devices on current device - # this function should be called before `intialize_parameter_mapping` + """ + Move everything except tensors on meta devices on current device + this function should be called before `intialize_parameter_meta` + """ for name, tensor in chain(model.named_parameters(), model.named_buffers()): if tensor.device == torch.device("meta"): continue diff --git a/optimum/onnxruntime/runs/__init__.py b/optimum/onnxruntime/runs/__init__.py index d21db2a4ac..1d98294934 100644 --- a/optimum/onnxruntime/runs/__init__.py +++ b/optimum/onnxruntime/runs/__init__.py @@ -110,9 +110,9 @@ def __init__(self, run_config): model_class = FeaturesManager.get_model_class_for_feature(get_autoclass_name(self.task)) self.torch_model = model_class.from_pretrained(run_config["model_name_or_path"]) - self.return_body["model_type"] = ( - self.torch_model.config.model_type - ) # return_body is initialized in parent class + self.return_body[ + "model_type" + ] = self.torch_model.config.model_type # return_body is initialized in parent class def _launch_time(self, trial): batch_size = trial.suggest_categorical("batch_size", self.batch_sizes) From 455c0c7f04b43e55ff8b04aef701bd294dd9534c Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 9 Jul 2024 00:50:45 +0200 Subject: [PATCH 07/25] lint --- optimum/fx/parallelization/distributed/dist_ops.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/optimum/fx/parallelization/distributed/dist_ops.py b/optimum/fx/parallelization/distributed/dist_ops.py index e55c1ac707..9e0b654210 100644 --- a/optimum/fx/parallelization/distributed/dist_ops.py +++ b/optimum/fx/parallelization/distributed/dist_ops.py @@ -14,6 +14,7 @@ # limitations under the License. import torch import torch.distributed as dist + from ..utils import ensure_divisibility @@ -37,12 +38,12 @@ def all_gather(group: dist.ProcessGroup, tensor: torch.Tensor, gather_dim: int = shape = tuple( tensor.size(dim) * world_size if dim == gather_dim else tensor.size(dim) for dim in range(tensor.ndim) ) - index = list( + index = [ slice(rank * tensor.size(dim), (rank + 1) * tensor.size(dim), None) if dim == gather_dim else slice(None, None, None) for dim in range(tensor.ndim) - ) + ] tensors = torch.empty(*shape, dtype=tensor.dtype, device=tensor.device) tensors[index] = tensor dist.all_gather_into_tensor(tensors, tensor, group=group) From 27a9bb822f789fdd20daeb9b161fe130583b68e9 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Sat, 13 Jul 2024 00:54:47 +0200 Subject: [PATCH 08/25] fix --- .../workflows/test_fx_automatic_parallel.yml | 63 +++++++++++++++++++ optimum/fx/parallelization/__init__.py | 17 +---- optimum/fx/parallelization/api.py | 31 +++++++++ optimum/fx/parallelization/core.py | 20 ++++-- .../parallelization/distributed/dist_ops.py | 24 ++++--- .../parallel_layers/embedding.py | 17 ++--- .../parallelization/parallel_layers/linear.py | 38 ++++------- optimum/fx/parallelization/passes.py | 33 +++++----- optimum/fx/parallelization/utils.py | 5 +- .../parallelization/test_tensor_parallel.py | 6 +- 10 files changed, 161 insertions(+), 93 deletions(-) create mode 100644 .github/workflows/test_fx_automatic_parallel.yml create mode 100644 optimum/fx/parallelization/api.py diff --git a/.github/workflows/test_fx_automatic_parallel.yml b/.github/workflows/test_fx_automatic_parallel.yml new file mode 100644 index 0000000000..d745b8c724 --- /dev/null +++ b/.github/workflows/test_fx_automatic_parallel.yml @@ -0,0 +1,63 @@ +name: Automatic Model Parallelism Test on GPUs + +on: + pull_request: + branches: + - main + paths: + - 'optimum/fx/parallelization/**.py' + push: + branches: + - main + paths: + - 'optimum/fx/parallelization/**.py' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + run_gpu_tests: + strategy: + fail-fast: false + matrix: + config: + - name: GPU-enabled Optimum Test Suite + image: nvidia/cuda:12.4.1-devel-ubuntu22.04 + gpu_target: ["nvidia-multi-gpu-l4-runners", "nvidia-multi-gpu-a10-runners"] + + name: ${{ matrix.config.name }} + runs-on: + group: "${{matrix.gpu_target}}" + + container: + image: ${{ matrix.config.image }} + options: --mount type=tmpfs,destination=/tmp --shm-size 64gb --gpus all --ipc host -v /mnt/hf_cache:/mnt/cache/ + + defaults: + run: + shell: bash + + steps: + - uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Checkout optimum + uses: actions/checkout@v4 + with: + fetch-depth: 1 + + - name: Run nvidia-smi + run: | + nvidia-smi + + - name: Install dependencies + run: | + python3 -m pip install -U pip + python3 -m pip install torch transformers + python3 -m pip install .[tests] + + - name: Run automatic model parallelism tests + run: | + pytest -s -v -o log_cli=true tests/fx/parallelization diff --git a/optimum/fx/parallelization/__init__.py b/optimum/fx/parallelization/__init__.py index 7f3d0e737b..bb42a0f133 100644 --- a/optimum/fx/parallelization/__init__.py +++ b/optimum/fx/parallelization/__init__.py @@ -12,20 +12,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - -import torch -from torch.fx import GraphModule - +from .api import parallelize_backend from .core import Config, ParallelExecutionCtx -from .passes import build_parallel_pass_pipeline - - -def parallelize_backend( - graph_module: GraphModule, example_inputs: List[torch.Tensor], ctx: ParallelExecutionCtx, config: Config -) -> GraphModule: - ctx.example_inputs = example_inputs - pass_pipeline = build_parallel_pass_pipeline() - graph_module = pass_pipeline(graph_module=graph_module, ctx=ctx, config=config) - ctx.compile_times += 1 - return graph_module diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py new file mode 100644 index 0000000000..7f3d0e737b --- /dev/null +++ b/optimum/fx/parallelization/api.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +import torch +from torch.fx import GraphModule + +from .core import Config, ParallelExecutionCtx +from .passes import build_parallel_pass_pipeline + + +def parallelize_backend( + graph_module: GraphModule, example_inputs: List[torch.Tensor], ctx: ParallelExecutionCtx, config: Config +) -> GraphModule: + ctx.example_inputs = example_inputs + pass_pipeline = build_parallel_pass_pipeline() + graph_module = pass_pipeline(graph_module=graph_module, ctx=ctx, config=config) + ctx.compile_times += 1 + return graph_module diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py index a040123bfe..bd50d0d059 100644 --- a/optimum/fx/parallelization/core.py +++ b/optimum/fx/parallelization/core.py @@ -14,7 +14,7 @@ # limitations under the License. from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.distributed as dist @@ -49,16 +49,20 @@ class ParameterSlice: along a specific axis (the potential parallel axis) right now. Attributes: - - source (`Optional[str]`): + - source (`Optional[str]`, defaults to `None`): Original parameter name which can be found in the weight dict. - - index (`Optional[slice]`): + - shape (`Optional[Tuple]`, defaults to `None`): + Shape of parameter tensor corresponding to `source`. + + - index (`slice`, defaults to `slice(None, None, None)`): Index to slice the tensor on the parallel axis. Assume tensor in weight dict has the same layout as their correspondings in memory. """ source: Optional[str] = None - index: Optional[slice] = None + shape: Optional[Tuple] = None + index: slice = slice(None, None, None) @dataclass @@ -70,23 +74,27 @@ class ParameterMeta: - is_tied (`bool`, defaults to `False`): Whether the parameter is shared accross multiple modules. + - is_parallel (`bool`, defaults to `False`): + Whether the parameter needs to be parallelized. + - is_modified_meta (`bool`, defaults to `False`): Whether the meta has already been modified since initialization. - need_initialize (`bool`, defaults to `False`): Whether need to manually initialize weights if not provided in weight map. - - init_fn (`Optional[Callable]`): + - init_fn (`Optional[Callable]`, defaults to `None`): Initialization function, can override `weight_init_fn` in `Config` if not None. - dim (`int`, defaults to `0`): - Axis on which `mapping` is based. + Axis on which `mapping` is based, also the parallel axis if `is_parallel`. - mapping (`Dict[HashableSlice, ParameterSlice]`): Mapping between the current parameter and weight tensor stored in weight map. """ is_tied: bool = False + is_parallel: bool = False is_modified_meta: bool = False need_initialize: bool = False init_fn: Optional[Callable] = None diff --git a/optimum/fx/parallelization/distributed/dist_ops.py b/optimum/fx/parallelization/distributed/dist_ops.py index 9e0b654210..081f84ce17 100644 --- a/optimum/fx/parallelization/distributed/dist_ops.py +++ b/optimum/fx/parallelization/distributed/dist_ops.py @@ -31,22 +31,19 @@ def all_gather(group: dist.ProcessGroup, tensor: torch.Tensor, gather_dim: int = world_size = dist.get_world_size(group) if world_size == 1: return tensor - rank = dist.get_rank(group=group) - - tensor = tensor.contiguous() gather_dim = (gather_dim + tensor.ndim) % tensor.ndim - shape = tuple( - tensor.size(dim) * world_size if dim == gather_dim else tensor.size(dim) for dim in range(tensor.ndim) - ) - index = [ - slice(rank * tensor.size(dim), (rank + 1) * tensor.size(dim), None) - if dim == gather_dim - else slice(None, None, None) - for dim in range(tensor.ndim) - ] + shape = [tensor.size(dim) * world_size if dim == gather_dim else tensor.size(dim) for dim in range(tensor.ndim)] + if gather_dim != 0: + shape[0], shape[gather_dim] = shape[gather_dim], shape[0] tensors = torch.empty(*shape, dtype=tensor.dtype, device=tensor.device) - tensors[index] = tensor + + if gather_dim != 0: + tensor = tensor.transpose(0, gather_dim) + tensor = tensor.contiguous() + dist.all_gather_into_tensor(tensors, tensor, group=group) + if gather_dim != 0: + tensors = tensors.transpose(0, gather_dim).contiguous() return tensors @@ -69,6 +66,7 @@ def scatter( ) -> torch.Tensor: world_size = dist.get_world_size(group) if world_size == 1: + output_tensor.copy_(tensor) return tensor rank = dist.get_rank(group) diff --git a/optimum/fx/parallelization/parallel_layers/embedding.py b/optimum/fx/parallelization/parallel_layers/embedding.py index 4cd21f9ebc..eb8cc9b294 100644 --- a/optimum/fx/parallelization/parallel_layers/embedding.py +++ b/optimum/fx/parallelization/parallel_layers/embedding.py @@ -12,9 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial -from typing import Callable - import torch import torch.distributed as dist import torch.nn as nn @@ -30,17 +27,11 @@ class VocabParallelEmbedding(nn.Module): Embedding layer parallelized in vocabulary dimension. Arguments: - ctx: parallel execution context which contains runtime information. - embedding: the original embedding module being replaced. - init_fn: weight initialization function. + ctx(`ParallelExecutionCtx`): parallel execution context which contains runtime information. + embedding(`torch.nn.Embedding`): the original embedding module being replaced. """ - def __init__( - self, - ctx: ParallelExecutionCtx, - embedding: nn.Embedding, - init_fn: Callable[[torch.Tensor], torch.Tensor] = partial(nn.init.normal_, mean=0, std=0.02), - ): + def __init__(self, ctx: ParallelExecutionCtx, embedding: nn.Embedding): super(VocabParallelEmbedding, self).__init__() self.process_group = ctx.tp_group world_size = dist.get_world_size(self.process_group) @@ -66,8 +57,8 @@ def __init__( assert weight_meta.is_tied, "only tied parameters could already have modified meta" else: weight_meta.need_initialize = True + weight_meta.is_parallel = True weight_meta.dim = 0 - weight_meta.init_fn = init_fn for _, Slice in weight_meta.mapping.items(): Slice.index = slice(self.vocab_start_idx, self.vocab_end_idx) weight_meta.is_modified_meta = True diff --git a/optimum/fx/parallelization/parallel_layers/linear.py b/optimum/fx/parallelization/parallel_layers/linear.py index 71c7d9d1b5..62d5894dac 100644 --- a/optimum/fx/parallelization/parallel_layers/linear.py +++ b/optimum/fx/parallelization/parallel_layers/linear.py @@ -12,9 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial -from typing import Callable - import torch import torch.distributed as dist import torch.nn as nn @@ -41,19 +38,12 @@ class ColumnParallelLinear(nn.Module): its second dimension as A = [A_1, ..., A_p]. Arguments: - ctx: parallel execution context which contains runtime information. - linear: the original linear module being replaced. - gather_output: whether gathering output in the end of forward. - init_fn: weight initialization function. + ctx(`ParallelExecutionCtx`): parallel execution context which contains runtime information. + linear(`torch.nn.Linear`): the original linear module being replaced. + gather_output(`bool`, defaults to `True`): whether gathering output in the end of forward. """ - def __init__( - self, - ctx: ParallelExecutionCtx, - linear: nn.Linear, - gather_output: bool = True, - init_fn: Callable[[torch.Tensor], torch.Tensor] = partial(nn.init.normal_, mean=0, std=0.02), - ) -> None: + def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, gather_output: bool = True) -> None: super(ColumnParallelLinear, self).__init__() self.process_group = ctx.tp_group world_size = dist.get_world_size(self.process_group) @@ -73,8 +63,8 @@ def __init__( assert weight_meta.is_tied, "only tied parameters could already have modified meta" else: weight_meta.need_initialize = True + weight_meta.is_parallel = True weight_meta.dim = 0 - weight_meta.init_fn = init_fn for _, Slice in weight_meta.mapping.items(): Slice.index = slice(tp_rank * out_features, (tp_rank + 1) * out_features) weight_meta.is_modified_meta = True @@ -93,6 +83,7 @@ def __init__( assert bias_meta.is_tied, "only tied parameters could already have modified meta" else: bias_meta.need_initialize = True + bias_meta.is_parallel = True bias_meta.init_fn = torch.zero_ bias_meta.dim = 0 for _, Slice in bias_meta.mapping.items(): @@ -124,19 +115,12 @@ class RowParallelLinear(nn.Module): | A_p | - - Arguments: - ctx: parallel execution context which contains runtime information. - linear: the original lineat module being replaced. - input_is_parallel: whether the input tensor has already been parallelized. - init_fn: weight initialization function. + ctx(`ParallelExecutionCtx`): parallel execution context which contains runtime information. + linear(`torch.nn.Linear`): the original linear module being replaced. + input_is_parallel(`bool`, defaults to `True`): whether the input tensor has already been parallelized. """ - def __init__( - self, - ctx: ParallelExecutionCtx, - linear: nn.Linear, - input_is_parallel: bool = False, - init_fn: Callable[[torch.Tensor], torch.Tensor] = partial(nn.init.normal_, mean=0, std=0.02), - ) -> None: + def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, input_is_parallel: bool = False) -> None: super(RowParallelLinear, self).__init__() self.process_group = ctx.tp_group world_size = dist.get_world_size(self.process_group) @@ -156,8 +140,8 @@ def __init__( assert weight_meta.is_tied, "only tied parameters could already have modified meta" else: weight_meta.need_initialize = True + weight_meta.is_parallel = True weight_meta.dim = 1 - weight_meta.init_fn = init_fn for _, Slice in weight_meta.mapping.items(): Slice.index = slice(tp_rank * in_features, (tp_rank + 1) * in_features) weight_meta.is_modified_meta = True diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index eda1d02b69..bdfc56cd21 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -495,41 +495,42 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf # skip already initialized parameters if not param_meta.need_initialize: continue + # skip already initialized tied parameters if param_meta.is_tied and id(param) in tied_parameters: new_parameters.append((name, tied_parameters[id(param)])) continue shape = [ - param.size(dim) // world_size if dim == param_meta.dim else param.size(dim) + param.size(dim) // world_size if dim == param_meta.dim and param_meta.is_parallel else param.size(dim) for dim in range(param.ndim) ] + new_param = nn.Parameter( - torch.randn(*shape, dtype=param.dtype, device=ctx.current_device), requires_grad=param.requires_grad + torch.zeros(*shape, dtype=param.dtype, device=ctx.current_device), requires_grad=param.requires_grad ) + for source, target in sorted(param_meta.mapping.items()): if target.source in ctx.weight_map: # TODO: add weights loading logic continue - if tp_rank == 0: + if not param_meta.is_parallel or tp_rank == 0: # initialize weight on master rank - start = source.start if source.start else 0 - stop = source.stop if source.stop else start + param.size(param_meta.dim) // world_size - shape = [ - (stop - start) * world_size if dim == param_meta.dim else param.size(dim) - for dim in range(param.ndim) - ] - weight = torch.empty(*shape, dtype=param.dtype, device="cpu") + weight = torch.empty(*target.shape, dtype=param.dtype, device="cpu") init_fn = param_meta.init_fn if param_meta.init_fn else config.weight_init_fn init_fn(weight) weight = weight.to(ctx.current_device) else: weight = None + index = [ + source.to_slice() if dim == param_meta.dim else slice(None, None, None) + for dim in range(param.ndim) + ] with torch.no_grad(): - index = [ - source.to_slice() if dim == param_meta.dim else slice(None, None, None) - for dim in range(param.ndim) - ] - scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim) + if param_meta.is_parallel: + scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim) + else: + new_param.data[index].copy_(weight) + setattr(new_param, "meta", param_meta) new_parameters.append((name, new_param)) if param_meta.is_tied: @@ -583,7 +584,7 @@ def __iter__( ): return self._passes.__iter__() - def append(self, PASS: PassBase): + def append(self, PASS: PassBase) -> None: self._passes.append(PASS) def __call__(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index 59c968ec5b..68b7a804b1 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -274,7 +274,10 @@ def initialize_parameter_meta(model: nn.Module) -> None: setattr( tensor, "meta", - ParameterMeta(dim=0, mapping={HashableSlice(None, None, None): ParameterSlice(source=name)}), + ParameterMeta( + dim=0, + mapping={HashableSlice(None, None, None): ParameterSlice(source=name, shape=tuple(tensor.shape))}, + ), ) parameter_ids.add(key) else: diff --git a/tests/fx/parallelization/test_tensor_parallel.py b/tests/fx/parallelization/test_tensor_parallel.py index a1a3ac09cb..9875b02a49 100644 --- a/tests/fx/parallelization/test_tensor_parallel.py +++ b/tests/fx/parallelization/test_tensor_parallel.py @@ -145,6 +145,9 @@ def run_test_parameters_persist_bewteen_recompile( parameter_ids = {id(param) for _, param in model.named_parameters()} model(**another_inputs) + # check second compilation has been triggered + assert ctx.compile_times > 1 + parameter_ids_after_recompile = {id(param) for _, param in model.named_parameters()} assert parameter_ids == parameter_ids_after_recompile @@ -236,7 +239,8 @@ def test_parameters_persist_bewteen_recompile( @parameterized.expand(DUMMY_MODELS_TO_TEST) @unittest.skipIf( - not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run" + not is_gpu_available() or not is_torch_compile_available() or NUM_AVAILABLE_DEVICES < 2, + "requires more than one gpu and torch version >= 2.3.0 to run", ) def test_parallel_results_matches_non_parallel( model_cls, From 0512b23983d80933f53c2b3e22fd1d28a8c404cf Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Sat, 13 Jul 2024 01:10:02 +0200 Subject: [PATCH 09/25] fix --- optimum/fx/parallelization/passes.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index bdfc56cd21..d4d563d5b6 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -372,7 +372,7 @@ class ParallelLayerReplacePass(PassBase): """ @staticmethod - def handle_linear(node: Node, ctx: ParallelExecutionCtx, config: Config) -> None: + def handle_linear(node: Node, ctx: ParallelExecutionCtx) -> None: graph_module = node.graph.owning_module axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis") if axis is None: @@ -396,17 +396,17 @@ def handle_linear(node: Node, ctx: ParallelExecutionCtx, config: Config) -> None gather_output = ParallelLayerAnnotatePass.get_stored_field_info( node, field="gather_output", must_have=True ) - new_mod = ColumnParallelLinear(ctx, mod, gather_output, config.weight_init_fn) + new_mod = ColumnParallelLinear(ctx, mod, gather_output) else: input_is_parallel = ParallelLayerAnnotatePass.get_stored_field_info( node, field="input_is_parallel", must_have=True ) - new_mod = RowParallelLinear(ctx, mod, input_is_parallel, config.weight_init_fn) + new_mod = RowParallelLinear(ctx, mod, input_is_parallel) layer_cache[key] = new_mod setattr(parent_mod, field, new_mod) @staticmethod - def handle_embedding(node: Node, ctx: ParallelExecutionCtx, config: Config) -> None: + def handle_embedding(node: Node, ctx: ParallelExecutionCtx) -> None: graph_module = node.graph.owning_module axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis") if axis is None: @@ -426,7 +426,7 @@ def handle_embedding(node: Node, ctx: ParallelExecutionCtx, config: Config) -> N if key in layer_cache: new_mod = layer_cache[key] else: - new_mod = VocabParallelEmbedding(ctx, mod, config.weight_init_fn) + new_mod = VocabParallelEmbedding(ctx, mod) layer_cache[key] = new_mod setattr(parent_mod, field, new_mod) @@ -468,9 +468,9 @@ def update(node: Node, new_shape: List[Any], parallel_axis: int): def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: for node in graph_module.graph.nodes: if is_linear(node): - self.handle_linear(node, ctx, config) + self.handle_linear(node, ctx) elif is_embedding(node): - self.handle_embedding(node, ctx, config) + self.handle_embedding(node, ctx) # correct the attention head num in parallel setting elif is_shape_consumer(node): self.handle_hard_coded_axis_param(node, ctx) From 8ec67277695a06026474236610887287920d40a5 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Sat, 13 Jul 2024 19:41:57 +0200 Subject: [PATCH 10/25] debug --- .github/workflows/test_fx_automatic_parallel.yml | 3 ++- optimum/fx/parallelization/passes.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test_fx_automatic_parallel.yml b/.github/workflows/test_fx_automatic_parallel.yml index d745b8c724..4b1cc21952 100644 --- a/.github/workflows/test_fx_automatic_parallel.yml +++ b/.github/workflows/test_fx_automatic_parallel.yml @@ -33,7 +33,8 @@ jobs: container: image: ${{ matrix.config.image }} options: --mount type=tmpfs,destination=/tmp --shm-size 64gb --gpus all --ipc host -v /mnt/hf_cache:/mnt/cache/ - + env: + NCCL_DEBUG: INFO defaults: run: shell: bash diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index d4d563d5b6..6574f5e883 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -426,6 +426,7 @@ def handle_embedding(node: Node, ctx: ParallelExecutionCtx) -> None: if key in layer_cache: new_mod = layer_cache[key] else: + assert ctx.compile_times == 0, "illegal path for recompilation" new_mod = VocabParallelEmbedding(ctx, mod) layer_cache[key] = new_mod setattr(parent_mod, field, new_mod) From 5095f1ed51cc5f757ee2d960b47c183f99537b2c Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Sat, 13 Jul 2024 20:43:16 +0200 Subject: [PATCH 11/25] fix --- optimum/fx/parallelization/core.py | 4 ++-- optimum/fx/parallelization/passes.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py index bd50d0d059..1e89f0e6ed 100644 --- a/optimum/fx/parallelization/core.py +++ b/optimum/fx/parallelization/core.py @@ -117,7 +117,7 @@ class ParallelExecutionCtx: - example_inputs (`List[Any]`): A list of tensors which are used as example inputs for graphs captured by dynamo. - - parallel_layer_cache (`Dict[int, nn.Module]`): + - parallel_layer_cache (`Dict[str, nn.Module]`): Cache which maps layers(`nn.Linear`, `nn.Embedding`) to their parallel counterparts. Note that we will build the cache in the first compilation process, and for recompilations later on, we will directly replace the modules with their parallel counterparts in the cache, @@ -135,7 +135,7 @@ class ParallelExecutionCtx: tp_group: dist.ProcessGroup current_device: torch.device example_inputs: List[Any] = field(default_factory=list) - parallel_layer_cache: Dict[int, nn.Module] = field(default_factory=dict) + parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict) weight_map: Dict[str, str] = field(default_factory=dict) compile_times: int = 0 diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index 6574f5e883..d14abc6b6a 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -388,7 +388,7 @@ def handle_linear(node: Node, ctx: ParallelExecutionCtx) -> None: field = node.target mod: nn.Linear = graph_module.get_submodule(node.target) - key, layer_cache = id(mod), ctx.parallel_layer_cache + key, layer_cache = node.target, ctx.parallel_layer_cache if key in layer_cache: new_mod = layer_cache[key] else: @@ -422,7 +422,7 @@ def handle_embedding(node: Node, ctx: ParallelExecutionCtx) -> None: field = node.target mod: nn.Embedding = graph_module.get_submodule(node.target) - key, layer_cache = id(mod), ctx.parallel_layer_cache + key, layer_cache = node.target, ctx.parallel_layer_cache if key in layer_cache: new_mod = layer_cache[key] else: From f6ebfc0e2561febdf5671d6879229cebc7b6e08f Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Mon, 15 Jul 2024 02:53:19 +0200 Subject: [PATCH 12/25] fix tests --- optimum/fx/parallelization/api.py | 1 + optimum/fx/parallelization/core.py | 5 + .../parallelization/test_tensor_parallel.py | 168 ++++++++++++------ 3 files changed, 121 insertions(+), 53 deletions(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 7f3d0e737b..e870e64385 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -28,4 +28,5 @@ def parallelize_backend( pass_pipeline = build_parallel_pass_pipeline() graph_module = pass_pipeline(graph_module=graph_module, ctx=ctx, config=config) ctx.compile_times += 1 + ctx.last_optimized_graph_module = graph_module return graph_module diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py index 1e89f0e6ed..cba7d45444 100644 --- a/optimum/fx/parallelization/core.py +++ b/optimum/fx/parallelization/core.py @@ -19,6 +19,7 @@ import torch import torch.distributed as dist import torch.nn as nn +from torch.fx import GraphModule class HashableSlice: @@ -128,6 +129,9 @@ class ParallelExecutionCtx: Mapping between parameter names and their locations on disk, useful when loading weights from disk. + - last_optimized_graph_module (`Optional[GraphModule]`, defaults to `None`): + Optimized graph module corresponding to the latest compilation. + - compile_times (`int`, defaults to `0`): Number of compilation times happened during the whole process. """ @@ -137,6 +141,7 @@ class ParallelExecutionCtx: example_inputs: List[Any] = field(default_factory=list) parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict) weight_map: Dict[str, str] = field(default_factory=dict) + last_optimized_graph_module: Optional[GraphModule] = None compile_times: int = 0 diff --git a/tests/fx/parallelization/test_tensor_parallel.py b/tests/fx/parallelization/test_tensor_parallel.py index 9875b02a49..b0211c2a1c 100644 --- a/tests/fx/parallelization/test_tensor_parallel.py +++ b/tests/fx/parallelization/test_tensor_parallel.py @@ -14,7 +14,7 @@ # limitations under the License. import unittest from functools import partial -from typing import Type +from typing import Type, Union import torch import torch.distributed as dist @@ -32,17 +32,35 @@ ) from optimum.fx.parallelization import Config, ParallelExecutionCtx, parallelize_backend -from optimum.fx.parallelization.utils import MetaAwareMethodsPatcher, initialize_parameter_meta, move_model_to_device +from optimum.fx.parallelization.parallel_layers import ColumnParallelLinear, VocabParallelEmbedding +from optimum.fx.parallelization.utils import ( + MetaAwareMethodsPatcher, + initialize_parameter_meta, + move_model_to_device, + stable_topological_sort, +) DUMMY_MODELS_TO_TEST = ( ( LlamaForCausalLM, - LlamaConfig, + LlamaConfig( + num_hidden_layers=2, + tie_word_embeddings=True, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + ), ), ( MistralForCausalLM, - MistralConfig, + MistralConfig( + num_hidden_layers=2, + tie_word_embeddings=True, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + ), ), ) @@ -55,19 +73,22 @@ def is_torch_compile_available(): return version.parse(torch.__version__) >= version.parse("2.3.0") -def dummify(config: PretrainedConfig): - config.num_hidden_layers = 2 - config.use_cache = False - config.output_attentions = False - config.output_hidden_states = False +def prepare_dummy_inputs( + model_config: PretrainedConfig, + batch_size: int = 1, + seq_len: int = 10, + device: Union[str, torch.device] = "cuda", +): + return { + "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(batch_size, seq_len), device=device), + "attention_mask": torch.ones((batch_size, seq_len), dtype=torch.int64, device=device), + "position_ids": torch.arange(0, seq_len, device=device).unsqueeze(0).expand(batch_size, -1), + } def run_test_all_rank_results_match( - rank: int, world_size: int, model_cls: Type[PreTrainedModel], config_cls: Type[PretrainedConfig] + rank: int, world_size: int, model_cls: Type[PreTrainedModel], model_config: PretrainedConfig ): - model_config = config_cls() - dummify(model_config) - # initialize default group dist_init(rank, world_size) tp_group = dist.new_group() @@ -76,12 +97,7 @@ def run_test_all_rank_results_match( device = torch.device(type="cuda", index=torch.cuda.current_device()) ctx, cfg = ParallelExecutionCtx(tp_group=tp_group, current_device=device), Config() - inputs = { - "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(1, 10), device=device), - "attention_mask": torch.ones((1, 10), dtype=torch.int64, device=device), - "position_ids": torch.arange(0, 10, device=device).unsqueeze(0), - } - + inputs = prepare_dummy_inputs(model_config) # this will initialize all linears on meta device with MetaAwareMethodsPatcher(): model = model_cls(model_config) @@ -105,11 +121,8 @@ def run_test_all_rank_results_match( def run_test_parameters_persist_bewteen_recompile( - rank: int, world_size: int, model_cls: Type[PreTrainedModel], config_cls: Type[PretrainedConfig] + rank: int, world_size: int, model_cls: Type[PreTrainedModel], model_config: PretrainedConfig ): - model_config = config_cls() - dummify(model_config) - # initialize default group dist_init(rank, world_size) tp_group = dist.new_group() @@ -118,18 +131,11 @@ def run_test_parameters_persist_bewteen_recompile( device = torch.device(type="cuda", index=torch.cuda.current_device()) ctx, cfg = ParallelExecutionCtx(tp_group=tp_group, current_device=device), Config() - inputs = { - "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(1, 10), device=device), - "attention_mask": torch.ones((1, 10), dtype=torch.int64, device=device), - "position_ids": torch.arange(0, 10, device=device).unsqueeze(0), - } + inputs = prepare_dummy_inputs(model_config) # different shape to trigger recompile - another_inputs = { - "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(1, 11), device=device), - "attention_mask": torch.ones((1, 11), dtype=torch.int64, device=device), - "position_ids": torch.arange(0, 11, device=device).unsqueeze(0), - } + another_inputs = prepare_dummy_inputs(model_config, seq_len=11) + yet_another_inputs = prepare_dummy_inputs(model_config, batch_size=2, seq_len=12) # this will initialize all linears on meta device with MetaAwareMethodsPatcher(): @@ -141,26 +147,26 @@ def run_test_parameters_persist_bewteen_recompile( model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) model(**inputs) + parameter_ids = {id(param) for _, param in ctx.last_optimized_graph_module.named_parameters()} - parameter_ids = {id(param) for _, param in model.named_parameters()} model(**another_inputs) - # check second compilation has been triggered - assert ctx.compile_times > 1 - - parameter_ids_after_recompile = {id(param) for _, param in model.named_parameters()} + assert ctx.compile_times == 2 + parameter_ids_after_recompile = {id(param) for _, param in ctx.last_optimized_graph_module.named_parameters()} assert parameter_ids == parameter_ids_after_recompile + model(**yet_another_inputs) + assert ctx.compile_times == 3 + parameter_ids_after_recompile = {id(param) for _, param in ctx.last_optimized_graph_module.named_parameters()} + assert parameter_ids == parameter_ids_after_recompile dist.barrier(tp_group) tearDown(tp_group) def run_test_parallel_results_matches_non_parallel( - rank: int, world_size: int, model_cls: Type[PreTrainedModel], config_cls: Type[PretrainedConfig] + rank: int, world_size: int, model_cls: Type[PreTrainedModel], model_config: PretrainedConfig ): - model_config = config_cls() - dummify(model_config) - + # initialize default group dist_init(rank, world_size) tp_group = dist.new_group(ranks=[rank]) @@ -168,11 +174,7 @@ def run_test_parallel_results_matches_non_parallel( device = torch.device(type="cuda", index=torch.cuda.current_device()) ctx, cfg = ParallelExecutionCtx(tp_group=tp_group, current_device=device), Config() - inputs = { - "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(1, 10), device=device), - "attention_mask": torch.ones((1, 10), dtype=torch.int64, device=device), - "position_ids": torch.arange(0, 10, device=device).unsqueeze(0), - } + inputs = prepare_dummy_inputs(model_config) set_seed(SEED) # non-parallel local forward @@ -209,17 +211,63 @@ def run_test_parallel_results_matches_non_parallel( tearDown() +def run_test_tie_word_embeddings( + rank: int, world_size: int, model_cls: Type[PreTrainedModel], model_config: PretrainedConfig +): + dist_init(rank, world_size) + tp_group = dist.new_group() + + # prepare config and context + device = torch.device(type="cuda", index=torch.cuda.current_device()) + ctx, cfg = ParallelExecutionCtx(tp_group=tp_group, current_device=device), Config() + + inputs = prepare_dummy_inputs(model_config) + + with MetaAwareMethodsPatcher(): + model = model_cls(model_config) + model.eval() + + move_model_to_device(model, device=device) + initialize_parameter_meta(model) + + model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) + model(**inputs) + + embedding_weight, lm_head_weight = None, None + graph_module = ctx.last_optimized_graph_module + stable_topological_sort(graph_module.graph) + for node in graph_module.graph.nodes: + if node.op == "call_module": + mod = graph_module.get_submodule(node.target) + if isinstance(mod, VocabParallelEmbedding): + embedding_weight = mod.weight + break + for node in reversed(graph_module.graph.nodes): + if node.op == "call_module": + mod = graph_module.get_submodule(node.target) + if isinstance(mod, ColumnParallelLinear): + lm_head_weight = mod.weight + break + assert ( + id(embedding_weight) == id(lm_head_weight) + and hasattr(embedding_weight, "meta") + and embedding_weight.meta.is_tied + ) + dist.barrier(tp_group) + tearDown() + + @parameterized.expand(DUMMY_MODELS_TO_TEST) @unittest.skipIf( not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run" ) def test_all_rank_results_match( model_cls, - config_cls, + model_config, ): for world_size in [1, 2, 4, 8]: if world_size <= NUM_AVAILABLE_DEVICES: - spawn(world_size, run_test_all_rank_results_match, model_cls, config_cls, deterministic=True) + spawn(world_size, run_test_all_rank_results_match, model_cls, model_config, deterministic=True) @parameterized.expand(DUMMY_MODELS_TO_TEST) @@ -228,12 +276,12 @@ def test_all_rank_results_match( ) def test_parameters_persist_bewteen_recompile( model_cls, - config_cls, + model_config, ): for world_size in [1, 2]: if world_size <= NUM_AVAILABLE_DEVICES: spawn( - world_size, run_test_parameters_persist_bewteen_recompile, model_cls, config_cls, deterministic=False + world_size, run_test_parameters_persist_bewteen_recompile, model_cls, model_config, deterministic=False ) @@ -244,7 +292,21 @@ def test_parameters_persist_bewteen_recompile( ) def test_parallel_results_matches_non_parallel( model_cls, - config_cls, + model_config, ): # world_size == 2 is enough - spawn(2, run_test_parallel_results_matches_non_parallel, model_cls, config_cls, deterministic=True) + spawn(2, run_test_parallel_results_matches_non_parallel, model_cls, model_config, deterministic=True) + + +@parameterized.expand(DUMMY_MODELS_TO_TEST) +@unittest.skipIf( + not is_gpu_available() or not is_torch_compile_available(), + "requires gpu and torch version >= 2.3.0 to run", +) +def test_tie_word_embeddings( + model_cls, + model_config, +): + for world_size in [1, 2]: + if world_size <= NUM_AVAILABLE_DEVICES: + spawn(world_size, run_test_tie_word_embeddings, model_cls, model_config, deterministic=False) From e71e5eada8d446c7f62119728c7bb0c161560afe Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 16 Jul 2024 23:51:04 +0200 Subject: [PATCH 13/25] add experimental API --- optimum/fx/parallelization/api.py | 113 +++++++++++++++++++++++++- optimum/fx/parallelization/utils.py | 120 +++++++++++++++++++++++++++- 2 files changed, 231 insertions(+), 2 deletions(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index e870e64385..40bbaa04f0 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -12,13 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +import glob +import importlib +import json +import os +from functools import partial +from typing import List, Optional, Union import torch from torch.fx import GraphModule from .core import Config, ParallelExecutionCtx from .passes import build_parallel_pass_pipeline +from .utils import ( + MetaAwareMethodsPatcher, + convert_bin_to_safetensors, + download_files_from_hf, + initialize_parameter_meta, + move_model_to_device, +) def parallelize_backend( @@ -30,3 +42,102 @@ def parallelize_backend( ctx.compile_times += 1 ctx.last_optimized_graph_module = graph_module return graph_module + + +def parallelize_model( + model: Union[torch.nn.Module, str], + parallel_ctx: ParallelExecutionCtx, + *model_args, + revision: str = "main", + cache_dir: Optional[str] = None, + local_files_only: bool = False, + **kwargs, +): + """ + API for automatic model parallelism through Pytorch FX. + + Args: + model (Union[torch.nn.Module, str]): + Model to parallelize, could either be a module or a model id in huggingface space. + parallel_ctx (ParallelExecutionCtx): + Parallel execution context containing process groups the current process belongs to. + model_args (additional postional arguments, optional): + Additional postional arguments for intializing the model if a model id is passed. + revision (`str`, defaults to `main`): + Model revision for weights downloading if a model id is passed. + cache_dir (`Optional[str]`, defaults to `None`): + Cache directory to store downloaded weights. Defaults to None. + local_files_only (bool, defaults to `False`): + Whether to use local files only, will avoid downloading from remote if set to `True`. + kwargs (additional keyword arguments, optional): + Addtional keyword arguments for overriding fields in parallel config, model config and `Model.__init__`. + """ + from safetensors import safe_open + from transformers import AutoConfig + from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME + + parallel_config = Config() + for k, v in kwargs.items(): + if k in parallel_config.__dict__: + setattr(parallel_config, k, v) + kwargs = {k: v for k, v in kwargs.items() if k not in parallel_config.__dict__} + + if isinstance(model, str): + is_local = os.path.isdir(model) + use_safetensors = False + allow_patterns = ["*.safetensors", "*.bin"] + if not is_local: + hf_folder = download_files_from_hf( + model_name_or_path=model, + cache_dir=cache_dir, + allow_patterns=allow_patterns, + revision=revision, + local_files_only=local_files_only, + ) + else: + hf_folder = model + for pattern in allow_patterns: + if len(glob.glob(os.path.join(hf_folder, pattern))) > 0: + use_safetensors = pattern == "*.safetensors" + break + # should be able to load config using only local files + model_config, kwargs = AutoConfig.from_pretrained( + hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs + ) + config_path = os.path.join(hf_folder, CONFIG_NAME) + if not os.path.isfile(config_path): + raise EnvironmentError(f"Can't find config file {config_path} in {hf_folder}") + + with open(config_path) as f: + config_dict = json.load(f) + model_arch = config_dict["architectures"] + model_cls = getattr(importlib.import_module("transformers"), model_arch[0]) + + index_path = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME) + if os.path.isfile(index_path): + with open(index_path) as f: + index_dict = json.load(f) + parallel_ctx.weight_map = index_dict["weight_map"] + weight_files = glob.glob(os.path.join(hf_folder, "*.safetensors" if use_safetensors else "*.bin")) + if not use_safetensors: + weight_map = parallel_ctx.weight_map if parallel_ctx.weight_map else {} + convert_bin_to_safetensors(model, cache_dir, weight_files, weight_map) + parallel_ctx.weight_map = weight_map + + # try directly construct weight_map from weight files, should have safetensors file on disk in any case + if not parallel_ctx.weight_map: + weight_map, weight_files = {}, glob.glob(os.path.join(hf_folder, "*.safetensors")) + for weight_file in weight_files: + with safe_open(filename=weight_file, framework="pt") as f: + for key in f.keys(): + weight_map[key] = weight_file + parallel_ctx.weight_map = weight_map + + with MetaAwareMethodsPatcher(): + model = model_cls(model_config, *model_args, **kwargs) + + move_model_to_device(model, device=parallel_ctx.current_device) + initialize_parameter_meta(model) + backend = partial(parallelize_backend, ctx=parallel_ctx, config=parallel_config) + model = torch.compile(model, fullgraph=True, backend=backend) + return model diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index 68b7a804b1..8881b36432 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -12,13 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import fnmatch +import hashlib import importlib import operator +import os +import re +import tempfile from collections import defaultdict from functools import wraps from itertools import chain -from typing import Callable, Dict, List, Union +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union +import filelock import torch import torch.nn as nn import torch.nn.functional as F @@ -305,3 +312,114 @@ def move_model_to_device(model: nn.Module, device: Union[torch.device, str]): if isinstance(tensor, nn.Parameter): new_tensor = nn.Parameter(new_tensor) setattr(parent_mod, attr_name, new_tensor) + + +temp_dir = tempfile.gettempdir() + + +def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): + lock_dir = cache_dir or temp_dir + os.makedirs(os.path.dirname(lock_dir), exist_ok=True) + model_name = model_name_or_path.replace("/", "-") + hash_name = hashlib.sha256(model_name.encode()).hexdigest() + # add hash to avoid conflict with old users' lock files + lock_file_name = hash_name + model_name + ".lock" + # mode 0o666 is required for the filelock to be shared across users + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) + return lock + + +# adpated from vllm.model_executor.model_loader.weight_utils.py +def download_files_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: List[str], + revision: Optional[str] = None, + local_files_only: bool = False, +) -> str: + """Download model weights, index and config files from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + allow_patterns (List[str]): The allowed patterns for the + weight files. Files matched by any of the patterns will be + downloaded. + revision (Optional[str]): The revision of the model. + local_files_only(bool): Should only use local files if True. + + Returns: + str: The path to the downloaded files. + """ + import huggingface_hub.constants + from huggingface_hub import HfFileSystem, snapshot_download + from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME + + if not huggingface_hub.constants.HF_HUB_OFFLINE: + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] + break + + extra_patterns = [CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME] + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_patterns + extra_patterns, + cache_dir=cache_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE or local_files_only, + ) + return hf_folder + + +# copied from optimum.neuron.utils.misc.py +def _original_filename_to_safetensors_filename(filename: str) -> str: + """Transforms the filename for any kind of checkpoint to a safetensors equivalent.""" + from transformers.utils import SAFE_WEIGHTS_NAME + + _, extension = filename.rsplit(".", maxsplit=1) + pattern = rf"\w+(-[0-9]*-of-[0-9]*)?\.{extension}" + match_ = re.match(pattern, filename) + if not match_: + raise ValueError(f"Could not convert {filename} to a safetensor filename.") + group_1 = match_.group(1) + index_out_of_total_str = group_1 if group_1 is not None else "" + safetensor_filename, safetensor_extension = SAFE_WEIGHTS_NAME.rsplit(".", maxsplit=1) + return f"{safetensor_filename}{index_out_of_total_str}.{safetensor_extension}" + + +def convert_bin_to_safetensors( + model_name_or_path: str, cache_dir: Optional[str], weight_files: List[str], weight_map: Dict[str, str] +): + """Convert to pytorch bin files to their safetensors equivalent.""" + from safetensors.torch import save_file + + with get_lock(model_name_or_path, cache_dir): + for weight_file in weight_files: + weight_file_path = Path(weight_file) + safetensors_filename = _original_filename_to_safetensors_filename(weight_file_path.name) + output_dir = cache_dir if cache_dir else weight_file_path.parent + output_file_path = os.path.join(output_dir, safetensors_filename) + if not os.path.isfile(output_file_path): + checkpoint = torch.load(weight_file, map_location=torch.device("cpu")) + data_pointers = set() + for k, v in checkpoint.items(): + if v.data_ptr() in data_pointers: + v = v.detach().clone() + v = v.contiguous() + checkpoint[k] = v + data_pointers.add(v.data_ptr()) + save_file(checkpoint, output_file_path) + keys = [key for key, value in weight_map if value == weight_file] + for key in keys: + weight_map[key] = output_file_path From 779c77dacf96291d0611b86992e33d04a49d3fc8 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 16 Jul 2024 23:54:21 +0200 Subject: [PATCH 14/25] nit --- optimum/fx/parallelization/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 40bbaa04f0..774a676a82 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -67,7 +67,7 @@ def parallelize_model( Model revision for weights downloading if a model id is passed. cache_dir (`Optional[str]`, defaults to `None`): Cache directory to store downloaded weights. Defaults to None. - local_files_only (bool, defaults to `False`): + local_files_only (`bool`, defaults to `False`): Whether to use local files only, will avoid downloading from remote if set to `True`. kwargs (additional keyword arguments, optional): Addtional keyword arguments for overriding fields in parallel config, model config and `Model.__init__`. From e09df2a86dabafd30bcf6f6f7250f273dfc3d6c2 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Wed, 17 Jul 2024 22:41:17 +0200 Subject: [PATCH 15/25] fix api --- optimum/fx/parallelization/api.py | 60 ++++++++++++++++------------- optimum/fx/parallelization/utils.py | 22 +++++++---- 2 files changed, 47 insertions(+), 35 deletions(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 774a676a82..a834dd5203 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -51,6 +51,7 @@ def parallelize_model( revision: str = "main", cache_dir: Optional[str] = None, local_files_only: bool = False, + skip_load_weights: bool = False, **kwargs, ): """ @@ -69,13 +70,11 @@ def parallelize_model( Cache directory to store downloaded weights. Defaults to None. local_files_only (`bool`, defaults to `False`): Whether to use local files only, will avoid downloading from remote if set to `True`. + skip_load_weights (`bool`, defaults to `False`): + Whether to skip loading weights from disk to model. kwargs (additional keyword arguments, optional): Addtional keyword arguments for overriding fields in parallel config, model config and `Model.__init__`. """ - from safetensors import safe_open - from transformers import AutoConfig - from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME - parallel_config = Config() for k, v in kwargs.items(): if k in parallel_config.__dict__: @@ -83,8 +82,10 @@ def parallelize_model( kwargs = {k: v for k, v in kwargs.items() if k not in parallel_config.__dict__} if isinstance(model, str): + from transformers import AutoConfig + from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME + is_local = os.path.isdir(model) - use_safetensors = False allow_patterns = ["*.safetensors", "*.bin"] if not is_local: hf_folder = download_files_from_hf( @@ -93,13 +94,11 @@ def parallelize_model( allow_patterns=allow_patterns, revision=revision, local_files_only=local_files_only, + skip_download_weights=skip_load_weights, ) else: hf_folder = model - for pattern in allow_patterns: - if len(glob.glob(os.path.join(hf_folder, pattern))) > 0: - use_safetensors = pattern == "*.safetensors" - break + # should be able to load config using only local files model_config, kwargs = AutoConfig.from_pretrained( hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs @@ -113,25 +112,32 @@ def parallelize_model( model_arch = config_dict["architectures"] model_cls = getattr(importlib.import_module("transformers"), model_arch[0]) - index_path = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME) - if os.path.isfile(index_path): - with open(index_path) as f: - index_dict = json.load(f) - parallel_ctx.weight_map = index_dict["weight_map"] - weight_files = glob.glob(os.path.join(hf_folder, "*.safetensors" if use_safetensors else "*.bin")) - if not use_safetensors: - weight_map = parallel_ctx.weight_map if parallel_ctx.weight_map else {} - convert_bin_to_safetensors(model, cache_dir, weight_files, weight_map) - parallel_ctx.weight_map = weight_map + if not skip_load_weights: + use_safetensors = False + for pattern in allow_patterns: + if len(glob.glob(os.path.join(hf_folder, pattern))) > 0: + use_safetensors = pattern == "*.safetensors" + break + index_path = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME) + if os.path.isfile(index_path): + with open(index_path) as f: + index_dict = json.load(f) + parallel_ctx.weight_map = index_dict["weight_map"] + weight_files = glob.glob(os.path.join(hf_folder, "*.safetensors" if use_safetensors else "*.bin")) + if not use_safetensors: + weight_map = parallel_ctx.weight_map if parallel_ctx.weight_map else {} + convert_bin_to_safetensors(model, cache_dir, weight_files, weight_map) + parallel_ctx.weight_map = weight_map - # try directly construct weight_map from weight files, should have safetensors file on disk in any case - if not parallel_ctx.weight_map: - weight_map, weight_files = {}, glob.glob(os.path.join(hf_folder, "*.safetensors")) - for weight_file in weight_files: - with safe_open(filename=weight_file, framework="pt") as f: - for key in f.keys(): - weight_map[key] = weight_file - parallel_ctx.weight_map = weight_map + # try directly construct weight_map from weight files, should have safetensors file on disk in any case + if not parallel_ctx.weight_map: + from safetensors import safe_open + weight_map, weight_files = {}, glob.glob(os.path.join(hf_folder, "*.safetensors")) + for weight_file in weight_files: + with safe_open(filename=weight_file, framework="pt") as f: + for key in f.keys(): + weight_map[key] = weight_file + parallel_ctx.weight_map = weight_map with MetaAwareMethodsPatcher(): model = model_cls(model_config, *model_args, **kwargs) diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index 8881b36432..3ffb3d380c 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -336,18 +336,20 @@ def download_files_from_hf( allow_patterns: List[str], revision: Optional[str] = None, local_files_only: bool = False, + skip_download_weights: bool = False, ) -> str: """Download model weights, index and config files from Hugging Face Hub. Args: - model_name_or_path (str): The model name or path. - cache_dir (Optional[str]): The cache directory to store the model + model_name_or_path (`str`): The model name or path. + cache_dir (`Optional[str]`): The cache directory to store the model weights. If None, will use HF defaults. - allow_patterns (List[str]): The allowed patterns for the + allow_patterns (`List[str]`): The allowed patterns for the weight files. Files matched by any of the patterns will be downloaded. - revision (Optional[str]): The revision of the model. - local_files_only(bool): Should only use local files if True. + revision (`Optional[str]`, defaults to `None`): The revision of the model. + local_files_only(`bool`): Should only use local files if True. + skip_download_weights (`bool`, defaults to `False`): Whether to skip downloading weights to disk. Returns: str: The path to the downloaded files. @@ -356,7 +358,7 @@ def download_files_from_hf( from huggingface_hub import HfFileSystem, snapshot_download from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME - if not huggingface_hub.constants.HF_HUB_OFFLINE: + if not skip_download_weights and not huggingface_hub.constants.HF_HUB_OFFLINE: # Before we download we look at that is available: fs = HfFileSystem() file_list = fs.ls(model_name_or_path, detail=False, revision=revision) @@ -368,13 +370,17 @@ def download_files_from_hf( allow_patterns = [pattern] break - extra_patterns = [CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME] + if skip_download_weights: + allow_patterns = [CONFIG_NAME] + else: + allow_patterns = allow_patterns + [CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME] + # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. with get_lock(model_name_or_path, cache_dir): hf_folder = snapshot_download( model_name_or_path, - allow_patterns=allow_patterns + extra_patterns, + allow_patterns=allow_patterns, cache_dir=cache_dir, revision=revision, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE or local_files_only, From 9fd29d1dba4250e231fa34716db6227506545d02 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Thu, 18 Jul 2024 00:16:11 +0000 Subject: [PATCH 16/25] fix api --- optimum/fx/parallelization/__init__.py | 2 +- optimum/fx/parallelization/api.py | 12 +++++------- optimum/fx/parallelization/utils.py | 7 +++++++ 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/optimum/fx/parallelization/__init__.py b/optimum/fx/parallelization/__init__.py index bb42a0f133..701badd4d5 100644 --- a/optimum/fx/parallelization/__init__.py +++ b/optimum/fx/parallelization/__init__.py @@ -12,5 +12,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .api import parallelize_backend +from .api import parallelize_backend, parallelize_model from .core import Config, ParallelExecutionCtx diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index a834dd5203..8e1cabf3a1 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -83,7 +83,7 @@ def parallelize_model( if isinstance(model, str): from transformers import AutoConfig - from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME + from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME is_local = os.path.isdir(model) allow_patterns = ["*.safetensors", "*.bin"] @@ -103,13 +103,9 @@ def parallelize_model( model_config, kwargs = AutoConfig.from_pretrained( hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs ) - config_path = os.path.join(hf_folder, CONFIG_NAME) - if not os.path.isfile(config_path): - raise EnvironmentError(f"Can't find config file {config_path} in {hf_folder}") - with open(config_path) as f: - config_dict = json.load(f) - model_arch = config_dict["architectures"] + # try getting model class info from config + model_arch = model_config.architectures model_cls = getattr(importlib.import_module("transformers"), model_arch[0]) if not skip_load_weights: @@ -141,6 +137,8 @@ def parallelize_model( with MetaAwareMethodsPatcher(): model = model_cls(model_config, *model_args, **kwargs) + # TODO: remove this once support training-time trace + model.eval() move_model_to_device(model, device=parallel_ctx.current_device) initialize_parameter_meta(model) diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index 3ffb3d380c..8df2db885f 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -23,6 +23,7 @@ from functools import wraps from itertools import chain from pathlib import Path +from tqdm.auto import tqdm from typing import Callable, Dict, List, Optional, Union import filelock @@ -329,6 +330,11 @@ def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): return lock +class DisabledTqdm(tqdm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, disable=True) + + # adpated from vllm.model_executor.model_loader.weight_utils.py def download_files_from_hf( model_name_or_path: str, @@ -384,6 +390,7 @@ def download_files_from_hf( cache_dir=cache_dir, revision=revision, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE or local_files_only, + tqdm_class=DisabledTqdm, ) return hf_folder From 01cfc256f5861e843d9cd7f734c0acde2b64f084 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Thu, 18 Jul 2024 00:19:21 +0000 Subject: [PATCH 17/25] format --- optimum/fx/parallelization/api.py | 1 + optimum/fx/parallelization/utils.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 8e1cabf3a1..01bb9259e4 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -128,6 +128,7 @@ def parallelize_model( # try directly construct weight_map from weight files, should have safetensors file on disk in any case if not parallel_ctx.weight_map: from safetensors import safe_open + weight_map, weight_files = {}, glob.glob(os.path.join(hf_folder, "*.safetensors")) for weight_file in weight_files: with safe_open(filename=weight_file, framework="pt") as f: diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index 8df2db885f..55b1c41347 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -23,7 +23,6 @@ from functools import wraps from itertools import chain from pathlib import Path -from tqdm.auto import tqdm from typing import Callable, Dict, List, Optional, Union import filelock @@ -31,6 +30,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.fx import Graph, Node +from tqdm.auto import tqdm from .core import ( HashableSlice, From 8c162679910f45d5fd5839dd0a158b37a5a3b4eb Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Thu, 18 Jul 2024 18:04:29 +0000 Subject: [PATCH 18/25] clean tests --- .../workflows/test_fx_automatic_parallel.yml | 1 + .../parallelization/test_tensor_parallel.py | 157 ++++++------------ 2 files changed, 50 insertions(+), 108 deletions(-) diff --git a/.github/workflows/test_fx_automatic_parallel.yml b/.github/workflows/test_fx_automatic_parallel.yml index 4b1cc21952..3c913e3f7e 100644 --- a/.github/workflows/test_fx_automatic_parallel.yml +++ b/.github/workflows/test_fx_automatic_parallel.yml @@ -35,6 +35,7 @@ jobs: options: --mount type=tmpfs,destination=/tmp --shm-size 64gb --gpus all --ipc host -v /mnt/hf_cache:/mnt/cache/ env: NCCL_DEBUG: INFO + HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }} defaults: run: shell: bash diff --git a/tests/fx/parallelization/test_tensor_parallel.py b/tests/fx/parallelization/test_tensor_parallel.py index b0211c2a1c..d12c8689de 100644 --- a/tests/fx/parallelization/test_tensor_parallel.py +++ b/tests/fx/parallelization/test_tensor_parallel.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import unittest -from functools import partial -from typing import Type, Union +from typing import Any, Dict, Union import torch import torch.distributed as dist @@ -22,45 +21,31 @@ from packaging import version from parameterized import parameterized from transformers import ( - LlamaConfig, - LlamaForCausalLM, - MistralConfig, - MistralForCausalLM, PretrainedConfig, - PreTrainedModel, set_seed, ) -from optimum.fx.parallelization import Config, ParallelExecutionCtx, parallelize_backend +from optimum.fx.parallelization import ParallelExecutionCtx, parallelize_model from optimum.fx.parallelization.parallel_layers import ColumnParallelLinear, VocabParallelEmbedding -from optimum.fx.parallelization.utils import ( - MetaAwareMethodsPatcher, - initialize_parameter_meta, - move_model_to_device, - stable_topological_sort, -) +from optimum.fx.parallelization.utils import stable_topological_sort + +DUMMY_MODEL_KWARGS = { + "num_hidden_layers": 2, + "use_cache": False, + "output_attentions": False, + "output_hidden_states": False, + "tie_word_embeddings": True, +} DUMMY_MODELS_TO_TEST = ( ( - LlamaForCausalLM, - LlamaConfig( - num_hidden_layers=2, - tie_word_embeddings=True, - use_cache=False, - output_attentions=False, - output_hidden_states=False, - ), + "meta-llama/Llama-2-7b-hf", + DUMMY_MODEL_KWARGS, ), ( - MistralForCausalLM, - MistralConfig( - num_hidden_layers=2, - tie_word_embeddings=True, - use_cache=False, - output_attentions=False, - output_hidden_states=False, - ), + "mistralai/Mistral-7B-v0.1", + DUMMY_MODEL_KWARGS, ), ) @@ -86,27 +71,17 @@ def prepare_dummy_inputs( } -def run_test_all_rank_results_match( - rank: int, world_size: int, model_cls: Type[PreTrainedModel], model_config: PretrainedConfig -): +def run_test_all_rank_results_match(rank: int, world_size: int, model_id: str, model_kwargs: Dict[str, Any]): # initialize default group dist_init(rank, world_size) tp_group = dist.new_group() # prepare config and context device = torch.device(type="cuda", index=torch.cuda.current_device()) - ctx, cfg = ParallelExecutionCtx(tp_group=tp_group, current_device=device), Config() - - inputs = prepare_dummy_inputs(model_config) - # this will initialize all linears on meta device - with MetaAwareMethodsPatcher(): - model = model_cls(model_config) - model.eval() - # move model to current device, with linears still on meta, and intialize parameter mapping - move_model_to_device(model, device=device) - initialize_parameter_meta(model) - - model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) + ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device) + + model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs) + inputs = prepare_dummy_inputs(model.config) logits = model(**inputs)[0] tensors = gather_at_main_process(tensor=logits, group=tp_group, rank=rank, world_size=world_size) @@ -121,7 +96,7 @@ def run_test_all_rank_results_match( def run_test_parameters_persist_bewteen_recompile( - rank: int, world_size: int, model_cls: Type[PreTrainedModel], model_config: PretrainedConfig + rank: int, world_size: int, model_id: str, model_kwargs: Dict[str, Any] ): # initialize default group dist_init(rank, world_size) @@ -129,23 +104,15 @@ def run_test_parameters_persist_bewteen_recompile( # prepare config and context device = torch.device(type="cuda", index=torch.cuda.current_device()) - ctx, cfg = ParallelExecutionCtx(tp_group=tp_group, current_device=device), Config() + ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device) - inputs = prepare_dummy_inputs(model_config) + model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs) + inputs = prepare_dummy_inputs(model.config) # different shape to trigger recompile - another_inputs = prepare_dummy_inputs(model_config, seq_len=11) - yet_another_inputs = prepare_dummy_inputs(model_config, batch_size=2, seq_len=12) - - # this will initialize all linears on meta device - with MetaAwareMethodsPatcher(): - model = model_cls(model_config) - model.eval() - # move model to current device, with linears still on meta - move_model_to_device(model, device=device) - initialize_parameter_meta(model) - - model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) + another_inputs = prepare_dummy_inputs(model.config, seq_len=11) + yet_another_inputs = prepare_dummy_inputs(model.config, batch_size=2, seq_len=12) + model(**inputs) parameter_ids = {id(param) for _, param in ctx.last_optimized_graph_module.named_parameters()} @@ -164,7 +131,7 @@ def run_test_parameters_persist_bewteen_recompile( def run_test_parallel_results_matches_non_parallel( - rank: int, world_size: int, model_cls: Type[PreTrainedModel], model_config: PretrainedConfig + rank: int, world_size: int, model_id: str, model_kwargs: Dict[str, Any] ): # initialize default group dist_init(rank, world_size) @@ -172,37 +139,21 @@ def run_test_parallel_results_matches_non_parallel( # prepare config and context device = torch.device(type="cuda", index=torch.cuda.current_device()) - ctx, cfg = ParallelExecutionCtx(tp_group=tp_group, current_device=device), Config() + ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device) - inputs = prepare_dummy_inputs(model_config) + model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs) + inputs = prepare_dummy_inputs(model.config) set_seed(SEED) - # non-parallel local forward - with MetaAwareMethodsPatcher(): - model = model_cls(model_config) - model.eval() - - # move model to current device, with linears still on meta - move_model_to_device(model, device=device) - initialize_parameter_meta(model) - - model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) logits = model(**inputs)[0] + torch._dynamo.reset() del model tp_group = dist.new_group() set_seed(SEED) ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device) - with MetaAwareMethodsPatcher(): - model = model_cls(model_config) - model.eval() - - # move model to current device, with linears still on meta - move_model_to_device(model, device=device) - initialize_parameter_meta(model) - - model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) + model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs) parallel_logits = model(**inputs)[0] torch.testing.assert_close(logits.cpu(), parallel_logits.cpu(), rtol=1e-4, atol=1e-4) @@ -211,26 +162,16 @@ def run_test_parallel_results_matches_non_parallel( tearDown() -def run_test_tie_word_embeddings( - rank: int, world_size: int, model_cls: Type[PreTrainedModel], model_config: PretrainedConfig -): +def run_test_tie_word_embeddings(rank: int, world_size: int, model_id: str, model_kwargs: Dict[str, Any]): dist_init(rank, world_size) tp_group = dist.new_group() # prepare config and context device = torch.device(type="cuda", index=torch.cuda.current_device()) - ctx, cfg = ParallelExecutionCtx(tp_group=tp_group, current_device=device), Config() - - inputs = prepare_dummy_inputs(model_config) - - with MetaAwareMethodsPatcher(): - model = model_cls(model_config) - model.eval() - - move_model_to_device(model, device=device) - initialize_parameter_meta(model) + ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device) + model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs) - model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg)) + inputs = prepare_dummy_inputs(model.config) model(**inputs) embedding_weight, lm_head_weight = None, None @@ -262,12 +203,12 @@ def run_test_tie_word_embeddings( not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run" ) def test_all_rank_results_match( - model_cls, - model_config, + model_id, + model_kwargs, ): for world_size in [1, 2, 4, 8]: if world_size <= NUM_AVAILABLE_DEVICES: - spawn(world_size, run_test_all_rank_results_match, model_cls, model_config, deterministic=True) + spawn(world_size, run_test_all_rank_results_match, model_id, model_kwargs, deterministic=True) @parameterized.expand(DUMMY_MODELS_TO_TEST) @@ -275,13 +216,13 @@ def test_all_rank_results_match( not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run" ) def test_parameters_persist_bewteen_recompile( - model_cls, - model_config, + model_id, + model_kwargs, ): for world_size in [1, 2]: if world_size <= NUM_AVAILABLE_DEVICES: spawn( - world_size, run_test_parameters_persist_bewteen_recompile, model_cls, model_config, deterministic=False + world_size, run_test_parameters_persist_bewteen_recompile, model_id, model_kwargs, deterministic=False ) @@ -291,11 +232,11 @@ def test_parameters_persist_bewteen_recompile( "requires more than one gpu and torch version >= 2.3.0 to run", ) def test_parallel_results_matches_non_parallel( - model_cls, - model_config, + model_id, + model_kwargs, ): # world_size == 2 is enough - spawn(2, run_test_parallel_results_matches_non_parallel, model_cls, model_config, deterministic=True) + spawn(2, run_test_parallel_results_matches_non_parallel, model_id, model_kwargs, deterministic=True) @parameterized.expand(DUMMY_MODELS_TO_TEST) @@ -304,9 +245,9 @@ def test_parallel_results_matches_non_parallel( "requires gpu and torch version >= 2.3.0 to run", ) def test_tie_word_embeddings( - model_cls, - model_config, + model_id, + model_kwargs, ): for world_size in [1, 2]: if world_size <= NUM_AVAILABLE_DEVICES: - spawn(world_size, run_test_tie_word_embeddings, model_cls, model_config, deterministic=False) + spawn(world_size, run_test_tie_word_embeddings, model_id, model_kwargs, deterministic=False) From 8ef00e033985819d29ca8470beacd85b1307afec Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Thu, 18 Jul 2024 19:44:29 +0000 Subject: [PATCH 19/25] fix weight_map --- optimum/fx/parallelization/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 01bb9259e4..772dfdccd1 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -118,7 +118,7 @@ def parallelize_model( if os.path.isfile(index_path): with open(index_path) as f: index_dict = json.load(f) - parallel_ctx.weight_map = index_dict["weight_map"] + parallel_ctx.weight_map = {k : os.path.join(hf_folder, v) for k, v in index_dict["weight_map"].items()} weight_files = glob.glob(os.path.join(hf_folder, "*.safetensors" if use_safetensors else "*.bin")) if not use_safetensors: weight_map = parallel_ctx.weight_map if parallel_ctx.weight_map else {} From 6ef2081e58570fadb28fc6ff2dd262524b8a63fb Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Mon, 22 Jul 2024 18:59:21 +0000 Subject: [PATCH 20/25] add weights loading --- optimum/fx/parallelization/api.py | 2 +- optimum/fx/parallelization/passes.py | 20 +++++++++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 772dfdccd1..35be5b54d4 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -118,7 +118,7 @@ def parallelize_model( if os.path.isfile(index_path): with open(index_path) as f: index_dict = json.load(f) - parallel_ctx.weight_map = {k : os.path.join(hf_folder, v) for k, v in index_dict["weight_map"].items()} + parallel_ctx.weight_map = {k: os.path.join(hf_folder, v) for k, v in index_dict["weight_map"].items()} weight_files = glob.glob(os.path.join(hf_folder, "*.safetensors" if use_safetensors else "*.bin")) if not use_safetensors: weight_map = parallel_ctx.weight_map if parallel_ctx.weight_map else {} diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index d14abc6b6a..97f0fccb82 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -511,9 +511,27 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf ) for source, target in sorted(param_meta.mapping.items()): + # weights loading if target.source in ctx.weight_map: - # TODO: add weights loading logic + from safetensors import safe_open + with safe_open(ctx.weight_map[target.source], framework="pt", device="cpu") as fp: + tensor_slice = fp.get_slice(target.source) + source_index = [ + source.to_slice() if dim == param_meta.dim else slice(None, None, None) + for dim in range(param.ndim) + ] + load_index = [ + target.index if dim == param_meta.dim else slice(None, None, None) + for dim in range(param.ndim) + ] + + tensor = tensor_slice[load_index].contiguous() + tensor = torch.empty_like(tensor).copy_(tensor) + with torch.no_grad(): + new_param.data[source_index].copy_(tensor) continue + + # initialization if not param_meta.is_parallel or tp_rank == 0: # initialize weight on master rank weight = torch.empty(*target.shape, dtype=param.dtype, device="cpu") From 2c561d3f173e0992d9184bab3b89d8c0ea6ec2cb Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Mon, 22 Jul 2024 19:00:42 +0000 Subject: [PATCH 21/25] format --- optimum/fx/parallelization/passes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index 97f0fccb82..cb4d6cc2e1 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -514,6 +514,7 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf # weights loading if target.source in ctx.weight_map: from safetensors import safe_open + with safe_open(ctx.weight_map[target.source], framework="pt", device="cpu") as fp: tensor_slice = fp.get_slice(target.source) source_index = [ From fc96b6f74f1905fddbdeb75f302d0ae09b93f1d8 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 23 Jul 2024 01:33:01 +0200 Subject: [PATCH 22/25] fix --- optimum/fx/parallelization/api.py | 7 ++++ optimum/fx/parallelization/passes.py | 63 +++++++++++++++------------- 2 files changed, 41 insertions(+), 29 deletions(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 35be5b54d4..1490848a6e 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -136,11 +136,18 @@ def parallelize_model( weight_map[key] = weight_file parallel_ctx.weight_map = weight_map + torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None + if torch_dtype is not None: + dtype_orig = model_cls._set_default_torch_dtype(torch_dtype) + with MetaAwareMethodsPatcher(): model = model_cls(model_config, *model_args, **kwargs) # TODO: remove this once support training-time trace model.eval() + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + move_model_to_device(model, device=parallel_ctx.current_device) initialize_parameter_meta(model) backend = partial(parallelize_backend, ctx=parallel_ctx, config=parallel_config) diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index cb4d6cc2e1..6546ce622d 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -493,25 +493,26 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf new_parameters, tied_parameters = [], {} for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)): param_meta: ParameterMeta = getattr(param, "meta") - # skip already initialized parameters - if not param_meta.need_initialize: - continue - # skip already initialized tied parameters + # skip already initialized/loaded tied parameters if param_meta.is_tied and id(param) in tied_parameters: new_parameters.append((name, tied_parameters[id(param)])) continue - shape = [ + shape = ( param.size(dim) // world_size if dim == param_meta.dim and param_meta.is_parallel else param.size(dim) for dim in range(param.ndim) - ] - - new_param = nn.Parameter( - torch.zeros(*shape, dtype=param.dtype, device=ctx.current_device), requires_grad=param.requires_grad ) + if shape == tuple(param.size()) and param.device == ctx.current_device: + new_param = param + else: + new_param = nn.Parameter( + torch.zeros(*shape, dtype=param.dtype, device=ctx.current_device), + requires_grad=param.requires_grad, + ) + + # load weights if possible for source, target in sorted(param_meta.mapping.items()): - # weights loading if target.source in ctx.weight_map: from safetensors import safe_open @@ -530,29 +531,33 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf tensor = torch.empty_like(tensor).copy_(tensor) with torch.no_grad(): new_param.data[source_index].copy_(tensor) - continue - # initialization - if not param_meta.is_parallel or tp_rank == 0: - # initialize weight on master rank - weight = torch.empty(*target.shape, dtype=param.dtype, device="cpu") - init_fn = param_meta.init_fn if param_meta.init_fn else config.weight_init_fn - init_fn(weight) - weight = weight.to(ctx.current_device) - else: - weight = None - index = [ - source.to_slice() if dim == param_meta.dim else slice(None, None, None) - for dim in range(param.ndim) - ] - with torch.no_grad(): - if param_meta.is_parallel: - scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim) + # weights initialization + if param_meta.need_initialize: + for source, target in sorted(param_meta.mapping.items()): + if target.source in ctx.weight_map: + continue + if not param_meta.is_parallel or tp_rank == 0: + # initialize weight on master rank + weight = torch.empty(*target.shape, dtype=param.dtype, device="cpu") + init_fn = param_meta.init_fn if param_meta.init_fn else config.weight_init_fn + init_fn(weight) + weight = weight.to(ctx.current_device) else: - new_param.data[index].copy_(weight) + weight = None + index = [ + source.to_slice() if dim == param_meta.dim else slice(None, None, None) + for dim in range(param.ndim) + ] + with torch.no_grad(): + if param_meta.is_parallel: + scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim) + else: + new_param.data[index].copy_(weight) setattr(new_param, "meta", param_meta) - new_parameters.append((name, new_param)) + if id(new_param) != id(param): + new_parameters.append((name, new_param)) if param_meta.is_tied: tied_parameters[id(param)] = new_param From 8d2cabbd96e0a908e2242ddd3c7d527beabfa975 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 23 Jul 2024 22:40:22 +0200 Subject: [PATCH 23/25] fix --- optimum/fx/parallelization/passes.py | 6 +- optimum/fx/parallelization/utils.py | 2 +- .../parallelization/test_tensor_parallel.py | 76 +++++++++---------- 3 files changed, 42 insertions(+), 42 deletions(-) diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index 6546ce622d..1b25e9e123 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -498,12 +498,12 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf new_parameters.append((name, tied_parameters[id(param)])) continue - shape = ( + shape = [ param.size(dim) // world_size if dim == param_meta.dim and param_meta.is_parallel else param.size(dim) for dim in range(param.ndim) - ) + ] - if shape == tuple(param.size()) and param.device == ctx.current_device: + if not param_meta.is_parallel and param.device == ctx.current_device: new_param = param else: new_param = nn.Parameter( diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index 55b1c41347..1bb3d07645 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -433,6 +433,6 @@ def convert_bin_to_safetensors( checkpoint[k] = v data_pointers.add(v.data_ptr()) save_file(checkpoint, output_file_path) - keys = [key for key, value in weight_map if value == weight_file] + keys = [key for key, value in weight_map.items() if value == weight_file] for key in keys: weight_map[key] = output_file_path diff --git a/tests/fx/parallelization/test_tensor_parallel.py b/tests/fx/parallelization/test_tensor_parallel.py index d12c8689de..fe09d6fff6 100644 --- a/tests/fx/parallelization/test_tensor_parallel.py +++ b/tests/fx/parallelization/test_tensor_parallel.py @@ -40,11 +40,11 @@ DUMMY_MODELS_TO_TEST = ( ( - "meta-llama/Llama-2-7b-hf", + "saibo/llama-1B", DUMMY_MODEL_KWARGS, ), ( - "mistralai/Mistral-7B-v0.1", + "PhoenixJie/dummy-mistral", DUMMY_MODEL_KWARGS, ), ) @@ -198,17 +198,17 @@ def run_test_tie_word_embeddings(rank: int, world_size: int, model_id: str, mode tearDown() -@parameterized.expand(DUMMY_MODELS_TO_TEST) -@unittest.skipIf( - not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run" -) -def test_all_rank_results_match( - model_id, - model_kwargs, -): - for world_size in [1, 2, 4, 8]: - if world_size <= NUM_AVAILABLE_DEVICES: - spawn(world_size, run_test_all_rank_results_match, model_id, model_kwargs, deterministic=True) +# @parameterized.expand(DUMMY_MODELS_TO_TEST) +# @unittest.skipIf( +# not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run" +# ) +# def test_all_rank_results_match( +# model_id, +# model_kwargs, +# ): +# for world_size in [1, 2, 4, 8]: +# if world_size <= NUM_AVAILABLE_DEVICES: +# spawn(world_size, run_test_all_rank_results_match, model_id, model_kwargs, deterministic=True) @parameterized.expand(DUMMY_MODELS_TO_TEST) @@ -226,28 +226,28 @@ def test_parameters_persist_bewteen_recompile( ) -@parameterized.expand(DUMMY_MODELS_TO_TEST) -@unittest.skipIf( - not is_gpu_available() or not is_torch_compile_available() or NUM_AVAILABLE_DEVICES < 2, - "requires more than one gpu and torch version >= 2.3.0 to run", -) -def test_parallel_results_matches_non_parallel( - model_id, - model_kwargs, -): - # world_size == 2 is enough - spawn(2, run_test_parallel_results_matches_non_parallel, model_id, model_kwargs, deterministic=True) - - -@parameterized.expand(DUMMY_MODELS_TO_TEST) -@unittest.skipIf( - not is_gpu_available() or not is_torch_compile_available(), - "requires gpu and torch version >= 2.3.0 to run", -) -def test_tie_word_embeddings( - model_id, - model_kwargs, -): - for world_size in [1, 2]: - if world_size <= NUM_AVAILABLE_DEVICES: - spawn(world_size, run_test_tie_word_embeddings, model_id, model_kwargs, deterministic=False) +# @parameterized.expand(DUMMY_MODELS_TO_TEST) +# @unittest.skipIf( +# not is_gpu_available() or not is_torch_compile_available() or NUM_AVAILABLE_DEVICES < 2, +# "requires more than one gpu and torch version >= 2.3.0 to run", +# ) +# def test_parallel_results_matches_non_parallel( +# model_id, +# model_kwargs, +# ): +# # world_size == 2 is enough +# spawn(2, run_test_parallel_results_matches_non_parallel, model_id, model_kwargs, deterministic=True) + + +# @parameterized.expand(DUMMY_MODELS_TO_TEST) +# @unittest.skipIf( +# not is_gpu_available() or not is_torch_compile_available(), +# "requires gpu and torch version >= 2.3.0 to run", +# ) +# def test_tie_word_embeddings( +# model_id, +# model_kwargs, +# ): +# for world_size in [1, 2]: +# if world_size <= NUM_AVAILABLE_DEVICES: +# spawn(world_size, run_test_tie_word_embeddings, model_id, model_kwargs, deterministic=False) From 97e6431ae9962bc6b1821e798cd70fc0f9b0a499 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 23 Jul 2024 22:46:33 +0200 Subject: [PATCH 24/25] enable tests --- .../parallelization/test_tensor_parallel.py | 72 +++++++++---------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/tests/fx/parallelization/test_tensor_parallel.py b/tests/fx/parallelization/test_tensor_parallel.py index fe09d6fff6..9626fccec3 100644 --- a/tests/fx/parallelization/test_tensor_parallel.py +++ b/tests/fx/parallelization/test_tensor_parallel.py @@ -198,17 +198,17 @@ def run_test_tie_word_embeddings(rank: int, world_size: int, model_id: str, mode tearDown() -# @parameterized.expand(DUMMY_MODELS_TO_TEST) -# @unittest.skipIf( -# not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run" -# ) -# def test_all_rank_results_match( -# model_id, -# model_kwargs, -# ): -# for world_size in [1, 2, 4, 8]: -# if world_size <= NUM_AVAILABLE_DEVICES: -# spawn(world_size, run_test_all_rank_results_match, model_id, model_kwargs, deterministic=True) +@parameterized.expand(DUMMY_MODELS_TO_TEST) +@unittest.skipIf( + not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run" +) +def test_all_rank_results_match( + model_id, + model_kwargs, +): + for world_size in [1, 2, 4, 8]: + if world_size <= NUM_AVAILABLE_DEVICES: + spawn(world_size, run_test_all_rank_results_match, model_id, model_kwargs, deterministic=True) @parameterized.expand(DUMMY_MODELS_TO_TEST) @@ -226,28 +226,28 @@ def test_parameters_persist_bewteen_recompile( ) -# @parameterized.expand(DUMMY_MODELS_TO_TEST) -# @unittest.skipIf( -# not is_gpu_available() or not is_torch_compile_available() or NUM_AVAILABLE_DEVICES < 2, -# "requires more than one gpu and torch version >= 2.3.0 to run", -# ) -# def test_parallel_results_matches_non_parallel( -# model_id, -# model_kwargs, -# ): -# # world_size == 2 is enough -# spawn(2, run_test_parallel_results_matches_non_parallel, model_id, model_kwargs, deterministic=True) - - -# @parameterized.expand(DUMMY_MODELS_TO_TEST) -# @unittest.skipIf( -# not is_gpu_available() or not is_torch_compile_available(), -# "requires gpu and torch version >= 2.3.0 to run", -# ) -# def test_tie_word_embeddings( -# model_id, -# model_kwargs, -# ): -# for world_size in [1, 2]: -# if world_size <= NUM_AVAILABLE_DEVICES: -# spawn(world_size, run_test_tie_word_embeddings, model_id, model_kwargs, deterministic=False) +@parameterized.expand(DUMMY_MODELS_TO_TEST) +@unittest.skipIf( + not is_gpu_available() or not is_torch_compile_available() or NUM_AVAILABLE_DEVICES < 2, + "requires more than one gpu and torch version >= 2.3.0 to run", +) +def test_parallel_results_matches_non_parallel( + model_id, + model_kwargs, +): + # world_size == 2 is enough + spawn(2, run_test_parallel_results_matches_non_parallel, model_id, model_kwargs, deterministic=True) + + +@parameterized.expand(DUMMY_MODELS_TO_TEST) +@unittest.skipIf( + not is_gpu_available() or not is_torch_compile_available(), + "requires gpu and torch version >= 2.3.0 to run", +) +def test_tie_word_embeddings( + model_id, + model_kwargs, +): + for world_size in [1, 2]: + if world_size <= NUM_AVAILABLE_DEVICES: + spawn(world_size, run_test_tie_word_embeddings, model_id, model_kwargs, deterministic=False) From efd5d28c96724ad55e526e34cddb15b838deb61c Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Wed, 24 Jul 2024 19:31:08 +0200 Subject: [PATCH 25/25] address comments --- optimum/fx/parallelization/api.py | 67 ++++++++--------------------- optimum/fx/parallelization/utils.py | 56 +++++++++++++++++++----- 2 files changed, 64 insertions(+), 59 deletions(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 1490848a6e..bd307bd93c 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -12,12 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import glob import importlib -import json import os from functools import partial -from typing import List, Optional, Union +from typing import List, Union import torch from torch.fx import GraphModule @@ -26,10 +24,10 @@ from .passes import build_parallel_pass_pipeline from .utils import ( MetaAwareMethodsPatcher, - convert_bin_to_safetensors, - download_files_from_hf, + download_model_from_hf, initialize_parameter_meta, move_model_to_device, + try_collect_weight_map, ) @@ -48,10 +46,6 @@ def parallelize_model( model: Union[torch.nn.Module, str], parallel_ctx: ParallelExecutionCtx, *model_args, - revision: str = "main", - cache_dir: Optional[str] = None, - local_files_only: bool = False, - skip_load_weights: bool = False, **kwargs, ): """ @@ -59,39 +53,41 @@ def parallelize_model( Args: model (Union[torch.nn.Module, str]): - Model to parallelize, could either be a module or a model id in huggingface space. + Model to parallelize, could either be a module or a model id on the Huggingface Hub. parallel_ctx (ParallelExecutionCtx): Parallel execution context containing process groups the current process belongs to. - model_args (additional postional arguments, optional): + *model_args (Any): Additional postional arguments for intializing the model if a model id is passed. - revision (`str`, defaults to `main`): + revision (str, defaults to `main`): Model revision for weights downloading if a model id is passed. - cache_dir (`Optional[str]`, defaults to `None`): + cache_dir (Optional[str], defaults to `None`): Cache directory to store downloaded weights. Defaults to None. - local_files_only (`bool`, defaults to `False`): + local_files_only (bool, defaults to `False`): Whether to use local files only, will avoid downloading from remote if set to `True`. - skip_load_weights (`bool`, defaults to `False`): + skip_load_weights (bool, defaults to `False`): Whether to skip loading weights from disk to model. - kwargs (additional keyword arguments, optional): + **kwargs (Dict[str, Any]): Addtional keyword arguments for overriding fields in parallel config, model config and `Model.__init__`. """ + revision = kwargs.pop("revision", "main") + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", False) + skip_load_weights = kwargs.pop("skip_load_weights", False) + parallel_config = Config() - for k, v in kwargs.items(): + for k, v in dict(kwargs).items(): if k in parallel_config.__dict__: setattr(parallel_config, k, v) - kwargs = {k: v for k, v in kwargs.items() if k not in parallel_config.__dict__} + kwargs.pop(k) if isinstance(model, str): from transformers import AutoConfig - from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME is_local = os.path.isdir(model) - allow_patterns = ["*.safetensors", "*.bin"] if not is_local: - hf_folder = download_files_from_hf( + hf_folder = download_model_from_hf( model_name_or_path=model, cache_dir=cache_dir, - allow_patterns=allow_patterns, revision=revision, local_files_only=local_files_only, skip_download_weights=skip_load_weights, @@ -109,32 +105,7 @@ def parallelize_model( model_cls = getattr(importlib.import_module("transformers"), model_arch[0]) if not skip_load_weights: - use_safetensors = False - for pattern in allow_patterns: - if len(glob.glob(os.path.join(hf_folder, pattern))) > 0: - use_safetensors = pattern == "*.safetensors" - break - index_path = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME) - if os.path.isfile(index_path): - with open(index_path) as f: - index_dict = json.load(f) - parallel_ctx.weight_map = {k: os.path.join(hf_folder, v) for k, v in index_dict["weight_map"].items()} - weight_files = glob.glob(os.path.join(hf_folder, "*.safetensors" if use_safetensors else "*.bin")) - if not use_safetensors: - weight_map = parallel_ctx.weight_map if parallel_ctx.weight_map else {} - convert_bin_to_safetensors(model, cache_dir, weight_files, weight_map) - parallel_ctx.weight_map = weight_map - - # try directly construct weight_map from weight files, should have safetensors file on disk in any case - if not parallel_ctx.weight_map: - from safetensors import safe_open - - weight_map, weight_files = {}, glob.glob(os.path.join(hf_folder, "*.safetensors")) - for weight_file in weight_files: - with safe_open(filename=weight_file, framework="pt") as f: - for key in f.keys(): - weight_map[key] = weight_file - parallel_ctx.weight_map = weight_map + parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder) torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None if torch_dtype is not None: diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index 1bb3d07645..f129ffbd40 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import fnmatch +import glob import hashlib import importlib +import json import operator import os import re @@ -32,11 +34,7 @@ from torch.fx import Graph, Node from tqdm.auto import tqdm -from .core import ( - HashableSlice, - ParameterMeta, - ParameterSlice, -) +from .core import HashableSlice, ParameterMeta, ParameterSlice def ensure_divisibility(numerator: int, denominator: int) -> None: @@ -336,10 +334,9 @@ def __init__(self, *args, **kwargs): # adpated from vllm.model_executor.model_loader.weight_utils.py -def download_files_from_hf( +def download_model_from_hf( model_name_or_path: str, cache_dir: Optional[str], - allow_patterns: List[str], revision: Optional[str] = None, local_files_only: bool = False, skip_download_weights: bool = False, @@ -350,9 +347,6 @@ def download_files_from_hf( model_name_or_path (`str`): The model name or path. cache_dir (`Optional[str]`): The cache directory to store the model weights. If None, will use HF defaults. - allow_patterns (`List[str]`): The allowed patterns for the - weight files. Files matched by any of the patterns will be - downloaded. revision (`Optional[str]`, defaults to `None`): The revision of the model. local_files_only(`bool`): Should only use local files if True. skip_download_weights (`bool`, defaults to `False`): Whether to skip downloading weights to disk. @@ -364,6 +358,8 @@ def download_files_from_hf( from huggingface_hub import HfFileSystem, snapshot_download from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME + allow_patterns = ["*.safetensors", "*.bin"] + if not skip_download_weights and not huggingface_hub.constants.HF_HUB_OFFLINE: # Before we download we look at that is available: fs = HfFileSystem() @@ -377,9 +373,12 @@ def download_files_from_hf( break if skip_download_weights: + # only need to download config file allow_patterns = [CONFIG_NAME] + elif allow_patterns[0] == "*.safetensors": + allow_patterns = allow_patterns + [CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME] else: - allow_patterns = allow_patterns + [CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME] + allow_patterns = allow_patterns + [CONFIG_NAME, WEIGHTS_INDEX_NAME] # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. @@ -436,3 +435,38 @@ def convert_bin_to_safetensors( keys = [key for key, value in weight_map.items() if value == weight_file] for key in keys: weight_map[key] = output_file_path + + +def try_collect_weight_map(model_name_or_path: str, cache_dir: Optional[str], folder_path: str) -> Dict[str, str]: + """Try collecting weight mapping information from the model folder.""" + from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME + + weight_map = {} + use_safetensors, weight_patterns = False, ["*safetensors", "*.bin"] + for pattern in weight_patterns: + if len(glob.glob(os.path.join(folder_path, pattern))) > 0: + use_safetensors = pattern == "*.safetensors" + break + index_path = os.path.join(folder_path, SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME) + weight_files = glob.glob(os.path.join(folder_path, "*.safetensors" if use_safetensors else "*.bin")) + + if os.path.isfile(index_path): + with open(index_path) as f: + index_dict = json.load(f) + weight_map = {k: os.path.join(folder_path, v) for k, v in index_dict["weight_map"].items()} + + # convert bin files to safetensors, modify `weight_map` meanwhile + if not use_safetensors: + convert_bin_to_safetensors(model_name_or_path, cache_dir, weight_files, weight_map) + + # last resort: try directly construct weight_map from weight files + if not weight_map: + from safetensors import safe_open + + # should have safetensors on disk in any case + weight_files = glob.glob(os.path.join(folder_path, "*.safetensors")) + for weight_file in weight_files: + with safe_open(filename=weight_file, framework="pt") as f: + for key in f.keys(): + weight_map[key] = weight_file + return weight_map