Skip to content

Commit

Permalink
Fall back to host numpy/scipy if CUDA oom
Browse files Browse the repository at this point in the history
  • Loading branch information
sevagh committed Apr 27, 2021
1 parent 65f6843 commit f3b7540
Showing 1 changed file with 39 additions and 15 deletions.
54 changes: 39 additions & 15 deletions museval/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,10 @@ def _compute_reference_correlations(reference_sources, filters_len):
reference_sources = _zeropad(reference_sources, filters_len - 1, axis=2)
n_fft = int(2**np.ceil(np.log2(nsampl + filters_len - 1.)))

sf = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(reference_sources, dtype=np.float32), n=n_fft, axis=2))
try:
sf = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(reference_sources, dtype=np.float32), n=n_fft, axis=2))
except cupy.cuda.memory.OutOfMemoryError:
sf = scipy.fft.rfft(reference_sources, n=nfft, axis=2)

# compute intercorrelation between sources
G = np.zeros((nsrc, nsrc, nchan, nchan, filters_len, filters_len))
Expand All @@ -537,7 +540,11 @@ def _compute_reference_correlations(reference_sources, filters_len):
),
2
):
ssf = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(sf[j, c2] * np.conj(sf[i, c1]))))
tmp = sf[j, c2] * np.conj(sf[i, c1])
try:
ssf = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp)))
except cupy.cuda.memory.OutOfMemoryError:
ssf = scipy.fft.irfft(tmp)
ss = toeplitz(
np.hstack((ssf[0], ssf[-1:-filters_len:-1])),
r=ssf[:filters_len]
Expand Down Expand Up @@ -573,33 +580,50 @@ def _compute_projection_filters(G, sf, estimated_source):
# compute its FFT
n_fft = int(2**np.ceil(np.log2(nsampl + filters_len - 1.)))

sef = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(estimated_source, dtype=np.float32), n=n_fft))
try:
sef = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(estimated_source, dtype=np.float32), n=n_fft))
except cupy.cuda.memory.OutOfMemoryError:
sef = scipy.fft.rfft(estimated_source, n=n_fft)

# compute the cross-correlations between sources and estimates
D = np.zeros((nsrc, nchan, filters_len, nchan))

for (j, cj, c) in itertools.product(
list(range(nsrc)), list(range(nchan)), list(range(nchan))
):
ssef = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(sf[j, cj] * np.conj(sef[c]))))
tmp = sf[j, cj] * np.conj(sef[c])
try:
ssef = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp)))
except cupy.cuda.memory.OutOfMemoryError:
ssef = scipy.fft.irfft(tmp)
D[j, cj, :, c] = np.hstack((ssef[0], ssef[-1:-filters_len:-1]))

# reshape matrices to build the filters
D = D.reshape(nsrc * nchan * filters_len, nchan)
G = _reshape_G(G)

D_gpu = cupy.asarray(D, dtype=np.float32)
G_gpu = cupy.asarray(G, dtype=np.float32)

# Distortion filters
try:
C = cupy.asnumpy(cupy.linalg.solve(G_gpu + eps*cupy.eye(G.shape[0]), D_gpu)).reshape(
nsrc, nchan, filters_len, nchan
)
except np.linalg.linalg.LinAlgError:
C = cupy.asnumpy(cupy.linalg.lstsq(G_gpu, D_gpu))[0].reshape(
nsrc, nchan, filters_len, nchan
)
D_gpu = cupy.asarray(D, dtype=np.float32)
G_gpu = cupy.asarray(G, dtype=np.float32)

# Distortion filters
try:
C = cupy.asnumpy(cupy.linalg.solve(G_gpu + eps*cupy.eye(G.shape[0]), D_gpu)).reshape(
nsrc, nchan, filters_len, nchan
)
except np.linalg.linalg.LinAlgError:
C = cupy.asnumpy(cupy.linalg.lstsq(G_gpu, D_gpu))[0].reshape(
nsrc, nchan, filters_len, nchan
)
except cupy.cuda.memory.OutOfMemoryError:
try:
C = np.linalg.solve(G + eps*cupy.eye(G.shape[0]), D).reshape(
nsrc, nchan, filters_len, nchan
)
except np.linalg.linalg.LinAlgError:
C = np.linalg.lstsq(G, D)[0].reshape(
nsrc, nchan, filters_len, nchan
)

# if we asked for one single reference source,
# return just a nchan X filters_len matrix
Expand Down

0 comments on commit f3b7540

Please sign in to comment.