From f3b7540faddc8d2b1bc91cecaf05fd20d96df7c6 Mon Sep 17 00:00:00 2001 From: Sevag Hanssian Date: Tue, 27 Apr 2021 11:40:00 -0400 Subject: [PATCH] Fall back to host numpy/scipy if CUDA oom --- museval/metrics.py | 54 +++++++++++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/museval/metrics.py b/museval/metrics.py index 8d5b701..655d38a 100644 --- a/museval/metrics.py +++ b/museval/metrics.py @@ -526,7 +526,10 @@ def _compute_reference_correlations(reference_sources, filters_len): reference_sources = _zeropad(reference_sources, filters_len - 1, axis=2) n_fft = int(2**np.ceil(np.log2(nsampl + filters_len - 1.))) - sf = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(reference_sources, dtype=np.float32), n=n_fft, axis=2)) + try: + sf = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(reference_sources, dtype=np.float32), n=n_fft, axis=2)) + except cupy.cuda.memory.OutOfMemoryError: + sf = scipy.fft.rfft(reference_sources, n=nfft, axis=2) # compute intercorrelation between sources G = np.zeros((nsrc, nsrc, nchan, nchan, filters_len, filters_len)) @@ -537,7 +540,11 @@ def _compute_reference_correlations(reference_sources, filters_len): ), 2 ): - ssf = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(sf[j, c2] * np.conj(sf[i, c1])))) + tmp = sf[j, c2] * np.conj(sf[i, c1]) + try: + ssf = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp))) + except cupy.cuda.memory.OutOfMemoryError: + ssf = scipy.fft.irfft(tmp) ss = toeplitz( np.hstack((ssf[0], ssf[-1:-filters_len:-1])), r=ssf[:filters_len] @@ -573,7 +580,10 @@ def _compute_projection_filters(G, sf, estimated_source): # compute its FFT n_fft = int(2**np.ceil(np.log2(nsampl + filters_len - 1.))) - sef = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(estimated_source, dtype=np.float32), n=n_fft)) + try: + sef = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(estimated_source, dtype=np.float32), n=n_fft)) + except cupy.cuda.memory.OutOfMemoryError: + sef = scipy.fft.rfft(estimated_source, n=n_fft) # compute the cross-correlations between sources and estimates D = np.zeros((nsrc, nchan, filters_len, nchan)) @@ -581,25 +591,39 @@ def _compute_projection_filters(G, sf, estimated_source): for (j, cj, c) in itertools.product( list(range(nsrc)), list(range(nchan)), list(range(nchan)) ): - ssef = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(sf[j, cj] * np.conj(sef[c])))) + tmp = sf[j, cj] * np.conj(sef[c]) + try: + ssef = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp))) + except cupy.cuda.memory.OutOfMemoryError: + ssef = scipy.fft.irfft(tmp) D[j, cj, :, c] = np.hstack((ssef[0], ssef[-1:-filters_len:-1])) # reshape matrices to build the filters D = D.reshape(nsrc * nchan * filters_len, nchan) G = _reshape_G(G) - D_gpu = cupy.asarray(D, dtype=np.float32) - G_gpu = cupy.asarray(G, dtype=np.float32) - - # Distortion filters try: - C = cupy.asnumpy(cupy.linalg.solve(G_gpu + eps*cupy.eye(G.shape[0]), D_gpu)).reshape( - nsrc, nchan, filters_len, nchan - ) - except np.linalg.linalg.LinAlgError: - C = cupy.asnumpy(cupy.linalg.lstsq(G_gpu, D_gpu))[0].reshape( - nsrc, nchan, filters_len, nchan - ) + D_gpu = cupy.asarray(D, dtype=np.float32) + G_gpu = cupy.asarray(G, dtype=np.float32) + + # Distortion filters + try: + C = cupy.asnumpy(cupy.linalg.solve(G_gpu + eps*cupy.eye(G.shape[0]), D_gpu)).reshape( + nsrc, nchan, filters_len, nchan + ) + except np.linalg.linalg.LinAlgError: + C = cupy.asnumpy(cupy.linalg.lstsq(G_gpu, D_gpu))[0].reshape( + nsrc, nchan, filters_len, nchan + ) + except cupy.cuda.memory.OutOfMemoryError: + try: + C = np.linalg.solve(G + eps*cupy.eye(G.shape[0]), D).reshape( + nsrc, nchan, filters_len, nchan + ) + except np.linalg.linalg.LinAlgError: + C = np.linalg.lstsq(G, D)[0].reshape( + nsrc, nchan, filters_len, nchan + ) # if we asked for one single reference source, # return just a nchan X filters_len matrix