Skip to content

Commit

Permalink
Fix model.transcribe
Browse files Browse the repository at this point in the history
Some of the inference functions now return lists of dicts where they previously returned a single dict, which causes them to fail when transcribe is called. This commit applies a trivial fix: whenever a list is returned, only its first element is used.
  • Loading branch information
jbetker committed Jan 25, 2024
1 parent ba3f3cd commit 3afb6af
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ def transcribe(
content_frames = mel.shape[-1] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)

def maybe_dereference_list(obj: Union[dict, list[dict]]) -> dict:
    """Return the first element of *obj* if it is a list, else *obj* itself.

    Some inference calls may hand back either a single result or a list of
    results; this collapses both shapes to a single result. Non-list values
    (including tuples) are passed through unchanged. An empty list raises
    ``IndexError``.
    """
    return obj[0] if isinstance(obj, list) else obj

if decode_options.get("language", None) is None:
if not model.is_multilingual:
decode_options["language"] = "en"
Expand All @@ -144,7 +149,7 @@ def transcribe(
)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
decode_options["language"] = max(maybe_dereference_list(probs), key=maybe_dereference_list(probs).get)
if verbose is not None:
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
Expand Down Expand Up @@ -192,7 +197,7 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
kwargs.pop("best_of", None)

options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options)
decode_result = maybe_dereference_list(model.decode(segment, options))

needs_fallback = False
if (
Expand Down

0 comments on commit 3afb6af

Please sign in to comment.