diff --git a/saxml/server/pax/lm/params/template.py b/saxml/server/pax/lm/params/template.py index a0ae3e9..4d31f90 100644 --- a/saxml/server/pax/lm/params/template.py +++ b/saxml/server/pax/lm/params/template.py @@ -56,6 +56,7 @@ class CommonServingTemplate: MAX_SEQ_LEN = None NUM_SAMPLES = 2 TOP_K = 40 + TEMPERATURE = 0.0 TOP_K_RECALL_TARGET = 1.0 # When < 1.0, use tpu optimized approx_max_k USE_TOP_K_FOR_LOGPROBS = False BEAM_SIZE = 4 @@ -263,7 +264,7 @@ def generate(self) -> Optional[servable_lm_model.DecodeHParams]: max_decode_steps=self.MAX_DECODE_STEPS, seqlen=seqlen, num_samples=self.NUM_SAMPLES, - temperature=0.0, + temperature=self.TEMPERATURE, eos_id=stop_token_ids, k=self.TOP_K, top_k_recall_target=self.TOP_K_RECALL_TARGET, @@ -347,7 +348,7 @@ def generate_stream(self) -> Optional[servable_lm_model.DecodeHParams]: max_decode_steps=self.MAX_DECODE_STEPS, seqlen=seqlen, num_samples=self.NUM_SAMPLES, - temperature=0.0, + temperature=self.TEMPERATURE, eos_id=stop_token_ids, k=self.TOP_K, top_k_recall_target=self.TOP_K_RECALL_TARGET,