Skip to content

Commit

Permalink
Make cupy optional and add helper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sevagh committed Aug 10, 2021
1 parent f3b7540 commit 178f294
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 29 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,4 @@ data/*
!data/fetch.sh
!data/decode.sh
Estimates/
*.json
1 change: 1 addition & 0 deletions museval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas as pd
from . aggregate import TrackStore, MethodStore, EvalStore, json2df
from . import metrics
from . metrics import disable_cupy, clear_cupy_cache


def _load_track_estimates(track, estimates_dir, output_dir, ext='wav'):
Expand Down
111 changes: 82 additions & 29 deletions museval/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,47 @@

import numpy as np
import scipy.fft
import cupyx
import cupy
from scipy.linalg import toeplitz
from scipy.signal import fftconvolve
import itertools
import collections
import warnings

use_cupy = False
try:
import cupyx
import cupy
use_cupy = True
except ImportError:
warnings.warn('cupy not available, falling back to regular numpy', file=sys.stderr)

# The maximum allowable number of sources (prevents insane computational load)
MAX_SOURCES = 100


# allows one to disable cupy even if its available
def disable_cupy():
global use_cupy
use_cupy = False


# fft plans take up space, you might need to call this between large tracks
def clear_cupy_cache():
# cupy disable fft caching to free blocks
fft_cache = cupy.fft.config.get_plan_cache()
orig_sz = fft_cache.get_size()
orig_memsz = fft_cache.get_memsize()

# clear the cache
fft_cache.set_size(0)

cupy.get_default_memory_pool().free_all_blocks()

# cupy reenable fft caching
fft_cache.set_size(orig_sz)
fft_cache.set_memsize(orig_memsz)


def validate(reference_sources, estimated_sources):
"""Checks that the input data to a metric are valid, and throws helpful
errors if not.
Expand Down Expand Up @@ -526,10 +555,13 @@ def _compute_reference_correlations(reference_sources, filters_len):
reference_sources = _zeropad(reference_sources, filters_len - 1, axis=2)
n_fft = int(2**np.ceil(np.log2(nsampl + filters_len - 1.)))

try:
sf = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(reference_sources, dtype=np.float32), n=n_fft, axis=2))
except cupy.cuda.memory.OutOfMemoryError:
sf = scipy.fft.rfft(reference_sources, n=nfft, axis=2)
if use_cupy:
try:
sf = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(reference_sources, dtype=np.float32), n=n_fft, axis=2))
except cupy.cuda.memory.OutOfMemoryError:
sf = scipy.fft.rfft(reference_sources, n=n_fft, axis=2)
else:
sf = scipy.fft.rfft(reference_sources, n=n_fft, axis=2)

# compute intercorrelation between sources
G = np.zeros((nsrc, nsrc, nchan, nchan, filters_len, filters_len))
Expand All @@ -541,10 +573,15 @@ def _compute_reference_correlations(reference_sources, filters_len):
2
):
tmp = sf[j, c2] * np.conj(sf[i, c1])
try:
ssf = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp)))
except cupy.cuda.memory.OutOfMemoryError:

if use_cupy:
try:
ssf = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp)))
except cupy.cuda.memory.OutOfMemoryError:
ssf = scipy.fft.irfft(tmp)
else:
ssf = scipy.fft.irfft(tmp)

ss = toeplitz(
np.hstack((ssf[0], ssf[-1:-filters_len:-1])),
r=ssf[:filters_len]
Expand Down Expand Up @@ -580,9 +617,12 @@ def _compute_projection_filters(G, sf, estimated_source):
# compute its FFT
n_fft = int(2**np.ceil(np.log2(nsampl + filters_len - 1.)))

try:
sef = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(estimated_source, dtype=np.float32), n=n_fft))
except cupy.cuda.memory.OutOfMemoryError:
if use_cupy:
try:
sef = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(estimated_source, dtype=np.float32), n=n_fft))
except cupy.cuda.memory.OutOfMemoryError:
sef = scipy.fft.rfft(estimated_source, n=n_fft)
else:
sef = scipy.fft.rfft(estimated_source, n=n_fft)

# compute the cross-correlations between sources and estimates
Expand All @@ -592,32 +632,45 @@ def _compute_projection_filters(G, sf, estimated_source):
list(range(nsrc)), list(range(nchan)), list(range(nchan))
):
tmp = sf[j, cj] * np.conj(sef[c])
try:
ssef = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp)))
except cupy.cuda.memory.OutOfMemoryError:
if use_cupy:
try:
ssef = cupy.asnumpy(cupyx.scipy.fft.irfft(cupy.asarray(tmp)))
except cupy.cuda.memory.OutOfMemoryError:
ssef = scipy.fft.irfft(tmp)
else:
ssef = scipy.fft.irfft(tmp)
D[j, cj, :, c] = np.hstack((ssef[0], ssef[-1:-filters_len:-1]))

# reshape matrices to build the filters
D = D.reshape(nsrc * nchan * filters_len, nchan)
G = _reshape_G(G)

try:
D_gpu = cupy.asarray(D, dtype=np.float32)
G_gpu = cupy.asarray(G, dtype=np.float32)

# Distortion filters
if use_cupy:
try:
C = cupy.asnumpy(cupy.linalg.solve(G_gpu + eps*cupy.eye(G.shape[0]), D_gpu)).reshape(
nsrc, nchan, filters_len, nchan
)
except np.linalg.linalg.LinAlgError:
C = cupy.asnumpy(cupy.linalg.lstsq(G_gpu, D_gpu))[0].reshape(
nsrc, nchan, filters_len, nchan
)
except cupy.cuda.memory.OutOfMemoryError:
D_gpu = cupy.asarray(D, dtype=np.float32)
G_gpu = cupy.asarray(G, dtype=np.float32)

# Distortion filters
try:
C = cupy.asnumpy(cupy.linalg.solve(G_gpu + eps*cupy.eye(G.shape[0]), D_gpu)).reshape(
nsrc, nchan, filters_len, nchan
)
except np.linalg.linalg.LinAlgError:
C = cupy.asnumpy(cupy.linalg.lstsq(G_gpu, D_gpu))[0].reshape(
nsrc, nchan, filters_len, nchan
)
except cupy.cuda.memory.OutOfMemoryError:
try:
C = np.linalg.solve(G + eps*np.eye(G.shape[0]), D).reshape(
nsrc, nchan, filters_len, nchan
)
except np.linalg.linalg.LinAlgError:
C = np.linalg.lstsq(G, D)[0].reshape(
nsrc, nchan, filters_len, nchan
)
else:
try:
C = np.linalg.solve(G + eps*cupy.eye(G.shape[0]), D).reshape(
C = np.linalg.solve(G + eps*np.eye(G.shape[0]), D).reshape(
nsrc, nchan, filters_len, nchan
)
except np.linalg.linalg.LinAlgError:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
extras_require={ # Optional
'dev': ['check-manifest'],
'tests': ['pytest'],
'cupy': ['cupy-cuda114'],
'docs': [
'sphinx',
'sphinx_rtd_theme',
Expand Down
2 changes: 2 additions & 0 deletions tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import museval
import numpy as np

#museval.disable_cupy()


@pytest.fixture()
def mus():
Expand Down

0 comments on commit 178f294

Please sign in to comment.