|
4 | 4 | import pytest
|
5 | 5 | import torch
|
6 | 6 | from pydantic import BaseModel, constr
|
| 7 | +from vllm.outputs import RequestOutput |
7 | 8 | from vllm.sampling_params import SamplingParams
|
8 | 9 |
|
9 | 10 | import outlines.generate as generate
|
@@ -42,15 +43,31 @@ def test_vllm_generation_api(model, generator_type, params):
|
42 | 43 | res = generator("test", stop_at=[".", "ab"])
|
43 | 44 | assert isinstance(res, str)
|
44 | 45 |
|
| 46 | + res = generator("test", original_output=True) |
| 47 | + assert isinstance(res, list) |
| 48 | + assert len(res) == 1 |
| 49 | + assert isinstance(res[0], RequestOutput) |
| 50 | + |
45 | 51 | res1 = generator("test", seed=1)
|
46 | 52 | res2 = generator("test", seed=1)
|
47 | 53 | assert isinstance(res1, str)
|
48 | 54 | assert isinstance(res2, str)
|
49 | 55 | assert res1 == res2
|
50 | 56 |
|
| 57 | + res1 = generator("test", seed=1, original_output=True) |
| 58 | + res2 = generator("test", seed=1) |
| 59 | + assert isinstance(res1[0], RequestOutput) |
| 60 | + assert isinstance(res2, str) |
| 61 | + text1 = [sample.text for sample in res1[0].outputs] |
| 62 | + assert len(text1) == 1 |
| 63 | + assert text1[0] == res2 |
| 64 | + |
51 | 65 | res = generator(["test", "test1"])
|
52 | 66 | assert len(res) == 2
|
53 | 67 |
|
| 68 | + res = generator(["test", "test1"], original_output=True) |
| 69 | + assert len(res) == 2 |
| 70 | + |
54 | 71 |
|
55 | 72 | def test_vllm_sampling_params(model):
|
56 | 73 | generator = generate.text(model)
|
|
0 commit comments