Skip to content

Commit

Permalink
Remove float32 dtypes
Browse files Browse the repository at this point in the history
Needs float64 or it gets large errors in regression tests
  • Loading branch information
sevagh committed Aug 10, 2021
1 parent 178f294 commit 6269809
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
6 changes: 3 additions & 3 deletions museval/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def _compute_reference_correlations(reference_sources, filters_len):

if use_cupy:
try:
sf = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(reference_sources, dtype=np.float32), n=n_fft, axis=2))
sf = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(reference_sources), n=n_fft, axis=2))
except cupy.cuda.memory.OutOfMemoryError:
sf = scipy.fft.rfft(reference_sources, n=n_fft, axis=2)
else:
Expand Down Expand Up @@ -647,8 +647,8 @@ def _compute_projection_filters(G, sf, estimated_source):

if use_cupy:
try:
D_gpu = cupy.asarray(D, dtype=np.float32)
G_gpu = cupy.asarray(G, dtype=np.float32)
D_gpu = cupy.asarray(D)
G_gpu = cupy.asarray(G)

# Distortion filters
try:
Expand Down
2 changes: 0 additions & 2 deletions tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import museval
import numpy as np

#museval.disable_cupy()


@pytest.fixture()
def mus():
Expand Down

0 comments on commit 6269809

Please sign in to comment.