Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Support optional cupy for faster BSS evaluation #84

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,4 @@ data/*
!data/fetch.sh
!data/decode.sh
Estimates/
*.json
1 change: 1 addition & 0 deletions museval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas as pd
from . aggregate import TrackStore, MethodStore, EvalStore, json2df
from . import metrics
from . metrics import disable_cupy, clear_cupy_cache


def _load_track_estimates(track, estimates_dir, output_dir, ext='wav'):
Expand Down
2 changes: 1 addition & 1 deletion museval/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def save(self, path):

def __repr__(self):
targets = self.df['target'].unique()
out = "Aggrated Scores ({} over frames, {} over tracks)\n".format(
out = "Aggregated Scores ({} over frames, {} over tracks)\n".format(
self.frames_agg, self.tracks_agg
)
for target in targets:
Expand Down
117 changes: 101 additions & 16 deletions museval/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,49 @@
1706, IRISA, April 2005."""

import numpy as np
import scipy.fftpack
import scipy.fft
from scipy.linalg import toeplitz
from scipy.signal import fftconvolve
import itertools
import collections
import warnings
import sys

use_cupy = False
try:
import cupyx
import cupy
use_cupy = True
except ImportError:
warnings.warn('cupy not available, falling back to regular numpy')

# The maximum allowable number of sources (prevents insane computational load)
MAX_SOURCES = 100


# Escape hatch: force the CPU code path even when cupy imported successfully.
def disable_cupy():
    """Disable cupy acceleration, falling back to NumPy/SciPy everywhere.

    Flips the module-level ``use_cupy`` flag off; useful when a GPU is
    present but should not be used (e.g. it is busy or low on memory).
    """
    global use_cupy
    use_cupy = False


# fft plans take up space, you might need to call this between large tracks
def clear_cupy_cache():
    """Release GPU memory held by cupy's cached FFT plans.

    cupy keeps FFT plans in a cache between calls; for large tracks this
    can exhaust GPU memory, so call this between evaluations.  The cache
    is emptied (by shrinking it to zero entries), freed blocks are
    returned to the memory pool, and the original cache limits are then
    restored so caching remains enabled afterwards.

    This is a no-op when cupy is not loaded — previously it raised
    ``NameError`` in that case, even though the module deliberately
    supports running without cupy.
    """
    # Guard: without cupy the names below do not exist.
    if 'cupy' not in sys.modules:
        return

    fft_cache = cupy.fft.config.get_plan_cache()
    orig_sz = fft_cache.get_size()
    orig_memsz = fft_cache.get_memsize()

    # shrinking the cache to zero entries evicts all cached plans
    fft_cache.set_size(0)

    cupy.get_default_memory_pool().free_all_blocks()

    # re-enable plan caching with the original limits
    fft_cache.set_size(orig_sz)
    fft_cache.set_memsize(orig_memsz)


def validate(reference_sources, estimated_sources):
"""Checks that the input data to a metric are valid, and throws helpful
errors if not.
Expand Down Expand Up @@ -523,25 +555,41 @@ def _compute_reference_correlations(reference_sources, filters_len):
# zero padding and FFT of references
reference_sources = _zeropad(reference_sources, filters_len - 1, axis=2)
n_fft = int(2**np.ceil(np.log2(nsampl + filters_len - 1.)))
sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=2)

if use_cupy:
try:
sf = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(reference_sources), n=n_fft, axis=2))
except cupy.cuda.memory.OutOfMemoryError:
sf = scipy.fft.rfft(reference_sources, n=n_fft, axis=2)
else:
sf = scipy.fft.rfft(reference_sources, n=n_fft, axis=2)

# compute intercorrelation between sources
G = np.zeros((nsrc, nsrc, nchan, nchan, filters_len, filters_len))

for ((i, c1), (j, c2)) in itertools.combinations_with_replacement(
itertools.product(
list(range(nsrc)), list(range(nchan))
),
2
):
tmp = sf[j, c2] * np.conj(sf[i, c1])

if use_cupy:
try:
ssf = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp)))
except cupy.cuda.memory.OutOfMemoryError:
ssf = scipy.fft.irfft(tmp)
else:
ssf = scipy.fft.irfft(tmp)

ssf = sf[j, c2] * np.conj(sf[i, c1])
ssf = np.real(scipy.fftpack.ifft(ssf))
ss = toeplitz(
np.hstack((ssf[0], ssf[-1:-filters_len:-1])),
r=ssf[:filters_len]
)
G[j, i, c2, c1] = ss
G[i, j, c1, c2] = ss.T

return G, sf


Expand Down Expand Up @@ -569,30 +617,67 @@ def _compute_projection_filters(G, sf, estimated_source):

# compute its FFT
n_fft = int(2**np.ceil(np.log2(nsampl + filters_len - 1.)))
sef = scipy.fftpack.fft(estimated_source, n=n_fft)

if use_cupy:
try:
sef = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(estimated_source, dtype=np.float32), n=n_fft))
except cupy.cuda.memory.OutOfMemoryError:
sef = scipy.fft.rfft(estimated_source, n=n_fft)
else:
sef = scipy.fft.rfft(estimated_source, n=n_fft)

# compute the cross-correlations between sources and estimates
D = np.zeros((nsrc, nchan, filters_len, nchan))

for (j, cj, c) in itertools.product(
list(range(nsrc)), list(range(nchan)), list(range(nchan))
):
ssef = sf[j, cj] * np.conj(sef[c])
ssef = np.real(scipy.fftpack.ifft(ssef))
tmp = sf[j, cj] * np.conj(sef[c])
if use_cupy:
try:
ssef = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp)))
except cupy.cuda.memory.OutOfMemoryError:
ssef = scipy.fft.irfft(tmp)
else:
ssef = scipy.fft.irfft(tmp)
D[j, cj, :, c] = np.hstack((ssef[0], ssef[-1:-filters_len:-1]))

# reshape matrices to build the filters
D = D.reshape(nsrc * nchan * filters_len, nchan)
G = _reshape_G(G)

# Distortion filters
try:
C = np.linalg.solve(G + eps*np.eye(G.shape[0]), D).reshape(
nsrc, nchan, filters_len, nchan
)
except np.linalg.linalg.LinAlgError:
C = np.linalg.lstsq(G, D)[0].reshape(
nsrc, nchan, filters_len, nchan
)
if use_cupy:
try:
D_gpu = cupy.asarray(D)
G_gpu = cupy.asarray(G)

# Distortion filters
try:
C = cupy.asnumpy(cupy.linalg.solve(G_gpu + eps*cupy.eye(G.shape[0]), D_gpu)).reshape(
nsrc, nchan, filters_len, nchan
)
except np.linalg.linalg.LinAlgError:
C = cupy.asnumpy(cupy.linalg.lstsq(G_gpu, D_gpu))[0].reshape(
nsrc, nchan, filters_len, nchan
)
except cupy.cuda.memory.OutOfMemoryError:
try:
C = np.linalg.solve(G + eps*np.eye(G.shape[0]), D).reshape(
nsrc, nchan, filters_len, nchan
)
except np.linalg.linalg.LinAlgError:
C = np.linalg.lstsq(G, D)[0].reshape(
nsrc, nchan, filters_len, nchan
)
else:
try:
C = np.linalg.solve(G + eps*np.eye(G.shape[0]), D).reshape(
nsrc, nchan, filters_len, nchan
)
except np.linalg.linalg.LinAlgError:
C = np.linalg.lstsq(G, D)[0].reshape(
nsrc, nchan, filters_len, nchan
)

# if we asked for one single reference source,
# return just a nchan X filters_len matrix
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
extras_require={ # Optional
'dev': ['check-manifest'],
'tests': ['pytest'],
'cupy': ['cupy-cuda114'],
'docs': [
'sphinx',
'sphinx_rtd_theme',
Expand Down