We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 2f47c70 commit bc75432Copy full SHA for bc75432
fms_extras/utils/generation.py
@@ -593,7 +593,9 @@ def speculative_generate(
593
if do_sample:
594
do_sample_vector = torch.ones(bsize, device=logits.device, dtype=torch.bool)
595
else:
596
- do_sample_vector = torch.zeros(bsize, device=logits.device, dtype=torch.bool)
+ do_sample_vector = torch.zeros(
597
+ bsize, device=logits.device, dtype=torch.bool
598
+ )
599
next_vals = __generate_targets(
600
logits, do_sample_vector, temperature=temperature, top_k=top_k
601
)
0 commit comments