Skip to content

Commit

Permalink
chore: Make MyPy pass with no errors in strict mode
Browse files Browse the repository at this point in the history
  • Loading branch information
tkarabela committed May 5, 2024
1 parent 0ad56db commit 8ac80f4
Show file tree
Hide file tree
Showing 32 changed files with 436 additions and 250 deletions.
19 changes: 18 additions & 1 deletion pysubs2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .ssafile import SSAFile
from .ssaevent import SSAEvent
from .ssastyle import SSAStyle
from . import time, formats, cli, whisper
from . import time, formats, cli, whisper, exceptions
from .exceptions import *
from .common import Color, Alignment, VERSION

Expand All @@ -18,3 +18,20 @@

#: Alias for `pysubs2.common.VERSION`.
__version__ = VERSION

__all__ = [
"SSAFile",
"SSAEvent",
"SSAStyle",
"time",
"formats",
"cli",
"whisper",
"exceptions",
"Color",
"Alignment",
"VERSION",
"load",
"load_from_whisper",
"make_time",
]
24 changes: 15 additions & 9 deletions pysubs2/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from io import open
import sys
from textwrap import dedent
from typing import List

from .formats import get_file_extension, FORMAT_IDENTIFIERS
from .time import make_time
from .ssafile import SSAFile
Expand Down Expand Up @@ -42,7 +44,7 @@ def change_ext(path: str, ext: str) -> str:


class Pysubs2CLI:
def __init__(self):
def __init__(self) -> None:
parser = self.parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
prog="pysubs2",
description=dedent("""
Expand Down Expand Up @@ -116,15 +118,16 @@ def __init__(self):
extra_sub_options.add_argument("--sub-no-write-fps-declaration", action="store_true",
help="(output) omit writing FPS as first zero-length subtitle")

def __call__(self, argv):
def __call__(self, argv: List[str]) -> int:
try:
self.main(argv)
return self.main(argv)
except KeyboardInterrupt:
exit("\nAborted by user.")
print("\nAborted by user.", file=sys.stderr)
return 1

def main(self, argv):
def main(self, argv: List[str]) -> int:
# Dealing with empty arguments
if argv == []:
if not argv:
argv = ["--help"]

args = self.parser.parse_args(argv)
Expand Down Expand Up @@ -169,10 +172,12 @@ def main(self, argv):
if args.output_format is None:
outpath = path
output_format = subs.format
assert output_format is not None, "subs.format must not be None (it was read from file)"
else:
ext = get_file_extension(args.output_format)
outpath = change_ext(path, ext)
output_format = args.output_format
assert output_format is not None, "args.output_format must not be None (see if/else)"

if args.output_dir is not None:
_, filename = op.split(outpath)
Expand All @@ -188,12 +193,13 @@ def main(self, argv):
subs = SSAFile.from_file(infile, args.input_format, args.fps)
self.process(subs, args)
output_format = args.output_format or subs.format
assert output_format is not None, "output_format must not be None (it's either given or inferred at read time)"
subs.to_file(outfile, output_format, args.fps, apply_styles=not args.clean)

return (0 if errors == 0 else 1)
return 0 if errors == 0 else 1

@staticmethod
def process(subs, args):
def process(subs: SSAFile, args: argparse.Namespace) -> None:
if args.shift is not None:
subs.shift(ms=args.shift)
elif args.shift_back is not None:
Expand All @@ -206,7 +212,7 @@ def process(subs, args):
subs.remove_miscellaneous_events()


def __main__():
def __main__() -> None:
cli = Pysubs2CLI()
rv = cli(sys.argv[1:])
sys.exit(rv)
Expand Down
4 changes: 2 additions & 2 deletions pysubs2/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Union
from typing import Tuple, Union
from enum import IntEnum


Expand Down Expand Up @@ -54,7 +54,7 @@ def to_ssa_alignment(self) -> int:
return SSA_ALIGNMENT[self.value - 1]


SSA_ALIGNMENT = (1, 2, 3, 9, 10, 11, 5, 6, 7)
SSA_ALIGNMENT: Tuple[int, ...] = (1, 2, 3, 9, 10, 11, 5, 6, 7)


#: Version of the pysubs2 library.
Expand Down
10 changes: 5 additions & 5 deletions pysubs2/formatbase.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional
import io
import pysubs2
from typing import Optional, TYPE_CHECKING, Any, TextIO
if TYPE_CHECKING:
from .ssafile import SSAFile


class FormatBase:
Expand All @@ -19,7 +19,7 @@ class FormatBase:
"""
@classmethod
def from_file(cls, subs: "pysubs2.SSAFile", fp: io.TextIOBase, format_: str, **kwargs):
def from_file(cls, subs: "SSAFile", fp: TextIO, format_: str, **kwargs: Any) -> None:
"""
Load subtitle file into an empty SSAFile.
Expand All @@ -42,7 +42,7 @@ def from_file(cls, subs: "pysubs2.SSAFile", fp: io.TextIOBase, format_: str, **k
raise NotImplementedError("Parsing is not supported for this format")

@classmethod
def to_file(cls, subs: "pysubs2.SSAFile", fp: io.TextIOBase, format_: str, **kwargs):
def to_file(cls, subs: "SSAFile", fp: TextIO, format_: str, **kwargs: Any) -> None:
"""
Write SSAFile into a file.
Expand Down
14 changes: 10 additions & 4 deletions pysubs2/jsonformat.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import dataclasses
import json
from typing import Any, Optional, TextIO, TYPE_CHECKING

from .common import Color
from .ssaevent import SSAEvent
from .ssastyle import SSAStyle
from .formatbase import FormatBase
if TYPE_CHECKING:
from .ssafile import SSAFile


# We're using Color dataclass
# https://stackoverflow.com/questions/51286748/make-the-python-json-encoder-support-pythons-new-dataclasses
class EnhancedJSONEncoder(json.JSONEncoder):
def default(self, o):
def default(self, o: Any) -> Any:
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
return super().default(o)
Expand All @@ -22,13 +26,15 @@ class JSONFormat(FormatBase):
This is essentially SubStation Alpha as JSON.
"""
@classmethod
def guess_format(cls, text):
def guess_format(cls, text: str) -> Optional[str]:
"""See :meth:`pysubs2.formats.FormatBase.guess_format()`"""
if text.startswith("{\"") and "\"info:\"" in text:
return "json"
else:
return None

@classmethod
def from_file(cls, subs, fp, format_, **kwargs):
def from_file(cls, subs: "SSAFile", fp: TextIO, format_: str, **kwargs: Any) -> None:
"""See :meth:`pysubs2.formats.FormatBase.from_file()`"""
data = json.load(fp)

Expand All @@ -47,7 +53,7 @@ def from_file(cls, subs, fp, format_, **kwargs):
subs.events = [SSAEvent(**fields) for fields in data["events"]]

@classmethod
def to_file(cls, subs, fp, format_, **kwargs):
def to_file(cls, subs: "SSAFile", fp: TextIO, format_: str, **kwargs: Any) -> None:
"""See :meth:`pysubs2.formats.FormatBase.to_file()`"""
data = {
"info": dict(**subs.info),
Expand Down
19 changes: 14 additions & 5 deletions pysubs2/microdvd.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from functools import partial
import re
from typing import Optional, TextIO, Any, TYPE_CHECKING

from .exceptions import UnknownFPSError
from .ssaevent import SSAEvent
from .ssastyle import SSAStyle
from .formatbase import FormatBase
from .substation import parse_tags
from .time import ms_to_frames, frames_to_ms
if TYPE_CHECKING:
from .ssafile import SSAFile


#: Matches a MicroDVD line.
MICRODVD_LINE = re.compile(r" *\{ *(\d+) *\} *\{ *(\d+) *\}(.+)")
Expand All @@ -14,13 +19,16 @@
class MicroDVDFormat(FormatBase):
"""MicroDVD subtitle format implementation"""
@classmethod
def guess_format(cls, text):
def guess_format(cls, text: str) -> Optional[str]:
"""See :meth:`pysubs2.formats.FormatBase.guess_format()`"""
if any(map(MICRODVD_LINE.match, text.splitlines())):
return "microdvd"
else:
return None

@classmethod
def from_file(cls, subs, fp, format_, fps=None, strict_fps_inference: bool = True, **kwargs):
def from_file(cls, subs: "SSAFile", fp: TextIO, format_: str, fps: Optional[float] = None,
strict_fps_inference: bool = True, **kwargs: Any) -> None:
"""
See :meth:`pysubs2.formats.FormatBase.from_file()`
Expand Down Expand Up @@ -60,10 +68,10 @@ def from_file(cls, subs, fp, format_, fps=None, strict_fps_inference: bool = Tru

start, end = map(partial(frames_to_ms, fps=fps), (fstart, fend))

def prepare_text(text):
def prepare_text(text: str) -> str:
text = text.replace("|", r"\N")

def style_replacer(match: re.Match) -> str:
def style_replacer(match: re.Match[str]) -> str:
tags = [c for c in "biu" if c in match.group(0)]
return "{%s}" % "".join(f"\\{c}1" for c in tags)

Expand All @@ -78,7 +86,8 @@ def style_replacer(match: re.Match) -> str:
subs.append(ev)

@classmethod
def to_file(cls, subs, fp, format_, fps=None, write_fps_declaration=True, apply_styles=True, **kwargs):
def to_file(cls, subs: "SSAFile", fp: TextIO, format_: str, fps: Optional[float] = None,
write_fps_declaration: bool = True, apply_styles: bool = True, **kwargs: Any) -> None:
"""
See :meth:`pysubs2.formats.FormatBase.to_file()`
Expand Down
12 changes: 8 additions & 4 deletions pysubs2/mpl2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import re

from typing import Optional, Any, TextIO, TYPE_CHECKING
from .time import times_to_ms
from .formatbase import FormatBase
from .ssaevent import SSAEvent
if TYPE_CHECKING:
from .ssafile import SSAFile


# thanks to http://otsaloma.io/gaupol/doc/api/aeidon.files.mpl2_source.html
Expand All @@ -12,13 +14,15 @@
class MPL2Format(FormatBase):
"""MPL2 subtitle format implementation"""
@classmethod
def guess_format(cls, text):
def guess_format(cls, text: str) -> Optional[str]:
"""See :meth:`pysubs2.formats.FormatBase.guess_format()`"""
if MPL2_FORMAT.search(text):
return "mpl2"
else:
return None

@classmethod
def from_file(cls, subs, fp, format_, **kwargs):
def from_file(cls, subs: "SSAFile", fp: TextIO, format_: str, **kwargs: Any) -> None:
"""See :meth:`pysubs2.formats.FormatBase.from_file()`"""
def prepare_text(lines: str) -> str:
out = []
Expand All @@ -42,7 +46,7 @@ def prepare_text(lines: str) -> str:
subs.append(e)

@classmethod
def to_file(cls, subs, fp, format_, **kwargs):
def to_file(cls, subs: "SSAFile", fp: TextIO, format_: str, **kwargs: Any) -> None:
"""
See :meth:`pysubs2.formats.FormatBase.to_file()`
Expand Down
28 changes: 14 additions & 14 deletions pysubs2/ssaevent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
import warnings
from typing import Optional, Dict, Any, ClassVar
from typing import Optional, Dict, Any, ClassVar, FrozenSet
import dataclasses

from .common import IntOrFloat
Expand Down Expand Up @@ -41,7 +41,7 @@ class SSAEvent:
type: str = "Dialogue" #: Line type (Dialogue/Comment)

@property
def FIELDS(self):
def FIELDS(self) -> FrozenSet[str]:
"""All fields in SSAEvent."""
warnings.warn("Deprecated in 1.2.0 - it's a dataclass now", DeprecationWarning)
return frozenset(field.name for field in dataclasses.fields(self))
Expand All @@ -57,7 +57,7 @@ def duration(self) -> IntOrFloat:
return self.end - self.start

@duration.setter
def duration(self, ms: int):
def duration(self, ms: int) -> None:
if ms >= 0:
self.end = self.start + ms
else:
Expand All @@ -74,7 +74,7 @@ def is_comment(self) -> bool:
return self.type == "Comment"

@is_comment.setter
def is_comment(self, value: bool):
def is_comment(self, value: bool) -> None:
if value:
self.type = "Comment"
else:
Expand Down Expand Up @@ -111,11 +111,11 @@ def plaintext(self) -> str:
return text

@plaintext.setter
def plaintext(self, text: str):
def plaintext(self, text: str) -> None:
self.text = text.replace("\n", r"\N")

def shift(self, h: IntOrFloat=0, m: IntOrFloat=0, s: IntOrFloat=0, ms: IntOrFloat=0,
frames: Optional[int]=None, fps: Optional[float]=None):
def shift(self, h: IntOrFloat = 0, m: IntOrFloat = 0, s: IntOrFloat = 0, ms: IntOrFloat = 0,
frames: Optional[int] = None, fps: Optional[float] = None) -> None:
"""
Shift start and end times.
Expand All @@ -141,36 +141,36 @@ def equals(self, other: "SSAEvent") -> bool:
else:
raise TypeError("Cannot compare to non-SSAEvent object")

def __eq__(self, other) -> bool:
def __eq__(self, other: object) -> bool:
# XXX document this
if not isinstance(other, SSAEvent):
return NotImplemented
return self.start == other.start and self.end == other.end

def __ne__(self, other) -> bool:
def __ne__(self, other: object) -> bool:
if not isinstance(other, SSAEvent):
return NotImplemented
return self.start != other.start or self.end != other.end

def __lt__(self, other) -> bool:
def __lt__(self, other: object) -> bool:
if not isinstance(other, SSAEvent):
return NotImplemented
return (self.start, self.end) < (other.start, other.end)

def __le__(self, other) -> bool:
def __le__(self, other: object) -> bool:
if not isinstance(other, SSAEvent):
return NotImplemented
return (self.start, self.end) <= (other.start, other.end)

def __gt__(self, other) -> bool:
def __gt__(self, other: object) -> bool:
if not isinstance(other, SSAEvent):
return NotImplemented
return (self.start, self.end) > (other.start, other.end)

def __ge__(self, other) -> bool:
def __ge__(self, other: object) -> bool:
if not isinstance(other, SSAEvent):
return NotImplemented
return (self.start, self.end) >= (other.start, other.end)

def __repr__(self):
def __repr__(self) -> str:
return f"<SSAEvent type={self.type} start={ms_to_str(self.start)} end={ms_to_str(self.end)} text={self.text!r}>"
Loading

0 comments on commit 8ac80f4

Please sign in to comment.