diff --git a/.gitignore b/.gitignore index db52ccb..2a9e0ba 100644 --- a/.gitignore +++ b/.gitignore @@ -87,3 +87,4 @@ data/* !data/fetch.sh !data/decode.sh Estimates/ +*.json diff --git a/museval/__init__.py b/museval/__init__.py index 346e748..d383fe2 100644 --- a/museval/__init__.py +++ b/museval/__init__.py @@ -9,6 +9,7 @@ import pandas as pd from . aggregate import TrackStore, MethodStore, EvalStore, json2df from . import metrics +from . metrics import disable_cupy, clear_cupy_cache def _load_track_estimates(track, estimates_dir, output_dir, ext='wav'): diff --git a/museval/metrics.py b/museval/metrics.py index 655d38a..34566a8 100644 --- a/museval/metrics.py +++ b/museval/metrics.py @@ -47,18 +47,47 @@ import numpy as np import scipy.fft -import cupyx -import cupy from scipy.linalg import toeplitz from scipy.signal import fftconvolve import itertools import collections import warnings +use_cupy = False +try: + import cupyx + import cupy + use_cupy = True +except ImportError: + warnings.warn('cupy not available, falling back to regular numpy', file=sys.stderr) + # The maximum allowable number of sources (prevents insane computational load) MAX_SOURCES = 100 +# allows one to disable cupy even if its available +def disable_cupy(): + global use_cupy + use_cupy = False + + +# fft plans take up space, you might need to call this between large tracks +def clear_cupy_cache(): + # cupy disable fft caching to free blocks + fft_cache = cupy.fft.config.get_plan_cache() + orig_sz = fft_cache.get_size() + orig_memsz = fft_cache.get_memsize() + + # clear the cache + fft_cache.set_size(0) + + cupy.get_default_memory_pool().free_all_blocks() + + # cupy reenable fft caching + fft_cache.set_size(orig_sz) + fft_cache.set_memsize(orig_memsz) + + def validate(reference_sources, estimated_sources): """Checks that the input data to a metric are valid, and throws helpful errors if not. @@ -526,10 +555,13 @@ def _compute_reference_correlations(reference_sources, filters_len): reference_sources = _zeropad(reference_sources, filters_len - 1, axis=2) n_fft = int(2**np.ceil(np.log2(nsampl + filters_len - 1.))) - try: - sf = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(reference_sources, dtype=np.float32), n=n_fft, axis=2)) - except cupy.cuda.memory.OutOfMemoryError: - sf = scipy.fft.rfft(reference_sources, n=nfft, axis=2) + if use_cupy: + try: + sf = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(reference_sources, dtype=np.float32), n=n_fft, axis=2)) + except cupy.cuda.memory.OutOfMemoryError: + sf = scipy.fft.rfft(reference_sources, n=n_fft, axis=2) + else: + sf = scipy.fft.rfft(reference_sources, n=n_fft, axis=2) # compute intercorrelation between sources G = np.zeros((nsrc, nsrc, nchan, nchan, filters_len, filters_len)) @@ -541,10 +573,15 @@ def _compute_reference_correlations(reference_sources, filters_len): 2 ): tmp = sf[j, c2] * np.conj(sf[i, c1]) - try: - ssf = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp))) - except cupy.cuda.memory.OutOfMemoryError: + + if use_cupy: + try: + ssf = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp))) + except cupy.cuda.memory.OutOfMemoryError: + ssf = scipy.fft.irfft(tmp) + else: ssf = scipy.fft.irfft(tmp) + ss = toeplitz( np.hstack((ssf[0], ssf[-1:-filters_len:-1])), r=ssf[:filters_len] @@ -580,9 +617,12 @@ def _compute_projection_filters(G, sf, estimated_source): # compute its FFT n_fft = int(2**np.ceil(np.log2(nsampl + filters_len - 1.))) - try: - sef = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(estimated_source, dtype=np.float32), n=n_fft)) - except cupy.cuda.memory.OutOfMemoryError: + if use_cupy: + try: + sef = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(estimated_source, dtype=np.float32), n=n_fft)) + except cupy.cuda.memory.OutOfMemoryError: + sef = scipy.fft.rfft(estimated_source, n=n_fft) + else: sef = scipy.fft.rfft(estimated_source, n=n_fft) # compute the cross-correlations between sources and estimates @@ -592,9 +632,12 @@ def _compute_projection_filters(G, sf, estimated_source): list(range(nsrc)), list(range(nchan)), list(range(nchan)) ): tmp = sf[j, cj] * np.conj(sef[c]) - try: - ssef = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp))) - except cupy.cuda.memory.OutOfMemoryError: + if use_cupy: + try: + ssef = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp))) + except cupy.cuda.memory.OutOfMemoryError: + ssef = scipy.fft.irfft(tmp) + else: ssef = scipy.fft.irfft(tmp) D[j, cj, :, c] = np.hstack((ssef[0], ssef[-1:-filters_len:-1])) @@ -602,22 +645,32 @@ def _compute_projection_filters(G, sf, estimated_source): D = D.reshape(nsrc * nchan * filters_len, nchan) G = _reshape_G(G) - try: - D_gpu = cupy.asarray(D, dtype=np.float32) - G_gpu = cupy.asarray(G, dtype=np.float32) - - # Distortion filters + if use_cupy: try: - C = cupy.asnumpy(cupy.linalg.solve(G_gpu + eps*cupy.eye(G.shape[0]), D_gpu)).reshape( - nsrc, nchan, filters_len, nchan - ) - except np.linalg.linalg.LinAlgError: - C = cupy.asnumpy(cupy.linalg.lstsq(G_gpu, D_gpu))[0].reshape( - nsrc, nchan, filters_len, nchan - ) - except cupy.cuda.memory.OutOfMemoryError: + D_gpu = cupy.asarray(D, dtype=np.float32) + G_gpu = cupy.asarray(G, dtype=np.float32) + + # Distortion filters + try: + C = cupy.asnumpy(cupy.linalg.solve(G_gpu + eps*cupy.eye(G.shape[0]), D_gpu)).reshape( + nsrc, nchan, filters_len, nchan + ) + except np.linalg.linalg.LinAlgError: + C = cupy.asnumpy(cupy.linalg.lstsq(G_gpu, D_gpu))[0].reshape( + nsrc, nchan, filters_len, nchan + ) + except cupy.cuda.memory.OutOfMemoryError: + try: + C = np.linalg.solve(G + eps*np.eye(G.shape[0]), D).reshape( + nsrc, nchan, filters_len, nchan + ) + except np.linalg.linalg.LinAlgError: + C = np.linalg.lstsq(G, D)[0].reshape( + nsrc, nchan, filters_len, nchan + ) + else: try: - C = np.linalg.solve(G + eps*cupy.eye(G.shape[0]), D).reshape( + C = np.linalg.solve(G + eps*np.eye(G.shape[0]), D).reshape( nsrc, nchan, filters_len, nchan ) except np.linalg.linalg.LinAlgError: diff --git a/setup.py b/setup.py index f990a2d..ed8b086 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ extras_require={ # Optional 'dev': ['check-manifest'], 'tests': ['pytest'], + 'cupy': ['cupy-cuda114'], 'docs': [ 'sphinx', 'sphinx_rtd_theme', diff --git a/tests/test_regression.py b/tests/test_regression.py index b2b6df4..820be53 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -5,6 +5,8 @@ import museval import numpy as np +#museval.disable_cupy() + @pytest.fixture() def mus():