Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
Add sanity checks to dtensor tests
Browse files Browse the repository at this point in the history
ghstack-source-id: 0f3887b4a5b87ea63adb8d09a2edfabdd4868e0a
Pull Request resolved: #302
  • Loading branch information
drisspg committed Jul 3, 2024
1 parent 36405a7 commit 474da90
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions test/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,12 @@ def test_fp8_mlp_tensor_parallelism_base(
x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])

tp_out = tp_model(x_fp32_tp_input)
assert (
tp_model.ffn.w1.weight.requires_grad
), "Expecting gradients to be enabled for TP model."
assert tp_out.requires_grad, "Expecting gradients to be enabled for TP model."
awaited_out = tp_out.wait()
assert awaited_out.requires_grad, "Expecting awaited out to require gradients"
tp_out.sum().backward()
sp_out = sp_model(x_fp32_sp_input)
sp_out.sum().backward()
Expand Down

0 comments on commit 474da90

Please sign in to comment.