Skip to content

Commit

Permalink
passes tests!
Browse files Browse the repository at this point in the history
  • Loading branch information
petrelharp committed Sep 26, 2024
1 parent cd741d5 commit 85e7531
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 24 deletions.
8 changes: 7 additions & 1 deletion c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -2024,6 +2024,12 @@ test_empty_genetic_relatedness_vector(void)
&ts, num_weights, weights, 1, windows, result, TSK_STAT_NONCENTRED);
CU_ASSERT_EQUAL_FATAL(ret, 0);

windows[0] = 0.5 * tsk_treeseq_get_sequence_length(&ts);
windows[1] = 0.75 * tsk_treeseq_get_sequence_length(&ts);
ret = tsk_treeseq_genetic_relatedness_vector(
&ts, num_weights, weights, 1, windows, result, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);

tsk_treeseq_free(&ts);
free(weights);
free(result);
Expand Down Expand Up @@ -2135,7 +2141,7 @@ test_paper_ex_genetic_relatedness_vector_errors(void)
&ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);

windows[0] = 10;
windows[0] = 12;
ret = tsk_treeseq_genetic_relatedness_vector(
&ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);
Expand Down
18 changes: 10 additions & 8 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -5197,6 +5197,7 @@ tsk_tree_position_print_state(const tsk_tree_position_t *self, FILE *out)
{
fprintf(out, "Tree position state\n");
fprintf(out, "index = %d\n", (int) self->index);
fprintf(out, "interval = [%f,\t%f)\n", self->interval.left, self->interval.right);
fprintf(
out, "out = start=%d\tstop=%d\n", (int) self->out.start, (int) self->out.stop);
fprintf(
Expand Down Expand Up @@ -10219,8 +10220,9 @@ tsk_matvec_calculator_run(tsk_matvec_calculator_t *self)
}

valid = tsk_tree_position_next(&tree_pos);
j = (tsk_size_t) tree_pos.in.start - 1;
k = (tsk_size_t) tree_pos.out.start - 1;
j = (tsk_size_t) tree_pos.in.start;
k = (tsk_size_t) tree_pos.out.start;
tsk_tree_position_print_state(&tree_pos, stdout);
while (m < self->num_windows) {
if (valid && self->position == tree_pos.interval.left) {
for (k = (tsk_size_t) tree_pos.out.start;
Expand All @@ -10240,13 +10242,13 @@ tsk_matvec_calculator_run(tsk_matvec_calculator_t *self)
valid = tsk_tree_position_next(&tree_pos);
}
next_position = windows[m + 1];
if (j + 1 < num_edges) {
next_position = TSK_MIN(next_position, edge_left[tree_pos.in.order[j + 1]]);
if (j < num_edges) {
next_position = TSK_MIN(next_position, edge_left[tree_pos.in.order[j]]);
}
if (k + 1 < num_edges) {
next_position
= TSK_MIN(next_position, edge_right[tree_pos.out.order[k + 1]]);
if (k < num_edges) {
next_position = TSK_MIN(next_position, edge_right[tree_pos.out.order[k]]);
}
tsk_bug_assert(self->position < next_position);
self->position = next_position;
if (self->position == windows[m + 1]) {
out = GET_2D_ROW(self->result, self->num_weights * n, m);
Expand Down Expand Up @@ -10278,7 +10280,7 @@ tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, tsk_size_t num
ret = TSK_ERR_UNSUPPORTED_STAT_MODE;
goto out;
}
ret = tsk_treeseq_check_windows(self, num_windows, windows, TSK_REQUIRE_FULL_SPAN);
ret = tsk_treeseq_check_windows(self, num_windows, windows, 0);
if (ret != 0) {
goto out;
}
Expand Down
71 changes: 56 additions & 15 deletions python/tests/test_relatedness_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ def run(self):
edges_parent = self.edges_parent
edges_child = self.edges_child
tree_pos = self.tree_pos
in_order = tree_pos.in_range.order
out_order = tree_pos.out_range.order
num_windows = len(self.windows) - 1
out = np.zeros((num_windows,) + self.sample_weights.shape)

Expand All @@ -239,7 +241,7 @@ def run(self):

# seek to first window
for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, 1):
e = tree_pos.in_range.order[j]
e = in_order[j]
if edges_left[e] <= self.position and self.position < edges_right[e]:
p = edges_parent[e]
c = edges_child[e]
Expand All @@ -251,26 +253,23 @@ def run(self):
while m < num_windows:
if valid and self.position == tree_pos.interval.left:
for k in range(tree_pos.out_range.start, tree_pos.out_range.stop, 1):
e = tree_pos.out_range.order[k]
e = out_order[k]
p = edges_parent[e]
c = edges_child[e]
self.remove_edge(p, c)
for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, 1):
e = tree_pos.in_range.order[j]
e = in_order[j]
p = edges_parent[e]
c = edges_child[e]
self.insert_edge(p, c)
assert self.parent[p] == tskit.NULL or self.x[p] == self.position
valid = tree_pos.next()
next_position = self.windows[m + 1]
if j + 1 < M:
next_position = min(
next_position, edges_left[tree_pos.in_range.order[j + 1]]
)
next_position = min(next_position, edges_left[in_order[j + 1]])
if k + 1 < M:
next_position = min(
next_position, edges_right[tree_pos.out_range.order[k + 1]]
)
next_position = min(next_position, edges_right[out_order[k + 1]])
assert self.position < next_position
self.position = next_position
if self.position == self.windows[m + 1]:
out[m] = self.write_output()
Expand Down Expand Up @@ -321,16 +320,27 @@ def relatedness_vector(ts, sample_weights, windows=None, **kwargs):


def relatedness_matrix(ts, windows, centre):
use_windows = windows
drop_first = windows is not None and windows[0] > 0
if drop_first:
use_windows = np.concatenate([[0], np.array(use_windows)])
drop_last = windows is not None and windows[-1] < ts.sequence_length
if drop_last:
use_windows = np.concatenate([np.array(use_windows), [ts.sequence_length]])
Sigma = ts.genetic_relatedness(
sample_sets=[[i] for i in ts.samples()],
indexes=[(i, j) for i in range(ts.num_samples) for j in range(ts.num_samples)],
windows=windows,
windows=use_windows,
mode="branch",
span_normalise=False,
proportion=False,
centre=centre,
)
if windows is not None:
if drop_first:
Sigma = Sigma[1:]
if drop_last:
Sigma = Sigma[:-1]
shape = (len(windows) - 1, ts.num_samples, ts.num_samples)
else:
shape = (ts.num_samples, ts.num_samples)
Expand Down Expand Up @@ -380,6 +390,10 @@ def check_relatedness_vector(
rng = np.random.default_rng(seed=seed)
if num_windows == 0:
windows = None
elif num_windows % 2 == 0:
windows = np.linspace(
0.2 * ts.sequence_length, 0.8 * ts.sequence_length, num_windows + 1
)
else:
windows = np.linspace(0, ts.sequence_length, num_windows + 1)
for k in range(n):
Expand Down Expand Up @@ -430,7 +444,7 @@ def test_bad_windows(self):
@pytest.mark.parametrize("n", [2, 3, 5])
@pytest.mark.parametrize("seed", range(1, 4))
@pytest.mark.parametrize("centre", (True, False))
@pytest.mark.parametrize("num_windows", (0, 1, 2))
@pytest.mark.parametrize("num_windows", (0, 1, 2, 3))
def test_small_internal_checks(self, n, seed, centre, num_windows):
ts = msprime.sim_ancestry(
n,
Expand All @@ -440,12 +454,14 @@ def test_small_internal_checks(self, n, seed, centre, num_windows):
random_seed=seed,
)
assert ts.num_trees >= 2
check_relatedness_vector(ts, internal_checks=True, centre=centre)
check_relatedness_vector(
ts, num_windows=num_windows, internal_checks=True, centre=centre
)

@pytest.mark.parametrize("n", [2, 3, 5, 15])
@pytest.mark.parametrize("seed", range(1, 5))
@pytest.mark.parametrize("centre", (True, False))
@pytest.mark.parametrize("num_windows", (0, 1, 3))
@pytest.mark.parametrize("num_windows", (0, 1, 2, 3))
def test_simple_sims(self, n, seed, centre, num_windows):
ts = msprime.sim_ancestry(
n,
Expand All @@ -460,6 +476,31 @@ def test_simple_sims(self, n, seed, centre, num_windows):
ts, num_windows=num_windows, centre=centre, verbosity=0
)

def test_simple_sims_windows(self):
L = 100
ts = msprime.sim_ancestry(
5,
ploidy=1,
population_size=20,
sequence_length=L,
recombination_rate=0.01,
random_seed=345,
)
assert ts.num_trees >= 2
W = np.linspace(0, 1, 2 * ts.num_samples).reshape((ts.num_samples, 2))
kwargs = {"centre": False, "mode": "branch"}
total = ts.genetic_relatedness_vector(W, **kwargs)
for windows in [[0, L], [0, L / 3, L / 2, L]]:
pieces = ts.genetic_relatedness_vector(W, windows=windows, **kwargs)
np.testing.assert_allclose(total, pieces.sum(axis=0), atol=1e-13)
assert len(pieces) == len(windows) - 1
for k in range(len(pieces)):
piece = ts.genetic_relatedness_vector(
W, windows=windows[k : k + 2], **kwargs
)
assert piece.shape[0] == 1
np.testing.assert_allclose(piece[0], pieces[k], atol=1e-13)

@pytest.mark.parametrize("n", [2, 3, 5, 15])
@pytest.mark.parametrize("centre", (True, False))
def test_single_balanced_tree(self, n, centre):
Expand All @@ -478,7 +519,7 @@ def test_internal_sample(self, centre):

@pytest.mark.parametrize("seed", range(1, 5))
@pytest.mark.parametrize("centre", (True, False))
@pytest.mark.parametrize("num_windows", (0, 1, 2))
@pytest.mark.parametrize("num_windows", (0, 1, 2, 3))
def test_one_internal_sample_sims(self, seed, centre, num_windows):
ts = msprime.sim_ancestry(
10,
Expand All @@ -498,7 +539,7 @@ def test_one_internal_sample_sims(self, seed, centre, num_windows):
check_relatedness_vector(ts, num_windows=num_windows, centre=centre)

@pytest.mark.parametrize("centre", (True, False))
@pytest.mark.parametrize("num_windows", (0, 1, 2))
@pytest.mark.parametrize("num_windows", (0, 1, 2, 3))
def test_missing_flanks(self, centre, num_windows):
ts = msprime.sim_ancestry(
2,
Expand Down

0 comments on commit 85e7531

Please sign in to comment.