From e2c1a714a45037c85bb11bacb1d7f2a6a31aeb71 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 19:17:31 -0400 Subject: [PATCH] TST: Fix test in `common.labels` (#747) * Fix assertion in tests/test_common/test_labels.py -- we now get back pandas.arrays.StringArray * Change dependency to 'dask[dataframe]' to squelch warning --- pyproject.toml | 2 +- tests/test_common/test_labels.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 31810f2f9..bacf95d7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ license = {file = "LICENSE"} dependencies = [ "attrs >=19.3.0", "crowsetta >=5.0.1", - "dask >=2.10.1", + "dask[dataframe] >=2.10.1", "evfuncs >=0.3.4", "joblib >=0.14.1", "pytorch-lightning >=2.0.7", diff --git a/tests/test_common/test_labels.py b/tests/test_common/test_labels.py index a0cdae139..3a3f794e2 100644 --- a/tests/test_common/test_labels.py +++ b/tests/test_common/test_labels.py @@ -1,7 +1,7 @@ import copy -import pathlib import numpy as np +import pandas as pd import pytest import vak.common.files.spect @@ -72,7 +72,7 @@ def test_from_df(config_type, model_name, audio_format, spect_format, annot_form out = vak.common.labels.from_df(df, dataset_path) assert isinstance(out, list) - assert all([isinstance(labels, np.ndarray) for labels in out]) + assert all([isinstance(labels, (np.ndarray, pd.arrays.StringArray)) for labels in out]) INTS_LABELMAP = {str(val): val for val in range(1, 20)}