diff --git a/nemo_rl/models/megatron/common.py b/nemo_rl/models/megatron/common.py index 6238576c54..8acad6d38c 100644 --- a/nemo_rl/models/megatron/common.py +++ b/nemo_rl/models/megatron/common.py @@ -84,8 +84,8 @@ def _pack_sequences_for_megatron( # Round up the pad_packed_seq_to to the nearest multiple of pad_packed_seq_to_multiple_of if pad_packed_seq_to is not None: - pad_packed_seq_to = _round_up_to_multiple( - pad_packed_seq_to, pad_packed_seq_to_multiple_of + assert pad_packed_seq_to % pad_packed_seq_to_multiple_of == 0, ( + f"pad_packed_seq_to ({pad_packed_seq_to}) is not a multiple of pad_packed_seq_to_multiple_of ({pad_packed_seq_to_multiple_of})." ) pad_factor = pad_individual_seqs_to_multiple_of @@ -275,6 +275,12 @@ def _get_pack_sequence_parameters_for_megatron( else: pad_packed_seq_to = None + # make sure the pad_packed_seq_to is a multiple of the pad_packed_seq_to_multiple_of + if pad_packed_seq_to is not None: + pad_packed_seq_to = _round_up_to_multiple( + pad_packed_seq_to, pad_packed_seq_to_multiple_of + ) + return ( pad_individual_seqs_to_multiple_of, pad_packed_seq_to_multiple_of, diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 963bac0d7f..b0136b0d2e 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1045,6 +1045,8 @@ def train( self.cfg["megatron_cfg"], seq_dim_size, ) + # if pad_full_seq_to is not None, we need to use it as the sequence length + seq_dim_size = pad_full_seq_to or seq_dim_size else: data_iterator = batch.make_microbatch_iterator(mbs) data_iterator_len = local_gbs // mbs