
Commit a04665b

committed: init llama infer
1 parent 9721837 commit a04665b

File tree

  • torchprime/experimental/torchax_models

1 file changed (+69 −0)

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
import torchax.interop
from llama import model
import torch
import torchax
import torchax.config
import jax
import time

env = torchax.default_env()
torch.manual_seed(42)
torch.set_default_dtype(torch.bfloat16)
torchax.enable_performance_mode()

# Reduced model dimensions for a quick run; full-size values in the trailing comments.
max_seq_len = 512  # 8192
vocab_size = 128  # 32000
n_layer = 1
n_heads = 4
dim = 8
block_size = 16  # 2048
batch_size = 1


def fake_dataloader(size, vocab_size, seqlen, batch_size):
  # Yields `size` batches of random token ids on CPU, standing in for a real dataloader.
  for _ in range(size):
    x = torch.randint(0, vocab_size, (batch_size, seqlen), device="cpu")
    yield x


if __name__ == "__main__":
  with torch.no_grad():
    input = torch.randint(0, vocab_size, (1, max_seq_len))
    model_args = model.ModelArgs(
        block_size=block_size,
        vocab_size=vocab_size,
        n_layer=n_layer,
        n_heads=n_heads,
        dim=dim,
        max_seq_len=max_seq_len,
    )
    # Rotary position embeddings, precomputed once and reused for every step.
    freqs_cis = model.precompute_freqs_cis(
        model_args.dim // model_args.n_heads,
        model_args.max_seq_len,
        model_args.rope_theta,
        model_args.use_scaled_rope,
    ).to(torch.bfloat16)
    m = model.Transformer(model_args)
    m.to(torch.bfloat16)

    def forward(input, freqs_cis, mask):
      return m(input, 0, freqs_cis=freqs_cis, mask=mask)

    # Wrap the torch forward pass so it is traced and compiled with jax.jit.
    jitted_forward = torchax.interop.jax_jit(forward)

    data_iter = fake_dataloader(5, vocab_size, max_seq_len, batch_size)
    with env:
      # Move weights and the rotary embeddings onto the jax backend.
      m.to("jax")
      freqs_cis = freqs_cis.to("jax")
      for i, input in enumerate(data_iter):
        input = input.to("jax")
        mask = torch.ones_like(input)
        step_start = time.perf_counter()
        output = jitted_forward(input, freqs_cis, mask)
        # Block on the underlying jax array so the measured time covers actual
        # execution, not just asynchronous dispatch.
        jax.block_until_ready(torchax.tensor.t2j(output))
        step_end = time.perf_counter()
        print(
            i,
            "step latency: ",
            step_end - step_start,
        )
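
The core pattern in this file is torchax.interop.jax_jit: it takes a function written with ordinary torch ops and returns a callable that is traced and compiled with jax.jit, operating on tensors that have been moved to the jax backend inside a torchax environment. A minimal, self-contained sketch of the same pattern, independent of the llama model above and assuming only the torchax APIs already used in this file, might look like:

import torch
import torchax
import torchax.interop

env = torchax.default_env()

def scaled_matmul(a, b):
  # Plain torch code; jax_jit traces it into a jax computation.
  return (a @ b) * 0.5

jitted = torchax.interop.jax_jit(scaled_matmul)

# Build inputs as ordinary CPU tensors, then move them to the jax backend
# inside the torchax environment, mirroring the script above.
a = torch.randn(4, 8)
b = torch.randn(8, 4)

with env:
  a = a.to("jax")
  b = b.to("jax")
  out = jitted(a, b)
  print(out.shape)  # expected (4, 4)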
