-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make py.test use more idiomatic (parametrize instead of internal loop…
…ing)
- Loading branch information
Showing
4 changed files
with
73 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,6 @@ | |
"pytest-cov", | ||
"pytest-xdist[psutil]", | ||
"setuptools", | ||
"tqdm", | ||
] | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,75 +1,67 @@ | ||
#!/usr/bin/python3 python | ||
from operator import attrgetter | ||
|
||
"""Testing creating pianorolls of notes. | ||
""" | ||
|
||
from pathlib import Path | ||
|
||
from tqdm import tqdm | ||
import pytest | ||
|
||
from miditoolkit import MidiFile | ||
from miditoolkit.constants import PITCH_RANGE | ||
from miditoolkit.pianoroll import notes2pianoroll, pianoroll2notes | ||
from tests.utils import MIDI_PATHS | ||
|
||
test_sets = [ | ||
{"pitch_range": (0, 127)}, | ||
{"pitch_range": (24, 96)}, | ||
{"pitch_range": (24, 116), "pitch_offset": 12}, | ||
{"pitch_range": (6, 96), "pitch_offset": 12}, | ||
{"pitch_range": (24, 96), "pitch_offset": 12, "velocity_threshold": 36}, | ||
] | ||
|
||
def test_pianoroll(): | ||
midi_paths = list(Path("tests", "testcases").glob("**/*.mid")) | ||
test_sets = [ | ||
{"pitch_range": (0, 127)}, | ||
{"pitch_range": (24, 96)}, | ||
{"pitch_range": (24, 116), "pitch_offset": 12}, | ||
{"pitch_range": (6, 96), "pitch_offset": 12}, | ||
{"pitch_range": (24, 96), "pitch_offset": 12, "velocity_threshold": 36}, | ||
] | ||
|
||
for path in tqdm(midi_paths, desc="Checking pianoroll conversion"): | ||
midi = MidiFile(path) | ||
|
||
for track in midi.instruments: | ||
# We do a first notes -> pianoroll -> notes conversion before | ||
# This step is required as the pianoroll conversion is lossy with overlapping notes. | ||
# notes2pianoroll has a "last income priority" logic, for which if a notes is occurs | ||
# when another one of the same pitch is already being played, this new note will be | ||
# represented and will end the previous one (if they have different velocities). | ||
# deduplicate_notes(track.notes) | ||
|
||
for test_set in test_sets: | ||
# Set pitch range parameters | ||
pitch_range = test_set.get("pitch_range", PITCH_RANGE) | ||
if "pitch_offset" in test_set: | ||
pitch_range = ( | ||
max(PITCH_RANGE[0], pitch_range[0] - test_set["pitch_offset"]), | ||
min(PITCH_RANGE[1], pitch_range[1] + test_set["pitch_offset"]), | ||
) | ||
@pytest.mark.parametrize("midi_path", MIDI_PATHS, ids=attrgetter("name")) | ||
def test_pianoroll(midi_path): | ||
"""Testing creating pianorolls of notes.""" | ||
midi = MidiFile(midi_path) | ||
|
||
# First pianoroll <--> notes conversion, losing overlapping notes | ||
pianoroll = notes2pianoroll(track.notes, **test_set) | ||
new_notes = pianoroll2notes(pianoroll, pitch_range=pitch_range) | ||
for track in midi.instruments: | ||
# We do a first notes -> pianoroll -> notes conversion before | ||
# This step is required as the pianoroll conversion is lossy with overlapping notes. | ||
# notes2pianoroll has a "last income priority" logic, for which if a notes is occurs | ||
# when another one of the same pitch is already being played, this new note will be | ||
# represented and will end the previous one (if they have different velocities). | ||
# deduplicate_notes(track.notes) | ||
|
||
# Second one, notes -> pianoroll -> new notes should be equal | ||
new_pianoroll = notes2pianoroll(new_notes, **test_set) | ||
new_new_notes = pianoroll2notes(new_pianoroll, pitch_range=pitch_range) | ||
if "velocity_threshold" in test_set: | ||
new_notes = [ | ||
note | ||
for note in new_notes | ||
if note.velocity >= test_set["velocity_threshold"] | ||
] | ||
for test_set in test_sets: | ||
# Set pitch range parameters | ||
pitch_range = test_set.get("pitch_range", PITCH_RANGE) | ||
if "pitch_offset" in test_set: | ||
pitch_range = ( | ||
max(PITCH_RANGE[0], pitch_range[0] - test_set["pitch_offset"]), | ||
min(PITCH_RANGE[1], pitch_range[1] + test_set["pitch_offset"]), | ||
) | ||
|
||
# Assert notes are all retrieved | ||
assert len(new_notes) == len( | ||
new_new_notes | ||
), "Number of notes changed in pianoroll conversion" | ||
for note1, note2 in zip(new_notes, new_new_notes): | ||
# We don't test the resampling factor as it might later the number of notes | ||
# if "resample_factor" in test_set: | ||
# note1.start = int(round(note1.start * test_set["resample_factor"])) | ||
# note1.end = int(round(note1.end * test_set["resample_factor"])) | ||
assert ( | ||
note1 == note2 | ||
), "Notes before and after pianoroll conversion are not the same" | ||
# First pianoroll <--> notes conversion, losing overlapping notes | ||
pianoroll = notes2pianoroll(track.notes, **test_set) | ||
new_notes = pianoroll2notes(pianoroll, pitch_range=pitch_range) | ||
|
||
# Second one, notes -> pianoroll -> new notes should be equal | ||
new_pianoroll = notes2pianoroll(new_notes, **test_set) | ||
new_new_notes = pianoroll2notes(new_pianoroll, pitch_range=pitch_range) | ||
if "velocity_threshold" in test_set: | ||
new_notes = [ | ||
note | ||
for note in new_notes | ||
if note.velocity >= test_set["velocity_threshold"] | ||
] | ||
|
||
if __name__ == "__main__": | ||
test_pianoroll() | ||
# Assert notes are all retrieved | ||
assert len(new_notes) == len( | ||
new_new_notes | ||
), "Number of notes changed in pianoroll conversion" | ||
for note1, note2 in zip(new_notes, new_new_notes): | ||
# We don't test the resampling factor as it might later the number of notes | ||
# if "resample_factor" in test_set: | ||
# note1.start = int(round(note1.start * test_set["resample_factor"])) | ||
# note1.end = int(round(note1.end * test_set["resample_factor"])) | ||
assert ( | ||
note1 == note2 | ||
), "Notes before and after pianoroll conversion are not the same" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,22 @@ | ||
#!/usr/bin/python3 python | ||
from operator import attrgetter | ||
|
||
"""Testing that a MIDI loaded and saved unchanged is indeed the save as before. | ||
""" | ||
|
||
import shutil | ||
from pathlib import Path | ||
|
||
from tqdm import tqdm | ||
import pytest | ||
|
||
from miditoolkit import MidiFile | ||
from tests.utils import MIDI_PATHS | ||
|
||
|
||
def test_load_dump(): | ||
midi_paths = list(Path("tests", "testcases").glob("**/*.mid")) | ||
out_path = Path("tests", "tmp", "load_dump") | ||
out_path.mkdir(parents=True, exist_ok=True) | ||
|
||
for path in tqdm(midi_paths, desc="Checking midis load/save"): | ||
midi = MidiFile(path) | ||
# Writing it unchanged | ||
midi.dump(out_path / path.name) | ||
# Loading it back | ||
midi2 = MidiFile(out_path / path.name) | ||
|
||
# Sorting the notes, as after dump the order might have changed | ||
for track1, track2 in zip(midi.instruments, midi2.instruments): | ||
track1.notes.sort(key=lambda x: (x.start, x.pitch, x.end, x.velocity)) | ||
track2.notes.sort(key=lambda x: (x.start, x.pitch, x.end, x.velocity)) | ||
|
||
assert midi == midi2 | ||
|
||
# deletes tmp directory after tests | ||
shutil.rmtree(out_path) | ||
@pytest.mark.parametrize("midi_path", MIDI_PATHS, ids=attrgetter("name")) | ||
def test_load_dump(midi_path, tmp_path): | ||
"""Test that a MIDI loaded and saved unchanged is indeed the save as before.""" | ||
midi1 = MidiFile(midi_path) | ||
dump_path = tmp_path / midi_path.name | ||
midi1.dump(dump_path) # Writing it unchanged | ||
midi2 = MidiFile(dump_path) # Loading it back | ||
|
||
# Sorting the notes, as after dump the order might have changed | ||
for track1, track2 in zip(midi1.instruments, midi2.instruments): | ||
track1.notes.sort(key=lambda x: (x.start, x.pitch, x.end, x.velocity)) | ||
track2.notes.sort(key=lambda x: (x.start, x.pitch, x.end, x.velocity)) | ||
|
||
if __name__ == "__main__": | ||
test_load_dump() | ||
assert midi1 == midi2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from pathlib import Path | ||
|
||
HERE = Path(__file__).parent | ||
|
||
MIDI_PATHS = sorted((HERE / "testcases").rglob("*.mid")) |