From 62698095b167dd0ccd6f68e12cc45dab3f03d4d1 Mon Sep 17 00:00:00 2001 From: Sevag Hanssian Date: Tue, 10 Aug 2021 10:04:39 -0400 Subject: [PATCH] Remove float32 dtypes Needs float64 or it gets large errors in regression tests --- museval/metrics.py | 6 +++--- tests/test_regression.py | 2 -- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/museval/metrics.py b/museval/metrics.py index 34566a8..0a4e687 100644 --- a/museval/metrics.py +++ b/museval/metrics.py @@ -557,7 +557,7 @@ def _compute_reference_correlations(reference_sources, filters_len): if use_cupy: try: - sf = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(reference_sources, dtype=np.float32), n=n_fft, axis=2)) + sf = cupy.asnumpy(cupyx.scipy.fft.rfft(cupy.asarray(reference_sources), n=n_fft, axis=2)) except cupy.cuda.memory.OutOfMemoryError: sf = scipy.fft.rfft(reference_sources, n=n_fft, axis=2) else: @@ -647,8 +647,8 @@ def _compute_projection_filters(G, sf, estimated_source): if use_cupy: try: - D_gpu = cupy.asarray(D, dtype=np.float32) - G_gpu = cupy.asarray(G, dtype=np.float32) + D_gpu = cupy.asarray(D) + G_gpu = cupy.asarray(G) # Distortion filters try: diff --git a/tests/test_regression.py b/tests/test_regression.py index 820be53..b2b6df4 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -5,8 +5,6 @@ import museval import numpy as np -#museval.disable_cupy() - @pytest.fixture() def mus():