
Commit 9ef2c3b

add torch memory profiler
1 parent 107b8af commit 9ef2c3b

File tree: 6 files changed (+186, −14 lines)
Binary file not shown (364 KB).

uncollects/learn_by_do/torch_lab/llama.py

Lines changed: 90 additions & 4 deletions
@@ -1,7 +1,10 @@
 from dataclasses import dataclass
+import socket
+from datetime import datetime
 import torch
 from torch import nn
 import torch.nn.functional as F
+from torch.autograd.profiler import record_function
 from typing import Optional


@@ -14,8 +17,9 @@ class ModelArgs:
     vocab_size: int = 1024
     ffn_hidden_dim: int = 8192
     norm_eps: float = 1e-5
-    rope_theta: float = 500000
+    rope_theta: float = 10000
     max_seq_len: int = 2048
+    max_batch_size: int = 2

     def __init__(self, **kwargs):
         if self.n_kv_heads is None:
@@ -32,7 +36,7 @@ def forward(self, max_seq_len):
         seq = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
         freqs = torch.outer(seq, self.inv_freq)
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-        return freqs_cis[:, None, None, :]
+        return freqs_cis[None, :, None, :]


 def apply_rotary_emb(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
@@ -43,6 +47,7 @@ def apply_rotary_emb(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:

 class FeedForward(nn.Module):
     def __init__(self, dim: int, ffn_hidden_dim: int):
+        super().__init__()
         self.w1 = nn.Linear(dim, ffn_hidden_dim, bias=False)
         self.w2 = nn.Linear(ffn_hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, ffn_hidden_dim, bias=False)
@@ -104,5 +109,86 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers
         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
         self.layers = nn.ModuleList(TransformerBlock(params) for _ in range(params.n_layers))
-        self.norm = nn.RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.output_norm = nn.RMSNorm(params.dim, eps=params.norm_eps)
+        self.lm_head = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.rotary_embeddings = RotaryEmbedding(params.dim // params.n_heads)
+
+    def forward(self, inputs: torch.Tensor):
+        bsz, seqlen = inputs.shape
+        h = self.tok_embeddings(inputs)
+        freqs_cis = self.rotary_embeddings(seqlen)
+        mask = torch.full((seqlen, seqlen), float("-inf"), device=inputs.device)
+        mask = torch.triu(mask, diagonal=1)
+        mask = mask.type_as(h)
+        for layer in self.layers:
+            h = layer(h, freqs_cis, mask)
+        h = self.output_norm(h)
+        logits = self.lm_head(h)
+        return logits
+
+
+if __name__ == "__main__":
+    # torch.cuda.memory._record_memory_history()
+
+    def get_device():
+        if torch.cuda.is_available():
+            return torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            return torch.device("mps")
+        return torch.device("cpu")
+
+    def get_mock_batch(batch_size, seq_len, vocab_size):
+        src = torch.randint(0, vocab_size, (batch_size, seq_len), device=get_device())
+        tgt = torch.cat((src[:, 1:], torch.randint(0, vocab_size, (batch_size, 1), device=get_device())), dim=1)
+        return src, tgt
+
+    def trace_handler(prof: torch.profiler.profile):
+        # Prefix for file names.
+        host_name = socket.gethostname()
+        timestamp = datetime.now().strftime("%b_%d_%H_%M_%S")
+        file_prefix = f"{host_name}_{timestamp}"
+
+        # Construct the trace file.
+        prof.export_chrome_trace(f"{file_prefix}.json.gz")
+
+        # Construct the memory timeline file.
+        prof.export_memory_timeline(f"{file_prefix}.html", device="cuda:0")
+
+    with torch.profiler.profile(
+        activities=[
+            torch.profiler.ProfilerActivity.CPU,
+            torch.profiler.ProfilerActivity.CUDA,
+        ],
+        schedule=torch.profiler.schedule(wait=0, warmup=0, active=3, repeat=1),
+        record_shapes=True,
+        profile_memory=True,
+        with_stack=True,
+        on_trace_ready=trace_handler,
+    ) as prof:
+        args = ModelArgs()
+        num_batches = 3
+        learning_rate = 1e-5
+        model = Llama(args)
+        model.to(device=get_device())
+        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+        criterion = nn.CrossEntropyLoss()
+
+        model.train()
+        total_loss = 0
+        for i in range(num_batches):
+            prof.step()
+            src, tgt = get_mock_batch(args.max_batch_size, args.max_seq_len, args.vocab_size)
+            with record_function("## forward ##"):
+                logits = model(src)
+
+            with record_function("## backward ##"):
+                loss = criterion(logits.view(-1, args.vocab_size), tgt.view(-1))
+                loss.backward()
+            with record_function("## optimizer ##"):
+                optimizer.step()
+                optimizer.zero_grad()
+            total_loss += loss.item()
+            print(f"Step {i + 1}, Loss: {loss.item():.4f}")
+
+    prof.export_memory_timeline("llama.html", device="cuda:0")
+    # torch.cuda.memory._dump_snapshot("llama.pickle")
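
Note on the commented-out lines: torch.cuda.memory._record_memory_history() at the top of the __main__ block and _dump_snapshot("llama.pickle") at the end hint at PyTorch's allocator-snapshot workflow, an alternative to the export_memory_timeline path wired up above. A minimal sketch of how that path would be used, assuming a CUDA device and a hypothetical run_training() stand-in for the profiled loop in this file:

import torch

def run_training():
    # Hypothetical stand-in for the training loop profiled in llama.py.
    ...

if torch.cuda.is_available():
    # Record per-allocation history for the CUDA caching allocator.
    torch.cuda.memory._record_memory_history(max_entries=100_000)
    run_training()
    # Write the snapshot; the pickle can be dropped onto https://pytorch.org/memory_viz
    # to inspect allocations over time.
    torch.cuda.memory._dump_snapshot("llama.pickle")
    # Stop recording to avoid further bookkeeping overhead.
    torch.cuda.memory._record_memory_history(enabled=None)

The HTML memory timelines written by trace_handler and the final export_memory_timeline call open directly in a browser, and the .json.gz chrome trace loads in the Perfetto UI or chrome://tracing.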

uncollects/learn_by_do/torch_lab/pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -5,5 +5,7 @@ description = "Add your description here"
 readme = "README.md"
 requires-python = ">=3.13"
 dependencies = [
+    "numpy>=2.2.5",
     "torch>=2.7.0",
+    "torchvision>=0.22.0",
 ]
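
A quick sanity check that the added dependencies resolve in the synced environment (a sketch, assuming the pins above install cleanly):

import numpy
import torch
import torchvision

# Print the resolved versions and whether a CUDA device is visible.
print("numpy", numpy.__version__)
print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("torchvision", torchvision.__version__)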
