Skip to content

Commit

Permalink
Merge pull request #240 from benjeffery/drop-pops
Browse files Browse the repository at this point in the history
Drop pop plot
  • Loading branch information
benjeffery authored Dec 8, 2024
2 parents 9a1eee9 + 0528928 commit 5e7b45f
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 217 deletions.
61 changes: 0 additions & 61 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,67 +220,6 @@ def test_multi_tree_with_polytomies_example(self):
nt.assert_array_equal(t["max_internal_arity"], [3.0, 3.0])


class TestMutationFrequencies:
    """
    Compare the per-population mutation counts and frequencies produced by
    ``preprocess`` with a simple tree-by-tree reference computation.
    """

    def example_ts(self):
        """Simulate ancestry for samples drawn from A and B, which split from C."""
        demography = msprime.Demography()
        for pop_name, pop_size in (("A", 10_000), ("B", 5_000), ("C", 1_000)):
            demography.add_population(name=pop_name, initial_size=pop_size)
        demography.add_population_split(time=1000, derived=["A", "B"], ancestral="C")
        simulated = msprime.sim_ancestry(
            samples={"A": 1, "B": 1},
            demography=demography,
            random_seed=12,
            sequence_length=10_000,
        )
        return simulated

    def compute_mutation_counts(self, ts):
        """
        Reference implementation: for each population, track that population's
        samples and record how many of them sit below each mutation's node.
        """
        counts = np.zeros((ts.num_populations, ts.num_mutations), dtype=int)
        for population in ts.populations():
            tracked = ts.samples(population=population.id)
            for tree in ts.trees(tracked_samples=tracked):
                for mutation in tree.mutations():
                    counts[population.id, mutation.id] = tree.num_tracked_samples(
                        mutation.node
                    )
        return counts

    def check_ts(self, ts):
        """Assert that counts and the derived frequency columns match the reference."""
        expected = self.compute_mutation_counts(ts)
        observed = preprocess.compute_population_mutation_counts(ts)
        nt.assert_array_equal(expected, observed)
        mutation_data = preprocess.mutations(ts)
        freq_columns = ("pop_A_freq", "pop_B_freq", "pop_C_freq")
        for pop_row, column in enumerate(freq_columns):
            nt.assert_array_equal(
                mutation_data[column], expected[pop_row] / ts.num_samples
            )

    def test_all_nodes(self):
        """Place one site and mutation on every node except the last one."""
        base_ts = self.example_ts()
        tables = base_ts.dump_tables()
        for node_id in range(base_ts.num_nodes - 1):
            new_site = tables.sites.add_row(node_id, "A")
            tables.mutations.add_row(site=new_site, node=node_id, derived_state="T")
        self.check_ts(tables.tree_sequence())

    @pytest.mark.parametrize("seed", range(1, 7))
    def test_simulated_mutations(self, seed):
        """Check simulated mutations for several mutation seeds."""
        mutated = msprime.sim_mutations(
            self.example_ts(), rate=1e-6, random_seed=seed
        )
        assert mutated.num_mutations > 0
        self.check_ts(mutated)

    def test_no_metadata_schema(self):
        """Population names must still be found once the metadata schema is cleared."""
        mutated = msprime.sim_mutations(self.example_ts(), rate=1e-6, random_seed=43)
        assert mutated.num_mutations > 0
        tables = mutated.dump_tables()
        tables.populations.metadata_schema = tskit.MetadataSchema(None)
        self.check_ts(tables.tree_sequence())

    def test_no_populations(self):
        """Samples without a population assignment are rejected with a ValueError."""
        tables = single_tree_example_ts().dump_tables()
        tables.populations.add_row(b"{}")
        with pytest.raises(ValueError, match="must be assigned to populations"):
            preprocess.mutations(tables.tree_sequence())


class TestNodeIsSample:
def test_simple_example(self):
ts = single_tree_example_ts()
Expand Down
52 changes: 1 addition & 51 deletions tsbrowse/pages/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from bokeh.models import HoverTool

from .. import config
from ..plot_helpers import center_plot_title
from ..plot_helpers import customise_ticks
from ..plot_helpers import filter_points
from ..plot_helpers import hover_points
Expand Down Expand Up @@ -130,70 +129,21 @@ def get_mut_data(x_range, y_range, index):
mut_data = filtered_data.loc[index[0]]
return mut_data

def update_pop_freq_plot(x_range, y_range, index):
    """
    Build the per-population frequency bar chart for the selected mutation.

    :param x_range: Passed through to ``get_mut_data``.
    :param y_range: Passed through to ``get_mut_data``.
    :param index: Current selection; empty when nothing is selected.
    :return: An ``hv.Bars`` element. An empty placeholder chart is returned
        when nothing is selected, when the mutation row carries no
        ``pop_<name>_freq`` columns, or when the mutation has zero frequency
        in every population.
    """
    prefix, suffix = "pop_", "_freq"

    def empty_plot():
        # Single definition of the placeholder shown when there is nothing to plot.
        return hv.Bars([], "population", "frequency").opts(
            title="Population frequencies",
            default_tools=[],
            tools=["hover"],
            hooks=[center_plot_title],
        )

    if not index:
        return empty_plot()

    mut_data = get_mut_data(x_range, y_range, index)
    # Match the exact ``pop_<name>_freq`` column pattern rather than any column
    # that merely contains "pop_", so unrelated columns are never plotted.
    pops = [
        col
        for col in mut_data.index
        if isinstance(col, str)
        and len(col) >= len(prefix) + len(suffix)
        and col.startswith(prefix)
        and col.endswith(suffix)
    ]
    if not pops:
        return empty_plot()

    df = pd.DataFrame(
        {
            # Strip only the leading prefix and trailing suffix; str.replace
            # would also remove these substrings from inside population names.
            "population": [col[len(prefix) : -len(suffix)] for col in pops],
            "frequency": [mut_data[col] for col in pops],
        }
    )
    df = df[df["frequency"] > 0]
    if df.empty:
        # The mutation is absent from every population; without this guard
        # the ylim computation below would fail on an empty sequence.
        return empty_plot()

    return hv.Bars(df, "population", "frequency").opts(
        framewise=True,
        title=f"Mutation {mut_data['id']}",
        ylim=(0, df["frequency"].max() * 1.1),
        xrotation=45,
        tools=["hover"],
        default_tools=[],
        yticks=3,
        yformatter="%.3f",
        hooks=[center_plot_title],
    )

def update_mut_info_table(x_range, y_range, index):
    """
    Build the key/value table describing the selected mutation.

    The floating panel is hidden when nothing is selected and shown
    otherwise. Population columns are left out of the table, and the time
    fields are rounded to two decimal places.
    """
    has_selection = bool(index)
    float_panel.visible = has_selection
    if not has_selection:
        return hv.Table([], kdims=["mutation"], vdims=["value"])

    row = get_mut_data(x_range, y_range, index)
    population_columns = [label for label in row.index if "pop_" in label]
    row = row.drop(population_columns)
    row["time"] = row["time"].round(2)
    if "log_time" in row:
        row["log_time"] = row["log_time"].round(2)
    table = hv.Table(row.items(), kdims=["mutation"], vdims=["value"])
    return table

pop_data_dynamic = hv.DynamicMap(
update_pop_freq_plot, streams=[range_stream, selection_stream]
)
pop_data_dynamic.opts(align=("center"))
mut_info_table_dynamic = hv.DynamicMap(
update_mut_info_table, streams=[range_stream, selection_stream]
)
tap_widgets_layout = (pop_data_dynamic + mut_info_table_dynamic).cols(1)
tap_widgets_layout = mut_info_table_dynamic
float_panel = pn.layout.FloatPanel(
pn.Column(
tap_widgets_layout,
Expand Down
105 changes: 0 additions & 105 deletions tsbrowse/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import dataclasses
import json
import pathlib
import warnings

Expand Down Expand Up @@ -100,92 +99,6 @@ def alloc_tree_position(ts):
)


@jit.numba_jit()
def _compute_population_mutation_counts(
    tree_pos,
    num_nodes,
    num_mutations,
    num_populations,
    edges_parent,
    edges_child,
    nodes_is_sample,
    nodes_population,
    mutations_position,
    mutations_node,
    mutations_parent,
):
    """
    Count, for every mutation, the samples of each population found below the
    mutation's node in the tree that covers the mutation's position.

    The trees are visited in order via ``tree_pos``; per-node, per-population
    sample counts are updated incrementally as edges are removed and inserted.

    Returns an int32 array of shape (num_populations, num_mutations).

    NOTE(review): ``mutations_parent`` is accepted but never read in this
    body, so the count is taken at the mutation's node only.
    NOTE(review): ``mut_id`` only moves forward, so this appears to rely on
    mutations being ordered by position -- confirm against the caller.
    """
    # num_pop_samples[u, k]: samples from population k in the subtree below u.
    num_pop_samples = np.zeros((num_nodes, num_populations), dtype=np.int32)

    pop_mutation_count = np.zeros((num_populations, num_mutations), dtype=np.int32)
    # parent[u] == -1 means u has no parent in the current tree.
    parent = np.zeros(num_nodes, dtype=np.int32) - 1

    # Each sample node starts out counting itself in its own population.
    for u in range(num_nodes):
        if nodes_is_sample[u]:
            num_pop_samples[u, nodes_population[u]] = 1

    mut_id = 0
    while tree_pos.next():
        # Edges listed for removal: detach the child, then subtract its counts
        # from the former parent and every ancestor above it.
        for j in range(tree_pos.out_range[0], tree_pos.out_range[1]):
            e = tree_pos.edge_removal_order[j]
            c = edges_child[e]
            p = edges_parent[e]
            parent[c] = -1
            u = p
            while u != -1:
                for k in range(num_populations):
                    num_pop_samples[u, k] -= num_pop_samples[c, k]
                u = parent[u]

        # Edges listed for insertion: attach the child, then add its counts to
        # the new parent and every ancestor above it.
        for j in range(tree_pos.in_range[0], tree_pos.in_range[1]):
            e = tree_pos.edge_insertion_order[j]
            p = edges_parent[e]
            c = edges_child[e]
            parent[c] = p
            u = p
            while u != -1:
                for k in range(num_populations):
                    num_pop_samples[u, k] += num_pop_samples[c, k]
                u = parent[u]

        # Consume every not-yet-processed mutation positioned before the right
        # end of the current interval and record the counts at its node.
        left, right = tree_pos.interval
        while mut_id < num_mutations and mutations_position[mut_id] < right:
            assert mutations_position[mut_id] >= left
            mutation_node = mutations_node[mut_id]
            for pop in range(num_populations):
                pop_mutation_count[pop, mut_id] = num_pop_samples[mutation_node, pop]
            mut_id += 1

    return pop_mutation_count


def compute_population_mutation_counts(ts):
    """
    Return a (num_populations, num_mutations) array giving, for each mutation,
    the number of samples from each population below the mutation's node in
    the specified tree sequence.

    These are raw counts, not frequencies; callers divide by the number of
    samples to obtain frequencies.

    :raises ValueError: If any sample node has no population (population -1).
    """
    logger.info(
        f"Computing mutation frequencies within {ts.num_populations} populations"
    )
    # Keep positions in their native dtype. Truncating to int would move a
    # mutation at a non-integer position to the left, which can place it in
    # the wrong tree when a breakpoint falls between the truncated and the
    # true position.
    mutations_position = ts.sites_position[ts.mutations_site]

    if np.any(ts.nodes_population[ts.samples()] == -1):
        raise ValueError("Sample nodes must be assigned to populations")

    return _compute_population_mutation_counts(
        alloc_tree_position(ts),
        ts.num_nodes,
        ts.num_mutations,
        ts.num_populations,
        ts.edges_parent,
        ts.edges_child,
        node_is_sample(ts),
        ts.nodes_population,
        mutations_position,
        ts.mutations_node,
        ts.mutations_parent,
    )


@dataclasses.dataclass
class MutationCounts:
num_parents: np.ndarray
Expand Down Expand Up @@ -322,23 +235,6 @@ def mutations(ts):
inherited_state[mutations_with_parent] = derived_state[parent]
mutations_inherited_state = inherited_state

population_data = {}
if ts.num_populations > 0:
pop_mutation_count = compute_population_mutation_counts(ts)
for pop in ts.populations():
name = f"pop{pop.id}"
if isinstance(pop.metadata, bytes):
try:
metadata_dict = json.loads(pop.metadata.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
metadata_dict = {}
else:
metadata_dict = pop.metadata
if "name" in metadata_dict:
name = metadata_dict["name"]
col_name = f"pop_{name}_freq"
population_data[col_name] = pop_mutation_count[pop.id] / ts.num_samples

counts = compute_mutation_counts(ts)
logger.info("Preprocessed mutations")
return {
Expand All @@ -347,7 +243,6 @@ def mutations(ts):
"num_descendants": counts.num_descendants,
"num_inheritors": counts.num_inheritors,
"num_parents": counts.num_parents,
**population_data,
}


Expand Down

0 comments on commit 5e7b45f

Please sign in to comment.