-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
71 lines (54 loc) · 2.12 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
from praxis import PraxisConfig, PraxisForCausalLM, PraxisModel
# Register the custom "praxis" architecture with the Hugging Face Auto* factories,
# so AutoConfig / AutoModel / AutoModelForCausalLM can resolve PraxisConfig and
# instantiate the matching model classes from it.
AutoConfig.register("praxis", PraxisConfig)
AutoModel.register(PraxisConfig, PraxisModel)
AutoModelForCausalLM.register(PraxisConfig, PraxisForCausalLM)
def test_praxis_model():
    """Smoke-test the base PraxisModel: run one forward pass and report shapes.

    Builds a PraxisModel from a fresh config, feeds a short prompt through it,
    and prints the final hidden-state shape plus the number of returned layers.
    """
    # Initialize configuration
    config = PraxisConfig(
        n_dim=768,
        n_layer=12,
        n_head=12,
    )
    # Placeholder tokenizer — Praxis does not ship a tokenizer of its own,
    # so a pretrained Llama-2 tokenizer is borrowed for encoding.
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
    # BUG FIX: this test targets the *base* model, so it must instantiate via
    # AutoModel (-> PraxisModel, registered above). The original used
    # AutoModelForCausalLM, whose output carries `logits` but no
    # `last_hidden_state`, making the shape print below raise AttributeError.
    model = AutoModel.from_config(config)
    model.eval()
    # Generate dummy input
    input_text = "Hello, world! This is a test."
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    # Forward pass — no gradients needed for a smoke test.
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True, return_dict=True)
    # Check outputs
    print("Model Output Shape:", outputs.last_hidden_state.shape)
    if outputs.hidden_states is not None:
        print("Number of layers in output:", len(outputs.hidden_states))
    else:
        print("Hidden states not returned")
def test_praxis_for_causal_lm():
    """Smoke-test PraxisForCausalLM: build from config and generate a short continuation."""
    # Fresh model configuration for the causal-LM head.
    config = PraxisConfig(n_dim=768, n_layer=12, n_head=12)
    # Placeholder tokenizer — a pretrained Llama-2 tokenizer stands in,
    # since Praxis has no dedicated tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
    # Instantiate the causal-LM variant from the registered config and
    # switch to inference mode.
    model = AutoModelForCausalLM.from_config(config)
    model.eval()
    # Encode a short dummy prompt.
    prompt = "Hello, world! This is a test."
    encoded = tokenizer.encode(prompt, return_tensors="pt")
    # Generate a brief continuation and decode it back to text.
    sequences = model.generate(encoded, max_new_tokens=16, num_return_sequences=1)
    continuation = tokenizer.decode(sequences[0], skip_special_tokens=True)
    print("Generated Text:", continuation)
if __name__ == "__main__":
    # Run each smoke test in order, announcing it first.
    for banner, runner in (
        ("Testing PraxisModel...", test_praxis_model),
        ("\nTesting PraxisForCausalLM...", test_praxis_for_causal_lm),
    ):
        print(banner)
        runner()