Skip to content

Commit

Permalink
Apply linting + flake8 fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
NickleDave committed Feb 2, 2024
1 parent a074f32 commit 6533b59
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 50 deletions.
10 changes: 4 additions & 6 deletions src/crowsetta/annotation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""A class to represent annotations for a single file."""
from __future__ import annotations

from pathlib import Path
import reprlib
from pathlib import Path
from typing import Optional

import crowsetta
Expand Down Expand Up @@ -89,12 +89,10 @@ def __init__(

if seq:
if not (
isinstance(seq, crowsetta.Sequence) or
(isinstance(seq, list) and all([isinstance(seq_, crowsetta.Sequence) for seq_ in seq]))
isinstance(seq, crowsetta.Sequence)
or (isinstance(seq, list) and all([isinstance(seq_, crowsetta.Sequence) for seq_ in seq]))
):
raise TypeError(
f"``seq`` should be a crowsetta.Sequence or list of Sequences but was: {type(seq)}"
)
raise TypeError(f"``seq`` should be a crowsetta.Sequence or list of Sequences but was: {type(seq)}")
self.seq = seq

if bboxes:
Expand Down
69 changes: 34 additions & 35 deletions src/crowsetta/formats/seq/birdsongrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
Boundaries in the Birdsong with Variable Sequences. PLoS ONE 11(7): e0159188.
doi:10.1371/journal.pone.0159188
"""
from __future__ import annotations

import os
import pathlib
import warnings
from typing import ClassVar, List, Optional
import xml.etree.ElementTree as ET
from typing import ClassVar, List, Optional

import attr
import numpy as np
Expand All @@ -36,13 +38,14 @@ class BirdsongRecSyllable:
text representation of syllable as classified by a human
or a machine learning algorithm
"""
def __init__(self, position, length, label):

def __init__(self, position: int, length: int, label: str) -> None:
if not isinstance(position, int):
raise TypeError(f'position must be an int, not type {type(position)}')
raise TypeError(f"position must be an int, not type {type(position)}")
if not isinstance(length, int):
raise TypeError(f'length must be an int, not type {type(length)}')
raise TypeError(f"length must be an int, not type {type(length)}")
if not isinstance(label, str):
raise TypeError(f'label must be a string, not type {type(label)}')
raise TypeError(f"label must be a string, not type {type(label)}")
self.position = position
self.length = length
self.label = label
Expand All @@ -67,20 +70,19 @@ class BirdsongRecSequence:
list of syllable objects that make up sequence
seq_spect : spectrogram object
"""
def __init__(self, wav_file, position, length, syl_list):

def __init__(self, wav_file: PathLike, position: int, length: int, syl_list: list[BirdsongRecSyllable]):
if not isinstance(wav_file, (str, pathlib.Path)):
raise TypeError(f'wav_file must be a string or pathlib.Path, not type {type(wav_file)}')
raise TypeError(f"wav_file must be a string or pathlib.Path, not type {type(wav_file)}")
wav_file = str(wav_file)
if not isinstance(position, int):
raise TypeError(f'position must be an int, not type {type(position)}')
raise TypeError(f"position must be an int, not type {type(position)}")
if not isinstance(length, int):
raise TypeError(f'length must be an int, not type {type(length)}')
raise TypeError(f"length must be an int, not type {type(length)}")
if not isinstance(syl_list, list):
raise TypeError(f'syl_list must be a list, not type {type(syl_list)}')
raise TypeError(f"syl_list must be a list, not type {type(syl_list)}")
if not all([type(syl) == BirdsongRecSyllable for syl in syl_list]):
raise TypeError('not all elements in syl list are of type BirdsongRecSyllable: '
f'{syl_list}')
raise TypeError("not all elements in syl list are of type BirdsongRecSyllable: " f"{syl_list}")
self.wav_file = wav_file
self.position = position
self.length = length
Expand All @@ -91,9 +93,12 @@ def __repr__(self):
return f"Sequence(wav_file={self.wav_file}, position={self.position}, length={self.length}, syls={self.syls})"



def parse_xml(xml_file, concat_seqs_into_songs=False, return_wav_abspath=False,
wav_abspath=None):
def parse_xml(
xml_file: PathLike,
concat_seqs_into_songs: bool = False,
return_wav_abspath: bool = False,
wav_abspath: PathLike = None,
) -> list[BirdsongRecSequence]:
"""parses Annotation.xml files from the BirdsongRecognition dataset:
Koumura, T. (2016). BirdsongRecognition (Version 1). figshare.
https://doi.org/10.6084/m9.figshare.3470165.v1
Expand Down Expand Up @@ -138,12 +143,11 @@ def parse_xml(xml_file, concat_seqs_into_songs=False, return_wav_abspath=False,
if return_wav_abspath:
if wav_abspath:
if not os.path.isdir(wav_abspath):
raise NotADirectoryError(f'return_wav_abspath is True but {wav_abspath} '
'is not a valid directory.')
raise NotADirectoryError(f"return_wav_abspath is True but {wav_abspath} " "is not a valid directory.")
tree = ET.ElementTree(file=xml_file)
seq_list = []
for seq in tree.iter(tag='Sequence'):
wav_file = seq.find('WaveFileName').text
for seq in tree.iter(tag="Sequence"):
wav_file = seq.find("WaveFileName").text
if return_wav_abspath:
if wav_abspath:
wav_file = os.path.join(wav_abspath, wav_file)
Expand All @@ -152,26 +156,21 @@ def parse_xml(xml_file, concat_seqs_into_songs=False, return_wav_abspath=False,
# Annotation.xml file is kept (since this is how the repository is
# structured)
xml_dirname = os.path.dirname(xml_file)
wav_file = os.path.join(xml_dirname, 'Wave', wav_file)
wav_file = os.path.join(xml_dirname, "Wave", wav_file)
if not os.path.isfile(wav_file):
raise FileNotFoundError('File {wav_file} is not found')
raise FileNotFoundError("File {wav_file} is not found")

position = int(seq.find('Position').text)
length = int(seq.find('Length').text)
position = int(seq.find("Position").text)
length = int(seq.find("Length").text)
syl_list = []
for syl in seq.iter(tag='Note'):
syl_position = int(syl.find('Position').text)
syl_length = int(syl.find('Length').text)
label = syl.find('Label').text

syl_obj = BirdsongRecSyllable(position=syl_position,
length=syl_length,
label=label)
for syl in seq.iter(tag="Note"):
syl_position = int(syl.find("Position").text)
syl_length = int(syl.find("Length").text)
label = syl.find("Label").text

syl_obj = BirdsongRecSyllable(position=syl_position, length=syl_length, label=label)
syl_list.append(syl_obj)
seq_obj = BirdsongRecSequence(wav_file=wav_file,
position=position,
length=length,
syl_list=syl_list)
seq_obj = BirdsongRecSequence(wav_file=wav_file, position=position, length=length, syl_list=syl_list)
seq_list.append(seq_obj)

if concat_seqs_into_songs:
Expand Down
10 changes: 5 additions & 5 deletions src/crowsetta/formats/seq/notmat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Module with functions that handle .not.mat annotation files
produced by evsonganaly GUI.
"""
from __future__ import annotations

import pathlib
from typing import ClassVar, Dict, Optional

Expand All @@ -12,7 +14,7 @@
from crowsetta.typing import PathLike


def load_notmat(filename):
def load_notmat(filename: PathLike) -> dict:
"""loads .not.mat files created by evsonganaly (Matlab GUI for labeling song)
Parameters
Expand Down Expand Up @@ -52,12 +54,10 @@ def load_notmat(filename):
filename = filename.parent.joinpath(filename.name + ".not.mat")
else:
ext = filename.suffix
raise ValueError(
f"Filename should have extension .cbin.not.mat or .cbin but extension was: {ext}"
)
raise ValueError(f"Filename should have extension .cbin.not.mat or .cbin but extension was: {ext}")
notmat_dict = scipy.io.loadmat(filename, squeeze_me=True)
# ensure that onsets and offsets are always arrays, not scalar
for key in ('onsets', 'offsets'):
for key in ("onsets", "offsets"):
if np.isscalar(notmat_dict[key]): # `squeeze_me` makes them a ``float``, this will be True in that case
value = np.array(notmat_dict[key])[np.newaxis] # ``np.newaxis`` ensures 1-d array with shape (1,)
notmat_dict[key] = value
Expand Down
6 changes: 3 additions & 3 deletions src/crowsetta/formats/seq/textgrid/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def parse_fp(fp: TextIO, keep_empty: bool = False) -> dict:
}

tiers = []
for i in range(n_tier):
for _ in range(n_tier):
if not is_short:
fp.readline() # skip item[\d]: (where \d is some number)
tier_type = get_str_from_next_line(fp)
Expand All @@ -182,7 +182,7 @@ def parse_fp(fp: TextIO, keep_empty: bool = False) -> dict:
xmax_tier = get_float_from_next_line(fp)

entries = [] # intervals or points depending on tier type
for i in range(get_int_from_next_line(fp)):
for _ in range(get_int_from_next_line(fp)):
if not is_short:
fp.readline() # skip intervals [\d]
if tier_type == INTERVAL_TIER:
Expand Down Expand Up @@ -246,7 +246,7 @@ def parse(textgrid_path: str | pathlib.Path, keep_empty: bool = False) -> dict:
try:
with textgrid_path.open("r", encoding="utf-16") as fp:
textgrid_raw = parse_fp(fp, keep_empty)
except (UnicodeError, UnicodeDecodeError):
except UnicodeError:
with textgrid_path.open("r", encoding="utf-8") as fp:
textgrid_raw = parse_fp(fp, keep_empty)
return textgrid_raw
4 changes: 3 additions & 1 deletion src/crowsetta/formats/seq/textgrid/textgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,9 @@ def to_seq(

return seq

def to_annot(self, tier: int | str | None = None, round_times: bool = True, decimals: int = 3) -> crowsetta.Annotation:
def to_annot(
self, tier: int | str | None = None, round_times: bool = True, decimals: int = 3
) -> crowsetta.Annotation:
"""Convert interval tier or tiers from this TextGrid annotation
to a :class:`crowsetta.Annotation` with a :data:`seq` attribute.
Expand Down
1 change: 1 addition & 0 deletions src/crowsetta/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(self, format: "Union[str, crowsetta.interface.SeqLike, crowsetta.in
"and the name 'csv' will stop working in the next version. "
"Please change any usages of the name 'csv' to 'generic-seq'` now.",
FutureWarning,
stacklevel=2,
)
_format_class = formats.by_name(format)
elif inspect.isclass(format):
Expand Down

0 comments on commit 6533b59

Please sign in to comment.