diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3796a3973..dffc17c61 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,8 +6,38 @@ on: pull_request: branches: - main + jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Fetch base branch + run: git fetch origin ${{ github.base_ref }} + - uses: actions/setup-python@v4 + with: + python-version: "3.8" + architecture: x64 + - name: Get pip cache dir + id: pip-cache + run: | + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip/pre-commit cache + uses: actions/cache@v3 + with: + path: | + ${{ steps.pip-cache.outputs.dir }} + ~/.cache/pre-commit + key: ${{ runner.os }}-pip-pre-commit-${{ hashFiles('**/.pre-commit-config.yaml') }} + restore-keys: | + ${{ runner.os }}-pip-pre-commit + - name: pre-commit + run: | + pip install -U pre-commit + pre-commit install --install-hooks + pre-commit run --all-files whisper-test: + needs: pre-commit runs-on: ubuntu-latest strategy: matrix: @@ -23,7 +53,4 @@ jobs: - uses: actions/checkout@v3 - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH - run: pip install .["dev"] - - run: black --check --diff -t py38 --include '(\.pyi?)$' . - - run: isort --check --diff . - - run: flake8 --ignore E203,W503,W504,E501,E731,E741 . - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..3f5a74b6d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,28 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: check-json + - id: end-of-file-fixer + types: [file, python] + - id: trailing-whitespace + types: [file, python] + - id: mixed-line-ending + - id: check-added-large-files + args: [--maxkb=4096] + - repo: https://github.com/psf/black + rev: 23.7.0 + hooks: + - id: black + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + name: isort (python) + args: ["--profile", "black", "-l", "88", "--trailing-comma", "--multi-line", "3"] + - repo: https://github.com/pycqa/flake8.git + rev: 6.0.0 + hooks: + - id: flake8 + types: [python] + args: ["--max-line-length", "88", "--ignore", "E203,E501,W503,W504"] diff --git a/CHANGELOG.md b/CHANGELOG.md index a77d966f9..58955410f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,45 @@ # CHANGELOG +## [v20231117](https://github.com/openai/whisper/releases/tag/v20231117) + +* Relax triton requirements for compatibility with pytorch 2.1 and newer ([#1802](https://github.com/openai/whisper/pull/1802)) + +## [v20231106](https://github.com/openai/whisper/releases/tag/v20231106) + +* large-v3 ([#1761](https://github.com/openai/whisper/pull/1761)) + +## [v20231105](https://github.com/openai/whisper/releases/tag/v20231105) + +* remove tiktoken pin ([#1759](https://github.com/openai/whisper/pull/1759)) +* docs: Disambiguation of the term "relative speed" in the README ([#1751](https://github.com/openai/whisper/pull/1751)) +* allow_pickle=False while loading of mel matrix IN audio.py ([#1511](https://github.com/openai/whisper/pull/1511)) +* handling transcribe exceptions. 
([#1682](https://github.com/openai/whisper/pull/1682)) +* Add new option to generate subtitles by a specific number of words ([#1729](https://github.com/openai/whisper/pull/1729)) +* Fix exception when an audio file with no speech is provided ([#1396](https://github.com/openai/whisper/pull/1396)) + +## [v20230918](https://github.com/openai/whisper/releases/tag/v20230918) + +* Add .pre-commit-config.yaml ([#1528](https://github.com/openai/whisper/pull/1528)) +* fix doc of TextDecoder ([#1526](https://github.com/openai/whisper/pull/1526)) +* Update model-card.md ([#1643](https://github.com/openai/whisper/pull/1643)) +* word timing tweaks ([#1559](https://github.com/openai/whisper/pull/1559)) +* Avoid rearranging all caches ([#1483](https://github.com/openai/whisper/pull/1483)) +* Improve timestamp heuristics. ([#1461](https://github.com/openai/whisper/pull/1461)) +* fix condition_on_previous_text ([#1224](https://github.com/openai/whisper/pull/1224)) +* Fix numba depreceation notice ([#1233](https://github.com/openai/whisper/pull/1233)) +* Updated README.md to provide more insight on BLEU and specific appendices ([#1236](https://github.com/openai/whisper/pull/1236)) +* Avoid computing higher temperatures on no_speech segments ([#1279](https://github.com/openai/whisper/pull/1279)) +* Dropped unused execute bit from mel_filters.npz. ([#1254](https://github.com/openai/whisper/pull/1254)) +* Drop ffmpeg-python dependency and call ffmpeg directly. ([#1242](https://github.com/openai/whisper/pull/1242)) +* Python 3.11 ([#1171](https://github.com/openai/whisper/pull/1171)) +* Update decoding.py ([#1219](https://github.com/openai/whisper/pull/1219)) +* Update decoding.py ([#1155](https://github.com/openai/whisper/pull/1155)) +* Update README.md to reference tiktoken ([#1105](https://github.com/openai/whisper/pull/1105)) +* Implement max line width and max line count, and make word highlighting optional ([#1184](https://github.com/openai/whisper/pull/1184)) +* Squash long words at window and sentence boundaries. ([#1114](https://github.com/openai/whisper/pull/1114)) +* python-publish.yml: bump actions version to fix node warning ([#1211](https://github.com/openai/whisper/pull/1211)) +* Update tokenizer.py ([#1163](https://github.com/openai/whisper/pull/1163)) + ## [v20230314](https://github.com/openai/whisper/releases/tag/v20230314) * abort find_alignment on empty input ([#1090](https://github.com/openai/whisper/pull/1090)) diff --git a/README.md b/README.md index 20532574c..afca9c971 100644 --- a/README.md +++ b/README.md @@ -57,8 +57,7 @@ pip install setuptools-rust ## Available models and languages -There are five model sizes, four with English-only versions, offering speed and accuracy tradeoffs. Below are the names of the available models and their approximate memory requirements and relative speed. - +There are five model sizes, four with English-only versions, offering speed and accuracy tradeoffs. Below are the names of the available models and their approximate memory requirements and inference speed relative to the large model; actual speed may vary depending on many factors including the available hardware. 
 | Size  | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
 |:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:|
@@ -70,9 +69,9 @@ There are five model sizes, four with English-only versions, offering speed and
 The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.

-Whisper's performance varies widely depending on the language. The figure below shows a WER (Word Error Rate) breakdown by languages of the Fleurs dataset using the `large-v2` model (The smaller the numbers, the better the performance). Additional WER scores corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4. Meanwhile, more BLEU (Bilingual Evaluation Understudy) scores can be found in Appendix D.3. Both are found in [the paper](https://arxiv.org/abs/2212.04356).
+Whisper's performance varies widely depending on the language. The figure below shows a performance breakdown of `large-v3` and `large-v2` models by language, using WERs (word error rates) or CER (character error rates, shown in *Italic*) evaluated on the Common Voice 15 and Fleurs datasets. Additional WER/CER metrics corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4 of [the paper](https://arxiv.org/abs/2212.04356), as well as the BLEU (Bilingual Evaluation Understudy) scores for translation in Appendix D.3.

-![WER breakdown by language](https://raw.githubusercontent.com/openai/whisper/main/language-breakdown.svg)
+![WER breakdown by language](https://github.com/openai/whisper/assets/266841/f4619d66-1058-4005-8f67-a9d811b77c62)
diff --git a/language-breakdown.svg b/language-breakdown.svg
index 49a0653eb..616fd57ea 100644
--- a/language-breakdown.svg
+++ b/language-breakdown.svg
[Figure: language-breakdown.svg regenerated (Matplotlib v3.5.1 → v3.7.3, dated 2022-12-03 → 2023-11-06) to show the WER/CER breakdown by language for the `large-v3` and `large-v2` models on Common Voice 15 and Fleurs; raw SVG path data omitted.]
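As a point of reference for the README changes above, the snippet below is a minimal sketch (not part of this patch) showing how one of the listed model sizes, including the newly added `large-v3` checkpoint, is loaded and run through the public Python API; the audio path is a placeholder.

```python
import whisper

# Any name from the table above works here, e.g. "tiny.en", "base", "small",
# "medium", or the new "large-v3" checkpoint added in this release.
model = whisper.load_model("base.en")

# Transcribe a local file (placeholder path) and print the recognized text.
result = model.transcribe("audio.mp3")
print(result["text"])
```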
diff --git a/model-card.md b/model-card.md
index 2ed85cf24..3c041a1c0 100644
--- a/model-card.md
+++ b/model-card.md
@@ -17,12 +17,12 @@ The Whisper models are trained for speech recognition and translation tasks, cap
 | medium | 769 M  | ✓ | ✓ |
 | large  | 1550 M |   | ✓ |

-In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661).
+In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661), and `large-v3` in November 2023.

 ### Release date

-September 2022 (original series) and December 2022 (`large-v2`)
+September 2022 (original series), December 2022 (`large-v2`), and November 2023 (`large-v3`)

 ### Model type
@@ -37,7 +37,7 @@ Sequence-to-sequence ASR (automatic speech recognition) and speech translation m

 ### Evaluated Use

-The primary intended users of these models are AI researchers studying robustness, generalization, capabilities, biases, and constraints of the current model. However, Whisper is also potentially quite useful as an ASR solution for developers, especially for English speech recognition. We recognize that once models are released, it is impossible to restrict access to only “intended” uses or to draw reasonable guidelines around what is or is not research.
+The primary intended users of these models are AI researchers studying the robustness, generalization, capabilities, biases, and constraints of the current model. However, Whisper is also potentially quite useful as an ASR solution for developers, especially for English speech recognition. We recognize that once models are released, it is impossible to restrict access to only “intended” uses or to draw reasonable guidelines around what is or is not research.

 The models are primarily trained and evaluated on ASR and speech translation to English tasks. They show strong ASR results in ~10 languages. They may exhibit additional capabilities, particularly if fine-tuned on certain tasks like voice activity detection, speaker classification, or speaker diarization but have not been robustly evaluated in these areas. We strongly recommend that users perform robust evaluations of the models in a particular context and domain before deploying them.

@@ -53,17 +53,17 @@ As discussed in [the accompanying paper](https://arxiv.org/abs/2212.04356), we s

 ## Performance and Limitations

-Our studies show that, over many existing ASR systems, the models exhibit improved robustness to accents, background noise, technical language, as well as zero shot translation from multiple languages into English; and that accuracy on speech recognition and translation is near the state-of-the-art level.
+Our studies show that, over many existing ASR systems, the models exhibit improved robustness to accents, background noise, and technical language, as well as zero-shot translation from multiple languages into English; and that accuracy on speech recognition and translation is near the state-of-the-art level. However, because the models are trained in a weakly supervised manner using large-scale noisy data, the predictions may include texts that are not actually spoken in the audio input (i.e. hallucination). We hypothesize that this happens because, given their general knowledge of language, the models combine trying to predict the next word in audio with trying to transcribe the audio itself. -Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://arxiv.org/abs/2212.04356). +Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include a higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://arxiv.org/abs/2212.04356). -In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis on these limitations are provided in [the paper](https://arxiv.org/abs/2212.04356). It is likely that this behavior and hallucinations may be worse on lower-resource and/or lower-discoverability languages. +In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis of these limitations is provided in [the paper](https://arxiv.org/abs/2212.04356). It is likely that this behavior and hallucinations may be worse in lower-resource and/or lower-discoverability languages. ## Broader Implications We anticipate that Whisper models’ transcription capabilities may be used for improving accessibility tools. While Whisper models cannot be used for real-time transcription out of the box – their speed and size suggest that others may be able to build applications on top of them that allow for near-real-time speech recognition and translation. The real value of beneficial applications built on top of Whisper models suggests that the disparate performance of these models may have real economic implications. -There are also potential dual use concerns that come with releasing Whisper. 
While we hope the technology will be used primarily for beneficial purposes, making ASR technology more accessible could enable more actors to build capable surveillance technologies or scale up existing surveillance efforts, as the speed and accuracy allow for affordable automatic transcription and translation of large volumes of audio communication. Moreover, these models may have some capabilities to recognize specific individuals out of the box, which in turn presents safety concerns related both to dual use and disparate performance. In practice, we expect that the cost of transcription is not the limiting factor of scaling up surveillance projects. +There are also potential dual-use concerns that come with releasing Whisper. While we hope the technology will be used primarily for beneficial purposes, making ASR technology more accessible could enable more actors to build capable surveillance technologies or scale up existing surveillance efforts, as the speed and accuracy allow for affordable automatic transcription and translation of large volumes of audio communication. Moreover, these models may have some capabilities to recognize specific individuals out of the box, which in turn presents safety concerns related both to dual use and disparate performance. In practice, we expect that the cost of transcription is not the limiting factor of scaling up surveillance projects. diff --git a/requirements.txt b/requirements.txt index 3c11ac32c..62f5f9d7c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ numpy torch tqdm more-itertools -tiktoken==0.3.3 +tiktoken +triton>=2.0.0,<3;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2" diff --git a/setup.py b/setup.py index 1161d815a..183b52755 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ -import os import platform import sys +from pathlib import Path import pkg_resources from setuptools import find_packages, setup @@ -13,7 +13,7 @@ def read_version(fname="whisper/version.py"): requirements = [] if sys.platform.startswith("linux") and platform.machine() == "x86_64": - requirements.append("triton==2.0.0") + requirements.append("triton>=2.0.0,<3") setup( name="openai-whisper", @@ -28,11 +28,10 @@ def read_version(fname="whisper/version.py"): url="https://github.com/openai/whisper", license="MIT", packages=find_packages(exclude=["tests*"]), - install_requires=requirements - + [ + install_requires=[ str(r) for r in pkg_resources.parse_requirements( - open(os.path.join(os.path.dirname(__file__), "requirements.txt")) + Path(__file__).with_name("requirements.txt").open() ) ], entry_points={ diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index 09d0351e1..be424e5fe 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -1,7 +1,17 @@ +import pytest + from whisper.tokenizer import get_tokenizer -def test_tokenizer(): +@pytest.mark.parametrize("multilingual", [True, False]) +def test_tokenizer(multilingual): + tokenizer = get_tokenizer(multilingual=False) + assert tokenizer.sot in tokenizer.sot_sequence + assert len(tokenizer.all_language_codes) == len(tokenizer.all_language_tokens) + assert all(c < tokenizer.timestamp_begin for c in tokenizer.all_language_tokens) + + +def test_multilingual_tokenizer(): gpt2_tokenizer = get_tokenizer(multilingual=False) multilingual_tokenizer = get_tokenizer(multilingual=True) @@ -20,5 +30,5 @@ def test_split_on_unicode(): tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378] words, word_tokens = 
multilingual_tokenizer.split_tokens_on_unicode(tokens) - assert words == [" elle", " est", " l", "'", "�", "é", "rit", "oire"] + assert words == [" elle", " est", " l", "'", "\ufffd", "é", "rit", "oire"] assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]] diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index e4f8fd0f7..599221af5 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -25,7 +25,7 @@ def test_transcribe(model_name: str): assert "your country" in transcription assert "do for you" in transcription - tokenizer = get_tokenizer(model.is_multilingual) + tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages) all_tokens = [t for s in result["segments"] for t in s["tokens"]] assert tokenizer.decode(all_tokens) == result["text"] assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>") diff --git a/whisper/__init__.py b/whisper/__init__.py index 379133b6a..d7fbba36f 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -25,7 +25,8 @@ "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", - "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", + "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", + "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", } # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are @@ -41,7 +42,8 @@ "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", - "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", + "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", + "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", } diff --git a/whisper/assets/mel_filters.npz b/whisper/assets/mel_filters.npz index 1a7839244..28ea26909 100644 Binary files a/whisper/assets/mel_filters.npz and b/whisper/assets/mel_filters.npz differ diff --git a/whisper/audio.py b/whisper/audio.py index 4f5b6e073..cf6c66ad9 100644 --- a/whisper/audio.py +++ b/whisper/audio.py @@ -12,7 +12,6 @@ # hard-coded audio hyperparameters SAMPLE_RATE = 16000 N_FFT = 400 -N_MELS = 80 HOP_LENGTH = 160 CHUNK_LENGTH = 30 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk @@ -90,7 +89,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): @lru_cache(maxsize=None) -def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: +def mel_filters(device, n_mels: int) -> torch.Tensor: """ load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
Allows decoupling librosa dependency; saved using: @@ -98,18 +97,19 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: np.savez_compressed( "mel_filters.npz", mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), ) """ - assert n_mels == 80, f"Unsupported n_mels: {n_mels}" - with np.load( - os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") - ) as f: + assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" + + filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") + with np.load(filters_path, allow_pickle=False) as f: return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) def log_mel_spectrogram( audio: Union[str, np.ndarray, torch.Tensor], - n_mels: int = N_MELS, + n_mels: int = 80, padding: int = 0, device: Optional[Union[str, torch.device]] = None, ): diff --git a/whisper/decoding.py b/whisper/decoding.py index 457ee7ccb..49485d009 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -32,7 +32,9 @@ def detect_language( list of dictionaries containing the probability distribution over all languages. """ if tokenizer is None: - tokenizer = get_tokenizer(model.is_multilingual) + tokenizer = get_tokenizer( + model.is_multilingual, num_languages=model.num_languages + ) if ( tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence @@ -146,6 +148,10 @@ def __init__(self, model: "Whisper", initial_token_length: int): self.kv_cache = {} self.hooks = [] + key_modules = [block.attn.key for block in self.model.decoder.blocks] + value_modules = [block.attn.value for block in self.model.decoder.blocks] + self.kv_modules = key_modules + value_modules + def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: if not self.kv_cache: self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() @@ -164,9 +170,10 @@ def cleanup_caching(self): self.hooks = [] def rearrange_kv_cache(self, source_indices): - for module, tensor in self.kv_cache.items(): - # update the key/value cache to contain the selected sequences - self.kv_cache[module] = tensor[source_indices].detach() + if source_indices != list(range(len(source_indices))): + for module in self.kv_modules: + # update the key/value cache to contain the selected sequences + self.kv_cache[module] = self.kv_cache[module][source_indices].detach() class SequenceRanker: @@ -509,7 +516,10 @@ def __init__(self, model: "Whisper", options: DecodingOptions): language = options.language or "en" tokenizer = get_tokenizer( - model.is_multilingual, language=language, task=options.task + model.is_multilingual, + num_languages=model.num_languages, + language=language, + task=options.task, ) self.tokenizer: Tokenizer = tokenizer self.options: DecodingOptions = self._verify_options(options) @@ -668,7 +678,6 @@ def _detect_language(self, audio_features: Tensor, tokens: Tensor): return languages, lang_probs def _main_loop(self, audio_features: Tensor, tokens: Tensor): - assert audio_features.shape[0] == tokens.shape[0] n_batch = tokens.shape[0] sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device) no_speech_probs = [np.nan] * n_batch @@ -721,8 +730,7 @@ def run(self, mel: Tensor) -> List[DecodingResult]: ) ] - # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling - audio_features = audio_features.repeat_interleave(self.n_group, dim=0) + # repeat text tensors by the group size, for beam search or best-of-n sampling tokens = 
tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) # call the main sampling loop diff --git a/whisper/model.py b/whisper/model.py index 3457fcfc6..a67828397 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -197,7 +197,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): """ x : torch.LongTensor, shape = (batch_size, <= n_ctx) the text tokens - xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx) + xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) the encoded audio features to be attended on """ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 @@ -236,7 +236,8 @@ def __init__(self, dims: ModelDimensions): self.dims.n_text_head, self.dims.n_text_layer, ) - # use the last half layers for alignment by default; see `set_alignment_heads()` below + # use the last half among the decoder layers for time alignment by default; + # to use a specific set of heads, see `set_alignment_heads()` below. all_heads = torch.zeros( self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool ) @@ -269,7 +270,11 @@ def device(self): @property def is_multilingual(self): - return self.dims.n_vocab == 51865 + return self.dims.n_vocab >= 51865 + + @property + def num_languages(self): + return self.dims.n_vocab - 51765 - int(self.is_multilingual) def install_kv_cache_hooks(self, cache: Optional[dict] = None): """ diff --git a/whisper/timing.py b/whisper/timing.py index 1a73eaafc..befcf464e 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -202,7 +202,7 @@ def find_alignment( hook.remove() # heads * tokens * frames - weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T]) + weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T]) weights = weights[:, :, : num_frames // 2] weights = (weights * qk_scale).softmax(dim=-1) std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False) @@ -214,6 +214,13 @@ def find_alignment( text_indices, time_indices = dtw(-matrix) words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot]) + if len(word_tokens) <= 1: + # return on eot only + # >>> np.pad([], (1, 0)) + # array([0.]) + # This results in crashes when we lookup jump_times with float, like + # IndexError: arrays used as indices must be of integer (or boolean) type + return [] word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) @@ -225,28 +232,6 @@ def find_alignment( for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) ] - # hack: truncate long words at the start of a window and the start of a sentence. - # a better segmentation algorithm based on VAD should be able to replace this. - word_durations = end_times - start_times - word_durations = word_durations[word_durations.nonzero()] - if len(word_durations) > 0: - median_duration = np.median(word_durations) - max_duration = median_duration * 2 - sentence_end_marks = ".。!!??" - # ensure words at sentence boundaries are not longer than twice the median word duration. - for i in range(1, len(start_times)): - if end_times[i] - start_times[i] > max_duration: - if words[i] in sentence_end_marks: - end_times[i] = start_times[i] + max_duration - elif words[i - 1] in sentence_end_marks: - start_times[i] = end_times[i] - max_duration - # ensure the first and second word is not longer than twice the median word duration. 
- if len(start_times) > 0 and end_times[0] - start_times[0] > max_duration: - if len(start_times) > 1 and end_times[1] - start_times[1] > max_duration: - boundary = max(end_times[1] / 2, end_times[1] - max_duration) - end_times[0] = start_times[1] = boundary - start_times[0] = max(0, end_times[0] - max_duration) - return [ WordTiming(word, tokens, start, end, probability) for word, tokens, start, end, probability in zip( @@ -298,6 +283,7 @@ def add_word_timestamps( num_frames: int, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", + last_speech_timestamp: float, **kwargs, ): if len(segments) == 0: @@ -310,6 +296,23 @@ def add_word_timestamps( text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment)) alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs) + word_durations = np.array([t.end - t.start for t in alignment]) + word_durations = word_durations[word_durations.nonzero()] + median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 + max_duration = median_duration * 2 + + # hack: truncate long words at sentence boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(word_durations) > 0: + sentence_end_marks = ".。!!??" + # ensure words at sentence boundaries are not longer than twice the median word duration. + for i in range(1, len(alignment)): + if alignment[i].end - alignment[i].start > max_duration: + if alignment[i].word in sentence_end_marks: + alignment[i].end = alignment[i].start + max_duration + elif alignment[i - 1].word in sentence_end_marks: + alignment[i].start = alignment[i].end - max_duration + merge_punctuations(alignment, prepend_punctuations, append_punctuations) time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE @@ -335,18 +338,48 @@ def add_word_timestamps( saved_tokens += len(timing.tokens) word_index += 1 + # hack: truncate long words at segment boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. if len(words) > 0: - segment["start"] = words[0]["start"] - # hack: prefer the segment-level end timestamp if the last word is too long. - # a better segmentation algorithm based on VAD should be able to replace this. + # ensure the first and second word after a pause is not longer than + # twice the median word duration. + if words[0]["end"] - last_speech_timestamp > median_duration * 4 and ( + words[0]["end"] - words[0]["start"] > max_duration + or ( + len(words) > 1 + and words[1]["end"] - words[0]["start"] > max_duration * 2 + ) + ): + if ( + len(words) > 1 + and words[1]["end"] - words[1]["start"] > max_duration + ): + boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration) + words[0]["end"] = words[1]["start"] = boundary + words[0]["start"] = max(0, words[0]["end"] - max_duration) + + # prefer the segment-level start timestamp if the first word is too long. + if ( + segment["start"] < words[0]["end"] + and segment["start"] - 0.5 > words[0]["start"] + ): + words[0]["start"] = max( + 0, min(words[0]["end"] - median_duration, segment["start"]) + ) + else: + segment["start"] = words[0]["start"] + + # prefer the segment-level end timestamp if the last word is too long. 
if ( segment["end"] > words[-1]["start"] and segment["end"] + 0.5 < words[-1]["end"] ): - # adjust the word-level timestamps based on the segment-level timestamps - words[-1]["end"] = segment["end"] + words[-1]["end"] = max( + words[-1]["start"] + median_duration, segment["end"] + ) else: - # adjust the segment-level timestamps based on the word-level timestamps segment["end"] = words[-1]["end"] + last_speech_timestamp = segment["end"] + segment["words"] = words diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py index 4030e15aa..2af837570 100644 --- a/whisper/tokenizer.py +++ b/whisper/tokenizer.py @@ -107,6 +107,7 @@ "ba": "bashkir", "jw": "javanese", "su": "sundanese", + "yue": "cantonese", } # language code lookup by name, with a few language aliases @@ -123,6 +124,7 @@ "moldovan": "ro", "sinhalese": "si", "castilian": "es", + "mandarin": "zh", } @@ -131,6 +133,7 @@ class Tokenizer: """A thin wrapper around `tiktoken` providing quick access to special tokens""" encoding: tiktoken.Encoding + num_languages: int language: Optional[str] = None task: Optional[str] = None sot_sequence: Tuple[int] = () @@ -145,7 +148,7 @@ def __post_init__(self): translate: int = self.special_tokens["<|translate|>"] transcribe: int = self.special_tokens["<|transcribe|>"] - langs = tuple(LANGUAGES.keys()) + langs = tuple(LANGUAGES.keys())[: self.num_languages] sot_sequence = [sot] if self.language is not None: sot_sequence.append(sot + 1 + langs.index(self.language)) @@ -211,10 +214,13 @@ def language_token(self) -> int: if self.language is None: raise ValueError("This tokenizer does not have language token configured") - if token := self.special_tokens.get(f"<|{self.language}|>", None): + return self.to_language_token(self.language) + + def to_language_token(self, language): + if token := self.special_tokens.get(f"<|{language}|>", None): return token - raise KeyError(f"Language {self.language} not found in tokenizer.") + raise KeyError(f"Language {language} not found in tokenizer.") @cached_property def all_language_tokens(self) -> Tuple[int]: @@ -222,11 +228,11 @@ def all_language_tokens(self) -> Tuple[int]: for token, token_id in self.special_tokens.items(): if token.strip("<|>") in LANGUAGES: result.append(token_id) - return tuple(result) + return tuple(result)[: self.num_languages] @cached_property def all_language_codes(self) -> Tuple[str]: - return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens) + return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens) @cached_property def sot_sequence_including_notimestamps(self) -> Tuple[int]: @@ -269,7 +275,7 @@ def non_speech_tokens(self) -> Tuple[int]: return tuple(sorted(result)) def split_to_word_tokens(self, tokens: List[int]): - if self.language in {"zh", "ja", "th", "lo", "my"}: + if self.language in {"zh", "ja", "th", "lo", "my", "yue"}: # These languages don't typically use spaces, so it is difficult to split words # without morpheme analysis. 
Here, we instead split words at any # position where the tokens are decoded as valid unicode points @@ -322,7 +328,7 @@ def split_tokens_on_spaces(self, tokens: List[int]): @lru_cache(maxsize=None) -def get_encoding(name: str = "gpt2"): +def get_encoding(name: str = "gpt2", num_languages: int = 99): vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken") ranks = { base64.b64decode(token): int(rank) @@ -334,7 +340,7 @@ def get_encoding(name: str = "gpt2"): specials = [ "<|endoftext|>", "<|startoftranscript|>", - *[f"<|{lang}|>" for lang in LANGUAGES.keys()], + *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], "<|translate|>", "<|transcribe|>", "<|startoflm|>", @@ -361,6 +367,7 @@ def get_encoding(name: str = "gpt2"): def get_tokenizer( multilingual: bool, *, + num_languages: int = 99, language: Optional[str] = None, task: Optional[str] = None, # Literal["transcribe", "translate", None] ) -> Tokenizer: @@ -381,6 +388,8 @@ def get_tokenizer( language = None task = None - encoding = get_encoding(name=encoding_name) + encoding = get_encoding(name=encoding_name, num_languages=num_languages) - return Tokenizer(encoding=encoding, language=language, task=task) + return Tokenizer( + encoding=encoding, num_languages=num_languages, language=language, task=task + ) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index d92d7da89..1e5c27c9d 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -1,5 +1,6 @@ import argparse import os +import traceback import warnings from typing import TYPE_CHECKING, Optional, Tuple, Union @@ -118,7 +119,7 @@ def transcribe( decode_options["fp16"] = False # Pad 30-seconds of silence to the input audio, for slicing - mel = log_mel_spectrogram(audio, padding=N_SAMPLES) + mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) content_frames = mel.shape[-1] - N_FRAMES if decode_options.get("language", None) is None: @@ -139,7 +140,12 @@ def transcribe( language: str = decode_options["language"] task: str = decode_options.get("task", "transcribe") - tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) + tokenizer = get_tokenizer( + model.is_multilingual, + num_languages=model.num_languages, + language=language, + task=task, + ) if word_timestamps and task == "translate": warnings.warn("Word-level timestamps on translations may not be reliable.") @@ -228,6 +234,7 @@ def new_segment( with tqdm.tqdm( total=content_frames, unit="frames", disable=verbose is not False ) as pbar: + last_speech_timestamp = 0.0 while seek < content_frames: time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) mel_segment = mel[:, seek : seek + N_FRAMES] @@ -327,10 +334,13 @@ def new_segment( num_frames=segment_size, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, + last_speech_timestamp=last_speech_timestamp, ) word_end_timestamps = [ w["end"] for s in current_segments for w in s["words"] ] + if len(word_end_timestamps) > 0: + last_speech_timestamp = word_end_timestamps[-1] if not single_timestamp_ending and len(word_end_timestamps) > 0: seek_shift = round( (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND @@ -380,10 +390,17 @@ def new_segment( def cli(): from . 
import available_models + def valid_model_name(name): + if name in available_models() or os.path.exists(name): + return name + raise ValueError( + f"model should be one of {available_models()} or path to a model checkpoint" + ) + # fmt: off parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") - parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") + parser.add_argument("--model", default="small", type=valid_model_name, help="name of the Whisper model to use") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") @@ -414,6 +431,7 @@ def cli(): parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt") parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line") parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment") + parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") # fmt: on @@ -446,17 +464,28 @@ def cli(): model = load_model(model_name, device=device, download_root=model_dir) writer = get_writer(output_format, output_dir) - word_options = ["highlight_words", "max_line_count", "max_line_width"] + word_options = [ + "highlight_words", + "max_line_count", + "max_line_width", + "max_words_per_line", + ] if not args["word_timestamps"]: for option in word_options: if args[option]: parser.error(f"--{option} requires --word_timestamps True") if args["max_line_count"] and not args["max_line_width"]: warnings.warn("--max_line_count has no effect without --max_line_width") + if args["max_words_per_line"] and args["max_line_width"]: + warnings.warn("--max_words_per_line has no effect with --max_line_width") writer_args = {arg: args.pop(arg) for arg in word_options} for audio_path in args.pop("audio"): - result = transcribe(model, audio_path, temperature=temperature, **args) - writer(result, audio_path, writer_args) + try: + result = transcribe(model, audio_path, temperature=temperature, **args) + writer(result, audio_path, **writer_args) + except Exception as e: + traceback.print_exc() + print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}") if __name__ == "__main__": diff --git a/whisper/utils.py b/whisper/utils.py index ba5a10c41..7a172c401 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -74,7 +74,9 @@ class ResultWriter: def __init__(self, output_dir: str): self.output_dir = output_dir - def __call__(self, result: dict, audio_path: str, options: dict): + def __call__( + self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs + ): audio_basename 
= os.path.basename(audio_path) audio_basename = os.path.splitext(audio_basename)[0] output_path = os.path.join( @@ -82,16 +84,20 @@ def __call__(self, result: dict, audio_path: str, options: dict): ) with open(output_path, "w", encoding="utf-8") as f: - self.write_result(result, file=f, options=options) + self.write_result(result, file=f, options=options, **kwargs) - def write_result(self, result: dict, file: TextIO, options: dict): + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): raise NotImplementedError class WriteTXT(ResultWriter): extension: str = "txt" - def write_result(self, result: dict, file: TextIO, options: dict): + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): for segment in result["segments"]: print(segment["text"].strip(), file=file, flush=True) @@ -100,12 +106,24 @@ class SubtitlesWriter(ResultWriter): always_include_hours: bool decimal_marker: str - def iterate_result(self, result: dict, options: dict): - raw_max_line_width: Optional[int] = options["max_line_width"] - max_line_count: Optional[int] = options["max_line_count"] - highlight_words: bool = options["highlight_words"] - max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width - preserve_segments = max_line_count is None or raw_max_line_width is None + def iterate_result( + self, + result: dict, + options: Optional[dict] = None, + *, + max_line_width: Optional[int] = None, + max_line_count: Optional[int] = None, + highlight_words: bool = False, + max_words_per_line: Optional[int] = None, + ): + options = options or {} + max_line_width = max_line_width or options.get("max_line_width") + max_line_count = max_line_count or options.get("max_line_count") + highlight_words = highlight_words or options.get("highlight_words", False) + max_words_per_line = max_words_per_line or options.get("max_words_per_line") + preserve_segments = max_line_count is None or max_line_width is None + max_line_width = max_line_width or 1000 + max_words_per_line = max_words_per_line or 1000 def iterate_subtitles(): line_len = 0 @@ -114,38 +132,54 @@ def iterate_subtitles(): subtitle: list[dict] = [] last = result["segments"][0]["words"][0]["start"] for segment in result["segments"]: - for i, original_timing in enumerate(segment["words"]): - timing = original_timing.copy() - long_pause = not preserve_segments and timing["start"] - last > 3.0 - has_room = line_len + len(timing["word"]) <= max_line_width - seg_break = i == 0 and len(subtitle) > 0 and preserve_segments - if line_len > 0 and has_room and not long_pause and not seg_break: - # line continuation - line_len += len(timing["word"]) - else: - # new line - timing["word"] = timing["word"].strip() + chunk_index = 0 + words_count = max_words_per_line + while chunk_index < len(segment["words"]): + remaining_words = len(segment["words"]) - chunk_index + if max_words_per_line > len(segment["words"]) - chunk_index: + words_count = remaining_words + for i, original_timing in enumerate( + segment["words"][chunk_index : chunk_index + words_count] + ): + timing = original_timing.copy() + long_pause = ( + not preserve_segments and timing["start"] - last > 3.0 + ) + has_room = line_len + len(timing["word"]) <= max_line_width + seg_break = i == 0 and len(subtitle) > 0 and preserve_segments if ( - len(subtitle) > 0 - and max_line_count is not None - and (long_pause or line_count >= max_line_count) - or seg_break + line_len > 0 + and has_room + and not long_pause + and not 
seg_break ): - # subtitle break - yield subtitle - subtitle = [] - line_count = 1 - elif line_len > 0: - # line break - line_count += 1 - timing["word"] = "\n" + timing["word"] - line_len = len(timing["word"].strip()) - subtitle.append(timing) - last = timing["start"] + # line continuation + line_len += len(timing["word"]) + else: + # new line + timing["word"] = timing["word"].strip() + if ( + len(subtitle) > 0 + and max_line_count is not None + and (long_pause or line_count >= max_line_count) + or seg_break + ): + # subtitle break + yield subtitle + subtitle = [] + line_count = 1 + elif line_len > 0: + # line break + line_count += 1 + timing["word"] = "\n" + timing["word"] + line_len = len(timing["word"].strip()) + subtitle.append(timing) + last = timing["start"] + chunk_index += max_words_per_line if len(subtitle) > 0: yield subtitle - if "words" in result["segments"][0]: + if len(result["segments"]) > 0 and "words" in result["segments"][0]: for subtitle in iterate_subtitles(): subtitle_start = self.format_timestamp(subtitle[0]["start"]) subtitle_end = self.format_timestamp(subtitle[-1]["end"]) @@ -190,9 +224,11 @@ class WriteVTT(SubtitlesWriter): always_include_hours: bool = False decimal_marker: str = "." - def write_result(self, result: dict, file: TextIO, options: dict): + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): print("WEBVTT\n", file=file) - for start, end, text in self.iterate_result(result, options): + for start, end, text in self.iterate_result(result, options, **kwargs): print(f"{start} --> {end}\n{text}\n", file=file, flush=True) @@ -201,9 +237,11 @@ class WriteSRT(SubtitlesWriter): always_include_hours: bool = True decimal_marker: str = "," - def write_result(self, result: dict, file: TextIO, options: dict): + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): for i, (start, end, text) in enumerate( - self.iterate_result(result, options), start=1 + self.iterate_result(result, options, **kwargs), start=1 ): print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) @@ -220,7 +258,9 @@ class WriteTSV(ResultWriter): extension: str = "tsv" - def write_result(self, result: dict, file: TextIO, options: dict): + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): print("start", "end", "text", sep="\t", file=file) for segment in result["segments"]: print(round(1000 * segment["start"]), file=file, end="\t") @@ -231,7 +271,9 @@ def write_result(self, result: dict, file: TextIO, options: dict): class WriteJSON(ResultWriter): extension: str = "json" - def write_result(self, result: dict, file: TextIO, options: dict): + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): json.dump(result, file) @@ -249,9 +291,11 @@ def get_writer( if output_format == "all": all_writers = [writer(output_dir) for writer in writers.values()] - def write_all(result: dict, file: TextIO, options: dict): + def write_all( + result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): for writer in all_writers: - writer(result, file, options) + writer(result, file, options, **kwargs) return write_all diff --git a/whisper/version.py b/whisper/version.py index 572259a53..c96dd9ce4 100644 --- a/whisper/version.py +++ b/whisper/version.py @@ -1 +1 @@ -__version__ = "20230314" +__version__ = "20231117"
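Taken together, the library-level changes in this patch (model-dependent mel bins, the new `Whisper.num_languages` property, the `num_languages` argument to `get_tokenizer`, and keyword options on the subtitle writers such as `max_words_per_line`) can be exercised roughly as in the sketch below. This is a hedged illustration of the post-patch API rather than part of the patch itself; the audio path and option values are placeholders.

```python
import whisper
from whisper.tokenizer import get_tokenizer
from whisper.utils import get_writer

model = whisper.load_model("large-v3")

# The mel filterbank size now follows the model: 128 bins for large-v3,
# 80 for the earlier checkpoints.
audio = whisper.load_audio("audio.mp3")
audio = whisper.pad_or_trim(audio)
mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)
_, probs = model.detect_language(mel)
print(f"Detected language: {max(probs, key=probs.get)}")

# The tokenizer is built with the model's language count; large-v3 adds
# Cantonese ("yue"), so model.num_languages is 100 instead of 99.
tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
print(len(tokenizer.all_language_codes))  # 100 for large-v3

# Subtitle writers now accept per-call keyword options, including the new
# max_words_per_line setting (requires word-level timestamps).
result = model.transcribe("audio.mp3", word_timestamps=True)
writer = get_writer("srt", ".")
writer(result, "audio.mp3", max_words_per_line=8, highlight_words=False)
```

The same option is exposed on the command line as `--max_words_per_line`, which requires `--word_timestamps True` and has no effect when `--max_line_width` is also set.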