
Commit 795a84d

yaoyu-33, yashaswikarnati, and parthmannan authored and committed
Add CP support to Neva in NeMo2 (#11850)
Squashed commit history:

* api updates and fixes
* fix
* fix arg
* update seq packing in mock ds
* save
* update preprocess_data
* update seq packing
* fix sp
* fix seq packing
* add truncation and padding
* Fix issues
* change LLaVATemplateConfig variables to class variables
* change to use field with default attributes
* Initial support for CP
* Add seq packing option in energon
* Fix energon conversation
* add energon option in neva training script
* Improvements
* add ci test for packed seq
* Fix for PP+CP
* Max seq len fix
* fix mock dataset seq packing
* fix lint and update seq pack func
* fix energon module
* fix comments
* address lightning issues
* Update sequence_packing.py
* update energon requirements
* Fix for energon update
* fix for test
* revert overlap config change
* Apply isort and black reformatting (multiple passes)

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>
Signed-off-by: yashaswikarnati <yashaswikarnati@users.noreply.github.com>
Signed-off-by: parthmannan <parthmannan@users.noreply.github.com>
Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com>
Signed-off-by: Abhinav Garg <abhgarg@nvidia.com>
Co-authored-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>
Co-authored-by: ykarnati <ykarnati@nvidia.com>
Co-authored-by: yashaswikarnati <yashaswikarnati@users.noreply.github.com>
Co-authored-by: Parth Mannan <pmannan@nvidia.com>
Co-authored-by: parthmannan <parthmannan@users.noreply.github.com>
Co-authored-by: Parth Mannan <parth.mannan95@gmail.com>
1 parent 49eaa69 commit 795a84d

Showing 2 changed files with 135 additions and 27 deletions.

nemo/collections/vlm/neva/model/base.py
133 additions & 27 deletions
@@ -132,6 +132,11 @@ def neva_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
             if value is not None:
                 setattr(packed_seq_params, attr, value.cuda(non_blocking=True))
         _batch["packed_seq_params"] = packed_seq_params
+    if ps.get_context_parallel_world_size() > 1:
+        num_valid_tokens_in_ub = None
+        if "loss_mask" in _batch and _batch["loss_mask"] is not None:
+            num_valid_tokens_in_ub = _batch["loss_mask"].sum()
+        _batch["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub
 
     return _batch
 
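The new branch records how many tokens in the micro-batch actually contribute to the loss: the sum of the loss mask. With context parallelism each rank later holds only a slice of the sequence, so the count is stashed in the batch before any sharding happens. A minimal sketch of what the field contains, using a toy loss mask (illustration only, not NeMo code):

import torch

loss_mask = torch.tensor([[1, 1, 1, 0, 0],
                          [1, 1, 0, 0, 0]], dtype=torch.float)  # [b, s], 1 = token counts toward the loss
num_valid_tokens_in_ub = loss_mask.sum()
print(num_valid_tokens_in_ub)  # tensor(5.)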

@@ -381,6 +386,59 @@ def forward(
         return super().forward(x, attention_mask)
 
 
+class _get_data_on_this_cp_rank(torch.autograd.Function):
+    """Performs sharding for Context Parallelism in THD format
+
+    In the forward pass, indices are selected for each CP rank and remaining tokens are dropped.
+    In the backward pass, this class takes care of managing gradients for dropped tokens on each
+    CP rank.
+    """
+
+    @staticmethod
+    # def forward(ctx, decoder_embeddings, labels, loss_mask, packed_seq_params):
+    def forward(ctx, batch, packed_seq_params):
+        cp_size = ps.get_context_parallel_world_size()
+        if cp_size > 1:
+            try:
+                import transformer_engine_torch as tex
+            except ModuleNotFoundError as e:
+                logging.error(
+                    "Please update Transformer Engine to >= 1.10 to use \
+                        Context Parallel with THD format data"
+                )
+                raise e
+            cp_rank = ps.get_context_parallel_rank()
+            for key, data in batch.items():
+                index = tex.thd_get_partitioned_indices(
+                    packed_seq_params.cu_seqlens_q_padded, data.size(1), cp_size, cp_rank
+                )
+                if key == "combined_embeddings":
+                    ctx.decoder_emb_index = index
+                    ctx.decoder_emb_seqlen = data.size(1)
+                batch[key] = data.index_select(1, index)
+
+        return batch
+
+    @staticmethod
+    def backward(ctx, grad_out, grad_label, grad_loss):
+        seqlen = ctx.decoder_emb_seqlen
+        index = ctx.decoder_emb_index
+        assert grad_out.size(1) == index.size(
+            0
+        ), f"Shape mismatch in incoming gradient {grad_out.shape} and \
+            index from THD CP sharding {index.shape}"
+        grad_in = torch.zeros(
+            grad_out.size(0),
+            seqlen,
+            *grad_out.size()[2:],
+            dtype=grad_out.dtype,
+            device=grad_out.device,
+        )
+        grad_in[:, ctx.decoder_emb_index, :] = grad_out
+
+        return (grad_in, None, None, None)
+
+
 class MCoreNevaModel(MCoreLLaVAModel):
     def __init__(
         self,
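tex.thd_get_partitioned_indices comes from Transformer Engine and selects, per CP rank, the token positions of each packed (THD-format) sequence that the rank should keep. As a rough mental model (a simplified sketch, not the Transformer Engine implementation): a common load-balanced scheme for causal attention splits every padded sequence into 2 * cp_size chunks and assigns rank r the chunks r and 2 * cp_size - 1 - r. The toy below implements that scheme under the assumption that every padded sequence length is divisible by 2 * cp_size:

import torch

def toy_thd_partition_indices(cu_seqlens_padded: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    # Hypothetical helper for illustration only; the real kernel lives in transformer_engine_torch.
    indices = []
    for start, end in zip(cu_seqlens_padded[:-1].tolist(), cu_seqlens_padded[1:].tolist()):
        chunk = (end - start) // (2 * cp_size)  # assumes exact divisibility
        for c in (cp_rank, 2 * cp_size - 1 - cp_rank):
            indices.append(torch.arange(start + c * chunk, start + (c + 1) * chunk))
    return torch.cat(indices)

# Two packed sequences of 8 tokens each, cp_size=2:
cu = torch.tensor([0, 8, 16])
print(toy_thd_partition_indices(cu, cp_size=2, cp_rank=0))  # tensor([ 0,  1,  6,  7,  8,  9, 14, 15])

The autograd.Function above then gathers only those positions in the forward pass and, in the backward pass, scatters the incoming gradient back into a zero tensor of the full sequence length so the dropped tokens receive zero gradient.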
@@ -604,6 +662,13 @@ def forward(
             packed_seq_params,
         )  # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len]
 
+        if self.context_parallel_lm > 1 or self.sequence_parallel_lm:
+            combined_embeddings, final_labels, final_loss_mask, packed_seq_params = (
+                self._process_embedding_token_parallel(
+                    combined_embeddings, final_labels, final_loss_mask, packed_seq_params
+                )
+            )
+
         output = self.language_model(
             input_ids=None,
             position_ids=None,
@@ -850,7 +915,9 @@ def _preprocess_data(
             final_loss_mask = final_loss_mask[:, : self._language_max_sequence_length]
 
         if final_embedding is not None:
-            final_embedding = final_embedding.transpose(1, 0).contiguous()
+            if self.context_parallel_lm == 1:
+                # Transpose to [s,b,h] if not using CP or not using packed_sequence/THD format
+                final_embedding = final_embedding.transpose(1, 0).contiguous()
             # Truncate if exceeding the language model's max sequence length.
             if final_embedding.shape[0] > self._language_max_sequence_length:
                 final_embedding = final_embedding[: self._language_max_sequence_length]
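As the new inline comment says, the [b, s, h] to [s, b, h] transpose is now skipped when context parallelism is active, because _process_embedding_token_parallel shards the sequence dimension first and transposes afterwards. A tiny shape check of the two layouts (illustrative values only):

import torch

b, s, h = 2, 8, 4
final_embedding = torch.zeros(b, s, h)        # [b, s, h] as assembled in _preprocess_data
print(final_embedding.transpose(1, 0).shape)  # torch.Size([8, 2, 4]) -> [s, b, h], taken only when CP == 1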
@@ -864,34 +931,73 @@ def _preprocess_data(
                 packed_seq_params.cu_seqlens_q[-1] >= packed_seq_params.cu_seqlens_q[-2]
             ), "with packed sequence, the truncation can only truncate on the last sequence."
 
-        if self.sequence_parallel_lm and not packed_sequence:
-            # Create an attention mask. This ensures correct computation.
-            # This is done even when no padding was done as we set mask_type to
-            # 'padding' or 'padding_causal' when using SP.
-            if attention_mask is None:
-                # Create base attention mask with original seq len to indicate valid tokens
-                attention_mask = (
-                    torch.ones(
-                        (
-                            final_embedding.shape[1],
-                            final_embedding.shape[0] - sp_padding_needed,
-                        ),
-                        device=final_embedding.device,
-                    )
-                    .unsqueeze(1)
-                    .unsqueeze(1)
-                )  # [b, 1, 1, final seq len - sp_padding_needed]
-            if sp_padding_needed > 0:
-                # Add the padding portion of the mask
-                attention_mask = torch.nn.functional.pad(attention_mask, (0, sp_padding_needed))
-
-            # Attention mask True/False meaning flipped in 1.7.0
-            attention_mask = attention_mask < 0.5
-        if self.sequence_parallel_lm:
-            final_embedding = tensor_parallel.scatter_to_sequence_parallel_region(final_embedding)
-
         return final_embedding, final_labels, final_loss_mask, attention_mask
 
+    def _process_embedding_token_parallel(self, combined_embeddings, new_labels, new_loss_mask, packed_seq_params):
+        """Processes the input data for model parallelism support."""
+
+        # No pre or post processing needed with PP middle chunks.
+        if not self.pre_process and not self.post_process:
+            return combined_embeddings, new_labels, new_loss_mask, packed_seq_params
+
+        if self.pre_process:
+            if self.context_parallel_lm > 1 and self.sequence_parallel_lm:
+                shard_factor = self.tensor_model_parallel_size_lm * self.context_parallel_lm * 2
+                seq_dim = 1
+            elif self.context_parallel_lm > 1:
+                shard_factor = self.context_parallel_lm * 2
+                seq_dim = 1
+            elif self.sequence_parallel_lm:
+                shard_factor = self.tensor_model_parallel_size_lm
+                seq_dim = 0
+
+            assert (
+                combined_embeddings.shape[seq_dim] % shard_factor == 0
+            ), f"Sequence length should be divisible by {shard_factor} for \
+                Sequence/Context parallelism"
+            if self.sequence_parallel_lm and self.tp_comm_overlap_lm:
+                assert (
+                    combined_embeddings.shape[seq_dim] == self._language_max_sequence_length
+                ), f"TP Comm overlap either requires Vision+Text token length \
+                    == language_max_sequence_length"
+
+        if self.context_parallel_lm > 1:
+            batch = dict()
+            if self.pre_process:
+                batch.update(
+                    {
+                        "combined_embeddings": combined_embeddings,
+                    }
+                )
+            if self.post_process:
+                batch.update(
+                    {
+                        "new_labels": new_labels,
+                        "new_loss_mask": new_loss_mask,
+                    }
+                )
+            # Distribute sequence across CP ranks
+            if packed_seq_params is None or packed_seq_params.qkv_format == 'sbhd':
+                from megatron.training.utils import get_batch_on_this_cp_rank
+
+                batch = get_batch_on_this_cp_rank(batch)
+            else:
+                batch = _get_data_on_this_cp_rank.apply(batch, packed_seq_params)
+
+            if self.pre_process:
+                combined_embeddings = batch["combined_embeddings"]  # [B, S/CP, H]
+                combined_embeddings = combined_embeddings.transpose(1, 0).contiguous()  # [B,S/CP,H] -> [S/CP,B,H]
+            if self.post_process:
+                new_labels = batch["new_labels"]
+                new_loss_mask = batch["new_loss_mask"]
+
+        if self.sequence_parallel_lm and self.pre_process:
+            combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
+                combined_embeddings
+            )  # [S/(CP*TP),B,H]
+
+        return combined_embeddings, new_labels, new_loss_mask, packed_seq_params
+
 
 class NevaModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin):
     def __init__(
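The shard_factor assertions above translate into a simple divisibility rule on the combined vision+text sequence length. A small worked example of that rule (illustration only, mirroring the branches above):

def shard_factor(tp_size: int, cp_size: int, sequence_parallel: bool) -> int:
    # Mirrors the branches in _process_embedding_token_parallel; the final
    # fallback of 1 is added here only to keep this toy helper total.
    if cp_size > 1 and sequence_parallel:
        return tp_size * cp_size * 2
    if cp_size > 1:
        return cp_size * 2
    if sequence_parallel:
        return tp_size
    return 1

# With --tp_size 2 and --cp_size 2 plus sequence parallelism, the combined
# embedding length must be a multiple of 2 * 2 * 2 = 8 tokens.
print(shard_factor(tp_size=2, cp_size=2, sequence_parallel=True))  # 8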

scripts/vlm/neva_finetune.py
2 additions & 0 deletions
@@ -153,6 +153,7 @@ def main(args):
         tensor_model_parallel_size=args.tp_size,
         pipeline_model_parallel_size=args.pp_size,
         encoder_pipeline_model_parallel_size=args.encoder_pp_size,
+        context_parallel_size=args.cp_size,
         pipeline_dtype=torch.bfloat16,
         sequence_parallel=True,
         ddp=DistributedDataParallelConfig(
@@ -271,6 +272,7 @@ def main(args):
     parser.add_argument("--max_steps", type=int, required=False, default=5190)
     parser.add_argument("--tp_size", type=int, required=False, default=1)
     parser.add_argument("--pp_size", type=int, required=False, default=1)
+    parser.add_argument("--cp_size", type=int, required=False, default=1)
     parser.add_argument("--encoder_pp_size", type=int, required=False, default=0)
     parser.add_argument("--projector_type", type=str, required=False, default="mcore_mlp")
     parser.add_argument("--name", type=str, required=False, default="neva_pretrain")
