test(tgi): add granite tests
dacorvo committed Dec 23, 2024
Parent 429aeb5 · commit a48cb6d
Showing 4 changed files with 12 additions and 0 deletions.
text-generation-inference/tests/fixtures/model.py (4 additions & 0 deletions)
@@ -41,6 +41,10 @@
"model_id": "Qwen/Qwen2.5-0.5B",
"export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "fp16"},
},
"granite": {
"model_id": "ibm-granite/granite-3.1-2b-instruct",
"export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"},
},
}
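For context, the export_kwargs of the new granite entry mirror the arguments that optimum-neuron's NeuronModelForCausalLM takes when compiling a checkpoint for Neuron cores. Below is a minimal sketch of an equivalent standalone export, assuming optimum-neuron on a Neuron host; the model fixture in this suite handles export and caching itself, so this is illustration only.

# Sketch: compile the granite checkpoint with the same settings as the fixture.
from optimum.neuron import NeuronModelForCausalLM

model = NeuronModelForCausalLM.from_pretrained(
    "ibm-granite/granite-3.1-2b-instruct",
    export=True,              # trigger Neuron compilation of the checkpoint
    batch_size=4,
    sequence_length=4096,
    num_cores=2,
    auto_cast_type="bf16",    # granite is exported in bf16, unlike the fp16 qwen2 entry
)
model.save_pretrained("granite-neuron")  # persist the compiled artifacts for reuse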


text-generation-inference/tests/integration/test_generate.py (3 additions & 0 deletions)
@@ -25,6 +25,7 @@ async def test_model_single_request(tgi_service):
"llama": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use",
"mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that",
"qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on",
"granite": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on",
}
assert response.generated_text == greedy_expectations[service_name]

@@ -50,6 +51,7 @@ async def test_model_single_request(tgi_service):
"llama": "Deep Learning",
"mistral": "Deep learning",
"qwen2": "Deep Learning",
"granite": "Deep Learning",
}
assert sample_expectations[service_name] in response

@@ -84,6 +86,7 @@ async def test_model_multiple_requests(tgi_service, generate_load):
"llama": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use",
"mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that",
"qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on",
"granite": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on",
}
expected = expectations[tgi_service.client.service_name]
for r in responses:
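The greedy expectations added above are deterministic, so they can be reproduced by hand against a running TGI Neuron endpoint. Below is a minimal sketch using huggingface_hub's InferenceClient; the URL, prompt, and token budget are placeholders, since the tests themselves talk to the server through the tgi_service fixture and its client.

# Sketch: issue a greedy (do_sample=False) request to a local TGI server and
# inspect the continuation. Endpoint, prompt and max_new_tokens are hypothetical.
from huggingface_hub import InferenceClient

client = InferenceClient("http://localhost:8080")
output = client.text_generation(
    "What is Deep Learning?",  # placeholder prompt
    max_new_tokens=17,
    do_sample=False,           # greedy decoding makes the output reproducible
)
print(output)  # compare against the pinned granite continuation above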
text-generation-inference/tests/server/test_decode.py (2 additions & 0 deletions)
@@ -40,6 +40,7 @@ def _test_decode(config_name, generator, do_sample):
"llama": "George Orwell, 1984",
"mistral": "The sky was",
"qwen2": " A young woman with",
"granite": " A young woman with",
}[config_name]
assert expected_text in output.text
else:
@@ -49,5 +50,6 @@ def _test_decode(config_name, generator, do_sample):
"llama": " George Orwell’s classic dystopian novel, 1984, begins with this ominous sentence. The story",
"mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.",
"qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a",
"granite": " I was sitting in my room, staring at the ceiling, when the door opened and in came a",
}[config_name]
assert output.text == expected_text
text-generation-inference/tests/server/test_prefill.py (3 additions & 0 deletions)
@@ -39,13 +39,15 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
"llama": [10058, " George"],
"mistral": [450, " The"],
"qwen2": [362, " A"],
"granite": [362, " A"],
}[config_name]
else:
expectations = {
"gpt2": [198, "\n"],
"llama": [10058, " George"],
"mistral": [13, "\n"],
"qwen2": [358, " I"],
"granite": [358, " I"],
}[config_name]
for g in generations:
tokens = g.tokens
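Each expectation in _test_prefill pairs the first generated token's id with its decoded text (for example 362 with " A" for granite and qwen2). Below is a minimal sketch of checking that pairing with the model's tokenizer, assuming hub access; it is not part of the test suite.

# Sketch: decode the expected first-token id with the granite tokenizer and
# check that it matches the expected text from the expectations above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.1-2b-instruct")
print(repr(tokenizer.decode([362])))  # the prefill expectation pairs id 362 with " A"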
@@ -80,6 +82,7 @@ def test_prefill_truncate(neuron_model_config):
"llama": [" —", " The", " He", " He"],
"mistral": [" He", "\n", " He", " He"],
"qwen2": [" He", " The", " He", " He"],
"granite": [" He", " The", " He", " He"],
}[config_name]
for i, g in enumerate(generations):
tokens = g.tokens
