The command that was run was:
CUDA_VISIBLE_DEVICES=0,1,2,3 tune run --nproc_per_node=4 full_finetune_distributed --config configs/qwen3_0.6b.yaml epochs=1
qwen3_0.6b.yaml is the config for the model downloaded via tune download Qwen/Qwen3-0.6B, with the single addition of context_parallel_dim: 4.
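For reference, the only change relative to the stock config is the one line below (a minimal sketch of the diff; every other field matches the resolved config in the log further down):

```yaml
# qwen3_0.6b.yaml -- stock config, unchanged except for this addition:
context_parallel_dim: 4   # shard the sequence across the 4 GPUs (ring/context parallelism)
```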
Error Message:
[rank3]: out = out - F.sigmoid(block_lse - lse) * (out - block_out)
[rank3]: ~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~
[rank3]: RuntimeError: The size of tensor a (64) must match the size of tensor b (44) at non-singleton dimension 2
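For context, the failing line is the log-sum-exp merge that ring attention uses to combine partial attention outputs at each context-parallel step (see _merge_one in the traceback). The broadcast error appears whenever the log-sum-exp correction term and the attention output disagree on the sequence length. The sketch below reproduces the same error in isolation; it is not torchtune code, and the shapes are illustrative, borrowed from the rank 1/2/3 message:

```python
# Standalone sketch of the merge step from
# torch.distributed.tensor.experimental._attention (_merge_one).
# Shapes are illustrative and chosen only to reproduce the broadcast error.
import torch
import torch.nn.functional as F

B, H, D = 2, 8, 128                        # batch, heads, head_dim (illustrative)
out       = torch.randn(B, H, 44, D)       # accumulated attention output
block_out = torch.randn(B, H, 44, D)       # partial output from the current ring step
lse       = torch.randn(B, H, 64, 1)       # accumulated log-sum-exp
block_lse = torch.randn(B, H, 64, 1)       # log-sum-exp of the partial output

# Raises: RuntimeError: The size of tensor a (64) must match the size of
# tensor b (44) at non-singleton dimension 2 -- the same error as in the log.
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
```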
Full error log:
Running with torchrun...
W0904 18:56:33.264000 2112914 site-packages/torch/distributed/run.py:774]
W0904 18:56:33.264000 2112914 site-packages/torch/distributed/run.py:774] *****************************************
W0904 18:56:33.264000 2112914 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0904 18:56:33.264000 2112914 site-packages/torch/distributed/run.py:774] *****************************************
Running FullFinetuneRecipeDistributed with resolved config:
batch_size: 2
checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /data3/jerrys17/torchtune/models/qwen3-0.6b
checkpoint_files:
- model.safetensors
model_type: QWEN3
output_dir: /data3/jerrys17/torchtune/output/qwen3_0_6B/full
recipe_checkpoint: null
clip_grad_norm: null
compile: false
context_parallel_dim: 4
dataset:
_component_: torchtune.datasets.alpaca_cleaned_dataset
packed: false
device: cuda
dtype: fp32
enable_activation_checkpointing: false
enable_activation_offloading: false
epochs: 1
gradient_accumulation_steps: 8
log_every_n_steps: 1
log_level: INFO
log_peak_memory_stats: true
loss:
_component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: null
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: /data3/jerrys17/torchtune/output/qwen3_0_6B/full/logs
model:
_component_: torchtune.models.qwen3.qwen3_0_6b_instruct
optimizer:
_component_: torch.optim.AdamW
fused: true
lr: 2.0e-05
optimizer_in_bwd: false
output_dir: /data3/jerrys17/torchtune/output/qwen3_0_6B/full
profiler:
_component_: torchtune.training.setup_torch_profiler
active_steps: 2
cpu: true
cuda: true
enabled: false
num_cycles: 1
output_dir: /data3/jerrys17/torchtune/output/qwen3_0_6B/full/profiling_outputs
profile_memory: false
record_shapes: true
wait_steps: 5
warmup_steps: 3
with_flops: false
with_stack: false
resume_from_checkpoint: false
seed: null
shuffle: true
tokenizer:
_component_: torchtune.models.qwen3.qwen3_tokenizer
max_seq_len: 2048
merges_file: /data3/jerrys17/torchtune/models/qwen3-0.6b/merges.txt
path: /data3/jerrys17/torchtune/models/qwen3-0.6b/vocab.json
Writing logs to /data3/jerrys17/torchtune/output/qwen3_0_6B/full/logs/log_1757037400.txt
Distributed training is enabled. Instantiating model and loading checkpoint on Rank 0 ...
Instantiating model and loading checkpoint took 3.15 secs
Memory stats after model init:
GPU peak memory active: 0.76 GiB
GPU peak memory alloc: 0.76 GiB
GPU peak memory reserved: 0.90 GiB
Optimizer is initialized.
Loss is initialized.
No learning rate scheduler configured. Using constant learning rate.
Profiling disabled.
Profiler config after instantiation: {'enabled': False}
0|0: 0%| | 0/3235 [00:00<?, ?it/s][rank0]: Traceback (most recent call last):
[rank0]: File "/data3/jerrys17/torchtune/recipes/full_finetune_distributed.py", line 1184, in <module>
[rank0]: sys.exit(recipe_main())
[rank0]: ^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank0]: sys.exit(recipe_main(conf))
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/torchtune/recipes/full_finetune_distributed.py", line 1179, in recipe_main
[rank0]: recipe.train()
[rank0]: File "/data3/jerrys17/torchtune/recipes/full_finetune_distributed.py", line 1031, in train
[rank0]: current_loss = self._loss_step(batch) * current_num_tokens
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/torchtune/recipes/full_finetune_distributed.py", line 871, in _loss_step
[rank0]: outputs = self._model(**batch)
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank0]: return inner()
[rank0]: ^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1827, in inner
[rank0]: result = forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/torchtune/torchtune/modules/transformer.py", line 659, in forward
[rank0]: h = layer(
[rank0]: ^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank0]: return inner()
[rank0]: ^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1827, in inner
[rank0]: result = forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/torchtune/torchtune/modules/transformer.py", line 132, in forward
[rank0]: attn_out = self.attn(h, h, mask=mask, input_pos=input_pos)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/torchtune/torchtune/models/qwen3/_attention.py", line 261, in forward
[rank0]: output = self._attention_call(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/torchtune/torchtune/modules/attention_utils.py", line 250, in _attention_call
[rank0]: return _sdpa_call(q, k, v, mask, dropout_p, is_causal)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/torchtune/torchtune/modules/attention_utils.py", line 209, in _sdpa_call
[rank0]: return nn.functional.scaled_dot_product_attention(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 1001, in inner_fn
[rank0]: output = target_fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/_compile.py", line 53, in inner
[rank0]: return disable_fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
[rank0]: return DTensor._op_dispatcher.dispatch(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 148, in dispatch
[rank0]: return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 590, in _sdpa_handler
[rank0]: local_results = _scaled_dot_product_ring_efficient_attention(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 244, in _scaled_dot_product_ring_efficient_attention
[rank0]: return _templated_ring_attention(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 560, in _templated_ring_attention
[rank0]: sdpa_merger.step(out, logsumexp, partial)
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 187, in step
[rank0]: self._merge_one(out, lse, partial)
[rank0]: File "/data3/jerrys17/miniconda3/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/experimental/_attention.py", line 156, in _merge_one
[rank0]: out = out - F.sigmoid(block_lse - lse) * (out - block_out)
[rank0]: ~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~
[rank0]: RuntimeError: The size of tensor a (32) must match the size of tensor b (22) at non-singleton dimension 2
[rank1]: RuntimeError: The size of tensor a (64) must match the size of tensor b (44) at non-singleton dimension 2
[rank2]: RuntimeError: The size of tensor a (64) must match the size of tensor b (44) at non-singleton dimension 2
[rank3]: RuntimeError: The size of tensor a (64) must match the size of tensor b (44) at non-singleton dimension 2