Qwen 3 0.6B full_finetune_distributed breaks with context parallelism enabled #2924

@jerryyiransun

Description

The command that was run was:

CUDA_VISIBLE_DEVICES=0,1,2,3 tune run --nproc_per_node=4 full_finetune_distributed --config configs/qwen3_0.6b.yaml epochs=1

qwen3_0.6b.yaml is the stock config for the model fetched with tune download Qwen/Qwen3-0.6B, with the single addition of context_parallel_dim: 4.
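
For reference, this is a sketch of the only edit applied on top of the downloaded config (the resolved config in the log below confirms the value); the same override could equivalently be passed on the command line as context_parallel_dim=4, in the same key=value style as the epochs=1 override above:

# qwen3_0.6b.yaml -- only line added, everything else left as downloaded
context_parallel_dim: 4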

Error Message:

[rank3]:     out = out - F.sigmoid(block_lse - lse) * (out - block_out)
[rank3]:                 ~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~
[rank3]: RuntimeError: The size of tensor a (64) must match the size of tensor b (44) at non-singleton dimension 2

Full error log:

Running with torchrun...
W0904 18:56:33.264000 2112914 site-packages/torch/distributed/run.py:774]
W0904 18:56:33.264000 2112914 site-packages/torch/distributed/run.py:774] *****************************************
W0904 18:56:33.264000 2112914 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0904 18:56:33.264000 2112914 site-packages/torch/distributed/run.py:774] *****************************************
Running FullFinetuneRecipeDistributed with resolved config:

batch_size: 2
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /data3/jerrys17/torchtune/models/qwen3-0.6b
  checkpoint_files:
  - model.safetensors
  model_type: QWEN3
  output_dir: /data3/jerrys17/torchtune/output/qwen3_0_6B/full
  recipe_checkpoint: null
clip_grad_norm: null
compile: false
context_parallel_dim: 4
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: false
device: cuda
dtype: fp32
enable_activation_checkpointing: false
enable_activation_offloading: false
epochs: 1
gradient_accumulation_steps: 8
log_every_n_steps: 1
log_level: INFO
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /data3/jerrys17/torchtune/output/qwen3_0_6B/full/logs
model:
  _component_: torchtune.models.qwen3.qwen3_0_6b_instruct
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  lr: 2.0e-05
optimizer_in_bwd: false
output_dir: /data3/jerrys17/torchtune/output/qwen3_0_6B/full
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /data3/jerrys17/torchtune/output/qwen3_0_6B/full/profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 3
  with_flops: false
  with_stack: false
resume_from_checkpoint: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.qwen3.qwen3_tokenizer
  max_seq_len: 2048
  merges_file: /data3/jerrys17/torchtune/models/qwen3-0.6b/merges.txt
  path: /data3/jerrys17/torchtune/models/qwen3-0.6b/vocab.json

[The resolved config above is printed identically once per rank; the three duplicate copies are omitted here.]

Writing logs to /data3/jerrys17/torchtune/output/qwen3_0_6B/full/logs/log_1757037400.txt
Distributed training is enabled. Instantiating model and loading checkpoint on Rank 0 ...
Instantiating model and loading checkpoint took 3.15 secs
Memory stats after model init:
        GPU peak memory active: 0.76 GiB
        GPU peak memory alloc: 0.76 GiB
        GPU peak memory reserved: 0.90 GiB
Optimizer is initialized.
Loss is initialized.
No learning rate scheduler configured. Using constant learning rate.
 Profiling disabled.
 Profiler config after instantiation: {'enabled': False}
0|0:   0%|                                                                                                                                                | 0/3235 [00:00<?, ?it/s][rank0]: Traceback (most recent call last):
[rank0]:   File "/data3/jerrys17/torchtune/recipes/full_finetune_distributed.py", line 1184, in <module>
[rank0]:     sys.exit(recipe_main())
[rank0]:              ^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank0]:     sys.exit(recipe_main(conf))
[rank0]:              ^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/torchtune/recipes/full_finetune_distributed.py", line 1179, in recipe_main
[rank0]:     recipe.train()
[rank0]:   File "/data3/jerrys17/torchtune/recipes/full_finetune_distributed.py", line 1031, in train
[rank0]:     current_loss = self._loss_step(batch) * current_num_tokens
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/torchtune/recipes/full_finetune_distributed.py", line 871, in _loss_step
[rank0]:     outputs = self._model(**batch)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1827, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/torchtune/torchtune/modules/transformer.py", line 659, in forward
[rank0]:     h = layer(
[rank0]:         ^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1827, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/torchtune/torchtune/modules/transformer.py", line 132, in forward
[rank0]:     attn_out = self.attn(h, h, mask=mask, input_pos=input_pos)
[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/torchtune/torchtune/models/qwen3/_attention.py", line 261, in forward
[rank0]:     output = self._attention_call(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/torchtune/torchtune/modules/attention_utils.py", line 250, in _attention_call
[rank0]:     return _sdpa_call(q, k, v, mask, dropout_p, is_causal)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/torchtune/torchtune/modules/attention_utils.py", line 209, in _sdpa_call
[rank0]:     return nn.functional.scaled_dot_product_attention(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 1001, in inner_fn
[rank0]:     output = target_fn(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/_compile.py", line 53, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
[rank0]:     return DTensor._op_dispatcher.dispatch(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 148, in dispatch
[rank0]:     return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 590, in _sdpa_handler
[rank0]:     local_results = _scaled_dot_product_ring_efficient_attention(
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 244, in _scaled_dot_product_ring_efficient_attention
[rank0]:     return _templated_ring_attention(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 560, in _templated_ring_attention
[rank0]:     sdpa_merger.step(out, logsumexp, partial)
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 187, in step
[rank0]:     self._merge_one(out, lse, partial)
[rank0]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 156, in _merge_one
[rank0]:     out = out - F.sigmoid(block_lse - lse) * (out - block_out)
[rank0]:                 ~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~
[rank0]: RuntimeError: The size of tensor a (32) must match the size of tensor b (22) at non-singleton dimension 2
[rank1]: Traceback (most recent call last):
[rank1]:   File "/data3/jerrys17/torchtune/recipes/full_finetune_distributed.py", line 1184, in <module>
[rank1]:     sys.exit(recipe_main())
[rank1]:              ^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank1]:     sys.exit(recipe_main(conf))
[rank1]:              ^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/torchtune/recipes/full_finetune_distributed.py", line 1179, in recipe_main
[rank1]:     recipe.train()
[rank1]:   File "/data3/jerrys17/torchtune/recipes/full_finetune_distributed.py", line 1031, in train
[rank1]:     current_loss = self._loss_step(batch) * current_num_tokens
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/torchtune/recipes/full_finetune_distributed.py", line 871, in _loss_step
[rank1]:     outputs = self._model(**batch)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank1]:     return inner()
[rank1]:            ^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1827, in inner
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/torchtune/torchtune/modules/transformer.py", line 659, in forward
[rank1]:     h = layer(
[rank1]:         ^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank1]:     return inner()
[rank1]:            ^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1827, in inner
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/torchtune/torchtune/modules/transformer.py", line 132, in forward
[rank1]:     attn_out = self.attn(h, h, mask=mask, input_pos=input_pos)
[rank1]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/torchtune/torchtune/models/qwen3/_attention.py", line 261, in forward
[rank1]:     output = self._attention_call(
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/torchtune/torchtune/modules/attention_utils.py", line 250, in _attention_call
[rank1]:     return _sdpa_call(q, k, v, mask, dropout_p, is_causal)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/torchtune/torchtune/modules/attention_utils.py", line 209, in _sdpa_call
[rank1]:     return nn.functional.scaled_dot_product_attention(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 1001, in inner_fn
[rank1]:     output = target_fn(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/_compile.py", line 53, in inner
[rank1]:     return disable_fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
[rank1]:     return DTensor._op_dispatcher.dispatch(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 148, in dispatch
[rank1]:     return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 590, in _sdpa_handler
[rank1]:     local_results = _scaled_dot_product_ring_efficient_attention(
[rank1]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 244, in _scaled_dot_product_ring_efficient_attention
[rank1]:     return _templated_ring_attention(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 560, in _templated_ring_attention
[rank1]:     sdpa_merger.step(out, logsumexp, partial)
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 187, in step
[rank1]:     self._merge_one(out, lse, partial)
[rank1]:   File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 156, in _merge_one
[rank1]:     out = out - F.sigmoid(block_lse - lse) * (out - block_out)
[rank1]:                 ~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~
[rank1]: RuntimeError: The size of tensor a (64) must match the size of tensor b (44) at non-singleton dimension 2
[rank2]: Traceback identical to rank1 (omitted), ending in:
[rank2]: RuntimeError: The size of tensor a (64) must match the size of tensor b (44) at non-singleton dimension 2
[rank3]: Traceback identical to rank1 (omitted), ending in:
[rank3]: RuntimeError: The size of tensor a (64) must match the size of tensor b (44) at non-singleton dimension 2
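
A plausible reading of the failure (an assumption on my part, not something confirmed by this log): with packed: false, Alpaca batches arrive with arbitrary sequence lengths, and ring attention under context parallelism shards the sequence dimension across the 4 CP ranks, so when the length does not divide evenly the per-rank partial outputs reach the LSE-merge step above with mismatched sizes (64 vs 44 on ranks 1-3, 32 vs 22 on rank 0). The sketch below only illustrates that divisibility issue and a padding-style workaround; pad_to_multiple is a hypothetical helper for illustration, not a torchtune API.

import torch
import torch.nn.functional as F

cp_degree = 4                     # context_parallel_dim from the config above
seq_len = 177                     # an arbitrary unpacked Alpaca sample length (illustrative)

# Ring attention splits the sequence dimension across the CP ranks. If the
# length is not a multiple of the CP degree (commonly 2 * cp_degree with
# load-balanced sharding), ranks hold shards of different lengths and the
# partial-output merge fails with the size mismatch seen in the traceback.
print(seq_len % (2 * cp_degree))  # 1 -> shards would be uneven

# Hypothetical workaround sketch: right-pad each batch to a safe multiple
# before it reaches the model (pad_to_multiple is not a torchtune helper).
def pad_to_multiple(tokens: torch.Tensor, multiple: int, pad_id: int = 0) -> torch.Tensor:
    pad = (-tokens.size(-1)) % multiple
    return F.pad(tokens, (0, pad), value=pad_id)

tokens = torch.zeros(2, seq_len, dtype=torch.long)
padded = pad_to_multiple(tokens, 2 * cp_degree)
print(padded.shape)               # torch.Size([2, 184]) -> divisible by 8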
