Skip to content

Commit

Permalink
Fix typing for python < 3.9
Browse files: browse the repository at this point in the history
  • Loading branch information
ryanheise committed Nov 25, 2023
1 parent 880c1dd commit bea1f24
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
12 changes: 6 additions & 6 deletions whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import traceback
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -49,7 +49,7 @@ def transcribe(
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
clip_timestamps: Union[str, list[float]] = "0",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
**decode_options,
):
Expand Down Expand Up @@ -105,7 +105,7 @@ def transcribe(
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
clip_timestamps: Union[str, list[float]]
clip_timestamps: Union[str, List[float]]
Comma-separated list of start,end,start,end,... timestamps (in seconds) of clips to process.
The last end timestamp defaults to the end of the file.
Expand Down Expand Up @@ -163,12 +163,12 @@ def transcribe(
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
seek_points: list[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: list[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))

punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"

Expand Down Expand Up @@ -317,7 +317,7 @@ def is_segment_anomaly(segment: Optional[dict]) -> bool:
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)

def next_words_segment(segments: list[dict]) -> Optional[dict]:
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)

timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
Expand Down
8 changes: 4 additions & 4 deletions whisper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import sys
import zlib
from typing import Callable, Optional, TextIO
from typing import Callable, List, Optional, TextIO

system_encoding = sys.getdefaultencoding()

Expand Down Expand Up @@ -68,14 +68,14 @@ def format_timestamp(
)


def get_start(segments: list[dict]) -> Optional[float]:
def get_start(segments: List[dict]) -> Optional[float]:
return next(
(w["start"] for s in segments for w in s["words"]),
segments[0]["start"] if segments else None,
)


def get_end(segments: list[dict]) -> Optional[float]:
def get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
Expand Down Expand Up @@ -143,7 +143,7 @@ def iterate_subtitles():
line_len = 0
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = []
subtitle: List[dict] = []
last: float = get_start(result["segments"]) or 0.0
for segment in result["segments"]:
chunk_index = 0
Expand Down

0 comments on commit bea1f24

Please sign in to comment.