Skip to content

Commit 98db5f8

Browse files
committed
test(tgi): add granite tests
1 parent 429aeb5 commit 98db5f8

File tree

4 files changed

+16
-0
lines changed

4 files changed

+16
-0
lines changed

text-generation-inference/tests/fixtures/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@
4141
"model_id": "Qwen/Qwen2.5-0.5B",
4242
"export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "fp16"},
4343
},
44+
"granite": {
45+
"model_id": "ibm-granite/granite-3.1-2b-instruct",
46+
"export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"},
47+
},
4448
}
4549

4650

text-generation-inference/tests/integration/test_generate.py

+7
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ async def test_model_single_request(tgi_service):
2525
"llama": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use",
2626
"mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that",
2727
"qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on",
28+
"granite": "\n\nDeep Learning is a subset of Machine Learning, which is a branch of Art",
2829
}
2930
assert response.generated_text == greedy_expectations[service_name]
3031

@@ -50,6 +51,11 @@ async def test_model_single_request(tgi_service):
5051
"llama": "Deep Learning",
5152
"mistral": "Deep learning",
5253
"qwen2": "Deep Learning",
54+
"granite": "Deep Learning",
5359
}
5460
assert sample_expectations[service_name] in response
5561

@@ -84,6 +90,7 @@ async def test_model_multiple_requests(tgi_service, generate_load):
8490
"llama": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use",
8591
"mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that",
8692
"qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on",
93+
"granite": "\n\nDeep Learning is a subset of Machine Learning, which is a branch of Art",
8794
}
8895
expected = expectations[tgi_service.client.service_name]
8996
for r in responses:

text-generation-inference/tests/server/test_decode.py

+2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def _test_decode(config_name, generator, do_sample):
4040
"llama": "George Orwell, 1984",
4141
"mistral": "The sky was",
4242
"qwen2": " A young woman with",
43+
"granite": "Aldous Huxley, Brave New World",
4344
}[config_name]
4445
assert expected_text in output.text
4546
else:
@@ -49,5 +50,6 @@ def _test_decode(config_name, generator, do_sample):
4950
"llama": " George Orwell’s classic dystopian novel, 1984, begins with this ominous sentence. The story",
5051
"mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.",
5152
"qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a",
53+
"granite": "\n\nThis opening line from George Orwell's dystopian novel \"198",
5254
}[config_name]
5355
assert output.text == expected_text

text-generation-inference/tests/server/test_prefill.py

+3
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
3939
"llama": [10058, " George"],
4040
"mistral": [450, " The"],
4141
"qwen2": [362, " A"],
42+
"granite": [429, " -"],
4243
}[config_name]
4344
else:
4445
expectations = {
4546
"gpt2": [198, "\n"],
4647
"llama": [10058, " George"],
4748
"mistral": [13, "\n"],
4849
"qwen2": [358, " I"],
50+
"granite": [203, "\n"],
4951
}[config_name]
5052
for g in generations:
5153
tokens = g.tokens
@@ -80,6 +82,7 @@ def test_prefill_truncate(neuron_model_config):
8082
"llama": [" —", " The", " He", " He"],
8183
"mistral": [" He", "\n", " He", " He"],
8284
"qwen2": [" He", " The", " He", " He"],
85+
"granite": ["\n", "\n", " I", " He"],
8386
}[config_name]
8487
for i, g in enumerate(generations):
8588
tokens = g.tokens

0 commit comments

Comments (0)