
[Draft] Add Tensor Parallel to torch_native_llama #1876

Open · wants to merge 6 commits into main

Conversation

@kwen2501 commented Nov 2, 2024

Motivation

The torch_native_llama model does not have Tensor Parallel support today. This PR adds it, using torch.distributed APIs.

Modifications

  • Added a .tensor_parallel() utility.
  • Added ColwiseParallel and RowwiseParallel annotations to the relevant sub-modules (see the sketch below).
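
For concreteness, a minimal sketch of what such a utility could look like with the public PyTorch TP API (the module paths, the per-layer loop, and the function signature below are illustrative assumptions, not code copied from this PR):

import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def tensor_parallel(model: nn.Module, device_mesh: DeviceMesh) -> None:
    """Annotate each decoder layer with colwise/rowwise TP styles (sketch)."""
    for layer in model.model.layers:  # assumed LlamaForCausalLM-style layout
        plan = {
            "self_attn.qkv_proj": ColwiseParallel(),  # shard the fused qkv along its output dim
            "self_attn.o_proj": RowwiseParallel(),    # shard the input dim, all-reduce the output
            "mlp.gate_up_proj": ColwiseParallel(),
            "mlp.down_proj": RowwiseParallel(),
        }
        parallelize_module(layer, device_mesh, plan)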

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

cc: @jerryzh168 @merrymercy @wz337

@@ -495,8 +554,43 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
param = self.lm_head.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, self.model.embed_tokens.weight)

# Re-arrange fused matrix for TP
tp_size = get_tensor_model_parallel_world_size()

Contributor:

Here, is it possible to do the following:

  1. split qkv into 3 tensors;
  2. apply TP to each tensor;
  3. concat the 3 tensors into a single DTensor.

This way we can rely on the split/concat ops in DTensor itself instead of worrying about the implementation details?
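
For concreteness, the three steps could look roughly like this with the public DTensor API in recent PyTorch; the mesh setup, the sizes, and the re-fusion via from_local are illustrative assumptions, not code from this PR:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor

tp_mesh = init_device_mesh("cuda", (tp_size,))  # tp_size assumed to be defined

def shard_fused_qkv(qkv_weight: torch.Tensor, q_size: int, kv_size: int) -> DTensor:
    # 1. split the fused weight into its q, k, v blocks along the output dim
    q, k, v = qkv_weight.split([q_size, kv_size, kv_size], dim=0)
    # 2. shard each block column-wise (dim 0 of the weight) across the TP mesh
    q_d, k_d, v_d = (distribute_tensor(t, tp_mesh, [Shard(0)]) for t in (q, k, v))
    # 3. fuse the per-rank local shards back and wrap them as one DTensor;
    #    note the resulting global layout is rank-interleaved, i.e. re-arranged
    local = torch.cat([t.to_local() for t in (q_d, k_d, v_d)], dim=0)
    return DTensor.from_local(local, tp_mesh, [Shard(0)])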

Author (@kwen2501, Nov 2, 2024):

Maybe? Although, at this location, we haven't applied TP yet, so there is no notion of DTensor.

Author:

Hmm, I see what you mean. We can use the DTensor API instead of the higher-level TP API here.

Author:

In the newer version, I added support for TP-aware weight loading, so we can directly construct a DTensor from the local shard. See the ColwiseParallelSharded strategy.
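
For reference, the idea behind such a strategy is roughly the following: subclass ColwiseParallel and override its partitioning hook so that the already-loaded local shard is wrapped as a DTensor, rather than sharding a full weight. This is a sketch; the exact hook and parameter handling in the PR may differ.

import torch
from torch.distributed.tensor import DTensor, Shard
from torch.distributed.tensor.parallel import ColwiseParallel


class ColwiseParallelSharded(ColwiseParallel):
    """Colwise style for weights whose local shard was already produced at load time."""

    def _partition_linear_fn(self, name, module, device_mesh):
        # Each parameter on this rank already holds only its local shard,
        # so it only needs to be wrapped as a DTensor placed as Shard(0).
        for param_name, param in module.named_parameters():
            dtensor = DTensor.from_local(param.data, device_mesh, [Shard(0)])
            module.register_parameter(
                param_name, torch.nn.Parameter(dtensor, requires_grad=False)
            )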

Contributor:

Just to understand: currently, step 2 is manual, right?

Author (@kwen2501, Nov 9, 2024):

Not manual per se. It is already packaged and can be applied with parallelize_module just like the other styles, so no extra involvement is needed from the user or model author.
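
In other words, usage stays declarative; the module paths below are illustrative:

from torch.distributed.tensor.parallel import RowwiseParallel, parallelize_module

plan = {
    "self_attn.qkv_proj": ColwiseParallelSharded(),  # fused qkv whose weights were sharded at load time
    "self_attn.o_proj": RowwiseParallel(),
}
parallelize_module(decoder_layer, tp_mesh, plan)  # decoder_layer and tp_mesh assumed to be defined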

@@ -153,6 +153,13 @@ def __init__(
min_per_gpu_memory = self.init_torch_distributed()
self.sampler = Sampler()
self.load_model()
if self.tp_size > 1:
    logger.info(f"Tensor parallelism is enabled, {self.tp_size} devices will be used.")

Contributor (@merrymercy, Nov 2, 2024):

Will this break other models? Can we only do this for torch_native_llama?
For example, check hasattr(self.model, "tensor_parallel").

Author:

Good catch. Added a supports_torch_tp attribute to the model.
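
A minimal sketch of the resulting guard in the model runner (apply_torch_tp is an assumed helper name, not necessarily what the PR uses):

# Inside the model runner's __init__, after self.load_model()
if self.tp_size > 1 and getattr(self.model, "supports_torch_tp", False):
    logger.info(f"Tensor parallelism is enabled, {self.tp_size} devices will be used.")
    self.apply_torch_tp()  # assumed helper that invokes the model's TP utility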

param_data = param.data
param_data = param_data.narrow(0, shard_offset, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return


def shuffle_qkv_proj_weight(

Contributor:

This part still seems complicated. It would be good if we could have some high-level APIs to simplify this, as pointed out by @jerryzh168.

Author:

Yep, I agree. This part is now removed.

Comment on lines +55 to +64
+ # shard_id: (shard_offset, shard_size)
+ gate_up_offsets = {}
+ current_shard_offset = 0
+ for i, output_size in enumerate(self.output_sizes):
+     gate_up_offsets[i] = (current_shard_offset, output_size)
+     current_shard_offset += output_size
  if loaded_shard_id is None:
-     shard_offsets: List[Tuple[int, int, int]] = []
-     for i, output_size in enumerate(self.output_sizes):
-         shard_offsets.append((i, current_shard_offset, output_size))
-         current_shard_offset += output_size
-     for shard_id, shard_offset, shard_size in shard_offsets:
+     for shard_id, (shard_offset, shard_size) in gate_up_offsets.items():
          loaded_weight_shard = loaded_weight.narrow(
-             output_dim, shard_offset, shard_size
+             0, shard_offset, shard_size

Author:

These are style changes only.

Comment on lines +124 to +131
+ # shard_id: (shard_offset, shard_size)
+ qkv_offsets = {
+     "q": (0, self.num_heads * self.head_size),
+     "k": (self.num_heads * self.head_size, self.num_kv_heads * self.head_size),
+     "v": ((self.num_heads + self.num_kv_heads) * self.head_size, self.num_kv_heads * self.head_size),
+ }
  if loaded_shard_id is None:
-     shard_offsets = [
-         # (shard_id, shard_offset, shard_size)
-         ("q", 0, self.total_num_heads * self.head_size),
-         (
-             "k",
-             self.total_num_heads * self.head_size,
-             self.total_num_kv_heads * self.head_size,
-         ),
-         (
-             "v",
-             (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
-             self.total_num_kv_heads * self.head_size,
-         ),
-     ]
-     for shard_id, shard_offset, shard_size in shard_offsets:
+     for shard_id, (shard_offset, shard_size) in qkv_offsets.items():

Author:

These are style changes only.

python/sglang/srt/models/torch_native_llama.py (outdated comment, resolved)
Comment on lines -216 to -221
self.qkv_proj._get_shard_offset_mapping = types.MethodType(
_get_shard_offset_mapping, self.qkv_proj
)
self.qkv_proj._get_shard_size_mapping = types.MethodType(
_get_shard_size_mapping, self.qkv_proj
)

Author:

Not used now.

Contributor @jerryzh168 left a comment:

LGTM. I think this is the best we can do for now, until we stop using fused qkv and rely on torch.compile for the speedup.

@kwen2501 marked this pull request as ready for review November 9, 2024 00:38
@@ -176,7 +165,6 @@ def __init__(
  ) -> None:
      super().__init__()
      self.hidden_size = hidden_size
-     tp_size = get_tensor_model_parallel_world_size()
      self.total_num_heads = num_heads

Contributor:

I see that we are already doing manual sharding here. I feel this code should move into separate TP-related code instead of being embedded in the model, if possible.

Author:

Yeah, it seems the "local" num_heads is needed for constructing the RadixAttention later:

self.attn = RadixAttention(
    self.num_heads,
    ...

I am not sure I can remove it, given that it would involve a contract change with that module.
