From eff37ebe9de324d34eb6fc37c13faa04bac32b45 Mon Sep 17 00:00:00 2001 From: Brieuc Date: Thu, 29 Jun 2023 15:44:49 +0100 Subject: [PATCH] First pass at genetic_relatedness_weighted tests --- .pre-commit-config.yaml | 18 ++--- python/tests/test_tree_stats.py | 127 +++++++++++++++++++++++++++++++- 2 files changed, 135 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 742c73d86f..e5aa957331 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.1.0 + rev: v4.3.0 hooks: - id: check-merge-conflict - id: debug-statements @@ -12,37 +12,37 @@ repos: hooks: - id: copyright-year - repo: https://github.com/benjeffery/pre-commit-clang-format - rev: c21a74d089aaeb86c2c19df371c7e7bf40c07207 + rev: '1.0' hooks: - id: clang-format exclude: dev-tools|examples verbose: true - repo: https://github.com/asottile/reorder_python_imports - rev: v3.0.1 + rev: v3.9.0 hooks: - id: reorder-python-imports args: [--application-directories=python, --unclassifiable-application-module=_tskit] - repo: https://github.com/asottile/pyupgrade - rev: v2.31.1 + rev: v3.2.2 hooks: - id: pyupgrade args: [--py3-plus, --py37-plus] - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 22.10.0 hooks: - id: black language_version: python3 - - repo: https://gitlab.com/pycqa/flake8 - rev: 3.9.2 + - repo: https://github.com/pycqa/flake8 + rev: 5.0.4 hooks: - id: flake8 args: [--config=python/.flake8] - additional_dependencies: ["flake8-bugbear==22.3.23", "flake8-builtins==1.5.3"] + additional_dependencies: ["flake8-bugbear==22.10.27", "flake8-builtins==2.0.1"] - repo: https://github.com/asottile/blacken-docs rev: v1.12.1 hooks: - id: blacken-docs args: [--skip-errors] additional_dependencies: [black==22.3.0] - language_version: python3 + language_version: python3 \ No newline at end of file diff --git a/python/tests/test_tree_stats.py 
b/python/tests/test_tree_stats.py
index 3d49a79c14..d24a4011d3 100644
--- a/python/tests/test_tree_stats.py
+++ b/python/tests/test_tree_stats.py
@@ -1,6 +1,6 @@
 # MIT License
 #
-# Copyright (c) 2018-2022 Tskit Developers
+# Copyright (c) 2018-2023 Tskit Developers
 # Copyright (C) 2016 University of Oxford
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -1010,6 +1010,43 @@ def wrapped_summary_func(x):
             self.assertArrayAlmostEqual(sigma1, sigma4)
 
 
+class kWayWeightStatsMixin(WeightStatsMixin):
+    """
+    Implements verify_definition: checks that both general_stat versions, the
+    ts method and the definition agree, with and without span normalisation.
+    """
+
+    def verify_definition(self, ts, W, windows, summary_func, ts_method, definition):
+
+        # general_stat will need an extra column for p
+        gW = self.transform_weights(W)
+
+        def wrapped_summary_func(x):
+            with suppress_division_by_zero_warning():
+                return summary_func(x)
+
+        # Determine output_dim of the function
+        M = len(wrapped_summary_func(gW[0]))
+        for sn in [True, False]:
+            sigma1 = ts.general_stat(
+                gW, wrapped_summary_func, M, windows, mode=self.mode, span_normalise=sn
+            )
+            sigma2 = general_stat(
+                ts, gW, wrapped_summary_func, windows, mode=self.mode, span_normalise=sn
+            )
+            sigma3 = ts_method(W, windows=windows, mode=self.mode, span_normalise=sn)
+            sigma4 = definition(
+                ts, W, windows=windows, mode=self.mode, span_normalise=sn
+            )
+
+            assert sigma1.shape == sigma2.shape
+            assert sigma1.shape == sigma3.shape
+            assert sigma1.shape == sigma4.shape
+            self.assertArrayAlmostEqual(sigma1, sigma2)
+            self.assertArrayAlmostEqual(sigma1, sigma3)
+            self.assertArrayAlmostEqual(sigma1, sigma4)
+
+
 class SampleSetStatsMixin:
     """
     Implements the verify method and dispatches it to verify_sample_sets
@@ -2101,6 +2138,131 @@ def test_match_K_c0(self):
         self.assertArrayAlmostEqual(A, B)
 
 
+############################################
+# Genetic relatedness weighted
+############################################
+
+
+def genetic_relatedness_matrix(ts, sample_sets, mode, windows=None):
+    """
+    Returns the symmetric matrix K of ts.genetic_relatedness values
+    (proportion=False, span_normalise=False) between all pairs of the
+    specified sample sets. K has shape (n, n) if windows is None and
+    (num_windows, n, n) otherwise, where n = len(sample_sets).
+    """
+    n = len(sample_sets)
+    indexes = list(itertools.combinations_with_replacement(range(n), 2))
+    values = np.asarray(
+        ts.genetic_relatedness(
+            sample_sets,
+            indexes,
+            windows=windows,
+            mode=mode,
+            proportion=False,
+            span_normalise=False,
+        )
+    )
+    if windows is None:
+        # ts.genetic_relatedness drops the windows dimension in this case;
+        # restore it so that both cases go through the same code below.
+        values = values[np.newaxis]
+    if values.ndim != 2:
+        raise NotImplementedError(
+            f"Expected one value per window and index pair; got shape {values.shape}"
+        )
+    upper = np.triu_indices(n)
+    K = np.zeros((values.shape[0], n, n))
+    for Kw, row in zip(K, values):
+        # Fill the upper triangle (including the diagonal), then mirror the
+        # strictly upper part so that Kw is symmetric.
+        Kw[upper] = row
+        Kw += np.triu(Kw, 1).transpose()
+    return K[0] if windows is None else K
+
+
+def genetic_relatedness_weighted(ts, W, indexes, windows=None, mode="site"):
+    """
+    Reference definition of the weighted genetic relatedness: for each pair
+    (i, j) in indexes, computes W[:, i] @ K @ W[:, j], where K is the
+    genetic_relatedness_matrix between all pairs of samples. The output has
+    one row per window (dropped if windows is None) and one column per
+    index pair (dropped if indexes is a single pair rather than a list).
+    """
+    # NOTE(review): K is not span normalised; confirm that this matches what
+    # ts.genetic_relatedness_weighted computes by default.
+    sample_sets = [[u] for u in ts.samples()]
+    K = genetic_relatedness_matrix(ts, sample_sets, mode, windows=windows)
+    if windows is None:
+        K = K[np.newaxis]
+    pairs = np.asarray(indexes, dtype=int)
+    drop_pairs = pairs.shape == (2,)
+    pairs = pairs.reshape((-1, 2))
+    out = np.array([[W[:, i] @ Kw @ W[:, j] for i, j in pairs] for Kw in K])
+    if drop_pairs:
+        out = out[:, 0]
+    return out[0] if windows is None else out
+
+
+def example_index_pairs(weights):
+    """
+    Generates example lists of index pairs into the columns of weights.
+    """
+    assert weights.shape[1] >= 2
+    yield [(0, 1)]
+    yield [(1, 0), (0, 1)]
+    if weights.shape[1] > 2:
+        yield [(0, 1), (1, 2), (0, 2)]
+
+
+class TestGeneticRelatednessWeighted(StatsTestCase, WeightStatsMixin):
+    mode = None
+
+    def verify_definition(
+        self, ts, W, indexes, windows, summary_func, ts_method, definition
+    ):
+        # TODO: also compare against ts.general_stat and general_stat using
+        # summary_func; for now only the ts method and the definition are
+        # compared with each other.
+        sigma3 = ts_method(
+            W,
+            indexes=indexes,
+            windows=windows,
+            mode=self.mode,
+        )
+        sigma4 = definition(
+            ts,
+            W,
+            indexes=indexes,
+            windows=windows,
+            mode=self.mode,
+        )
+        assert sigma3.shape == sigma4.shape
+        self.assertArrayAlmostEqual(sigma3, sigma4)
+
+    def verify(self, ts):
+        for W, windows in subset_combos(
+            self.example_weights(ts, min_size=2), example_windows(ts), p=0.1
+        ):
+            for indexes in example_index_pairs(W):
+                self.verify_weighted_stat(ts, W, indexes, windows)
+
+    def verify_weighted_stat(self, ts, W, indexes, windows):
+        n = W.shape[0]
+
+        def f(x):
+            return (x**2) / (2 * (n - 1) * (n - 1))
+
+        self.verify_definition(
+            ts,
+            W,
+            indexes,
+            windows,
+            f,
+            ts.genetic_relatedness_weighted,
+            genetic_relatedness_weighted,
+        )
+
+
 ############################################
 # Fst
 ############################################