From c580f466bc97fa6e2ebc896cd52731a6fdcb7553 Mon Sep 17 00:00:00 2001 From: Bob Lin Date: Wed, 20 Dec 2023 11:32:04 +0800 Subject: [PATCH] Enhancement: Accept Path type for `audio_path` parameter and return `output_path` --- whisper/utils.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/whisper/utils.py b/whisper/utils.py index 9b9b13862..b67f5afd8 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -3,7 +3,9 @@ import re import sys import zlib -from typing import Callable, List, Optional, TextIO +from typing import Callable, List, Optional, TextIO, Union +from pathlib import Path +from typing import Callable, Optional, TextIO, Union system_encoding = sys.getdefaultencoding() @@ -89,16 +91,17 @@ def __init__(self, output_dir: str): self.output_dir = output_dir def __call__( - self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs - ): - audio_basename = os.path.basename(audio_path) - audio_basename = os.path.splitext(audio_basename)[0] + self, result: dict, audio_path: Union[str, Path], options: Optional[dict] = None, **kwargs + ) -> str: + if not isinstance(audio_path, Path): + audio_path = Path(audio_path) output_path = os.path.join( - self.output_dir, audio_basename + "." + self.extension + self.output_dir, audio_path.with_suffix(self.extension).name ) with open(output_path, "w", encoding="utf-8") as f: self.write_result(result, file=f, options=options, **kwargs) + return output_path def write_result( self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs