
Commit 4b86dab

Reuse cos.device in reset_parameters (Lightning-AI#993)

1 parent 1e045fe

File tree: 2 files changed (+13, -1 lines)

lit_gpt/model.py (+1, -1)

@@ -59,7 +59,7 @@ def max_seq_length(self, value: int) -> None:
 
     def reset_parameters(self) -> None:
         # Trigger resetting the rope-cache
-        self.cos, self.sin = self.rope_cache()
+        self.cos, self.sin = self.rope_cache(device=self.cos.device)
 
     def _init_weights(self, module: nn.Module) -> None:
         """Meant to be used with `gpt.apply(gpt._init_weights)`."""

tests/test_model.py (+12)

@@ -10,6 +10,7 @@
 from conftest import RunIf
 from lightning import Fabric
 from lightning.fabric.utilities.imports import _IS_WINDOWS
+from lightning.fabric.utilities.init import _materialize_meta_tensors
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -832,3 +833,14 @@ def test_rope_init_under_fsdp():
     cos, sin = model.rope_cache(device=fabric.device)
     torch.testing.assert_close(model.cos, cos)
     torch.testing.assert_close(model.sin, sin)
+
+
+@RunIf(min_cuda_gpus=1)
+def test_reset_parameters_device():
+    from lit_gpt import GPT
+
+    with torch.device("meta"):
+        model = GPT.from_name("pythia-14m", n_layer=1)
+    _materialize_meta_tensors(model, torch.device("cuda"))
+    model.reset_parameters()
+    assert model.cos.device.type == "cuda"
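The new test exercises exactly that path on a GPU: the model is instantiated on the meta device, materialized on CUDA with Lightning's _materialize_meta_tensors helper, and reset_parameters() must leave model.cos on CUDA. A rough equivalent using only public PyTorch APIs and the hypothetical RopeCacheDemo sketch above could look like this; it assumes a CUDA device is available, matching the @RunIf(min_cuda_gpus=1) guard in the real test.

import torch

# Create the module on the meta device: buffers exist but hold no real storage.
with torch.device("meta"):
    demo = RopeCacheDemo()

# Module.to_empty allocates uninitialized storage on the target device for every
# tensor that currently lives on meta (a public stand-in for _materialize_meta_tensors).
demo = demo.to_empty(device="cuda")

# The storage is uninitialized, so the cache must be recomputed, and it must stay on CUDA.
demo.reset_parameters()
assert demo.cos.device.type == "cuda"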
