Skip to content

Commit

Permalink
adding type hints and minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Natooz committed Nov 22, 2023
1 parent 3b0374f commit a857e33
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 108 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ If you are working with seconds time units (for e.g. music transcription), you'l

## TODO

* better documentation
* absolute timing
* cropping: Control Changes
* cropping: bars
* better documentation;
* finish the code cleaning of the pianoroll methods (vis);
* a way to switch the time in seconds across the whole MidiFile object;
* cropping Control Changes and bars;
* symbolic features
* new structural analysis

Expand Down
1 change: 0 additions & 1 deletion miditoolkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""
Author: Wen-Yi Hsiao, Taiwan
Update date: 2020.06.23
"""

__version__ = "1.0.1"
Expand Down
19 changes: 19 additions & 0 deletions miditoolkit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,22 @@
"""

PITCH_RANGE = (0, 127)
PITCH_ID_TO_NAME = {
0: "C",
1: "C#",
2: "D",
3: "D#",
4: "E",
5: "F",
6: "F#",
7: "G",
8: "G#",
9: "A",
10: "A#",
11: "B",
}

MAJOR_NAMES = ["M", "Maj", "Major", "maj", "major"]
MINOR_NAMES = ["m", "Min", "Minor", "min", "minor"]

DEFAULT_BPM = 120
16 changes: 8 additions & 8 deletions miditoolkit/midi/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass
from typing import List, Optional, Union

from ..constants import MAJOR_NAMES, MINOR_NAMES


@dataclass
class Note:
Expand All @@ -27,7 +29,7 @@ class Note:
end: int

@property
def duration(self):
def duration(self) -> int:
"""Get the duration of the note in ticks."""
return self.end - self.start

Expand All @@ -49,7 +51,7 @@ class Pedal:
end: int

@property
def duration(self):
def duration(self) -> int:
return self.end - self.start


Expand Down Expand Up @@ -333,10 +335,8 @@ def __eq__(self, other):
return True


def _key_name_to_key_number(key_string: str):
def _key_name_to_key_number(key_string: str) -> int:
# Create lists of possible mode names (major or minor)
major_strs = ["M", "Maj", "Major", "maj", "major"]
minor_strs = ["m", "Min", "Minor", "min", "minor"]
# Construct regular expression for matching key
pattern = re.compile(
# Start with any of A-G, a-g
Expand All @@ -348,7 +348,7 @@ def _key_name_to_key_number(key_string: str):
# Next, look for any of the mode strings
"(?P<mode>(?:(?:" +
# Next, look for any of the major or minor mode strings
")|(?:".join(major_strs + minor_strs)
")|(?:".join(MAJOR_NAMES + MINOR_NAMES)
+ "))?)$"
)
# Match provided key string
Expand All @@ -371,8 +371,8 @@ def _key_name_to_key_number(key_string: str):
# Circle around 12 pitch classes
key_number = key_number % 12
# Offset if mode is minor, or the key name is lowercase
if result["mode"] in minor_strs or (
result["key"].islower() and result["mode"] not in major_strs
if result["mode"] in MINOR_NAMES or (
result["key"].islower() and result["mode"] not in MAJOR_NAMES
):
key_number += 12

Expand Down
83 changes: 47 additions & 36 deletions miditoolkit/midi/parser.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import collections
import functools
from pathlib import Path
from typing import Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union

import mido
import numpy as np

from ..constants import DEFAULT_BPM
from .containers import (
ControlChange,
Instrument,
Expand All @@ -19,8 +20,6 @@
TimeSignature,
)

DEFAULT_BPM = 120

# We "hack" mido's Note_on messages checks to allow to add an "end" attribute, that
# will serve us to sort the messages in the good order when writing a MIDI file.
new_set = {"end", *mido.messages.SPEC_BY_TYPE["note_on"]["attribute_names"]}
Expand All @@ -40,12 +39,12 @@ def __init__(
# create empty file
self.ticks_per_beat: int = ticks_per_beat
self.max_tick: int = 0
self.tempo_changes: Sequence[TempoChange] = []
self.time_signature_changes: Sequence[TimeSignature] = []
self.key_signature_changes: Sequence[KeySignature] = []
self.lyrics: Sequence[str] = []
self.markers: Sequence[Marker] = []
self.instruments: Sequence[Instrument] = []
self.tempo_changes: List[TempoChange] = []
self.time_signature_changes: List[TimeSignature] = []
self.key_signature_changes: List[KeySignature] = []
self.lyrics: List[Lyric] = []
self.markers: List[Marker] = []
self.instruments: List[Instrument] = []

# load file
if filename or file:
Expand All @@ -58,7 +57,7 @@ def __init__(
self.ticks_per_beat = mido_obj.ticks_per_beat

# convert delta time to cumulative time
mido_obj = self._convert_delta_to_cumulative(mido_obj)
self._convert_delta_to_cumulative(mido_obj)

# load tempo changes
self.tempo_changes = self._load_tempo_changes(mido_obj)
Expand Down Expand Up @@ -89,16 +88,15 @@ def __init__(
# tick and sec mapping

@staticmethod
def _convert_delta_to_cumulative(mido_obj):
def _convert_delta_to_cumulative(mido_obj: mido.MidiFile):
for track in mido_obj.tracks:
tick = 0
for event in track:
event.time += tick
tick = event.time
return mido_obj

@staticmethod
def _load_tempo_changes(mido_obj: mido.MidiFile):
def _load_tempo_changes(mido_obj: mido.MidiFile) -> List[TempoChange]:
# default bpm
tempo_changes = [TempoChange(DEFAULT_BPM, 0)]

Expand All @@ -118,7 +116,7 @@ def _load_tempo_changes(mido_obj: mido.MidiFile):
return tempo_changes

@staticmethod
def _load_time_signatures(mido_obj):
def _load_time_signatures(mido_obj: mido.MidiFile) -> List[TimeSignature]:
# no default
time_signature_changes = []

Expand All @@ -133,7 +131,7 @@ def _load_time_signatures(mido_obj):
return time_signature_changes

@staticmethod
def _load_key_signatures(mido_obj):
def _load_key_signatures(mido_obj: mido.MidiFile) -> List[KeySignature]:
# no default
key_signature_changes = []

Expand All @@ -146,7 +144,7 @@ def _load_key_signatures(mido_obj):
return key_signature_changes

@staticmethod
def _load_markers(mido_obj):
def _load_markers(mido_obj: mido.MidiFile) -> List[Marker]:
# no default
markers = []

Expand All @@ -158,7 +156,7 @@ def _load_markers(mido_obj):
return markers

@staticmethod
def _load_lyrics(mido_obj):
def _load_lyrics(mido_obj: mido.MidiFile) -> List[Lyric]:
# no default
lyrics = []

Expand All @@ -170,51 +168,56 @@ def _load_lyrics(mido_obj):
return lyrics

@staticmethod
def _load_instruments(midi_data):
def _load_instruments(midi_data: mido.MidiFile) -> List[Instrument]:
instrument_map = collections.OrderedDict()
# Store a similar mapping to instruments storing "straggler events",
# e.g. events which appear before we want to initialize an Instrument
stragglers = {}
# This dict will map track indices to any track names encountered
track_name_map = collections.defaultdict(str)

def __get_instrument(program, channel, track, create_new):
def __get_instrument(
program_: int,
channel: int,
track_: int,
create_new: bool,
):
"""Gets the Instrument corresponding to the given program number,
drum/non-drum type, channel, and track index. If no such
instrument exists, one is created.
"""
# If we have already created an instrument for this program
# number/track/channel, return it
if (program, channel, track) in instrument_map:
return instrument_map[(program, channel, track)]
if (program_, channel, track_) in instrument_map:
return instrument_map[(program_, channel, track_)]
# If there's a straggler instrument for this instrument and we
# aren't being requested to create a new instrument
if not create_new and (channel, track) in stragglers:
return stragglers[(channel, track)]
if not create_new and (channel, track_) in stragglers:
return stragglers[(channel, track_)]
is_drum = channel == 9
# If we are told to, create a new instrument and store it
if create_new:
instrument_ = Instrument(program, is_drum, track_name_map[track_idx])
instrument_ = Instrument(program_, is_drum, track_name_map[track_idx])
# If any events appeared for this instrument before now,
# include them in the new instrument
if (channel, track) in stragglers:
straggler = stragglers[(channel, track)]
if (channel, track_) in stragglers:
straggler = stragglers[(channel, track_)]
instrument_.control_changes = straggler.control_changes
instrument_.pitch_bends = straggler.pitch_bends
instrument_.pedals = straggler.pedals
# Add the instrument to the instrument map
instrument_map[(program, channel, track)] = instrument_
instrument_map[(program_, channel, track_)] = instrument_
# Otherwise, create a "straggler" instrument which holds events
# which appear before we actually want to create a proper new
# instrument
else:
# Create a "straggler" instrument
instrument_ = Instrument(program, is_drum, track_name_map[track_idx])
instrument_ = Instrument(program_, is_drum, track_name_map[track_idx])
# Note that stragglers ignores program number, because we want
# to store all events on a track which appear before the first
# note-on, regardless of program
stragglers[(channel, track)] = instrument_
stragglers[(channel, track_)] = instrument_
return instrument_

for track_idx, track in enumerate(midi_data.tracks):
Expand Down Expand Up @@ -271,7 +274,7 @@ def __get_instrument(program, channel, track, create_new):
# instrument
# Create a new instrument if none exists
instrument = __get_instrument(
program, event.channel, track_idx, 1
program, event.channel, track_idx, True
)
# Add the note event
instrument.notes.append(note)
Expand All @@ -291,7 +294,9 @@ def __get_instrument(program, channel, track, create_new):
program = current_instrument[event.channel]
# Retrieve the Instrument instance for the current inst
# Don't create a new instrument if none exists
instrument = __get_instrument(program, event.channel, track_idx, 0)
instrument = __get_instrument(
program, event.channel, track_idx, False
)
# Add the pitch bend event
instrument.pitch_bends.append(bend)
# Store control changes
Expand All @@ -303,7 +308,9 @@ def __get_instrument(program, channel, track, create_new):
program = current_instrument[event.channel]
# Retrieve the Instrument instance for the current inst
# Don't create a new instrument if none exists
instrument = __get_instrument(program, event.channel, track_idx, 0)
instrument = __get_instrument(
program, event.channel, track_idx, False
)
# Add the control change event
instrument.control_changes.append(control_change)

Expand All @@ -323,7 +330,7 @@ def __get_instrument(program, channel, track, create_new):
instruments = [i for i in instrument_map.values()]
return instruments

def get_tick_to_time_mapping(self):
def get_tick_to_time_mapping(self) -> np.ndarray:
return _get_tick_to_second_mapping(
self.ticks_per_beat, self.max_tick, self.tempo_changes
)
Expand Down Expand Up @@ -374,7 +381,7 @@ def dump(
filename: Optional[Union[str, Path]] = None,
file=None,
segment: Optional[Tuple[int, int]] = None,
shift=True,
shift: bool = True,
instrument_idx: Optional[int] = None,
charset: str = "latin1",
):
Expand Down Expand Up @@ -689,8 +696,12 @@ def _is_note_within_tick_range(


def _include_meta_events_within_tick_range(
events, start_tick: int, end_tick: int, shift: bool = False, front: bool = True
):
events: Sequence[Union[mido.MetaMessage, mido.Message]],
start_tick: int,
end_tick: int,
shift: bool = False,
front: bool = True,
) -> Sequence[mido.MetaMessage]:
r"""
Args:
Expand Down
2 changes: 1 addition & 1 deletion miditoolkit/midi/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os


def example_midi_file():
def example_midi_file() -> str:
# like librosa.util.example_audio_file()
path_curfile = os.path.dirname(os.path.abspath(__file__))
path_midi = os.path.join(path_curfile, "examples_data", "1390.mid")
Expand Down
2 changes: 1 addition & 1 deletion miditoolkit/pianoroll/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def pianoroll2notes(
pianoroll: np.ndarray,
resample_factor: Optional[float] = None,
pitch_range: Optional[Union[int, Tuple[int, int]]] = None,
):
) -> List[Note]:
"""Converts a pianoroll (numpy array) into a sequence of notes.
Args:
Expand Down
Loading

0 comments on commit a857e33

Please sign in to comment.