Skip to content

Commit

Permalink
Also add a TEMPERATURE variable in params to make it configurable in deriv…
Browse files Browse the repository at this point in the history
…ed class w/o the extra inputs field.

PiperOrigin-RevId: 675692334
Change-Id: Ie1572096d2ecacebd8fea26a2787d09536757431
  • Loading branch information
Sax Authors authored and copybara-github committed Sep 17, 2024
1 parent 4684536 commit 869e61c
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions saxml/server/pax/lm/params/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class CommonServingTemplate:
MAX_SEQ_LEN = None
NUM_SAMPLES = 2
TOP_K = 40
TEMPERATURE = 0.0
TOP_K_RECALL_TARGET = 1.0 # When < 1.0, use tpu optimized approx_max_k
USE_TOP_K_FOR_LOGPROBS = False
BEAM_SIZE = 4
Expand Down Expand Up @@ -263,7 +264,7 @@ def generate(self) -> Optional[servable_lm_model.DecodeHParams]:
max_decode_steps=self.MAX_DECODE_STEPS,
seqlen=seqlen,
num_samples=self.NUM_SAMPLES,
temperature=0.0,
temperature=self.TEMPERATURE,
eos_id=stop_token_ids,
k=self.TOP_K,
top_k_recall_target=self.TOP_K_RECALL_TARGET,
Expand Down Expand Up @@ -347,7 +348,7 @@ def generate_stream(self) -> Optional[servable_lm_model.DecodeHParams]:
max_decode_steps=self.MAX_DECODE_STEPS,
seqlen=seqlen,
num_samples=self.NUM_SAMPLES,
temperature=0.0,
temperature=self.TEMPERATURE,
eos_id=stop_token_ids,
k=self.TOP_K,
top_k_recall_target=self.TOP_K_RECALL_TARGET,
Expand Down

0 comments on commit 869e61c

Please sign in to comment.