diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3796a3973..0f9cfabe4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,8 +6,38 @@ on: pull_request: branches: - main + jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Fetch base branch + run: git fetch origin ${{ github.base_ref }} + - uses: actions/setup-python@v4 + with: + python-version: "3.8" + architecture: x64 + - name: Get pip cache dir + id: pip-cache + run: | + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip/pre-commit cache + uses: actions/cache@v3 + with: + path: | + ${{ steps.pip-cache.outputs.dir }} + ~/.cache/pre-commit + key: ${{ runner.os }}-pip-pre-commit-${{ hashFiles('**/.pre-commit-config.yaml') }} + restore-keys: | + ${{ runner.os }}-pip-pre-commit + - name: pre-commit + run: | + pip install -U pre-commit + pre-commit install --install-hooks + pre-commit run --from-ref=origin/${{ github.base_ref }} --to-ref=HEAD whisper-test: + needs: pre-commit runs-on: ubuntu-latest strategy: matrix: @@ -23,7 +53,4 @@ jobs: - uses: actions/checkout@v3 - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH - run: pip install .["dev"] - - run: black --check --diff -t py38 --include '(\.pyi?)$' . - - run: isort --check --diff . - - run: flake8 --ignore E203,W503,W504,E501,E731,E741 . - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..3f5a74b6d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,28 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: check-json + - id: end-of-file-fixer + types: [file, python] + - id: trailing-whitespace + types: [file, python] + - id: mixed-line-ending + - id: check-added-large-files + args: [--maxkb=4096] + - repo: https://github.com/psf/black + rev: 23.7.0 + hooks: + - id: black + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + name: isort (python) + args: ["--profile", "black", "-l", "88", "--trailing-comma", "--multi-line", "3"] + - repo: https://github.com/pycqa/flake8.git + rev: 6.0.0 + hooks: + - id: flake8 + types: [python] + args: ["--max-line-length", "88", "--ignore", "E203,E501,W503,W504"] diff --git a/whisper/timing.py b/whisper/timing.py index 56e84d432..befcf464e 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -202,7 +202,7 @@ def find_alignment( hook.remove() # heads * tokens * frames - weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T]) + weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T]) weights = weights[:, :, : num_frames // 2] weights = (weights * qk_scale).softmax(dim=-1) std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False) diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py index 4030e15aa..3b2399184 100644 --- a/whisper/tokenizer.py +++ b/whisper/tokenizer.py @@ -226,7 +226,7 @@ def all_language_tokens(self) -> Tuple[int]: @cached_property def all_language_codes(self) -> Tuple[str]: - return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens) + return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens) @cached_property def sot_sequence_including_notimestamps(self) -> Tuple[int]: