diff --git a/src/crowsetta/annotation.py b/src/crowsetta/annotation.py index 7ddd84d..62e55d4 100644 --- a/src/crowsetta/annotation.py +++ b/src/crowsetta/annotation.py @@ -1,8 +1,8 @@ """A class to represent annotations for a single file.""" from __future__ import annotations -from pathlib import Path import reprlib +from pathlib import Path from typing import Optional import crowsetta @@ -89,12 +89,10 @@ def __init__( if seq: if not ( - isinstance(seq, crowsetta.Sequence) or - (isinstance(seq, list) and all([isinstance(seq_, crowsetta.Sequence) for seq_ in seq])) + isinstance(seq, crowsetta.Sequence) + or (isinstance(seq, list) and all([isinstance(seq_, crowsetta.Sequence) for seq_ in seq])) ): - raise TypeError( - f"``seq`` should be a crowsetta.Sequence or list of Sequences but was: {type(seq)}" - ) + raise TypeError(f"``seq`` should be a crowsetta.Sequence or list of Sequences but was: {type(seq)}") self.seq = seq if bboxes: diff --git a/src/crowsetta/formats/seq/birdsongrec.py b/src/crowsetta/formats/seq/birdsongrec.py index 3ad9d0c..0f001f2 100644 --- a/src/crowsetta/formats/seq/birdsongrec.py +++ b/src/crowsetta/formats/seq/birdsongrec.py @@ -8,11 +8,13 @@ Boundaries in the Birdsong with Variable Sequences. PLoS ONE 11(7): e0159188. doi:10.1371/journal.pone.0159188 """ +from __future__ import annotations + import os import pathlib import warnings -from typing import ClassVar, List, Optional import xml.etree.ElementTree as ET +from typing import ClassVar, List, Optional import attr import numpy as np @@ -36,13 +38,14 @@ class BirdsongRecSyllable: text representation of syllable as classified by a human or a machine learning algorithm """ - def __init__(self, position, length, label): + + def __init__(self, position: int, length: int, label: str) -> None: if not isinstance(position, int): - raise TypeError(f'position must be an int, not type {type(position)}') + raise TypeError(f"position must be an int, not type {type(position)}") if not isinstance(length, int): - raise TypeError(f'length must be an int, not type {type(length)}') + raise TypeError(f"length must be an int, not type {type(length)}") if not isinstance(label, str): - raise TypeError(f'label must be a string, not type {type(label)}') + raise TypeError(f"label must be a string, not type {type(label)}") self.position = position self.length = length self.label = label @@ -67,20 +70,19 @@ class BirdsongRecSequence: list of syllable objects that make up sequence seq_spect : spectrogram object """ - def __init__(self, wav_file, position, length, syl_list): + def __init__(self, wav_file: PathLike, position: int, length: int, syl_list: list[BirdsongRecSyllable]): if not isinstance(wav_file, (str, pathlib.Path)): - raise TypeError(f'wav_file must be a string or pathlib.Path, not type {type(wav_file)}') + raise TypeError(f"wav_file must be a string or pathlib.Path, not type {type(wav_file)}") wav_file = str(wav_file) if not isinstance(position, int): - raise TypeError(f'position must be an int, not type {type(position)}') + raise TypeError(f"position must be an int, not type {type(position)}") if not isinstance(length, int): - raise TypeError(f'length must be an int, not type {type(length)}') + raise TypeError(f"length must be an int, not type {type(length)}") if not isinstance(syl_list, list): - raise TypeError(f'syl_list must be a list, not type {type(syl_list)}') + raise TypeError(f"syl_list must be a list, not type {type(syl_list)}") if not all([type(syl) == BirdsongRecSyllable for syl in syl_list]): - raise TypeError('not all elements in syl list are of type BirdsongRecSyllable: ' - f'{syl_list}') + raise TypeError("not all elements in syl list are of type BirdsongRecSyllable: " f"{syl_list}") self.wav_file = wav_file self.position = position self.length = length @@ -91,9 +93,12 @@ def __repr__(self): return f"Sequence(wav_file={self.wav_file}, position={self.position}, length={self.length}, syls={self.syls})" - -def parse_xml(xml_file, concat_seqs_into_songs=False, return_wav_abspath=False, - wav_abspath=None): +def parse_xml( + xml_file: PathLike, + concat_seqs_into_songs: bool = False, + return_wav_abspath: bool = False, + wav_abspath: PathLike = None, +) -> list[BirdsongRecSequence]: """parses Annotation.xml files from the BirdsongRecognition dataset: Koumura, T. (2016). BirdsongRecognition (Version 1). figshare. https://doi.org/10.6084/m9.figshare.3470165.v1 @@ -138,12 +143,11 @@ def parse_xml(xml_file, concat_seqs_into_songs=False, return_wav_abspath=False, if return_wav_abspath: if wav_abspath: if not os.path.isdir(wav_abspath): - raise NotADirectoryError(f'return_wav_abspath is True but {wav_abspath} ' - 'is not a valid directory.') + raise NotADirectoryError(f"return_wav_abspath is True but {wav_abspath} " "is not a valid directory.") tree = ET.ElementTree(file=xml_file) seq_list = [] - for seq in tree.iter(tag='Sequence'): - wav_file = seq.find('WaveFileName').text + for seq in tree.iter(tag="Sequence"): + wav_file = seq.find("WaveFileName").text if return_wav_abspath: if wav_abspath: wav_file = os.path.join(wav_abspath, wav_file) @@ -152,26 +156,21 @@ def parse_xml(xml_file, concat_seqs_into_songs=False, return_wav_abspath=False, # Annotation.xml file is kept (since this is how the repository is # structured) xml_dirname = os.path.dirname(xml_file) - wav_file = os.path.join(xml_dirname, 'Wave', wav_file) + wav_file = os.path.join(xml_dirname, "Wave", wav_file) if not os.path.isfile(wav_file): - raise FileNotFoundError('File {wav_file} is not found') + raise FileNotFoundError("File {wav_file} is not found") - position = int(seq.find('Position').text) - length = int(seq.find('Length').text) + position = int(seq.find("Position").text) + length = int(seq.find("Length").text) syl_list = [] - for syl in seq.iter(tag='Note'): - syl_position = int(syl.find('Position').text) - syl_length = int(syl.find('Length').text) - label = syl.find('Label').text - - syl_obj = BirdsongRecSyllable(position=syl_position, - length=syl_length, - label=label) + for syl in seq.iter(tag="Note"): + syl_position = int(syl.find("Position").text) + syl_length = int(syl.find("Length").text) + label = syl.find("Label").text + + syl_obj = BirdsongRecSyllable(position=syl_position, length=syl_length, label=label) syl_list.append(syl_obj) - seq_obj = BirdsongRecSequence(wav_file=wav_file, - position=position, - length=length, - syl_list=syl_list) + seq_obj = BirdsongRecSequence(wav_file=wav_file, position=position, length=length, syl_list=syl_list) seq_list.append(seq_obj) if concat_seqs_into_songs: diff --git a/src/crowsetta/formats/seq/notmat.py b/src/crowsetta/formats/seq/notmat.py index 92dd2da..4e0cf14 100644 --- a/src/crowsetta/formats/seq/notmat.py +++ b/src/crowsetta/formats/seq/notmat.py @@ -1,6 +1,8 @@ """Module with functions that handle .not.mat annotation files produced by evsonganaly GUI. """ +from __future__ import annotations + import pathlib from typing import ClassVar, Dict, Optional @@ -12,7 +14,7 @@ from crowsetta.typing import PathLike -def load_notmat(filename): +def load_notmat(filename: PathLike) -> dict: """loads .not.mat files created by evsonganaly (Matlab GUI for labeling song) Parameters @@ -52,12 +54,10 @@ def load_notmat(filename): filename = filename.parent.joinpath(filename.name + ".not.mat") else: ext = filename.suffix - raise ValueError( - f"Filename should have extension .cbin.not.mat or .cbin but extension was: {ext}" - ) + raise ValueError(f"Filename should have extension .cbin.not.mat or .cbin but extension was: {ext}") notmat_dict = scipy.io.loadmat(filename, squeeze_me=True) # ensure that onsets and offsets are always arrays, not scalar - for key in ('onsets', 'offsets'): + for key in ("onsets", "offsets"): if np.isscalar(notmat_dict[key]): # `squeeze_me` makes them a ``float``, this will be True in that case value = np.array(notmat_dict[key])[np.newaxis] # ``np.newaxis`` ensures 1-d array with shape (1,) notmat_dict[key] = value diff --git a/src/crowsetta/formats/seq/textgrid/parse.py b/src/crowsetta/formats/seq/textgrid/parse.py index 44a7ca2..a0f93b8 100644 --- a/src/crowsetta/formats/seq/textgrid/parse.py +++ b/src/crowsetta/formats/seq/textgrid/parse.py @@ -173,7 +173,7 @@ def parse_fp(fp: TextIO, keep_empty: bool = False) -> dict: } tiers = [] - for i in range(n_tier): + for _ in range(n_tier): if not is_short: fp.readline() # skip item[\d]: (where \d is some number) tier_type = get_str_from_next_line(fp) @@ -182,7 +182,7 @@ def parse_fp(fp: TextIO, keep_empty: bool = False) -> dict: xmax_tier = get_float_from_next_line(fp) entries = [] # intervals or points depending on tier type - for i in range(get_int_from_next_line(fp)): + for _ in range(get_int_from_next_line(fp)): if not is_short: fp.readline() # skip intervals [\d] if tier_type == INTERVAL_TIER: @@ -246,7 +246,7 @@ def parse(textgrid_path: str | pathlib.Path, keep_empty: bool = False) -> dict: try: with textgrid_path.open("r", encoding="utf-16") as fp: textgrid_raw = parse_fp(fp, keep_empty) - except (UnicodeError, UnicodeDecodeError): + except UnicodeError: with textgrid_path.open("r", encoding="utf-8") as fp: textgrid_raw = parse_fp(fp, keep_empty) return textgrid_raw diff --git a/src/crowsetta/formats/seq/textgrid/textgrid.py b/src/crowsetta/formats/seq/textgrid/textgrid.py index 1c0d63e..821e5b2 100644 --- a/src/crowsetta/formats/seq/textgrid/textgrid.py +++ b/src/crowsetta/formats/seq/textgrid/textgrid.py @@ -368,7 +368,9 @@ def to_seq( return seq - def to_annot(self, tier: int | str | None = None, round_times: bool = True, decimals: int = 3) -> crowsetta.Annotation: + def to_annot( + self, tier: int | str | None = None, round_times: bool = True, decimals: int = 3 + ) -> crowsetta.Annotation: """Convert interval tier or tiers from this TextGrid annotation to a :class:`crowsetta.Annotation` with a :data:`seq` attribute. diff --git a/src/crowsetta/transcriber.py b/src/crowsetta/transcriber.py index 3c95b6a..124f3fb 100644 --- a/src/crowsetta/transcriber.py +++ b/src/crowsetta/transcriber.py @@ -121,6 +121,7 @@ def __init__(self, format: "Union[str, crowsetta.interface.SeqLike, crowsetta.in "and the name 'csv' will stop working in the next version. " "Please change any usages of the name 'csv' to 'generic-seq'` now.", FutureWarning, + stacklevel=2, ) _format_class = formats.by_name(format) elif inspect.isclass(format):