File tree 2 files changed +13
-1
lines changed
2 files changed +13
-1
lines changed Original file line number Diff line number Diff line change @@ -59,7 +59,7 @@ def max_seq_length(self, value: int) -> None:
59
59
60
60
def reset_parameters (self ) -> None :
61
61
# Trigger resetting the rope-cache
62
- self .cos , self .sin = self .rope_cache ()
62
+ self .cos , self .sin = self .rope_cache (device = self . cos . device )
63
63
64
64
def _init_weights (self , module : nn .Module ) -> None :
65
65
"""Meant to be used with `gpt.apply(gpt._init_weights)`."""
Original file line number Diff line number Diff line change 10
10
from conftest import RunIf
11
11
from lightning import Fabric
12
12
from lightning .fabric .utilities .imports import _IS_WINDOWS
13
+ from lightning .fabric .utilities .init import _materialize_meta_tensors
13
14
14
15
# support running without installing as a package
15
16
wd = Path (__file__ ).parent .parent .resolve ()
@@ -832,3 +833,14 @@ def test_rope_init_under_fsdp():
832
833
cos , sin = model .rope_cache (device = fabric .device )
833
834
torch .testing .assert_close (model .cos , cos )
834
835
torch .testing .assert_close (model .sin , sin )
836
+
837
+
838
+ @RunIf (min_cuda_gpus = 1 )
839
+ def test_reset_parameters_device ():
840
+ from lit_gpt import GPT
841
+
842
+ with torch .device ("meta" ):
843
+ model = GPT .from_name ("pythia-14m" , n_layer = 1 )
844
+ _materialize_meta_tensors (model , torch .device ("cuda" ))
845
+ model .reset_parameters ()
846
+ assert model .cos .device .type == "cuda"
You can’t perform that action at this time.
0 commit comments