@@ -4,34 +4,36 @@
 import copy
 import functools
 import itertools
-import unittest
 from typing import Any, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed.fsdp import fully_shard
 from torch.nn.parallel.scatter_gather import _is_namedtuple
-from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
     DoubleLinear,
     FSDPTest,
     FSDPTestMultiThread,
+    get_devtype,
     MLP,
 )
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
     Transformer,
 )
 
 
+device_type = torch.device(get_devtype())
+
+
 class TestFullyShardAutograd(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.get_device_module(device_type).device_count())
 
     def _reduce_1d_partial_grads(
         self, module: nn.Module, group: Optional[dist.ProcessGroup] = None
@@ -58,7 +60,7 @@ def _test_unused_forward_output(self, reshard_after_forward: Union[bool, int]):
         local_batch_size = 2
         global_batch_size, dim = (self.world_size * local_batch_size, 24)
         model = DoubleLinear(dim=dim, use_second_linear=True)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         fully_shard(model.lin1, reshard_after_forward=reshard_after_forward)
         fully_shard(model, reshard_after_forward=reshard_after_forward)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
@@ -68,7 +70,7 @@ def _test_unused_forward_output(self, reshard_after_forward: Union[bool, int]):
         for iter_idx in range(10):
             # Use all forward outputs in the loss/backward for the first half
             # of the iterations and only the 1st forward output for the rest
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device=device_type)
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -104,7 +106,7 @@ def _test_unused_forward_module(self, reshard_after_forward: Union[bool, int]):
         local_batch_size, dim = (2, 24)
         global_batch_size = self.world_size * local_batch_size
         model = DoubleLinear(dim=dim, use_second_linear=False)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         fully_shard(model.lin1, reshard_after_forward=reshard_after_forward)
         fully_shard(model.lin2, reshard_after_forward=reshard_after_forward)
         fully_shard(model, reshard_after_forward=reshard_after_forward)
@@ -113,7 +115,7 @@ def _test_unused_forward_module(self, reshard_after_forward: Union[bool, int]):
 
         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(10):
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device=device_type)
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -214,7 +216,7 @@ def forward(self, x: torch.Tensor):
             Module(dim),
             FromContainerType(container_type),
         )
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         for module in model:
             fully_shard(module)
         fully_shard(model)
@@ -223,7 +225,7 @@ def forward(self, x: torch.Tensor):
 
         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(10):
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device=device_type)
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -245,7 +247,7 @@ class TestFullyShardPostAccGradHookMultiThread(FSDPTestMultiThread):
     def world_size(self) -> int:
         return 2
 
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_post_acc_grad_hook_runs(self):
         param_name_to_hook_count = collections.defaultdict(int)
 
@@ -260,7 +262,7 @@ def hook(param_name: str, param: torch.Tensor) -> None:
             param_hook = functools.partial(hook, param_name)
             param.register_post_accumulate_grad_hook(param_hook)
 
-        inp = torch.randn((2, 8), device="cuda")
+        inp = torch.randn((2, 8), device=device_type)
         model(inp).sum().backward()
         param_names = {param_name for param_name, _ in model.named_parameters()}
         self.assertEqual(param_names, set(param_name_to_hook_count.keys()))
@@ -271,7 +273,7 @@ def hook(param_name: str, param: torch.Tensor) -> None:
 class TestFullyShardPostAccGradHookMultiProcess(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 2)
+        return min(torch.get_device_module(device_type).device_count(), 2)
 
     @skip_if_lt_x_gpu(2)
     def test_post_acc_grad_hook_optim_parity(self):
@@ -283,7 +285,7 @@ def test_post_acc_grad_hook_optim_parity(self):
         model_args = ModelArgs(dropout_p=0.0)
         model = Transformer(model_args)
 
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         for module in itertools.chain(ref_model.layers, [ref_model]):
             fully_shard(module)
         optim_kwargs = {"lr": 1e-2, "foreach": False}
@@ -312,7 +314,7 @@ def optim_hook(param: nn.Parameter) -> None:
             param.register_post_accumulate_grad_hook(optim_hook)
 
         torch.manual_seed(42 + self.rank)
-        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type)
         for _ in range(10):
             ref_loss = ref_model(inp).sum()
             ref_loss.backward()
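
Note on the pattern above: the change makes these tests device-agnostic by resolving the test device once at module scope instead of hard-coding "cuda" at every call site. A minimal standalone sketch of the idea, assuming get_devtype() returns a device type string such as "cuda" or "cpu" (the real helper is imported from torch.testing._internal.common_fsdp; the version below is a hypothetical stand-in):

import torch

# Hypothetical stand-in for common_fsdp.get_devtype(): prefer the CUDA
# backend when available, otherwise fall back to CPU.
def get_devtype() -> str:
    return "cuda" if torch.cuda.is_available() else "cpu"

device_type = torch.device(get_devtype())

# torch.get_device_module(device) returns the backend module for the given
# device type (e.g., torch.cuda for "cuda"), so device_count() no longer
# needs to hard-code torch.cuda.device_count().
num_devices = torch.get_device_module(device_type).device_count()

# Tensors and modules then target the resolved device uniformly:
inp = torch.rand((2, 8), device=device_type)   # replaces device="cuda"
model = torch.nn.Linear(8, 8).to(device_type)  # replaces .cuda()

Since .to() and the device= keyword both accept a torch.device, every former "cuda" literal can be routed through the single device_type constant, which is why the diff only touches device-resolution sites and leaves the test logic unchanged.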