55from text_generation .types import FinishReason , PrefillToken , Token
66
77
8- def test_generate (bloom_url , hf_headers ):
9- client = Client (bloom_url , hf_headers )
8+ def test_generate (flan_t5_xxl_url , hf_headers ):
9+ client = Client (flan_t5_xxl_url , hf_headers )
1010 response = client .generate ("test" , max_new_tokens = 1 )
1111
12- assert response .generated_text == ". "
12+ assert response .generated_text == ""
1313 assert response .details .finish_reason == FinishReason .Length
1414 assert response .details .generated_tokens == 1
1515 assert response .details .seed is None
1616 assert len (response .details .prefill ) == 1
17- assert response .details .prefill [0 ] == PrefillToken (
18- id = 9234 , text = "test" , logprob = None
19- )
17+ assert response .details .prefill [0 ] == PrefillToken (id = 0 , text = "<pad>" , logprob = None )
2018 assert len (response .details .tokens ) == 1
2119 assert response .details .tokens [0 ] == Token (
22- id = 17 , text = ". " , logprob = - 1.75 , special = False
20+ id = 3 , text = " " , logprob = - 1.984375 , special = False
2321 )
2422
2523
24+ def test_generate_best_of (flan_t5_xxl_url , hf_headers ):
25+ client = Client (flan_t5_xxl_url , hf_headers )
26+ response = client .generate ("test" , max_new_tokens = 1 , best_of = 2 , do_sample = True )
27+
28+ assert response .details .seed is not None
29+ assert response .details .best_of_sequences is not None
30+ assert len (response .details .best_of_sequences ) == 1
31+ assert response .details .best_of_sequences [0 ].seed is not None
32+
33+
2634def test_generate_not_found (fake_url , hf_headers ):
2735 client = Client (fake_url , hf_headers )
2836 with pytest .raises (NotFoundError ):
@@ -35,16 +43,16 @@ def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
3543 client .generate ("test" , max_new_tokens = 10_000 )
3644
3745
38- def test_generate_stream (bloom_url , hf_headers ):
39- client = Client (bloom_url , hf_headers )
46+ def test_generate_stream (flan_t5_xxl_url , hf_headers ):
47+ client = Client (flan_t5_xxl_url , hf_headers )
4048 responses = [
4149 response for response in client .generate_stream ("test" , max_new_tokens = 1 )
4250 ]
4351
4452 assert len (responses ) == 1
4553 response = responses [0 ]
4654
47- assert response .generated_text == ". "
55+ assert response .generated_text == ""
4856 assert response .details .finish_reason == FinishReason .Length
4957 assert response .details .generated_tokens == 1
5058 assert response .details .seed is None
@@ -63,21 +71,19 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
6371
6472
6573@pytest .mark .asyncio
66- async def test_generate_async (bloom_url , hf_headers ):
67- client = AsyncClient (bloom_url , hf_headers )
74+ async def test_generate_async (flan_t5_xxl_url , hf_headers ):
75+ client = AsyncClient (flan_t5_xxl_url , hf_headers )
6876 response = await client .generate ("test" , max_new_tokens = 1 )
6977
70- assert response .generated_text == ". "
78+ assert response .generated_text == ""
7179 assert response .details .finish_reason == FinishReason .Length
7280 assert response .details .generated_tokens == 1
7381 assert response .details .seed is None
7482 assert len (response .details .prefill ) == 1
75- assert response .details .prefill [0 ] == PrefillToken (
76- id = 9234 , text = "test" , logprob = None
77- )
83+ assert response .details .prefill [0 ] == PrefillToken (id = 0 , text = "<pad>" , logprob = None )
7884 assert len (response .details .tokens ) == 1
7985 assert response .details .tokens [0 ] == Token (
80- id = 17 , text = ". " , logprob = - 1.75 , special = False
86+ id = 3 , text = " " , logprob = - 1.984375 , special = False
8187 )
8288
8389
@@ -96,16 +102,16 @@ async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
96102
97103
98104@pytest .mark .asyncio
99- async def test_generate_stream_async (bloom_url , hf_headers ):
100- client = AsyncClient (bloom_url , hf_headers )
105+ async def test_generate_stream_async (flan_t5_xxl_url , hf_headers ):
106+ client = AsyncClient (flan_t5_xxl_url , hf_headers )
101107 responses = [
102108 response async for response in client .generate_stream ("test" , max_new_tokens = 1 )
103109 ]
104110
105111 assert len (responses ) == 1
106112 response = responses [0 ]
107113
108- assert response .generated_text == ". "
114+ assert response .generated_text == ""
109115 assert response .details .finish_reason == FinishReason .Length
110116 assert response .details .generated_tokens == 1
111117 assert response .details .seed is None
0 commit comments