Skip to content

Commit

Permalink
add requirements.txt
Browse files Browse the repository at this point in the history
  • Loading branch information
feifeibear authored Sep 20, 2023
2 parents c2016ad + 300a2e8 commit 60740d0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
transformers==4.29.2
torch==2.0.1
6 changes: 3 additions & 3 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,17 +266,17 @@ def generate(input_text, approx_model_name, target_model_name, num_tokens=40, ra
print(f"small (approx) model autoregressive_sampling: {generated_text}")

torch.manual_seed(123)
output = speculative_sampling(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, verbose = verbose)
output = speculative_sampling(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed, verbose = verbose)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(f"google's speculative_sampling: {generated_text}")

torch.manual_seed(123)
output = speculative_sampling_v2(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, )
output = speculative_sampling_v2(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(f"deepmind's speculative_sampling: {generated_text}")

if __name__ == "__main__":
args = parse_arguments()
generate(args.input, args.approx_model_name, args.target_model_name, args.seed, verbose=args.verbose)
generate(args.input, args.approx_model_name, args.target_model_name, random_seed = args.seed, verbose=args.verbose)


0 comments on commit 60740d0

Please sign in to comment.