Skip to content

Commit

Permalink
test: fix neuronx_distributed import
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed Jan 2, 2025
1 parent 4cd2843 commit a2039d4
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/distributed/test_model_parallelization.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@
import torch_xla.core.xla_model as xm

if is_neuronx_distributed_available():
from neuronx_distributed.modules.qkv_linear import get_kv_shared_group
from neuronx_distributed.parallel_layers.parallel_state import (
get_kv_shared_group,
get_pipeline_model_parallel_rank,
get_tensor_model_parallel_group,
get_tensor_model_parallel_size,
Expand Down

0 comments on commit a2039d4

Please sign in to comment.