Skip to content

Commit

Permalink
Fixes remove_invalid_notes (renamed remove_notes_with_no_duration) + …
Browse files Browse the repository at this point in the history
…minor improvements (#42)

* fixes remove_invalid_notes (renamed remove_notes_with_no_duration) + minor improvements

* Update miditoolkit/midi/containers.py

Co-authored-by: Aarni Koskela <[email protected]>

* fixing warnings call/import

* renamed nb contractions to num

* num_instruments property for MidiFile

---------

Co-authored-by: Aarni Koskela <[email protected]>
  • Loading branch information
Natooz and akx authored Nov 22, 2023
1 parent c81dfe2 commit 3b0374f
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 39 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: setup.py
cache-dependency-path: pyproject.toml
- name: Install dependencies
run: |
# Install local package with tests dependencies extras
Expand All @@ -36,9 +36,9 @@ jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v3
uses: actions/setup-python@v4
with:
python-version: '3.x'
- name: Install dependencies
Expand Down
45 changes: 23 additions & 22 deletions miditoolkit/midi/containers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import warnings
from dataclasses import dataclass
from typing import List, Optional, Union

Expand Down Expand Up @@ -38,9 +39,9 @@ class Pedal:
Parameters
----------
start : int
Time where the pedal starts.
Time when the pedal starts.
end : int
Time where the pedal ends.
Time when the pedal ends.
"""

Expand All @@ -61,7 +62,7 @@ class PitchBend:
pitch : int
MIDI pitch bend amount, in the range ``[-8192, 8191]``.
time : int
Time where the pitch bend occurs.
Time when the pitch bend occurs.
"""

Expand All @@ -80,7 +81,7 @@ class ControlChange:
value : int
The value of the control change, in ``[0, 127]``.
time : int
Time where the control change occurs.
Time when the control change occurs.
"""

Expand Down Expand Up @@ -284,26 +285,26 @@ def __init__(
self.control_changes = [] if control_changes is None else control_changes
self.pedals = [] if pedals is None else pedals

def remove_invalid_notes(self, verbose: bool = True):
"""Removes any notes whose end time is before or at their start time."""
# Crete a list of all invalid notes
notes_to_delete = [note for note in self.notes if note.end <= note.start]
if verbose:
if len(notes_to_delete):
print("\nInvalid notes:\n", notes_to_delete, "\n\n") # noqa: T201
else:
print("no invalid notes found") # noqa: T201
return True

# Remove the notes found
for note in notes_to_delete:
self.notes.remove(note)
return False
def remove_invalid_notes(self, verbose: bool = True) -> None:
warnings.warn(
"Call remove_notes_with_no_duration() instead.",
DeprecationWarning,
stacklevel=2,
)
return self.remove_notes_with_no_duration()

def remove_notes_with_no_duration(self) -> None:
"""Removes (inplace) notes whose end time is before or at their start time."""
for i in range(self.num_notes - 1, -1, -1):
if self.notes[i].start >= self.notes[i].end:
del self.notes[i]

@property
def num_notes(self) -> int:
return len(self.notes)

def __repr__(self):
return 'Instrument(program={}, is_drum={}, name="{}")'.format(
self.program, self.is_drum, self.name.replace('"', r"\"")
)
return f"Instrument(program={self.program}, is_drum={self.is_drum}, name={self.name}) - {self.num_notes} notes"

def __eq__(self, other):
# Here we check all tracks attributes except the name.
Expand Down
30 changes: 16 additions & 14 deletions miditoolkit/midi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,18 @@ def __init__(
charset: str = "latin1",
):
# create empty file
if filename is None and file is None:
self.ticks_per_beat = ticks_per_beat
self.max_tick = 0
self.tempo_changes = []
self.time_signature_changes = []
self.key_signature_changes = []
self.lyrics = []
self.markers = []
self.instruments = []

# load
else:
self.ticks_per_beat: int = ticks_per_beat
self.max_tick: int = 0
self.tempo_changes: Sequence[TempoChange] = []
self.time_signature_changes: Sequence[TimeSignature] = []
self.key_signature_changes: Sequence[KeySignature] = []
self.lyrics: Sequence[str] = []
self.markers: Sequence[Marker] = []
self.instruments: Sequence[Instrument] = []

# load file
if filename or file:
if filename:
# filename
mido_obj = mido.MidiFile(filename=filename, clip=clip, charset=charset)
else:
mido_obj = mido.MidiFile(file=file, clip=clip, charset=charset)
Expand Down Expand Up @@ -330,6 +328,10 @@ def get_tick_to_time_mapping(self):
self.ticks_per_beat, self.max_tick, self.tempo_changes
)

@property
def num_instruments(self) -> int:
return len(self.instruments)

def __repr__(self):
return self.__str__()

Expand All @@ -342,7 +344,7 @@ def __str__(self):
f"key sig: {len(self.key_signature_changes)}",
f"markers: {len(self.markers)}",
f"lyrics: {bool(len(self.lyrics))}",
f"instruments: {len(self.instruments)}",
f"instruments: {self.num_instruments}",
]
output_str = "\n".join(output_list)
return output_str
Expand Down
24 changes: 24 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from operator import attrgetter

import pytest

from miditoolkit import MidiFile, Note
from tests.utils import MIDI_PATHS


@pytest.mark.parametrize("midi_path", MIDI_PATHS[:5], ids=attrgetter("name"))
def test_remove_notes_with_no_duration(midi_path, tmp_path):
"""Test that a MIDI loaded and saved unchanged is indeed the save as before."""
# Load the MIDI file and removes the notes with durations <= 0
midi = MidiFile(midi_path)
midi.instruments[0].remove_notes_with_no_duration()
num_notes_before = midi.instruments[0].num_notes

# Adding notes with durations <= 0, then reapply the method
midi.instruments[0].notes.append(Note(50, 50, 100, 100))
midi.instruments[0].notes.append(Note(50, 50, 101, 100))
midi.instruments[0].remove_notes_with_no_duration()

assert (
midi.instruments[0].num_notes == num_notes_before
), "The notes with duration <=0 were not removed by test_remove_notes_with_no_duration"

0 comments on commit 3b0374f

Please sign in to comment.