diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index d619caa..53b9754 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -23,7 +23,7 @@ jobs: with: python-version: ${{ matrix.python-version }} cache: pip - cache-dependency-path: setup.py + cache-dependency-path: pyproject.toml - name: Install dependencies run: | # Install local package with tests dependencies extras @@ -36,9 +36,9 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: '3.x' - name: Install dependencies diff --git a/miditoolkit/midi/containers.py b/miditoolkit/midi/containers.py index 4e0fe6b..bf78592 100755 --- a/miditoolkit/midi/containers.py +++ b/miditoolkit/midi/containers.py @@ -1,4 +1,5 @@ import re +import warnings from dataclasses import dataclass from typing import List, Optional, Union @@ -38,9 +39,9 @@ class Pedal: Parameters ---------- start : int - Time where the pedal starts. + Time when the pedal starts. end : int - Time where the pedal ends. + Time when the pedal ends. """ @@ -61,7 +62,7 @@ class PitchBend: pitch : int MIDI pitch bend amount, in the range ``[-8192, 8191]``. time : int - Time where the pitch bend occurs. + Time when the pitch bend occurs. """ @@ -80,7 +81,7 @@ class ControlChange: value : int The value of the control change, in ``[0, 127]``. time : int - Time where the control change occurs. + Time when the control change occurs. """ @@ -284,26 +285,26 @@ def __init__( self.control_changes = [] if control_changes is None else control_changes self.pedals = [] if pedals is None else pedals - def remove_invalid_notes(self, verbose: bool = True): - """Removes any notes whose end time is before or at their start time.""" - # Crete a list of all invalid notes - notes_to_delete = [note for note in self.notes if note.end <= note.start] - if verbose: - if len(notes_to_delete): - print("\nInvalid notes:\n", notes_to_delete, "\n\n") # noqa: T201 - else: - print("no invalid notes found") # noqa: T201 - return True - - # Remove the notes found - for note in notes_to_delete: - self.notes.remove(note) - return False + def remove_invalid_notes(self, verbose: bool = True) -> None: + warnings.warn( + "Call remove_notes_with_no_duration() instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.remove_notes_with_no_duration() + + def remove_notes_with_no_duration(self) -> None: + """Removes (inplace) notes whose end time is before or at their start time.""" + for i in range(self.num_notes - 1, -1, -1): + if self.notes[i].start >= self.notes[i].end: + del self.notes[i] + + @property + def num_notes(self) -> int: + return len(self.notes) def __repr__(self): - return 'Instrument(program={}, is_drum={}, name="{}")'.format( - self.program, self.is_drum, self.name.replace('"', r"\"") - ) + return f"Instrument(program={self.program}, is_drum={self.is_drum}, name={self.name}) - {self.num_notes} notes" def __eq__(self, other): # Here we check all tracks attributes except the name. diff --git a/miditoolkit/midi/parser.py b/miditoolkit/midi/parser.py index 0377722..e26b275 100755 --- a/miditoolkit/midi/parser.py +++ b/miditoolkit/midi/parser.py @@ -38,20 +38,18 @@ def __init__( charset: str = "latin1", ): # create empty file - if filename is None and file is None: - self.ticks_per_beat = ticks_per_beat - self.max_tick = 0 - self.tempo_changes = [] - self.time_signature_changes = [] - self.key_signature_changes = [] - self.lyrics = [] - self.markers = [] - self.instruments = [] - - # load - else: + self.ticks_per_beat: int = ticks_per_beat + self.max_tick: int = 0 + self.tempo_changes: Sequence[TempoChange] = [] + self.time_signature_changes: Sequence[TimeSignature] = [] + self.key_signature_changes: Sequence[KeySignature] = [] + self.lyrics: Sequence[str] = [] + self.markers: Sequence[Marker] = [] + self.instruments: Sequence[Instrument] = [] + + # load file + if filename or file: if filename: - # filename mido_obj = mido.MidiFile(filename=filename, clip=clip, charset=charset) else: mido_obj = mido.MidiFile(file=file, clip=clip, charset=charset) @@ -330,6 +328,10 @@ def get_tick_to_time_mapping(self): self.ticks_per_beat, self.max_tick, self.tempo_changes ) + @property + def num_instruments(self) -> int: + return len(self.instruments) + def __repr__(self): return self.__str__() @@ -342,7 +344,7 @@ def __str__(self): f"key sig: {len(self.key_signature_changes)}", f"markers: {len(self.markers)}", f"lyrics: {bool(len(self.lyrics))}", - f"instruments: {len(self.instruments)}", + f"instruments: {self.num_instruments}", ] output_str = "\n".join(output_list) return output_str diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..1ca2872 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,24 @@ +from operator import attrgetter + +import pytest + +from miditoolkit import MidiFile, Note +from tests.utils import MIDI_PATHS + + +@pytest.mark.parametrize("midi_path", MIDI_PATHS[:5], ids=attrgetter("name")) +def test_remove_notes_with_no_duration(midi_path, tmp_path): + """Test that a MIDI loaded and saved unchanged is indeed the save as before.""" + # Load the MIDI file and removes the notes with durations <= 0 + midi = MidiFile(midi_path) + midi.instruments[0].remove_notes_with_no_duration() + num_notes_before = midi.instruments[0].num_notes + + # Adding notes with durations <= 0, then reapply the method + midi.instruments[0].notes.append(Note(50, 50, 100, 100)) + midi.instruments[0].notes.append(Note(50, 50, 101, 100)) + midi.instruments[0].remove_notes_with_no_duration() + + assert ( + midi.instruments[0].num_notes == num_notes_before + ), "The notes with duration <=0 were not removed by test_remove_notes_with_no_duration"