diff --git a/thunder/distributed/tensor_parallel/common.py b/thunder/distributed/tensor_parallel/common.py index 010691059..ec5d0eaa1 100644 --- a/thunder/distributed/tensor_parallel/common.py +++ b/thunder/distributed/tensor_parallel/common.py @@ -107,35 +107,40 @@ def eligible_for_comm_optimization(self) -> bool: return self._has_other_tensor_parallel def __call__(self, bsym: BoundSymbol) -> VISIT_TYPE: + from thunder.core.prims import PrimIDs from thunder.core.transforms import VISIT_TYPE - from thunder.core.trace import get_tracectx from thunder.core.proxies import variableify - for t in bsym.flat_proxy_args: + if bsym.sym.id in { + PrimIDs.UNPACK_TRIVIAL, + PrimIDs.UNPACK_SEQUENCE, + PrimIDs.UNPACK_KEY, + PrimIDs.UNPACK_EMPTY_DICT, + }: + return VISIT_TYPE.NO_OP + + pre_post_process: PrePostProcessInterface | None = self.bsym_to_prepostprocess.get(bsym, None) + new_bsym = bsym.from_bsym_swap_proxies(self.swap_map) + for t in new_bsym.flat_proxy_args: self._maybe_other_tensor_parallel(t) - input_swap_map: dict[VariableInterface, ProxyInterface] = {} - pre_post_process: PrePostProcessInterface | None = None - if bsym in self.bsym_to_prepostprocess: - pre_post_process = self.bsym_to_prepostprocess[bsym] - orig_arg = bsym.flat_proxy_args[0] + if pre_post_process is not None: + orig_arg = new_bsym.flat_proxy_args[0] new_arg, preprocess_artifacts = pre_post_process.preprocess(orig_arg) if new_arg.name != orig_arg.name: - input_swap_map[variableify(orig_arg)] = new_arg - - new_bsym = bsym.from_bsym_swap_proxies(self.swap_map, skip_output=True) - if pre_post_process is not None: - new_bsym = new_bsym.from_bsym_swap_proxies(input_swap_map) + new_bsym = new_bsym.from_bsym_swap_proxies({variableify(orig_arg): new_arg}) new_bsym = pre_post_process.maybe_modify_args_and_kwargs(new_bsym) # note(crcrpar): This header seems to be lost in the extrace. new_bsym.header = f"{pre_post_process.__class__.layer_type}" - trace = get_tracectx() - trace.scopes[-1].append(new_bsym) + new_out = new_bsym.sym(*new_bsym.args, **new_bsym.kwargs) + + var_original_bsym_output = variableify(new_bsym.flat_proxy_outs[0]) if pre_post_process is not None: - y = bsym.flat_proxy_outs[0] - processed_y = pre_post_process.postprocess(y, preprocess_artifacts) - self.swap_map[variableify(y)] = processed_y + processed_y = pre_post_process.postprocess(new_out, preprocess_artifacts) + self.swap_map[var_original_bsym_output] = processed_y + else: + self.swap_map[var_original_bsym_output] = new_out return VISIT_TYPE.REPLACE diff --git a/thunder/tests/distributed/modules.py b/thunder/tests/distributed/modules.py new file mode 100644 index 000000000..144b394c8 --- /dev/null +++ b/thunder/tests/distributed/modules.py @@ -0,0 +1,50 @@ +from __future__ import annotations +from typing import ClassVar, TYPE_CHECKING + +import torch.nn as nn + +from thunder.core import utils + +if TYPE_CHECKING: + import torch + + +__all__ = [ + "ParallelMLP", +] + + +class ParallelMLP(nn.Module): + """Simplified version of Megatron/NeMo's ParallelMLP. + + Ref: https://github.com/NVIDIA/NeMo/blob/95ca2f4/nemo/collections/nlp/modules/common/megatron/mlp.py#L61 + """ + + COLUMN_WISE: ClassVar[tuple[str]] = ("dense_h_to_4h",) + ROW_WISE: ClassVar[tuple[str]] = ("dense_4h_to_h",) + + SUPPORTED_GELU_APPROX: ClassVar[tuple[str, str]] = ("none", "tanh") + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int | None = None, + bias: bool = True, + gelu_approximate: str = "none", + ) -> None: + utils.check( + gelu_approximate in ParallelMLP.SUPPORTED_GELU_APPROX, + lambda: f"Invalid {gelu_approximate}, supported are {ParallelMLP.SUPPORTED_GELU_APPROX}", + ) + if ffn_hidden_size is None: + ffn_hidden_size = 4 * hidden_size + + super().__init__() + self.dense_h_to_4h = nn.Linear(hidden_size, ffn_hidden_size, bias=bias) + self.dense_4h_to_h = nn.Linear(ffn_hidden_size, hidden_size, bias=bias) + self.gelu = nn.GELU(approximate=gelu_approximate) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + four_h = self.gelu(self.dense_h_to_4h(x)) + h = self.dense_4h_to_h(four_h) + return h diff --git a/thunder/tests/distributed/test_tensor_parallel.py b/thunder/tests/distributed/test_tensor_parallel.py index 1a85a13d6..0b509099a 100644 --- a/thunder/tests/distributed/test_tensor_parallel.py +++ b/thunder/tests/distributed/test_tensor_parallel.py @@ -8,9 +8,9 @@ from thunder.distributed import column_parallel, row_parallel import thunder.executors from thunder.tests.distributed.helper import ToyModel, DataParallelTestCase +from thunder.tests.distributed.modules import ParallelMLP from torch.testing._internal import common_utils -from torch.distributed import distributed_c10d as c10d _COL = "column" _ROW = "row" @@ -24,7 +24,7 @@ class TensorParallelTest(DataParallelTestCase): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="") @common_utils.parametrize("name,bias", product(tuple(_name_to_transform.keys()), (True, False))) - def test_tensor_parallel_linear(self, name, bias): + def test_linear(self, name, bias): device = torch.device("cuda", self.rank) x = torch.randn(2, 12).to(device).requires_grad_() x_ref = x.clone().detach().requires_grad_() @@ -74,7 +74,7 @@ def test_tensor_parallel_linear(self, name, bias): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="") @common_utils.parametrize("name", tuple(_name_to_transform.keys())) - def test_tensor_parallel_embedding(self, name): + def test_embedding(self, name): num_embeddings = 128 embedding_dim = 32 @@ -130,7 +130,7 @@ def forward(self, x): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="") @common_utils.parametrize("bias", (True, False)) - def test_tensor_parallel_both_column_and_row(self, bias): + def test_both_column_and_row(self, bias): num_embeddings = 128 embedding_dim = 32 n_hidden = 96 @@ -189,6 +189,55 @@ def forward(self, x): grad = tp_model.get_parameter(param_fqn).grad torch.testing.assert_close(actual=grad, expected=ref_grad, msg=msg) + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="") + def test_parallel_mlp(self): + from thunder.distributed.prims import PrimIDs + + sequence_length: int = 32 + batch_size: int = 4 + hidden_size: int = 128 + ffn_hidden_size: int = 512 + device = torch.device("cuda", self.rank) + + ref_mlp = ParallelMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size).to(device) + ref_state_dict = ref_mlp.state_dict() + mlp = ParallelMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size).to(device) + mlp.load_state_dict(ref_state_dict) + tp_mlp = thunder.jit(mlp) + tp_mlp = column_parallel(tp_mlp, ParallelMLP.COLUMN_WISE) + tp_mlp = row_parallel(tp_mlp, ParallelMLP.ROW_WISE) + + # See https://github.com/NVIDIA/NeMo/blob/95ca2f4/nemo/collections/nlp/modules/common/megatron/mlp.py#L221 for the input shape. + x_ref = torch.randn((sequence_length, batch_size, hidden_size), device=device, requires_grad=True) + x = x_ref.clone().detach().requires_grad_(True) + + expected = ref_mlp(x_ref) + actual = tp_mlp(x) + torch.testing.assert_close(actual=actual, expected=expected) + + grad = torch.rand_like(x_ref) + expected.backward(grad) + actual.backward(grad) + torch.testing.assert_close(actual=x.grad, expected=x_ref.grad) + + tp_syncs = {PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_INPUT, PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_OUTPUT} + fwd_traces_with_tensor_parallel_syncs = list( + filter( + lambda trace: any(bsym.sym.id in tp_syncs for bsym in trace.bound_symbols), + thunder.last_traces(tp_mlp), + ) + ) + + last_fwd_trace_with_tp_sync = fwd_traces_with_tensor_parallel_syncs[-1] + bsyms_of_tp_sync = tuple( + filter(lambda bsym: bsym.sym.id in tp_syncs, last_fwd_trace_with_tp_sync.bound_symbols) + ) + msg = f"{bsyms_of_tp_sync=}" + # Two bsyms are supposed to be + # - preprocessing of column-wise parallel linear + # - postprocessing of row-wise parallel linear + self.assertEqual(len(bsyms_of_tp_sync), 2, msg=msg) + common_utils.instantiate_parametrized_tests(TensorParallelTest)