Skip to content

Commit

Permalink
Set metadata schemas
Browse files Browse the repository at this point in the history
Defaults to a "struct" type unless a schema already exists (or if there is a null schema that can be interpreted as JSON). Fixes tskit-dev#302
  • Loading branch information
hyanwong committed Jul 23, 2023
1 parent 7282a30 commit 1bde2e7
Show file tree
Hide file tree
Showing 5 changed files with 562 additions and 86 deletions.
53 changes: 30 additions & 23 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
Test cases for the python API for tsdate.
"""
import collections
import json
import logging
import unittest

Expand Down Expand Up @@ -1626,7 +1625,9 @@ def test_posterior_mean_var(self):
ts = utility_functions.single_tree_ts_n2()
for distr in ("gamma", "lognorm"):
posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, distr)
ts_node_metadata, mn_post, vr_post = posterior_mean_var(ts, posterior)
ts_node_metadata, mn_post, vr_post = posterior_mean_var(
ts, posterior, save_metadata=False
)
assert np.array_equal(
mn_post,
[
Expand All @@ -1638,31 +1639,35 @@ def test_posterior_mean_var(self):

def test_node_metadata_single_tree_n2(self):
ts = utility_functions.single_tree_ts_n2()
tables = ts.dump_tables()
tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json()
ts = tables.tree_sequence()
posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, "lognorm")
ts_node_metadata, mn_post, vr_post = posterior_mean_var(ts, posterior)
assert json.loads(ts_node_metadata.node(2).metadata)["mn"] == mn_post[2]
assert json.loads(ts_node_metadata.node(2).metadata)["vr"] == vr_post[2]
assert ts_node_metadata.node(2).metadata["mn"] == mn_post[2]
assert ts_node_metadata.node(2).metadata["vr"] == vr_post[2]

def test_node_metadata_simulated_tree(self):
larger_ts = msprime.simulate(
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
)
_, mn_post, _, _, eps, _ = get_dates(
larger_ts, mutation_rate=None, population_size=10000
)
dated_ts = date(larger_ts, population_size=10000, mutation_rate=None)
metadata = dated_ts.tables.nodes.metadata
metadata_offset = dated_ts.tables.nodes.metadata_offset
unconstrained_mn = [
json.loads(met.decode())["mn"]
for met in tskit.unpack_bytes(metadata, metadata_offset)
if len(met.decode()) > 0
]
assert np.array_equal(unconstrained_mn, mn_post[larger_ts.num_samples :])
assert np.all(
dated_ts.tables.nodes.time[larger_ts.num_samples :]
>= mn_post[larger_ts.num_samples :]
is_sample = np.zeros(larger_ts.num_nodes, dtype=bool)
is_sample[larger_ts.samples()] = True
is_not_sample = np.logical_not(is_sample)
# This calls posterior_mean_var
_, mn_post, _, _, _, _ = get_dates(
larger_ts,
method="inside_outside",
population_size=1,
mutation_rate=1,
save_metadata=False,
)
constrained_time = constrain_ages_topo(larger_ts, mn_post, eps=1e-6)
# Samples identical in all methods
assert np.array_equal(larger_ts.nodes_time[is_sample], mn_post[is_sample])
assert np.array_equal(constrained_time[is_sample], mn_post[is_sample])
# Non-samples should adhere to constraints
assert np.all(constrained_time[is_not_sample] >= mn_post[is_not_sample])


class TestConstrainAgesTopo:
Expand Down Expand Up @@ -1862,7 +1867,9 @@ def test_node_selection_param(self):
def test_sites_time_insideoutside(self):
ts = utility_functions.two_tree_mutation_ts()
dated = tsdate.date(ts, mutation_rate=None, population_size=1)
_, mn_post, _, _, eps, _ = get_dates(ts, mutation_rate=None, population_size=1)
_, mn_post, _, _, eps, _ = get_dates(
ts, mutation_rate=None, population_size=1, save_metadata=False
)
assert np.array_equal(
mn_post[ts.tables.mutations.node],
tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0),
Expand Down Expand Up @@ -1966,15 +1973,15 @@ def test_sites_time_simulated(self):
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
)
_, mn_post, _, _, _, _ = get_dates(
larger_ts, mutation_rate=None, population_size=10000
larger_ts, mutation_rate=None, population_size=10000, save_metadata=False
)
dated = date(larger_ts, mutation_rate=None, population_size=10000)
assert np.array_equal(
mn_post[larger_ts.tables.mutations.node],
mn_post[larger_ts.mutations_node],
tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0),
)
assert np.array_equal(
dated.tables.nodes.time[larger_ts.tables.mutations.node],
dated.nodes_time[larger_ts.mutations_node],
tsdate.sites_time_from_ts(dated, unconstrained=False, min_time=0),
)

Expand Down
287 changes: 287 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
# MIT License
#
# Copyright (c) 2021-23 Tskit Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Test cases for metadata setting functionality in tsdate.
"""
import json
import logging

import pytest
import tskit
import utility_functions

from tsdate.core import date
from tsdate.metadata import node_md_struct

def _struct_schema(properties):
    """
    Return a "struct"-codec object schema with the given ``properties`` and
    with additional properties disallowed.
    """
    return tskit.MetadataSchema(
        {
            "codec": "struct",
            "type": "object",
            "properties": properties,
            "additionalProperties": False,
        }
    )


# A struct schema holding only an integer "node_id" (no "mn" or "vr" fields)
struct_obj_only_example = _struct_schema(
    {"node_id": {"type": "integer", "binaryFormat": "i"}}
)

# A struct schema in which "mn" is declared as an integer
struct_bad_mn = _struct_schema({"mn": {"type": "integer", "binaryFormat": "i"}})

# A struct schema in which "vr" is declared as a string
struct_bad_vr = _struct_schema({"vr": {"type": "string", "binaryFormat": "10p"}})


def _replace_raw_node_metadata(ts, metadata_for):
    """
    Return a tree sequence built from ``ts`` in which the node table has been
    cleared and every node re-appended with ``metadata_for(node)`` (raw bytes)
    as its metadata.
    """
    tables = ts.dump_tables()
    tables.nodes.clear()
    for node in ts.nodes():
        tables.nodes.append(node.replace(metadata=metadata_for(node)))
    return tables.tree_sequence()


class TestBytes:
    """
    Tests for when existing node metadata is in raw bytes
    """

    def test_no_existing(self):
        """
        Empty metadata and a null schema: dating stores "mn" and "vr" on the root.
        """
        undated = utility_functions.single_tree_ts_n2()
        root = undated.first().root
        assert undated.node(root).metadata == b""
        assert undated.table_metadata_schemas.node == tskit.MetadataSchema(None)
        dated = date(undated, mutation_rate=1, population_size=1)
        assert dated.node(root).metadata["mn"] == pytest.approx(
            dated.nodes_time[root]
        )
        assert dated.node(root).metadata["vr"] > 0

    def test_append_existing(self):
        """
        Raw bytes that parse as JSON keep their keys and gain "mn" and "vr".
        """
        base = utility_functions.single_tree_ts_n2()
        root = base.first().root
        assert base.table_metadata_schemas.node == tskit.MetadataSchema(None)
        with_ids = _replace_raw_node_metadata(
            base, lambda nd: b'{"node_id": %d}' % nd.id
        )
        raw_root_md = with_ids.node(root).metadata.decode()
        assert json.loads(raw_root_md)["node_id"] == root
        dated = date(with_ids, mutation_rate=1, population_size=1)
        assert dated.node(root).metadata["node_id"] == root
        assert dated.node(root).metadata["mn"] == pytest.approx(
            dated.nodes_time[root]
        )
        assert dated.node(root).metadata["vr"] > 0

    def test_replace_existing(self):
        """
        A pre-existing "mn" value in JSON-like raw bytes is overwritten.
        """
        base = utility_functions.single_tree_ts_n2()
        root = base.first().root
        assert base.table_metadata_schemas.node == tskit.MetadataSchema(None)
        with_mn = _replace_raw_node_metadata(base, lambda nd: b'{"mn": 1.0}')
        raw_root_md = with_mn.node(root).metadata.decode()
        assert json.loads(raw_root_md)["mn"] == pytest.approx(1.0)
        dated = date(with_mn, mutation_rate=1, population_size=1)
        assert dated.node(root).metadata["mn"] != pytest.approx(1.0)
        assert dated.node(root).metadata["mn"] == pytest.approx(
            dated.nodes_time[root]
        )
        assert dated.node(root).metadata["vr"] > 0

    def test_existing_bad(self):
        """
        Raw bytes that are not JSON make date() raise a "Cannot modify" error.
        """
        base = utility_functions.single_tree_ts_n2()
        assert base.table_metadata_schemas.node == tskit.MetadataSchema(None)
        unparseable = _replace_raw_node_metadata(base, lambda nd: b"!!")
        with pytest.raises(ValueError, match="Cannot modify"):
            date(unparseable, mutation_rate=1, population_size=1)

    def test_erase_existing_bad(self, caplog):
        """
        With set_metadata=True, unparseable raw bytes are erased with a warning.
        """
        base = utility_functions.single_tree_ts_n2()
        root = base.first().root
        assert base.table_metadata_schemas.node == tskit.MetadataSchema(None)
        unparseable = _replace_raw_node_metadata(base, lambda nd: b"!!")
        # Should be able to replace using set_metadata=True
        with caplog.at_level(logging.WARNING):
            dated = date(
                unparseable, mutation_rate=1, population_size=1, set_metadata=True
            )
        assert "Erasing existing node metadata" in caplog.text
        assert dated.table_metadata_schemas.node.schema["codec"] == "struct"
        assert dated.node(root).metadata["mn"] == pytest.approx(
            dated.nodes_time[root]
        )
        assert dated.node(root).metadata["vr"] > 0


def _set_struct_node_metadata(ts, schema, rows):
    """
    Return a tree sequence built from ``ts`` whose node table uses ``schema``
    and whose node metadata is each item of ``rows`` encoded with that schema.
    """
    tables = ts.dump_tables()
    tables.nodes.metadata_schema = schema
    # Encode with the schema as read back from the table
    table_schema = tables.nodes.metadata_schema
    tables.nodes.packset_metadata(
        [table_schema.validate_and_encode_row(row) for row in rows]
    )
    return tables.tree_sequence()


class TestStruct:
    """
    Tests for when existing node metadata is as a struct
    """

    def test_append_existing(self):
        """
        A struct schema holding only "node_id" keeps that field and gains
        "mn" and "vr" after dating.
        """
        base = utility_functions.single_tree_ts_n2()
        root = base.first().root
        with_ids = _set_struct_node_metadata(
            base,
            struct_obj_only_example,
            [{"node_id": i} for i in range(base.num_nodes)],
        )
        assert with_ids.node(root).metadata["node_id"] == root
        dated = date(with_ids, mutation_rate=1, population_size=1)
        assert dated.node(root).metadata["node_id"] == root
        assert dated.node(root).metadata["mn"] == pytest.approx(
            dated.nodes_time[root]
        )
        assert dated.node(root).metadata["vr"] > 0

    def test_replace_existing(self, caplog):
        """
        With the ``node_md_struct`` schema already set and null metadata on
        every node, dating logs the replacement of "mn" and "vr", fills them
        in for the root, and leaves the sample's metadata as None.
        """
        base = utility_functions.single_tree_ts_n2()
        root = base.first().root
        nulled = _set_struct_node_metadata(
            base, node_md_struct, [None for _ in range(base.num_nodes)]
        )
        assert nulled.node(root).metadata is None
        with caplog.at_level(logging.INFO):
            dated = date(nulled, mutation_rate=1, population_size=1)
        assert dated.table_metadata_schemas.node.schema["codec"] == "struct"
        for expected_log in ("Replacing 'mn'", "Replacing 'vr'", "Schema modified"):
            assert expected_log in caplog.text
        assert dated.node(root).metadata["mn"] == pytest.approx(
            dated.nodes_time[root]
        )
        assert dated.node(root).metadata["vr"] > 0
        first_sample = dated.samples()[0]
        assert dated.node(first_sample).metadata is None

    def test_existing_bad_mn(self, caplog):
        """
        An existing integer-typed "mn" field makes date() raise a ValueError.
        """
        base = utility_functions.single_tree_ts_n2()
        clashing = _set_struct_node_metadata(
            base, struct_bad_mn, [{"mn": 1} for _ in range(base.num_nodes)]
        )
        with pytest.raises(
            ValueError, match=r"Cannot change type of node.metadata\['mn'\]"
        ):
            date(clashing, mutation_rate=1, population_size=1)

    def test_existing_bad_vr(self, caplog):
        """
        An existing string-typed "vr" field makes date() raise a ValueError.
        """
        base = utility_functions.single_tree_ts_n2()
        clashing = _set_struct_node_metadata(
            base, struct_bad_vr, [{"vr": "foo"} for _ in range(base.num_nodes)]
        )
        with pytest.raises(
            ValueError, match=r"Cannot change type of node.metadata\['vr'\]"
        ):
            date(clashing, mutation_rate=1, population_size=1)


class TestJson:
    """
    Tests for when existing node metadata is json encoded
    """

    def test_replace_existing(self, caplog):
        """
        Under a permissive JSON schema, string-valued "mn" / "vr" entries are
        removed from the sample node and replaced on the root, while the
        per-node "node <id>" key is kept and the codec stays "json".
        """
        base = utility_functions.single_tree_ts_n2()
        root = base.first().root
        tables = base.dump_tables()
        schema = tskit.MetadataSchema.permissive_json()
        tables.nodes.metadata_schema = schema
        encoded_rows = []
        for i in range(base.num_nodes):
            row = {f"node {i}": 1, "mn": "foo", "vr": "bar"}
            encoded_rows.append(schema.validate_and_encode_row(row))
        tables.nodes.packset_metadata(encoded_rows)
        with_json = tables.tree_sequence()
        assert "node 0" in with_json.node(0).metadata
        assert with_json.node(0).metadata["mn"] == "foo"
        with caplog.at_level(logging.INFO):
            dated = date(with_json, mutation_rate=1, population_size=1)
        assert dated.table_metadata_schemas.node.schema["codec"] == "json"
        assert "Schema modified" in caplog.text
        sample = dated.samples()[0]
        assert f"node {sample}" in dated.node(sample).metadata
        # Should have deleted mn and vr
        for key in ("mn", "vr"):
            assert key not in dated.node(sample).metadata
        assert f"node {root}" in dated.node(root).metadata
        assert dated.node(root).metadata["mn"] == pytest.approx(
            dated.nodes_time[root]
        )
        assert dated.node(root).metadata["vr"] > 0


# The dating methods that each TestNoSetMetadata test is run against
_each_dating_method = pytest.mark.parametrize(
    "method", ["inside_outside", "maximization", "variational_gamma"]
)


class TestNoSetMetadata:
    """
    Tests for when metadata is not saved
    """

    @_each_dating_method
    def test_empty(self, method):
        """
        Node metadata that starts empty is still empty after dating with
        set_metadata=False.
        """
        undated = utility_functions.single_tree_ts_n2()
        assert len(undated.tables.nodes.metadata) == 0
        dated = date(
            undated,
            mutation_rate=1,
            population_size=1,
            method=method,
            set_metadata=False,
        )
        assert len(dated.tables.nodes.metadata) == 0

    @_each_dating_method
    def test_random_md(self, method):
        """
        Arbitrary raw node metadata has the same total length after dating
        with set_metadata=False.
        """
        plain = utility_functions.single_tree_ts_n2()
        assert len(plain.tables.nodes.metadata) == 0
        tables = plain.dump_tables()
        raw_rows = [(b"random %i" % u) for u in range(plain.num_nodes)]
        tables.nodes.packset_metadata(raw_rows)
        with_md = tables.tree_sequence()
        assert len(with_md.tables.nodes.metadata) > 0
        dts = date(
            with_md,
            mutation_rate=1,
            population_size=1,
            method=method,
            set_metadata=False,
        )
        assert len(with_md.tables.nodes.metadata) == len(dts.tables.nodes.metadata)
Loading

0 comments on commit 1bde2e7

Please sign in to comment.