Skip to content

Commit

Permalink
Improve AFS branch mode
Browse files Browse the repository at this point in the history
  • Loading branch information
tforest committed Jul 18, 2024
1 parent e0832e9 commit 8f5fa02
Showing 1 changed file with 30 additions and 63 deletions.
93 changes: 30 additions & 63 deletions python/tests/test_tree_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def windowed_tree_stat(ts, stat, windows, span_normalise=True):
return A


# Timewindows test
def naive_branch_general_stat(
ts, w, f, windows=None, time_windows=None, polarised=False, span_normalise=True
):
Expand All @@ -153,6 +152,9 @@ def naive_branch_general_stat(
drop_time_windows = time_windows is None
if time_windows is None:
time_windows = [0.0, np.inf]
else:
if time_windows[0] != 0:
time_windows = [0] + time_windows
n, k = w.shape
tw = len(time_windows) - 1
# hack to determine m
Expand Down Expand Up @@ -180,7 +182,7 @@ def naive_branch_general_stat(
for u in tree.nodes()
)
sigma[tree.index, j, :] = s * tree.span
for j in range(1, len(time_windows) - 1):
for j in range(1, tw):
sigma[:, j, :] = sigma[:, j, :] - sigma[:, j - 1, :]
if isinstance(windows, str) and windows == "trees":
# need to average across the windows
Expand All @@ -191,48 +193,11 @@ def naive_branch_general_stat(
else:
out = windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise)
if drop_time_windows:
# beware: this assumes the first dimension is windows
assert out.shape[1] == 1
assert out.shape[1] == 3
out = out[:, 0]
return out


# Previous version without tw
# def naive_branch_general_stat(
# ts, w, f, windows=None, polarised=False, span_normalise=True
# ):
# if windows is None:
# windows = [0.0, ts.sequence_length]
# n, k = w.shape
# # hack to determine m
# m = len(f(w[0]))
# total = np.sum(w, axis=0)

# sigma = np.zeros((ts.num_trees, m))
# for tree in ts.trees():
# x = np.zeros((ts.num_nodes, k))
# x[ts.samples()] = w
# for u in tree.nodes(order="postorder"):
# for v in tree.children(u):
# x[u] += x[v]
# if polarised:
# s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes())
# else:
# s = sum(
# tree.branch_length(u) * (f(x[u]) + f(total - x[u]))
# for u in tree.nodes()
# )
# sigma[tree.index] = s * tree.span
# if isinstance(windows, str) and windows == "trees":
# # need to average across the windows
# if span_normalise:
# for j, tree in enumerate(ts.trees()):
# sigma[j] /= tree.span
# return sigma
# else:
# return windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise)


def branch_general_stat(
ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True
):
Expand Down Expand Up @@ -313,7 +278,6 @@ def polarised_summary(u):
# for the next tree
break

# print("window_index:", window_index, windows.shape)
assert window_index == windows.shape[0] - 1
if span_normalise:
for j in range(num_windows):
Expand Down Expand Up @@ -3644,14 +3608,19 @@ def naive_branch_allele_frequency_spectrum(
drop_windows = windows is None
if windows is None:
windows = [0.0, ts.sequence_length]
else:
if windows[0] != 0:
windows = [0] + windows
drop_time_windows = time_windows is None
if time_windows is None:
time_windows = [0.0, np.inf]
else:
if time_windows[0] != 0:
time_windows = [0] + time_windows
windows = ts.parse_windows(windows)
num_windows = len(windows) - 1
num_time_windows = len(time_windows) - 1
out_dim = [1 + len(sample_set) for sample_set in sample_sets]
out = np.zeros([num_windows] + out_dim)
out = np.zeros([num_windows] + [num_time_windows] + out_dim)
for j in range(num_windows):
begin = windows[j]
Expand Down Expand Up @@ -3689,15 +3658,11 @@ def naive_branch_allele_frequency_spectrum(
out[j, k, :] = S

if drop_time_windows:
# beware: this assumes the first dimension is windows
assert out.shape[1] == 1
assert out.ndim == 2 + len(out_dim)
out = out[:, 0]
elif drop_windows:
# drop windows dim if only using time windows
assert out.shape[0] == 1
out = out[0]
# assert out.shape[0] == 1
# Warning: when using Windows and TimeWindows,
# the output has three dimensions
return out


Expand Down Expand Up @@ -3756,14 +3721,14 @@ def branch_allele_frequency_spectrum(
last_update = np.zeros(ts.num_nodes)
window_index = 0
parent = np.zeros(ts.num_nodes, dtype=np.int32) - 1
branch_length = np.zeros(ts.num_nodes)
# branch_length = np.zeros(ts.num_nodes)
tree_index = 0

def update_result(window_index, u, right, time_windows):
def update_result(window_index, u, right):
for k_tw, _ in enumerate(time_windows[:-1]):
if 0 < count[u, -1] < ts.num_samples:
# interval between child and parent inside the window
t_v = branch_length[u] + time[u]
# t_v = branch_length[u] + time[u]
t_v = time[parent[u]]
tw_branch_length = min(time_windows[k_tw + 1], t_v) - max(
time_windows[0], time[u]
)
Expand All @@ -3779,21 +3744,21 @@ def update_result(window_index, u, right, time_windows):
for edge in edges_out:
u = edge.child
v = edge.parent
update_result(window_index, u, t_left, time_windows)
update_result(window_index, u, t_left)
while v != -1:
update_result(window_index, v, t_left, time_windows)
update_result(window_index, v, t_left)
count[v] -= count[u]
v = parent[v]
parent[u] = -1
branch_length[u] = 0
# branch_length[u] = 0

for edge in edges_in:
u = edge.child
v = edge.parent
parent[u] = v
branch_length[u] = time[v] - time[u]
# branch_length[u] = time[v] - time[u]
while v != -1:
update_result(window_index, v, t_left, time_windows)
update_result(window_index, v, t_left)
count[v] += count[u]
v = parent[v]

Expand All @@ -3812,7 +3777,7 @@ def update_result(window_index, u, right, time_windows):
# non-zero branches, but this would add a O(log n) cost to each edge
# insertion and removal and a lot of complexity to the C implementation.
for u in range(ts.num_nodes):
update_result(window_index, u, w_right, time_windows)
update_result(window_index, u, w_right)
window_index += 1
tree_index += 1

Expand All @@ -3822,13 +3787,12 @@ def update_result(window_index, u, right, time_windows):
result[j] /= windows[j + 1] - windows[j]

if drop_time_windows:
# beware: this assumes the first dimension is windows
assert result.ndim == 2 + len(out_dim)
assert result.shape[1] == 1
result = result[:, 0]
elif drop_windows:
# drop windows dim if only using time windows
assert result.shape[0] == 1
result = result[0]
# assert out.shape[0] == 1
return result


Expand Down Expand Up @@ -6952,6 +6916,9 @@ def test_afs_branch(self):
self.assertArrayAlmostEqual(sfs1_tws, sfs1_tws_opti)
# test that the naive and optimised versions agree when both
# windows and time windows are used
# Warning: when using Windows and TimeWindows,
# the output has three dimensions
self.assertArrayAlmostEqual(sfs1_w_tw, sfs1_w_tw_opti)
# dimension tests
assert sfs1_tws.ndim == sfs1_w.ndim
# dimensions are dim1: windows ; dim2: time_windows ;
# dim3-or-more: num_sample_sets
assert sfs1_w_tw.ndim == 3

0 comments on commit 8f5fa02

Please sign in to comment.