Skip to content

Commit

Permalink
Add support for string window-specs to divmat
Browse files Browse the repository at this point in the history
Closes #2791
  • Loading branch information
jeromekelleher committed Aug 15, 2023
1 parent 844083f commit 893484e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
10 changes: 8 additions & 2 deletions python/tests/test_divmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def span_normalise_windows(D, windows):

def branch_divergence_matrix(ts, windows=None, samples=None, span_normalise=True):
windows_specified = windows is not None
windows = [0, ts.sequence_length] if windows is None else windows
windows = ts.parse_windows(windows)
num_windows = len(windows) - 1
samples = ts.samples() if samples is None else samples

Expand Down Expand Up @@ -296,7 +296,7 @@ def group_alleles(genotypes, num_alleles):

def site_divergence_matrix(ts, windows=None, samples=None, span_normalise=True):
windows_specified = windows is not None
windows = [0, ts.sequence_length] if windows is None else windows
windows = ts.parse_windows(windows)
num_windows = len(windows) - 1
samples = ts.samples() if samples is None else samples

Expand Down Expand Up @@ -954,6 +954,8 @@ def check(self, ts, num_threads, *, windows, samples=None, mode=None):
([5, 7, 9, 20],),
([5.1, 5.2, 5.3, 5.5, 6],),
([5.1, 5.2, 6.5],),
("trees",),
("sites",),
],
)
@pytest.mark.parametrize("mode", DIVMAT_MODES)
Expand All @@ -968,6 +970,8 @@ def test_all_trees(self, num_threads, windows, mode):
[
([0, 26],),
(None,),
("trees",),
("sites",),
],
)
@pytest.mark.parametrize("mode", DIVMAT_MODES)
Expand All @@ -984,6 +988,8 @@ def test_all_trees_samples(self, samples, windows, mode):
([50, 75, 95, 100],),
([0, 50, 75, 95],),
(list(range(100)),),
("trees",),
("sites",),
],
)
@pytest.mark.parametrize("mode", DIVMAT_MODES)
Expand Down
7 changes: 3 additions & 4 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -7865,12 +7865,11 @@ def divergence_matrix(
span_normalise=True,
):
windows_specified = windows is not None
windows = [0, self.sequence_length] if windows is None else windows

windows = self.parse_windows(windows)
mode = "site" if mode is None else mode

# NOTE: maybe we want to use a different default for num_threads here, just
# following the approach in GNN
# FIXME this logic should be merged into __run_windowed_stat if
# we generalise the num_threads argument to all stats.
if num_threads <= 0:
D = self._ll_tree_sequence.divergence_matrix(
windows,
Expand Down

0 comments on commit 893484e

Please sign in to comment.