diff --git a/.gitignore b/.gitignore index db52ccb..2a9e0ba 100644 --- a/.gitignore +++ b/.gitignore @@ -87,3 +87,4 @@ data/* !data/fetch.sh !data/decode.sh Estimates/ +*.json diff --git a/museval/__init__.py b/museval/__init__.py index 346e748..d383fe2 100644 --- a/museval/__init__.py +++ b/museval/__init__.py @@ -9,6 +9,7 @@ import pandas as pd from . aggregate import TrackStore, MethodStore, EvalStore, json2df from . import metrics +from . metrics import disable_cupy, clear_cupy_cache def _load_track_estimates(track, estimates_dir, output_dir, ext='wav'): diff --git a/museval/aggregate.py b/museval/aggregate.py index d414ced..a6e8870 100644 --- a/museval/aggregate.py +++ b/museval/aggregate.py @@ -258,7 +258,7 @@ def save(self, path): def __repr__(self): targets = self.df['target'].unique() - out = "Aggrated Scores ({} over frames, {} over tracks)\n".format( + out = "Aggregated Scores ({} over frames, {} over tracks)\n".format( self.frames_agg, self.tracks_agg ) for target in targets: diff --git a/museval/metrics.py b/museval/metrics.py index 7bfb78b..20eff98 100644 --- a/museval/metrics.py +++ b/museval/metrics.py @@ -46,17 +46,49 @@ 1706, IRISA, April 2005.""" import numpy as np -import scipy.fftpack +import scipy.fft from scipy.linalg import toeplitz from scipy.signal import fftconvolve import itertools import collections import warnings +import sys + +use_cupy = False +try: + import cupyx + import cupy + use_cupy = True +except ImportError: + warnings.warn('cupy not available, falling back to regular numpy') # The maximum allowable number of sources (prevents insane computational load) MAX_SOURCES = 100 +# allows one to disable cupy even if its available +def disable_cupy(): + global use_cupy + use_cupy = False + + +# fft plans take up space, you might need to call this between large tracks +def clear_cupy_cache(): + # cupy disable fft caching to free blocks + fft_cache = cupy.fft.config.get_plan_cache() + orig_sz = fft_cache.get_size() + orig_memsz = fft_cache.get_memsize() + + # clear the cache + fft_cache.set_size(0) + + cupy.get_default_memory_pool().free_all_blocks() + + # cupy reenable fft caching + fft_cache.set_size(orig_sz) + fft_cache.set_memsize(orig_memsz) + + def validate(reference_sources, estimated_sources): """Checks that the input data to a metric are valid, and throws helpful errors if not. @@ -523,25 +555,41 @@ def _compute_reference_correlations(reference_sources, filters_len): # zero padding and FFT of references reference_sources = _zeropad(reference_sources, filters_len - 1, axis=2) n_fft = int(2**np.ceil(np.log2(nsampl + filters_len - 1.))) - sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=2) + + if use_cupy: + try: + sf = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(reference_sources), n=n_fft, axis=2)) + except cupy.cuda.memory.OutOfMemoryError: + sf = scipy.fft.rfft(reference_sources, n=n_fft, axis=2) + else: + sf = scipy.fft.rfft(reference_sources, n=n_fft, axis=2) # compute intercorrelation between sources G = np.zeros((nsrc, nsrc, nchan, nchan, filters_len, filters_len)) + for ((i, c1), (j, c2)) in itertools.combinations_with_replacement( itertools.product( list(range(nsrc)), list(range(nchan)) ), 2 ): + tmp = sf[j, c2] * np.conj(sf[i, c1]) + + if use_cupy: + try: + ssf = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp))) + except cupy.cuda.memory.OutOfMemoryError: + ssf = scipy.fft.irfft(tmp) + else: + ssf = scipy.fft.irfft(tmp) - ssf = sf[j, c2] * np.conj(sf[i, c1]) - ssf = np.real(scipy.fftpack.ifft(ssf)) ss = toeplitz( np.hstack((ssf[0], ssf[-1:-filters_len:-1])), r=ssf[:filters_len] ) G[j, i, c2, c1] = ss G[i, j, c1, c2] = ss.T + return G, sf @@ -569,30 +617,67 @@ def _compute_projection_filters(G, sf, estimated_source): # compute its FFT n_fft = int(2**np.ceil(np.log2(nsampl + filters_len - 1.))) - sef = scipy.fftpack.fft(estimated_source, n=n_fft) + + if use_cupy: + try: + sef = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(estimated_source, dtype=np.float32), n=n_fft)) + except cupy.cuda.memory.OutOfMemoryError: + sef = scipy.fft.rfft(estimated_source, n=n_fft) + else: + sef = scipy.fft.rfft(estimated_source, n=n_fft) # compute the cross-correlations between sources and estimates D = np.zeros((nsrc, nchan, filters_len, nchan)) + for (j, cj, c) in itertools.product( list(range(nsrc)), list(range(nchan)), list(range(nchan)) ): - ssef = sf[j, cj] * np.conj(sef[c]) - ssef = np.real(scipy.fftpack.ifft(ssef)) + tmp = sf[j, cj] * np.conj(sef[c]) + if use_cupy: + try: + ssef = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp))) + except cupy.cuda.memory.OutOfMemoryError: + ssef = scipy.fft.irfft(tmp) + else: + ssef = scipy.fft.irfft(tmp) D[j, cj, :, c] = np.hstack((ssef[0], ssef[-1:-filters_len:-1])) # reshape matrices to build the filters D = D.reshape(nsrc * nchan * filters_len, nchan) G = _reshape_G(G) - # Distortion filters - try: - C = np.linalg.solve(G + eps*np.eye(G.shape[0]), D).reshape( - nsrc, nchan, filters_len, nchan - ) - except np.linalg.linalg.LinAlgError: - C = np.linalg.lstsq(G, D)[0].reshape( - nsrc, nchan, filters_len, nchan - ) + if use_cupy: + try: + D_gpu = cupy.asarray(D) + G_gpu = cupy.asarray(G) + + # Distortion filters + try: + C = cupy.asnumpy(cupy.linalg.solve(G_gpu + eps*cupy.eye(G.shape[0]), D_gpu)).reshape( + nsrc, nchan, filters_len, nchan + ) + except np.linalg.linalg.LinAlgError: + C = cupy.asnumpy(cupy.linalg.lstsq(G_gpu, D_gpu))[0].reshape( + nsrc, nchan, filters_len, nchan + ) + except cupy.cuda.memory.OutOfMemoryError: + try: + C = np.linalg.solve(G + eps*np.eye(G.shape[0]), D).reshape( + nsrc, nchan, filters_len, nchan + ) + except np.linalg.linalg.LinAlgError: + C = np.linalg.lstsq(G, D)[0].reshape( + nsrc, nchan, filters_len, nchan + ) + else: + try: + C = np.linalg.solve(G + eps*np.eye(G.shape[0]), D).reshape( + nsrc, nchan, filters_len, nchan + ) + except np.linalg.linalg.LinAlgError: + C = np.linalg.lstsq(G, D)[0].reshape( + nsrc, nchan, filters_len, nchan + ) # if we asked for one single reference source, # return just a nchan X filters_len matrix diff --git a/setup.py b/setup.py index f990a2d..ed8b086 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ extras_require={ # Optional 'dev': ['check-manifest'], 'tests': ['pytest'], + 'cupy': ['cupy-cuda114'], 'docs': [ 'sphinx', 'sphinx_rtd_theme',