diff --git a/whisper/utils.py b/whisper/utils.py
index 9b9b13862..f4bc11474 100644
--- a/whisper/utils.py
+++ b/whisper/utils.py
@@ -3,7 +3,8 @@
 import re
 import sys
 import zlib
-from typing import Callable, List, Optional, TextIO
+from pathlib import Path
+from typing import Callable, List, Optional, TextIO, Union
 
 system_encoding = sys.getdefaultencoding()
 
@@ -89,16 +90,28 @@ def __init__(self, output_dir: str):
         self.output_dir = output_dir
 
     def __call__(
-        self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
-    ):
+        self,
+        result: dict,
+        audio_path: Union[str, Path],
+        options: Optional[dict] = None,
+        **kwargs,
+    ) -> str:
+        """Write ``result`` to ``<output_dir>/<audio basename>.<extension>``.
+
+        ``audio_path`` may be a ``str`` or a ``Path``; only its basename,
+        minus the final suffix, is used. Returns the path that was written.
+        """
+        # The os.path helpers accept str and Path alike (os.fspath protocol).
+        # self.extension is expected to have no leading dot; it is added below.
         audio_basename = os.path.basename(audio_path)
         audio_basename = os.path.splitext(audio_basename)[0]
         output_path = os.path.join(
             self.output_dir, audio_basename + "." + self.extension
         )
 
         with open(output_path, "w", encoding="utf-8") as f:
             self.write_result(result, file=f, options=options, **kwargs)
+        return output_path
 
     def write_result(
         self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs