33from parameterized import parameterized_class
44from transformers import AutoTokenizer
55
6- from tests .tools import CHAT_TEMPLATE , RayUnittestBaseAsync , get_model_path , get_template_config
6+ from tests .tools import (
7+ CHAT_TEMPLATE ,
8+ RayUnittestBaseAsync ,
9+ get_model_path ,
10+ get_template_config ,
11+ )
712from trinity .common .models import create_explorer_models
813from trinity .common .models .model import ModelWrapper
914
@@ -20,7 +25,9 @@ async def prepare_engines(engines, auxiliary_engines):
2025
2126def assert_experience_tokens_match_text (test_case , tokenizer , exp , prompt_contents , response_text ):
2227 full_text = tokenizer .decode (exp .tokens .tolist (), skip_special_tokens = False )
23- prompt_text = tokenizer .decode (exp .tokens [: exp .prompt_length ].tolist (), skip_special_tokens = False )
28+ prompt_text = tokenizer .decode (
29+ exp .tokens [: exp .prompt_length ].tolist (), skip_special_tokens = False
30+ )
2431 decoded_response_text = tokenizer .decode (
2532 exp .tokens [exp .prompt_length :].tolist (), skip_special_tokens = False
2633 )
@@ -67,7 +74,9 @@ def setUp(self):
6774 def _assert_experience_matches_text (self , exp , prompt_contents , response_text ):
6875 self .assertGreater (exp .prompt_length , 0 )
6976 self .assertGreater (len (exp .tokens ), exp .prompt_length )
70- assert_experience_tokens_match_text (self , self .tokenizer , exp , prompt_contents , response_text )
77+ assert_experience_tokens_match_text (
78+ self , self .tokenizer , exp , prompt_contents , response_text
79+ )
7180
7281 def _assert_history_matches_responses (self , expected_count , prompt_contents , response_texts ):
7382 if not self .enable_history :
0 commit comments