Skip to content

Commit

Permalink
Use tiktoken (openai#1044)
Browse files Browse the repository at this point in the history
* use tiktoken==0.3.0

* formatting

* tuple should be safer

* Update whisper/tokenizer.py

Co-authored-by: Ruhollah Majdoddin <[email protected]>

* use tiktoken 0.3.1

* reflecting suggestions

* cleanup

* bypassing load_tiktoken_bpe to avoid blobfile dep

---------

Co-authored-by: Ruhollah Majdoddin <[email protected]>
  • Loading branch information
jongwook and Majdoddin authored Mar 13, 2023
1 parent ad3250a commit 839639a
Show file tree
Hide file tree
Showing 15 changed files with 100,601 additions and 100,096 deletions.
2 changes: 0 additions & 2 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,4 @@ include requirements.txt
include README.md
include LICENSE
include whisper/assets/*
include whisper/assets/gpt2/*
include whisper/assets/multilingual/*
include whisper/normalizers/english.json
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ numpy
torch
tqdm
more-itertools
transformers>=4.19.0
tiktoken==0.3.1
ffmpeg-python==0.2.0
7 changes: 6 additions & 1 deletion tests/test_transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

import whisper
from whisper.tokenizer import get_tokenizer


@pytest.mark.parametrize("model_name", whisper.available_models())
Expand All @@ -24,14 +25,18 @@ def test_transcribe(model_name: str):
assert "your country" in transcription
assert "do for you" in transcription

tokenizer = get_tokenizer(model.is_multilingual)
all_tokens = [t for s in result["segments"] for t in s["tokens"]]
assert tokenizer.decode(all_tokens) == result["text"]
assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")

timing_checked = False
for segment in result["segments"]:
for timing in segment["words"]:
assert timing["start"] < timing["end"]
if timing["word"].strip(" ,") == "Americans":
assert timing["start"] <= 1.8
assert timing["end"] >= 1.8
print(timing)
timing_checked = True

assert timing_checked
Loading

0 comments on commit 839639a

Please sign in to comment.