Skip to content

Commit

Permalink
First pass at genetic_relatedness_weighted tests
Browse files Browse the repository at this point in the history
  • Loading branch information
brieuclehmann committed Jun 29, 2023
1 parent d442686 commit eff37eb
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 10 deletions.
18 changes: 9 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
rev: v4.3.0
hooks:
- id: check-merge-conflict
- id: debug-statements
Expand All @@ -12,37 +12,37 @@ repos:
hooks:
- id: copyright-year
- repo: https://github.com/benjeffery/pre-commit-clang-format
rev: c21a74d089aaeb86c2c19df371c7e7bf40c07207
rev: '1.0'
hooks:
- id: clang-format
exclude: dev-tools|examples
verbose: true
- repo: https://github.com/asottile/reorder_python_imports
rev: v3.0.1
rev: v3.9.0
hooks:
- id: reorder-python-imports
args: [--application-directories=python,
--unclassifiable-application-module=_tskit]
- repo: https://github.com/asottile/pyupgrade
rev: v2.31.1
rev: v3.2.2
hooks:
- id: pyupgrade
args: [--py3-plus, --py37-plus]
- repo: https://github.com/psf/black
rev: 22.3.0
rev: 22.10.0
hooks:
- id: black
language_version: python3
- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2
- repo: https://github.com/pycqa/flake8
rev: 5.0.4
hooks:
- id: flake8
args: [--config=python/.flake8]
additional_dependencies: ["flake8-bugbear==22.3.23", "flake8-builtins==1.5.3"]
additional_dependencies: ["flake8-bugbear==22.10.27", "flake8-builtins==2.0.1"]
- repo: https://github.com/asottile/blacken-docs
rev: v1.12.1
hooks:
- id: blacken-docs
args: [--skip-errors]
additional_dependencies: [black==22.3.0]
language_version: python3
language_version: python3
127 changes: 126 additions & 1 deletion python/tests/test_tree_stats.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2018-2022 Tskit Developers
# Copyright (c) 2018-2023 Tskit Developers
# Copyright (C) 2016 University of Oxford
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
Expand Down Expand Up @@ -1010,6 +1010,43 @@ def wrapped_summary_func(x):
self.assertArrayAlmostEqual(sigma1, sigma4)


class kWayWeightStatsMixin(WeightStatsMixin):
"""
Implements the verify method and dispatches it to verify_weighted_stat
for a representative set of sample sets and windows.
"""

def verify_definition(self, ts, W, windows, summary_func, ts_method, definition):

# general_stat will need an extra column for p
gW = self.transform_weights(W)

def wrapped_summary_func(x):
with suppress_division_by_zero_warning():
return summary_func(x)

# Determine output_dim of the function
M = len(wrapped_summary_func(gW[0]))
for sn in [True, False]:
sigma1 = ts.general_stat(
gW, wrapped_summary_func, M, windows, mode=self.mode, span_normalise=sn
)
sigma2 = general_stat(
ts, gW, wrapped_summary_func, windows, mode=self.mode, span_normalise=sn
)
sigma3 = ts_method(W, windows=windows, mode=self.mode, span_normalise=sn)
sigma4 = definition(
ts, W, windows=windows, mode=self.mode, span_normalise=sn
)

assert sigma1.shape == sigma2.shape
assert sigma1.shape == sigma3.shape
assert sigma1.shape == sigma4.shape
self.assertArrayAlmostEqual(sigma1, sigma2)
self.assertArrayAlmostEqual(sigma1, sigma3)
self.assertArrayAlmostEqual(sigma1, sigma4)


class SampleSetStatsMixin:
"""
Implements the verify method and dispatches it to verify_sample_sets
Expand Down Expand Up @@ -2101,6 +2138,94 @@ def test_match_K_c0(self):
self.assertArrayAlmostEqual(A, B)


############################################
# Genetic relatedness weighted
############################################
# still need to implement multiple index pairs and multiple windows


def genetic_relatedness_matrix(ts, sample_sets, mode, windows=None):
n = len(sample_sets)
indexes = [
(n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2)
]
K = np.zeros((n, n))
K[np.triu_indices(n)] = ts.genetic_relatedness(
sample_sets, indexes, mode=mode, proportion=False, span_normalise=False
)
K = K + np.triu(K, 1).transpose()
return K


def genetic_relatedness_weighted(ts, W, indexes, windows=None, mode="site"):
sample_sets = [[u] for u in ts.samples()]
K = genetic_relatedness_matrix(ts, sample_sets, mode)
i1 = indexes[0]
i2 = indexes[1]
return W[:, i1] @ K @ W[:, i2]


def example_index_pairs(weights):
assert weights.shape[1] >= 2
yield [(0, 1)]
yield [(1, 0), (0, 1)]
if weights.shape[1] > 2:
yield [(0, 1), (1, 2), (0, 2)]


class TestGeneticRelatednessWeighted(StatsTestCase, WeightStatsMixin):
mode = None

def verify_definition(
self, ts, W, indexes, windows, summary_func, ts_method, definition
):
# sigma1 = ts.general_stat(W, summary_func, windows, mode=self.mode)
# sigma2 = general_stat(ts, W, summary_func, windows, mode=self.mode)

sigma3 = ts_method(
W,
indexes=indexes,
windows=windows,
mode=self.mode,
)
sigma4 = definition(
ts,
W,
indexes=indexes,
windows=windows,
mode=self.mode,
)
# assert sigma1.shape == sigma2.shape
# assert sigma1.shape == sigma3.shape
assert sigma3.shape == sigma4.shape
# self.assertArrayAlmostEqual(sigma1, sigma2)
# self.assertArrayAlmostEqual(sigma1, sigma3)
self.assertArrayAlmostEqual(sigma3, sigma4)

def verify(self, ts):
for W, windows in subset_combos(
self.example_weights(ts, min_size=2), example_windows(ts), p=0.1
):
for indexes in example_index_pairs(W):
self.verify_weighted_stat(ts, W, indexes, windows)

def verify_weighted_stat(self, ts, W, indexes, windows):
n = W.shape[0]

def f(x):
return (x**2) / (2 * (n - 1) * (n - 1))

self.verify_definition(
ts,
W,
indexes,
windows,
f,
ts.genetic_relatedness_weighted,
genetic_relatedness_weighted,
)


############################################
# Fst
############################################
Expand Down

0 comments on commit eff37eb

Please sign in to comment.