Skip to content

Commit

Permalink
More ruff linting (#40)
Browse files Browse the repository at this point in the history
* Apply Ruff UP fixes

* Add explicit optionals according to PEP 484

* Enable ruff B lints

* Remove commented-out code

* Use more list comprehensions (and ignore PERF401 where relevant)

* Enable ruff T lint, add noqa for now

---------

Co-authored-by: Ruff <[email protected]>
  • Loading branch information
akx and Ruff authored Nov 21, 2023
1 parent 38cfa38 commit ad3cc01
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 75 deletions.
22 changes: 9 additions & 13 deletions miditoolkit/midi/containers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from dataclasses import dataclass
from typing import List, Union
from typing import List, Optional, Union


@dataclass
Expand Down Expand Up @@ -270,10 +270,10 @@ def __init__(
program: int,
is_drum: bool = False,
name: str = "",
notes: List[Note] = None,
pitch_bends: List[PitchBend] = None,
control_changes: List[ControlChange] = None,
pedals: List[Pedal] = None,
notes: Optional[List[Note]] = None,
pitch_bends: Optional[List[PitchBend]] = None,
control_changes: Optional[List[ControlChange]] = None,
pedals: Optional[List[Pedal]] = None,
):
"""Create the Instrument."""
self.program = program
Expand All @@ -287,16 +287,12 @@ def __init__(
def remove_invalid_notes(self, verbose: bool = True):
"""Removes any notes whose end time is before or at their start time."""
# Crete a list of all invalid notes
notes_to_delete = []
for note in self.notes:
if note.end <= note.start:
notes_to_delete.append(note)
notes_to_delete = [note for note in self.notes if note.end <= note.start]
if verbose:
if len(notes_to_delete):
print("\nInvalid notes:")
print(notes_to_delete, "\n\n")
print("\nInvalid notes:\n", notes_to_delete, "\n\n") # noqa: T201
else:
print("no invalid notes found")
print("no invalid notes found") # noqa: T201
return True

# Remove the notes found
Expand Down Expand Up @@ -357,7 +353,7 @@ def _key_name_to_key_number(key_string: str):
# Match provided key string
result = re.match(pattern, key_string)
if result is None:
raise ValueError("Supplied key {} is not valid.".format(key_string))
raise ValueError(f"Supplied key {key_string} is not valid.")
# Convert result to dictionary
result = result.groupdict()

Expand Down
71 changes: 19 additions & 52 deletions miditoolkit/midi/parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import collections
import functools
from pathlib import Path
from typing import Sequence, Tuple, Union
from typing import Optional, Sequence, Tuple, Union

import mido
import numpy as np
Expand All @@ -19,19 +19,19 @@
TimeSignature,
)

DEFAULT_BPM = int(120)
DEFAULT_BPM = 120

# We "hack" mido's Note_on messages checks to allow to add an "end" attribute, that
# will serve us to sort the messages in the good order when writing a MIDI file.
new_set = set(list(mido.messages.SPEC_BY_TYPE["note_on"]["attribute_names"]) + ["end"])
new_set = {"end", *mido.messages.SPEC_BY_TYPE["note_on"]["attribute_names"]}
mido.messages.SPEC_BY_TYPE["note_on"]["attribute_names"] = new_set
mido.messages.checks._CHECKS["end"] = mido.messages.checks.check_time


class MidiFile(object):
class MidiFile:
def __init__(
self,
filename: Union[Path, str] = None,
filename: Optional[Union[Path, str]] = None,
file=None,
ticks_per_beat: int = 480,
clip: bool = False,
Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(
@staticmethod
def _convert_delta_to_cumulative(mido_obj):
for track in mido_obj.tracks:
tick = int(0)
tick = 0
for event in track:
event.time += tick
tick = event.time
Expand Down Expand Up @@ -221,8 +221,7 @@ def __get_instrument(program, channel, track, create_new):

for track_idx, track in enumerate(midi_data.tracks):
# Keep track of last note on location:
# key = (instrument, note),
# value = (note-on tick, velocity)
# key = (instrument, note), value = (note-on tick, velocity)
last_note_on = collections.defaultdict(list)
# Keep track of which instrument is playing in each channel
# initialize to program 0 for all channels
Expand Down Expand Up @@ -336,14 +335,14 @@ def __repr__(self):

def __str__(self):
output_list = [
"ticks per beat: {}".format(self.ticks_per_beat),
"max tick: {}".format(self.max_tick),
"tempo changes: {}".format(len(self.tempo_changes)),
"time sig: {}".format(len(self.time_signature_changes)),
"key sig: {}".format(len(self.key_signature_changes)),
"markers: {}".format(len(self.markers)),
"lyrics: {}".format(bool(len(self.lyrics))),
"instruments: {}".format(len(self.instruments)),
f"ticks per beat: {self.ticks_per_beat}",
f"max tick: {self.max_tick}",
f"tempo changes: {len(self.tempo_changes)}",
f"time sig: {len(self.time_signature_changes)}",
f"key sig: {len(self.key_signature_changes)}",
f"markers: {len(self.markers)}",
f"lyrics: {bool(len(self.lyrics))}",
f"instruments: {len(self.instruments)}",
]
output_str = "\n".join(output_list)
return output_str
Expand All @@ -370,11 +369,11 @@ def __eq__(self, other):

def dump(
self,
filename: Union[str, Path] = None,
filename: Optional[Union[str, Path]] = None,
file=None,
segment: Tuple[int, int] = None,
segment: Optional[Tuple[int, int]] = None,
shift=True,
instrument_idx: int = None,
instrument_idx: Optional[int] = None,
charset: str = "latin1",
):
# comparison function
Expand Down Expand Up @@ -412,7 +411,7 @@ def event_compare(event1, event2):
return 0

if (filename is None) and (file is None):
raise IOError("please specify the output.")
raise OSError("please specify the output.")

if instrument_idx is None:
pass
Expand All @@ -427,9 +426,7 @@ def event_compare(event1, event2):
midi_parsed = mido.MidiFile(ticks_per_beat=self.ticks_per_beat, charset=charset)

# Create track 0 with timing information
# meta_track = mido.MidiTrack()

# -- meta track -- #
# 1. Time signature
# add default
add_ts = True
Expand Down Expand Up @@ -635,36 +632,6 @@ def event_compare(event1, event2):
)
track = sorted(track, key=functools.cmp_to_key(event_compare))

"""memo = 0
i = 0
while i < len(track):
# print(i)
# print(len(track))
if track[i].type == "control_change":
tmp = track[i].value
if tmp == memo:
track.pop(i)
else:
memo = track[i].value
i += 1
else:
i += 1"""

# i = 0
# while i <= len(cc_list)-1:
# assert cc_list[i].value == 127
# if cc_list[i].time < track[0].time:
# track.insert(0, cc_list[i])
# track.insert(0, cc_list[i+1])
# i = i+2
# else:
# for j in range(len(track)-1):
# if track[j].time <= cc_list[i].time < track[j+1].time:
# track.insert(j+1, cc_list[i])
# track.insert(j+2, cc_list[i+1])
# i = i+2
# break

# Finally, add in an end of track event
track.append(mido.MetaMessage("end_of_track", time=track[-1].time + 1))
# Add to the list of output tracks
Expand Down
12 changes: 6 additions & 6 deletions miditoolkit/pianoroll/parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Callable, List, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import numpy as np

Expand All @@ -9,12 +9,12 @@

def notes2pianoroll(
notes: List[Note],
pitch_range: Tuple[int, int] = None,
pitch_range: Optional[Tuple[int, int]] = None,
pitch_offset: int = 0,
resample_factor: float = None,
resample_factor: Optional[float] = None,
resample_method: Callable = round,
velocity_threshold: int = 0,
time_portion: Tuple[int, int] = None,
time_portion: Optional[Tuple[int, int]] = None,
keep_note_with_zero_duration: bool = True,
) -> np.ndarray:
r"""Converts a sequence of notes into a pianoroll numpy array.
Expand Down Expand Up @@ -120,8 +120,8 @@ def notes2pianoroll(

def pianoroll2notes(
pianoroll: np.ndarray,
resample_factor: float = None,
pitch_range: Union[int, Tuple[int, int]] = None,
resample_factor: Optional[float] = None,
pitch_range: Optional[Union[int, Tuple[int, int]]] = None,
):
"""Converts a pianoroll (numpy array) into a sequence of notes.
Expand Down
9 changes: 9 additions & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
target-version = "py37"
extend-select = [
"B",
"ERA",
"I",
"PERF",
"RUF013",
"T",
"UP",
]

[extend-per-file-ignores]
"miditoolkit/midi/parser.py" = ["PERF401"]
4 changes: 0 additions & 4 deletions tests/test_pianoroll.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def test_pianoroll():
# notes2pianoroll has a "last income priority" logic, for which if a notes is occurs
# when another one of the same pitch is already being played, this new note will be
# represented and will end the previous one (if they have different velocities).
# deduplicate_notes(track.notes)

for test_set in test_sets:
# Set pitch range parameters
Expand Down Expand Up @@ -63,9 +62,6 @@ def test_pianoroll():
), "Number of notes changed in pianoroll conversion"
for note1, note2 in zip(new_notes, new_new_notes):
# We don't test the resampling factor as it might later the number of notes
# if "resample_factor" in test_set:
# note1.start = int(round(note1.start * test_set["resample_factor"]))
# note1.end = int(round(note1.end * test_set["resample_factor"]))
assert (
note1 == note2
), "Notes before and after pianoroll conversion are not the same"
Expand Down

0 comments on commit ad3cc01

Please sign in to comment.