Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
sahilsuneja1 committed Sep 6, 2024
1 parent 8039afd commit 08fd47d
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions fms_extras/models/paged_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ def factory(**kwargs):
hidden_grow_factor=3.5,
multiple_of=1024,
max_expected_seq_len=8192,
rope_theta=500000,
)

_70b_llama3_config = PagedLLaMAConfig(
Expand Down
2 changes: 1 addition & 1 deletion scripts/inference_simplified.sh
Original file line number Diff line number Diff line change
Expand Up @@ -270,4 +270,4 @@ MODEL_ARGS_LLAMA3_8B_HF="\
torchrun \
--nproc_per_node=8 \
scripts/paged_speculative_inference.py \
${MODEL_ARGS_LLAMA3_70B_SPECU2_CONVERTED_HF}
${MODEL_ARGS_LLAMA3_8B_HF}
2 changes: 1 addition & 1 deletion scripts/paged_speculative_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@
if args.distributed:
#dist.init_process_group()
#torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
tp_size = 16
tp_size = 8
base_model_mesh = dist.device_mesh.init_device_mesh(
"cuda", (world_size // tp_size, tp_size), mesh_dim_names=("dp", "tp")
)
Expand Down

0 comments on commit 08fd47d

Please sign in to comment.