 from datetime import datetime
 import gc
 import json
+import os
 import random
 import time
 from typing import Any, AsyncGenerator, Optional
-import os
-

+from benchmarks.eval_accuracy import eval_accuracy
+from benchmarks.metrics import CounterMetric, EventMetric
 import grpc
-from benchmarks.metrics import EventMetric, CounterMetric
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
 from jetstream.engine.token_utils import load_vocab
 from jetstream.external_tokenizers.llama3 import llama3_tokenizer
 import numpy as np
-from tqdm.asyncio import tqdm  # pytype: disable=pyi-error
 import pandas
-
-from eval_accuracy import eval_accuracy
+from tqdm.asyncio import tqdm  # pytype: disable=pyi-error
 from transformers import AutoTokenizer


@@ -706,136 +704,7 @@ def sample_warmup_requests(requests):
         break


-def main(args: argparse.Namespace):
-  print(args)
-  random.seed(args.seed)
-  np.random.seed(args.seed)
-
-  model_id = args.model
-  tokenizer_id = args.tokenizer
-  use_hf_tokenizer = args.use_hf_tokenizer
-
-  prefill_quota = AsyncCounter(init_value=3)
-  active_req_quota = AsyncCounter(init_value=450)
-
-  api_url = f"{args.server}:{args.port}"
-
-  tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
-  if tokenizer == "test" or args.dataset == "test":
-    input_requests = mock_requests(
-        args.total_mock_requests
-    )  # e.g. [("AB", 2, "AB", 3)]
-  else:
-    dataset = []
-    if args.dataset == "openorca":
-      dataset = load_openorca_dataset_pkl(args.dataset_path)
-    elif args.dataset == "sharegpt":
-      dataset = load_sharegpt_dataset(
-          args.dataset_path,
-          args.conversation_starter,
-      )
-
-    # A given args.max_output_length value caps the generation steps; when
-    # args.max_output_length is left at its default of None, each sample's
-    # golden output length decides the generation steps.
-    input_requests = sample_requests(
-        dataset=dataset,
-        tokenizer=tokenizer,
-        num_requests=args.num_prompts,
-        max_output_length=args.max_output_length,
-    )
-
-  warmup_requests = None
-  if args.warmup_mode == "full":
-    warmup_requests = input_requests
-  elif args.warmup_mode == "sampled":
-    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
-
-  if warmup_requests:
-    print(f"Warmup (mode: {args.warmup_mode}) is starting.")
-    _, _ = asyncio.run(
-        benchmark(
-            api_url=api_url,
-            tokenizer=tokenizer,
-            input_requests=warmup_requests,
-            request_rate=args.request_rate,
-            disable_tqdm=args.disable_tqdm,
-            prefill_quota=prefill_quota,
-            active_req_quota=active_req_quota,
-            is_warmup=True,
-        )
-    )
-    print(f"Warmup (mode: {args.warmup_mode}) has completed.")
-
-  # TODO: Replace this with a warmup-complete signal once supported.
-  # Wait for the server to fully warm up before running the benchmark.
-  time.sleep(5)
-
-  benchmark_result, request_outputs = asyncio.run(
-      benchmark(
-          api_url=api_url,
-          tokenizer=tokenizer,
-          input_requests=input_requests,
-          request_rate=args.request_rate,
-          disable_tqdm=args.disable_tqdm,
-          prefill_quota=prefill_quota,
-          active_req_quota=active_req_quota,
-      )
-  )
-
-  # Process output
-  output = [output.to_dict() for output in request_outputs]
-  if args.run_eval:
-    eval_json = eval_accuracy(output)
-
-  # Save config and results to json
-  if args.save_result:
-    # dimensions values are strings
-    dimensions_json = {}
-    # metrics values are numerical
-    metrics_json = {}
-
-    # Setup
-    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
-    dimensions_json["date"] = current_dt
-    dimensions_json["model_id"] = model_id
-    dimensions_json["tokenizer_id"] = tokenizer_id
-    if args.additional_metadata_metrics_to_save is not None:
-      dimensions_json = {
-          **dimensions_json,
-          **json.loads(args.additional_metadata_metrics_to_save),
-      }
-    metrics_json["num_prompts"] = args.num_prompts
-
-    # Traffic
-    metrics_json["request_rate"] = args.request_rate
-    metrics_json = {**metrics_json, **benchmark_result}
-    if args.run_eval:
-      metrics_json = {**metrics_json, **eval_json}
-
-    final_json = {}
-    final_json["metrics"] = metrics_json
-    final_json["dimensions"] = dimensions_json
-
-    # Save to file
-    base_model_id = model_id.split("/")[-1]
-    file_name = (
-        f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
-    )
-    with open(file_name, "w", encoding="utf-8") as outfile:
-      json.dump(final_json, outfile)
-
-  if args.save_request_outputs:
-    file_path = args.request_outputs_file_path
-    with open(file_path, "w", encoding="utf-8") as output_file:
-      json.dump(
-          output,
-          output_file,
-          indent=4,
-      )
-
-
-if __name__ == "__main__":
+def parse_args() -> argparse.Namespace:
   parser = argparse.ArgumentParser(
       description="Benchmark the online serving throughput."
   )
@@ -909,7 +778,6 @@ def main(args: argparse.Namespace):
       default=150,
       help="The maximum number of mock requests to send for benchmark testing.",
   )
-
   parser.add_argument(
       "--max-output-length",
       type=int,
@@ -926,7 +794,6 @@ def main(args: argparse.Namespace):
           "the output length of the golden dataset would be passed."
       ),
   )
-
   parser.add_argument("--seed", type=int, default=0)
   parser.add_argument(
       "--disable-tqdm",
@@ -977,7 +844,138 @@ def main(args: argparse.Namespace):
       choices=["human", "gpt", "both"],
       help="What entity should be the one starting the conversations.",
   )
+  return parser.parse_args()
+
+
+def main(args: argparse.Namespace):
+  print(args)
+  random.seed(args.seed)
+  np.random.seed(args.seed)
+
+  model_id = args.model
+  tokenizer_id = args.tokenizer
+  use_hf_tokenizer = args.use_hf_tokenizer
+
+  prefill_quota = AsyncCounter(init_value=3)
+  active_req_quota = AsyncCounter(init_value=450)
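+  # The two AsyncCounter quotas above act as client-side throttles: judging by
+  # their names, at most 3 concurrent prefills and 450 requests in flight.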
+
+  api_url = f"{args.server}:{args.port}"
+
+  tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
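+  # get_tokenizer appears to return the literal string "test" in mock mode;
+  # the check below uses that to fall back to synthetic requests.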
+  if tokenizer == "test" or args.dataset == "test":
+    input_requests = mock_requests(
+        args.total_mock_requests
+    )  # e.g. [("AB", 2, "AB", 3)]
+  else:
+    dataset = []
+    if args.dataset == "openorca":
+      dataset = load_openorca_dataset_pkl(args.dataset_path)
+    elif args.dataset == "sharegpt":
+      dataset = load_sharegpt_dataset(
+          args.dataset_path,
+          args.conversation_starter,
+      )
+
+    # A given args.max_output_length value caps the generation steps; when
+    # args.max_output_length is left at its default of None, each sample's
+    # golden output length decides the generation steps.
+    input_requests = sample_requests(
+        dataset=dataset,
+        tokenizer=tokenizer,
+        num_requests=args.num_prompts,
+        max_output_length=args.max_output_length,
+    )
+
+  warmup_requests = None
+  if args.warmup_mode == "full":
+    warmup_requests = input_requests
+  elif args.warmup_mode == "sampled":
+    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
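+    # Note: the sampled warmup set is sent twice, presumably to exercise each
+    # prompt-length bucket more than once before measuring.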
+
+  if warmup_requests:
+    print(f"Warmup (mode: {args.warmup_mode}) is starting.")
+    _, _ = asyncio.run(
+        benchmark(
+            api_url=api_url,
+            tokenizer=tokenizer,
+            input_requests=warmup_requests,
+            request_rate=args.request_rate,
+            disable_tqdm=args.disable_tqdm,
+            prefill_quota=prefill_quota,
+            active_req_quota=active_req_quota,
+            is_warmup=True,
+        )
+    )
+    print(f"Warmup (mode: {args.warmup_mode}) has completed.")
+
+  # TODO: Replace this with a warmup-complete signal once supported.
+  # Wait for the server to fully warm up before running the benchmark.
+  time.sleep(5)
+
+  benchmark_result, request_outputs = asyncio.run(
+      benchmark(
+          api_url=api_url,
+          tokenizer=tokenizer,
+          input_requests=input_requests,
+          request_rate=args.request_rate,
+          disable_tqdm=args.disable_tqdm,
+          prefill_quota=prefill_quota,
+          active_req_quota=active_req_quota,
+      )
+  )
+
+  # Process output
+  output = [output.to_dict() for output in request_outputs]
+  if args.run_eval:
+    eval_json = eval_accuracy(output)
+
+  # Save config and results to json
+  if args.save_result:
+    # dimensions values are strings
+    dimensions_json = {}
+    # metrics values are numerical
+    metrics_json = {}

-  parsed_args = parser.parse_args()
+    # Setup
+    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
+    dimensions_json["date"] = current_dt
+    dimensions_json["model_id"] = model_id
+    dimensions_json["tokenizer_id"] = tokenizer_id
+    if args.additional_metadata_metrics_to_save is not None:
+      dimensions_json = {
+          **dimensions_json,
+          **json.loads(args.additional_metadata_metrics_to_save),
+      }
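+      # e.g. --additional-metadata-metrics-to-save '{"accelerator": "tpu-v5e"}'
+      # (illustrative value; the flag is a JSON string merged in above)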
+    metrics_json["num_prompts"] = args.num_prompts
+
+    # Traffic
+    metrics_json["request_rate"] = args.request_rate
+    metrics_json = {**metrics_json, **benchmark_result}
+    if args.run_eval:
+      metrics_json = {**metrics_json, **eval_json}
+
+    final_json = {}
+    final_json["metrics"] = metrics_json
+    final_json["dimensions"] = dimensions_json
+
+    # Save to file
+    base_model_id = model_id.split("/")[-1]
+    file_name = (
+        f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
+    )
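+    # Illustrative example: JetStream-5.0qps-Llama-2-7b-hf-20240101-120000.json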
+    with open(file_name, "w", encoding="utf-8") as outfile:
+      json.dump(final_json, outfile)
+
+  if args.save_request_outputs:
+    file_path = args.request_outputs_file_path
+    with open(file_path, "w", encoding="utf-8") as output_file:
+      json.dump(
+          output,
+          output_file,
+          indent=4,
+      )
+
+
+if __name__ == "__main__":
   gc.disable()
-  main(parsed_args)
+  main(parse_args())