Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Prediction Output #131

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@

[![Coverage Status](https://coveralls.io/repos/github/ubclaunchpad/minutes/badge.svg)](https://coveralls.io/github/ubclaunchpad/minutes)

Audio speaker diarization library.

## Under Construction!
Audio speaker diarization library.

## :running: Development

Expand Down Expand Up @@ -43,9 +41,6 @@ minutes.add_speakers([s1, s2])
# Fit the model.
minutes.fit()

# Collect a new conversation for prediction.
conversation = Conversation('/path/to/conversation.wav')

# Create phrases from the conversation.
phrases = minutes.phrases(conversation)
# Predict against a new conversation had by speakers s1 and s2.
conversation = Conversation('/path/to/conversation.wav', minutes)
```
6 changes: 3 additions & 3 deletions minutes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .minutes import Minutes # noqa
from .speaker import Speaker # noqa
from .conversation import Conversation # noqa
from minutes.minutes import Minutes # noqa
from minutes.speaker import Speaker # noqa
from minutes.conversation import Phrase, Conversation # noqa
28 changes: 14 additions & 14 deletions minutes/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@


class Audio:
"""Internal audio maninpulation class. I reserve the right to change this
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: s/maninpulation/manipulation/g

API :)
"""

def __init__(self, audio_loc):
if os.path.isdir(audio_loc):
Expand Down Expand Up @@ -40,19 +43,19 @@ def samples_per_observation(self, ms_per_observation):
"""
return int(self.rate * ms_per_observation / 1000.)

def get_spectrograms(self, ms_per_observation, verbose=False):
"""Converts a internal table of raw audio audio phrases into with
one spectrogram per row.
def get_observations(self, ms_per_observation):
"""Converts a internal raw audio vector into table with
one spectrogram per row. Also returns raw observations.

Arguments:
ms_per_observation {int} -- The number of desired ms per obs.

Returns:
np.array -- An array of spectrograms, one per row. The width
of each spectrogram depends on the ms_per_observation,
The number of rows depends on the length of the audio file
and the ms per observations.
raw -- The raw audio observation table.
processed -- An array of spectrograms, one per row. The width
of each spectrogram depends on the ms_per_observation.
"""
# Reshape for processing into spectrograms.
d = self.samples_per_observation(ms_per_observation)
N = len(self.data) // d

Expand All @@ -65,17 +68,14 @@ def get_spectrograms(self, ms_per_observation, verbose=False):
# Truncate last (partial) observation.
data = data[:N * d]

if verbose:
t = len(self.data) - (N * d)
print('Truncating {} bytes from end of sample'.format(t))

# Reshape for processing into spectrograms.
data = data.reshape((N, d))
raw = data.reshape((N, d))

def spec_from_row(row):
_, _, Sxx = signal.spectrogram(row)
return Sxx

# This is very slow! Perhaps some logging?
rows = (spec_from_row(row) for row in data)
return np.array([x for x in rows])
rows = (spec_from_row(row) for row in raw)
processed = np.array([x for x in rows])
return raw, processed
44 changes: 39 additions & 5 deletions minutes/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import os
import pickle

from keras import backend as K
from keras.models import Sequential, load_model
Expand All @@ -24,6 +23,19 @@ class BaseModel:
'random_state',
}

@property
def preprocessing_params(self):
"""Returns a mapping of parameters that are required to do preprocessing
of audio data suitable for this model. Useful as kwargs to audio
manipulation classes.
"""
return {
i: getattr(self, i) for i in self.intialization_params
if i in {
'ms_per_observation',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be good to declare this set as a constant somewhere and refer to it by name.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course!

}
}

@property
def fitted(self):
return self.model is not None
Expand Down Expand Up @@ -63,7 +75,7 @@ def load_model(cls, name):
def __init__(self, name, ms_per_observation=3000, test_size=0.33,
random_state=42):
self.name = name
self.speakers = set()
self.speakers = []
self.test_size = test_size
self.random_state = random_state
self.ms_per_observation = ms_per_observation
Expand All @@ -78,7 +90,7 @@ def add_speaker(self, speaker):
"""
if speaker in self.speakers:
raise LookupError(f'Speaker {speaker.name} already added.')
self.speakers.add(speaker)
self.speakers.append(speaker)

def add_speakers(self, speakers):
"""Add a collection of speakers to the model.
Expand All @@ -98,8 +110,12 @@ def _generate_training_data(self):
y -- a categorical one-hot encoding of different speakers
numbered 1..k.
"""
obs = [s.get_observations(self.ms_per_observation)
for s in self.speakers]
obs = []
for s in self.speakers:
_, processed = s.get_observations(**self.preprocessing_params)
obs += processed,

# Generate and flatten labels.
labels = [[i] * len(o) for i, o in enumerate(obs)]
flattened_labels = [j for i in labels for j in i]

Expand Down Expand Up @@ -153,3 +169,21 @@ def save_model(self):
# Save internal model.
if self.model is not None:
self.model.save(os.path.join(self.home, 'keras.h5'))

def predict(self, observations):
"""Predict against a table of audio observations.

Arguments:
observations {np.array} -- A table of processed audio observations.

Returns:
np.array -- An array of predicted speakers.
"""
result = self.model.predict(observations)
y_hat_indices = np.argmax(result, axis=1)

# Index into the speaker array using the predicted speaker indicies.
return np.array(self.speakers)[y_hat_indices]

def __str__(self):
return self.name
40 changes: 23 additions & 17 deletions minutes/conversation.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,36 @@
from minutes.audio import Audio


class Conversation:
class Phrase:
def __init__(self, observation, speaker):
"""A phrase in a conversation, characterized by an audio segment
and a speaker.

def __init__(self, audio_loc, speakers):
Arguments:
observation {np.array} -- 1 dimensional audio sample.
speaker {Speaker} -- The inferred speaker for the audio segment.
"""
self.observation = observation
self.speaker = speaker


class Conversation(Audio):

def __init__(self, audio_loc, model):
"""Create a new conversation from audio sample.

Arguments:
audio_loc {str} -- The absolute location of an audio conversation
sample.
speakers {List[Speaker]} -- A list of speakers included in this
model {Minutes} -- A model trained on speakers within this
conversation.
"""
self.speakers = speakers
self.audio = Audio(audio_loc)

def get_observations(self, ms_per_observation, verbose=False):
"""Converts the conversation audio sample into an n x d matrix of
observations.
self.model = model
super().__init__(audio_loc)

Keyword Arguments:
verbose {bool} -- (default: {False})
ms_per_observation {int} -- (default: {False})

Returns:
np.array -- An n x d matrix of observations.
"""
return self.audio.get_spectrograms(ms_per_observation, verbose)
# Predict against the conversation spectrograms.
raw, X_hat = self.get_observations(**model.preprocessing_params)
y_hat = model.predict(X_hat)

# Convert to a list of phrases.
self.phrases = [Phrase(o, speaker) for o, speaker in zip(raw, y_hat)]
22 changes: 6 additions & 16 deletions minutes/minutes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ def __init__(self, parent='cnn', test_size=0.33, random_state=42):

# Load in parent, copy in fixed parameters.
self.parent = BaseModel.load_model(parent)
self.ms_per_observation = self.parent.ms_per_observation

self.model = None
self.speakers = set()
self.test_size = test_size
self.random_state = random_state
super().__init__(
self.parent.name + '-child',
self.parent.ms_per_observation,
test_size,
random_state,
)

def fit(self, verbose=0):
"""Trains the model, given the speakers currently added."""
Expand Down Expand Up @@ -72,14 +73,3 @@ def fit(self, verbose=0):
batch_size=16,
verbose=verbose
)

def phrases(self, conversation):
"""Predict against a new conversation.

Arguments:
conversation {Conversation} -- A conversation built from an audio
sample.

Returns: <TODO>
"""
pass # TODO
11 changes: 7 additions & 4 deletions minutes/speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ def add_audio(self, audio_loc):
"""
self.audio += Audio(audio_loc),

def get_observations(self, ms_per_observation, verbose=False):
obs = [a.get_spectrograms(ms_per_observation, verbose)
for a in self.audio]
return np.concatenate(obs)
def get_observations(self, **preprocessing_params):
raw, processed = [], []
for a in self.audio:
r, p = a.get_observations(**preprocessing_params)
raw += r,
processed += p,
return np.concatenate(raw), np.concatenate(processed)

def __eq__(self, other):
return self.name == other.name
Expand Down
2 changes: 1 addition & 1 deletion minutes/utils/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

def copy_model(model):
"""Returns a copy of the model.

Arguments:
model {keras.Sequential} -- A model for copying.
"""
Expand Down
3 changes: 2 additions & 1 deletion test/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
import shutil
import tempfile


from minutes import Speaker


TEST_DIR = os.path.dirname(os.path.realpath(__file__))
ROOT_DIR = os.path.join(TEST_DIR, '..')
FIXTURE_DIR = os.path.join(TEST_DIR, 'fixtures')

SPEAKER1_AUDIO = os.path.join(FIXTURE_DIR, 'sample1.wav')
SPEAKER2_AUDIO = os.path.join(FIXTURE_DIR, 'sample2.wav')
CONVERSATION_AUDIO = os.path.join(FIXTURE_DIR, 'conversation.wav')

# Load speaker audio just once for all tests.
SPEAKER1 = Speaker('speaker1')
Expand Down
Binary file added test/fixtures/conversation.wav
Binary file not shown.
2 changes: 1 addition & 1 deletion test/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ def test_samples_per_observation():

def test_get_spectrograms():
audio = Audio(c.SPEAKER1_AUDIO)
spec = audio.get_spectrograms(3000)
_, spec = audio.get_observations(3000)
assert spec.shape == (5, 129, 214)
assert spec.dtype == np.float64
22 changes: 21 additions & 1 deletion test/test_minutes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from minutes import Minutes
import numpy as np

from minutes import Minutes, Conversation
import test.config as c


Expand All @@ -15,3 +17,21 @@ def test_train():

def test_parents():
assert Minutes.parents == ['cnn']


def test_phrases():
for model_name in Minutes.parents:
minutes = Minutes(parent=model_name)
minutes.add_speaker(c.SPEAKER1)
minutes.add_speaker(c.SPEAKER2)
minutes.fit()

# Predict new phrases (make sure we ony predict once per obs)
conversation = Conversation(c.CONVERSATION_AUDIO, minutes)
raw, _ = conversation.get_observations(**minutes.preprocessing_params)
assert len(conversation.phrases) == len(raw)
print(conversation.phrases)

# Make sure we ony predicted on speaker 1 and 2.
names = [p.speaker.name for p in conversation.phrases]
assert sorted(list(np.unique(names))) == ['speaker1', 'speaker2']