
Modify Parallelization Strategy to Make it More General #1988

Conversation


@zhenglongjiepheonix zhenglongjiepheonix commented Aug 14, 2024

As per the title, this PR tries a more general approach rather than relying purely on human heuristics. It uses the following steps to search for a possible parallelization strategy for a transformer model:

  • Use dynamo for graph tracing so that we get a graph to operate on
  • Decompose and functionalize the traced graph so that we get a smaller op set to work with
  • Apply parallel axis analysis and do a constrained backtracking search on the whole graph to get a possible solution (not necessarily optimal)
  • Replace ops in the original traced graph with their parallelized versions (Linear -> ColumnLinear/RowLinear)

As for the API design, we drop support for passing custom modules and focus only on models in transformers, because supporting custom models is not the priority for now.
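
To make the last step concrete, here is a minimal standalone sketch (illustrative only, not this PR's implementation) of why a Linear layer can be swapped for a column-parallel or row-parallel variant: sharding the weight along the output dimension and concatenating the partial results, or along the input dimension and summing them, reproduces the original output.

import torch
import torch.nn as nn

full = nn.Linear(8, 8, bias=False)
x = torch.randn(2, 8)

# Column parallel: shard the output dimension, concatenate partial outputs.
w0, w1 = full.weight.chunk(2, dim=0)
col_out = torch.cat([x @ w0.T, x @ w1.T], dim=-1)

# Row parallel: shard the input dimension, sum partial outputs.
v0, v1 = full.weight.chunk(2, dim=1)
x0, x1 = x.chunk(2, dim=-1)
row_out = x0 @ v0.T + x1 @ v1.T

assert torch.allclose(col_out, full(x), atol=1e-6)
assert torch.allclose(row_out, full(x), atol=1e-6)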

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

optimum/fx/parallelization/api.py
"""
API for automatic model parallelism through Pytorch FX.

Args:
model (Union[torch.nn.Module, str]):
Model to parallelize, could either be a module or a model id on the Huggingface Hub.
model (str):
Member

Suggested change
model (str):
model (`str`):

Comment on lines 63 to 67
class DecompTracer(GraphAppendingTracer):
    def __init__(self, graph: Graph):
        super().__init__(graph)
        self.tensor_tracker = WeakTensorKeyDictionary()
        self.symnode_tracker = _SymNodeDict()
Member

Maybe a docstring explaining what it does.

optimum/fx/parallelization/decomp.py
Comment on lines 74 to 79
that certain primitive layers(like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific
heuristic based parallelization strategy for them so that we can conveniently replace them into their parallelized counterparts
in the orignal graph module.

Note that the traced graph is a low-level equivalent representation of the original graph module, and is only used for
parallel axis propagation and analysis, the original graph module is still used for real execution.
Member

Can you group notes as follows:

Notes:
1. Certain primitive layers ....
2. The traced graph is a low-level equivalent...

leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention],
) -> Callable:
"""
API to decompose and funcitonalize a high-level graph module.
Member

Suggested change
API to decompose and funcitonalize a high-level graph module.
API to decompose and functionalize a high-level graph module.

Comment on lines 196 to 200
graph_module (GraphModule):
The high-level graph module to be decomposed and functionalized.
decomposition_table (Dict[torch._ops.OperatorBase, Callable], defaults to `core_aten_decompostions()`):
The lookup table which maps high-level torch op to their equivalent low-level implementation.
leaf_function_targets (List[Callable], defaults to `[F.scaled_dot_product_attention]`):
Member

Suggested change
graph_module (GraphModule):
The high-level graph module to be decomposed and functionalized.
decomposition_table (Dict[torch._ops.OperatorBase, Callable], defaults to `core_aten_decompostions()`):
The lookup table which maps high-level torch op to their equivalent low-level implementation.
leaf_function_targets (List[Callable], defaults to `[F.scaled_dot_product_attention]`):
graph_module (`GraphModule`):
The high-level graph module to be decomposed and functionalized.
decomposition_table (`Dict[torch._ops.OperatorBase, Callable]`, defaults to `core_aten_decompostions()`):
The lookup table which maps high-level torch op to their equivalent low-level implementation.
leaf_function_targets (`List[Callable]`, defaults to `[F.scaled_dot_product_attention]`):
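
As an aside, a sketch of what decomposing to a smaller aten op set looks like using stock PyTorch utilities (make_fx plus core_aten_decompositions); this is not the PR's code, which implements its own pass in optimum/fx/parallelization/decomp.py, but it illustrates the same idea.

import torch
import torch.nn.functional as F
from torch._decomp import core_aten_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

def fn(x, w, b):
    # F.linear is a high-level op; the decomposition table lowers it to aten
    # primitives such as aten.t / aten.addmm.
    return F.relu(F.linear(x, w, b))

gm = make_fx(fn, decomposition_table=core_aten_decompositions())(
    torch.randn(2, 4), torch.randn(3, 4), torch.randn(3)
)
print(gm.graph)  # the printed graph now contains low-level aten ops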

Comment on lines 25 to 41
class Registry:
    def __init__(self) -> None:
        self.mapping = {}

    def register(self, op_types):
        def wrapper(cls):
            if isinstance(op_types, (list, tuple)):
                for op_type in op_types:
                    self.mapping[op_type] = cls
            else:
                self.mapping[op_types] = cls
            return cls

        return wrapper

    def is_supported(self, op_type) -> bool:
        return op_type in self.mapping
Member

What is this registry used for?

Contributor Author

This is for registering the parallel axis policies of different aten ops.

Member

Can you add a docstring to explain that please?
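
For context, a hypothetical usage sketch (the handler class and the op choice are made up for illustration, assuming the Registry class shown above): the decorator maps one or several aten ops to the handler class that knows how to propagate their parallel axis.

import torch

REGISTRY = Registry()

@REGISTRY.register(torch.ops.aten.relu.default)
class ElementwiseParallelAxisHandler:
    """Elementwise ops keep the parallel axis of their input unchanged."""
    ...

assert REGISTRY.is_supported(torch.ops.aten.relu.default)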

Comment on lines 127 to 130
def propagate(self) -> bool:
    arg = self.node.all_input_nodes[0]
    axis = self.extract_axis(arg)
    return [axis]
Member

Nit: it's not returning a bool.
If I understand properly, it returns the axis that is supposed to be parallel?

Contributor Author

@zhenglongjiepheonix zhenglongjiepheonix Aug 20, 2024

Yes, you are right. It looks up the axes of all the inputs and tries to infer the axis of the output, returning an empty list if there is no valid axis on which the output can be parallelized.
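
For illustration, a minimal standalone sketch (the names are assumed, not taken from the PR) of that contract for a binary elementwise op: the handler returns the list of axes along which the output could be parallelized given the axes of its inputs, and an empty list when there is no valid choice.

from typing import List, Optional

def propagate_binary_elementwise(lhs_axis: Optional[int], rhs_axis: Optional[int]) -> List[Optional[int]]:
    # A replicated input (None) can follow the sharding of the other input;
    # two sharded inputs must agree on the axis, otherwise there is no valid
    # axis on which the output can be parallelized.
    if lhs_axis is None:
        return [rhs_axis]
    if rhs_axis is None:
        return [lhs_axis]
    return [lhs_axis] if lhs_axis == rhs_axis else []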

Comment on lines +199 to +201
def search(idx: int):
    if idx == len(nodes):
        return True
Member

Suggested change
def search(idx: int):
    if idx == len(nodes):
        return True
def search(idx: int) -> bool:
    return idx == len(nodes)

Contributor Author

The `search` here is actually a backtracking search function with more logic following this snippet, so we can only return True once we have reached the very last op.
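
To illustrate the pattern being described, here is a generic standalone sketch (names and structure are illustrative, not the PR's actual search over fx nodes): a constrained backtracking search assigns a candidate parallel axis to each node in order, undoes the choice when the partial assignment becomes infeasible, and succeeds exactly when the last node has been assigned.

from typing import Callable, Dict, List

def backtrack(nodes: List[str],
              candidates: Dict[str, List[int]],
              feasible: Callable[[Dict[str, int]], bool],
              assignment: Dict[str, int],
              idx: int = 0) -> bool:
    if idx == len(nodes):
        return True  # every node received a consistent parallel axis
    node = nodes[idx]
    for axis in candidates[node]:
        assignment[node] = axis
        if feasible(assignment) and backtrack(nodes, candidates, feasible, assignment, idx + 1):
            return True
        del assignment[node]  # undo and try the next candidate axis
    return False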

Member

@michaelbenayoun michaelbenayoun left a comment

LGTM!

Comment on lines 27 to 28
Registry class handles registration of parallel axis propagation handlers of different aten ops, to support a new
aten op, you need to register the corresponding handler class by decorating it with `register` function.
Member

nit:

Suggested change
Registry class handles registration of parallel axis propagation handlers of different aten ops, to support a new
aten op, you need to register the corresponding handler class by decorating it with `register` function.
Registry class handles registration of parallel axis propagation handlers of different aten ops.
To support a new aten op, you need to register the corresponding handler class by decorating
it with the `register` function.

@zhenglongjiepheonix
Contributor Author

Merging this; the failures are irrelevant to this PR.

@zhenglongjiepheonix zhenglongjiepheonix merged commit ad98dc9 into huggingface:main Sep 2, 2024
43 of 46 checks passed