Skip to content

Commit 0452a25

Browse files
authored
Merge pull request #213 from benjeffery/fix-quant-cols
Correctly decode tszip cols
2 parents b319c7c + 8d4cf4c commit 0452a25

File tree

2 files changed

+84
-19
lines changed

2 files changed

+84
-19
lines changed

tests/test_data_model.py

+63-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import msprime
2+
import numpy as np
23
import pytest
34
import tskit
45
import tszip
@@ -9,12 +10,38 @@
910

1011

1112
def test_model(tmpdir):
13+
# Generate a tree sequence with populations and migrations
14+
N = 1000
15+
demography = msprime.Demography()
16+
demography.add_population(name="pop1", initial_size=N)
17+
demography.add_population(name="pop2", initial_size=N)
18+
demography.add_population(name="ancestral", initial_size=N)
19+
demography.set_symmetric_migration_rate(["pop1", "pop2"], 0.01)
20+
demography.add_population_split(
21+
time=1000, derived=["pop1", "pop2"], ancestral="ancestral"
22+
)
1223
ts = msprime.sim_ancestry(
13-
recombination_rate=1e-3, samples=10, sequence_length=1_000, random_seed=42
24+
samples={"pop1": 5, "pop2": 5},
25+
demography=demography,
26+
sequence_length=1e4,
27+
record_migrations=True,
28+
random_seed=42,
1429
)
15-
ts = msprime.sim_mutations(ts, rate=1e-2, random_seed=43)
30+
ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=42)
31+
assert ts.num_populations > 0
32+
assert ts.num_sites > 0
33+
assert ts.num_migrations > 0
34+
assert ts.num_mutations > 0
35+
1636
tables = ts.tables
1737
tables.nodes.metadata_schema = tskit.MetadataSchema({"codec": "json"})
38+
39+
# Give each individual a location
40+
indiv_copy = tables.individuals.copy()
41+
tables.individuals.clear()
42+
for i, ind in enumerate(indiv_copy):
43+
tables.individuals.append(ind.replace(location=[i / 2, i + 1]))
44+
1845
ts = tables.tree_sequence()
1946

2047
tszip.compress(ts, tmpdir / "test.tszip")
@@ -25,11 +52,44 @@ def test_model(tmpdir):
2552
assert tsm.name == "test"
2653
assert tsm.file_uuid == ts.file_uuid
2754
assert len(tsm.summary_df) == 9
28-
assert len(tsm.edges_df) == ts.num_edges
2955
assert len(tsm.trees_df) == ts.num_trees
56+
57+
assert len(tsm.edges_df) == ts.num_edges
58+
for col in ["left", "right", "parent", "child"]:
59+
assert np.array_equal(tsm.edges_df[col].values, getattr(ts.tables.edges, col))
60+
3061
assert len(tsm.mutations_df) == ts.num_mutations
62+
for m1, m2 in zip(ts.mutations(), tsm.mutations_df.to_dict("records")):
63+
assert m1.derived_state == m2["derived_state"]
64+
assert m1.site == m2["site"]
65+
assert m1.node == m2["node"]
66+
assert m1.parent == m2["parent"]
67+
assert m1.time == m2["time"]
68+
3169
assert len(tsm.nodes_df) == ts.num_nodes
70+
for col in ["time", "flags", "population", "individual"]:
71+
assert np.array_equal(tsm.nodes_df[col].values, getattr(ts.tables.nodes, col))
72+
3273
assert len(tsm.sites_df) == ts.num_sites
74+
for m1, m2 in zip(ts.sites(), tsm.sites_df.to_dict("records")):
75+
assert m1.ancestral_state == m2["ancestral_state"]
76+
assert m1.position == m2["position"]
77+
78+
assert len(tsm.individuals_df) == ts.num_individuals
79+
for m1, m2 in zip(ts.individuals(), tsm.individuals_df.to_dict("records")):
80+
assert m1.flags == m2["flags"]
81+
assert np.array_equal(m1.location, m2["location"])
82+
assert np.array_equal(m1.parents, m2["parents"])
83+
84+
assert len(tsm.populations_df) == ts.num_populations
85+
for m1, m2 in zip(ts.populations(), tsm.populations_df.to_dict("records")):
86+
assert m1.metadata == m2["metadata"]
87+
88+
assert len(tsm.migrations_df) == ts.num_migrations
89+
for col in ["left", "right", "node", "source", "dest", "time"]:
90+
assert np.array_equal(
91+
tsm.migrations_df[col].values, getattr(ts.tables.migrations, col)
92+
)
3393

3494

3595
def test_model_errors(tmpdir):

tsbrowse/model.py

+21-16
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def __init__(self, tsbrowse_path):
3333
self.ts = tszip.load(tsbrowse_path)
3434
self.name = tsbrowse_path.stem
3535
self.full_path = tsbrowse_path
36+
ts_tables = self.ts.tables
3637
for table_name in [
3738
"edges",
38-
"trees",
3939
"mutations",
4040
"nodes",
4141
"sites",
@@ -44,30 +44,35 @@ def __init__(self, tsbrowse_path):
4444
"migrations",
4545
"provenances",
4646
]:
47+
ts_table = getattr(ts_tables, table_name)
4748
# filter out ragged arrays with offset
4849
array_names = set(root[table_name].keys())
4950
ragged_array_names = {
5051
"_".join(name.split("_")[:-1])
5152
for name in array_names
5253
if "offset" in name
5354
}
54-
array_names -= set(ragged_array_names)
5555
array_names -= {"metadata_schema"}
5656
array_names -= {f"{name}_offset" for name in ragged_array_names}
57-
arrays = {name: root[table_name][name][:] for name in array_names}
58-
ragged_array_names -= {"metadata"}
59-
for name in ragged_array_names:
60-
array = root[table_name][name][:]
61-
offsets = root[table_name][f"{name}_offset"][:]
62-
arrays[name] = np.array(
63-
[
64-
array[s].tobytes().decode("utf-8")
65-
for s in (
66-
slice(start, end)
67-
for start, end in zip(offsets[:-1], offsets[1:])
68-
)
69-
]
70-
)
57+
arrays = {}
58+
for name in array_names:
59+
if hasattr(ts_table, name):
60+
if name in ragged_array_names:
61+
arrays[name] = [
62+
getattr(row, name) for row in getattr(self.ts, table_name)()
63+
]
64+
else:
65+
arrays[name] = getattr(ts_table, name)
66+
else:
67+
arrays[name] = root[table_name][name][:]
68+
df = pd.DataFrame(arrays)
69+
df["id"] = df.index
70+
setattr(self, f"{table_name}_df", df)
71+
72+
for table_name in ["trees"]:
73+
arrays = {
74+
name: root[table_name][name][:] for name in root[table_name].keys()
75+
}
7176
df = pd.DataFrame(arrays)
7277
df["id"] = df.index
7378
setattr(self, f"{table_name}_df", df)

0 commit comments

Comments
 (0)