diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py index 91b84e84b2..bf5118a5d4 100644 --- a/python/tests/test_divmat.py +++ b/python/tests/test_divmat.py @@ -163,7 +163,7 @@ def span_normalise_windows(D, windows): def branch_divergence_matrix(ts, windows=None, samples=None, span_normalise=True): windows_specified = windows is not None - windows = [0, ts.sequence_length] if windows is None else windows + windows = ts.parse_windows(windows) num_windows = len(windows) - 1 samples = ts.samples() if samples is None else samples @@ -296,7 +296,7 @@ def group_alleles(genotypes, num_alleles): def site_divergence_matrix(ts, windows=None, samples=None, span_normalise=True): windows_specified = windows is not None - windows = [0, ts.sequence_length] if windows is None else windows + windows = ts.parse_windows(windows) num_windows = len(windows) - 1 samples = ts.samples() if samples is None else samples @@ -954,6 +954,8 @@ def check(self, ts, num_threads, *, windows, samples=None, mode=None): ([5, 7, 9, 20],), ([5.1, 5.2, 5.3, 5.5, 6],), ([5.1, 5.2, 6.5],), + ("trees",), + ("sites",), ], ) @pytest.mark.parametrize("mode", DIVMAT_MODES) @@ -968,6 +970,8 @@ def test_all_trees(self, num_threads, windows, mode): [ ([0, 26],), (None,), + ("trees",), + ("sites",), ], ) @pytest.mark.parametrize("mode", DIVMAT_MODES) @@ -984,6 +988,8 @@ def test_all_trees_samples(self, samples, windows, mode): ([50, 75, 95, 100],), ([0, 50, 75, 95],), (list(range(100)),), + ("trees",), + ("sites",), ], ) @pytest.mark.parametrize("mode", DIVMAT_MODES) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index cf0b9c78c1..31ad272a05 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1537,6 +1537,8 @@ def test_divergence_matrix(self): assert D.shape == (1, n, n) D = ts.divergence_matrix(windows, samples=[0, 1]) assert D.shape == (1, 2, 2) + D = ts.divergence_matrix(windows, samples=[0, 1], span_normalise=True) + assert D.shape == (1, 2, 2) with pytest.raises(TypeError, match="str"): ts.divergence_matrix(windows, span_normalise="xdf") with pytest.raises(TypeError): diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 0d27f1678c..2cc47b2959 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7865,12 +7865,11 @@ def divergence_matrix( span_normalise=True, ): windows_specified = windows is not None - windows = [0, self.sequence_length] if windows is None else windows - + windows = self.parse_windows(windows) mode = "site" if mode is None else mode - # NOTE: maybe we want to use a different default for num_threads here, just - # following the approach in GNN + # FIXME this logic should be merged into __run_windowed_stat if + # we generalise the num_threads argument to all stats. if num_threads <= 0: D = self._ll_tree_sequence.divergence_matrix( windows,