diff --git a/.circleci/config.yml b/.circleci/config.yml index 0b2e2056ca..d4b68b468f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,4 +1,6 @@ version: 2.1 +orbs: + codecov: codecov/codecov@3.2.4 commands: setup: @@ -12,7 +14,6 @@ commands: sudo pip install meson pip install numpy==1.18.5 pip install --user -r python/requirements/CI-complete/requirements.txt - pip install twine --user # Remove tskit installed by msprime pip uninstall tskit -y echo 'export PATH=/home/circleci/.local/bin:$PATH' >> $BASH_ENV @@ -48,12 +49,14 @@ commands: ninja -C build-gcc test - run: - name: Run gcov & upload coverage. + name: Run gcov command: | cd build-gcc find ../c/tskit/*.c -type f -printf "%f\n" | xargs -i gcov -pb libtskit.a.p/tskit_{}.gcno ../c/tskit/{} - cd .. - bash <(curl -s https://codecov.io/bash) -X gcov -X coveragepy -F c-tests + + - codecov/upload: + flags: c-tests + token: CODECOV_TOKEN - run: name: Valgrind for C tests. @@ -116,14 +119,17 @@ commands: python -m pytest -n2 - run: - name: Upload LWT coverage + name: Generate coverage command: | # Make sure the C coverage reports aren't lying around rm -fR build-gcc ls -R cd python/lwt_interface gcov -pb -o ./build/temp.linux*/*.gcno example_c_module.c - bash <(curl -s https://codecov.io/bash) -X gcov -F lwt-tests + + - codecov/upload: + flags: lwt-tests + token: CODECOV_TOKEN - run: name: Run Python tests @@ -132,30 +138,17 @@ commands: python -m pytest --cov=tskit --cov-report=xml --cov-branch -n2 tests/test_lowlevel.py tests/test_tables.py tests/test_file_format.py - run: - name: Upload Python coverage + name: Generate Python coverage command: | # Make sure the C coverage reports aren't lying around rm -fR build-gcc + rm -f python/lwt_interface/*.gcov cd python gcov -pb -o ./build/temp.linux*/*.gcno _tskitmodule.c - bash <(curl -s https://codecov.io/bash) -X gcov -F python-c-tests - - run: - name: Build Python package - command: | - cd python - rm -fR build - python setup.py sdist - python setup.py check - python -m twine check dist/*.tar.gz - python -m venv venv - source venv/bin/activate - pip install --upgrade setuptools pip wheel - python setup.py build_ext - python setup.py egg_info - python setup.py bdist_wheel - pip install dist/*.tar.gz - tskit --help + - codecov/upload: + flags: python-c-tests + token: CODECOV_TOKEN jobs: build: @@ -176,6 +169,28 @@ jobs: paths: - "/home/circleci/.local" - compile_and_test + - run: + name: Install dependencies for wheel test + command: | + ARGO_NET_GIT_FETCH_WITH_CLI=1 pip install twine --user + # Remove tskit installed by msprime + pip uninstall tskit -y + - run: + name: Build Python package + command: | + cd python + rm -fR build + python setup.py sdist + python setup.py check + python -m twine check dist/*.tar.gz + python -m venv venv + source venv/bin/activate + pip install --upgrade setuptools pip wheel + python setup.py build_ext + python setup.py egg_info + python setup.py bdist_wheel + pip install dist/*.tar.gz + tskit --help build-32: docker: @@ -191,6 +206,8 @@ jobs: key: tskit-32-{{ .Branch }}-v7 paths: - "/home/circleci/.local" + # We need to install curl for the codecov upload. 
+ - run: sudo apt-get install -y curl - compile_and_test workflows: diff --git a/.github/workflows/docker/shared.env b/.github/workflows/docker/shared.env index 6f613fb81c..26a8ea4318 100644 --- a/.github/workflows/docker/shared.env +++ b/.github/workflows/docker/shared.env @@ -1,4 +1,5 @@ PYTHON_VERSIONS=( + cp311-cp311 cp310-cp310 cp39-cp39 cp38-cp38 diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 2da3ccc3b0..d6c90ff434 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -19,42 +19,47 @@ env: jobs: build-deploy-docs: name: Docs - runs-on: ubuntu-18.04 + runs-on: ubuntu-latest steps: - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@0.6.0 + uses: styfle/cancel-workflow-action@0.11.0 with: access_token: ${{ github.token }} - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v4 with: - python-version: 3.8 - - - uses: actions/cache@v2 - id: cache - with: - path: venv - key: docs-venv-v2-${{ hashFiles(env.REQUIREMENTS) }} - - - name: Build virtualenv - if: steps.cache.outputs.cache-hit != 'true' - run: python -m venv venv - - - name: Install deps - run: venv/bin/activate && pip install -r ${{env.REQUIREMENTS}} + python-version: "3.10" - name: Install apt deps if: env.APTGET run: sudo apt-get install -y ${{env.APTGET}} + - uses: actions/cache@v3 + id: venv-cache + with: + path: venv + key: docs-venv-v4-${{ hashFiles(env.REQUIREMENTS) }} + + - name: Create venv and install deps (one by one to avoid conflict errors) + if: steps.venv-cache.outputs.cache-hit != 'true' + run: | + python -m venv venv + . venv/bin/activate + pip install --upgrade pip wheel + pip install -r ${{env.REQUIREMENTS}} + - name: Build C module if: env.MAKE_TARGET - run: venv/bin/activate && make $MAKE_TARGET + run: | + . venv/bin/activate + make $MAKE_TARGET - name: Build Docs - run: venv/bin/activate && make -C docs + run: | + . 
venv/bin/activate + make -C docs - name: Trigger docs site rebuild if: github.ref == 'refs/heads/main' diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8c617bba7e..58eb68a9d3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,10 +8,10 @@ on: jobs: pre-commit: name: Lint - runs-on: ubuntu-18.04 + runs-on: ubuntu-latest steps: - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@0.10.0 + uses: styfle/cancel-workflow-action@0.11.0 with: access_token: ${{ github.token }} - uses: actions/checkout@v3 @@ -21,9 +21,8 @@ jobs: - name: install clang-format if: steps.clang_format.outputs.cache-hit != 'true' run: | - sudo apt-get remove -y clang-6.0 libclang-common-6.0-dev libclang1-6.0 libllvm6.0 - sudo apt-get autoremove - sudo apt-get install clang-format clang-format-6.0 + sudo pip install clang-format==6.0.1 + sudo ln -s /usr/local/bin/clang-format /usr/local/bin/clang-format-6.0 - uses: pre-commit/action@v3.0.0 benchmark: @@ -35,11 +34,11 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' cache: 'pip' - cache-dependency-path: python/requirements/development.txt + cache-dependency-path: python/requirements/benchmark.txt - name: Install deps - run: pip install -r python/requirements/development.txt + run: pip install -r python/requirements/benchmark.txt - name: Build module run: | cd python @@ -61,7 +60,7 @@ jobs: strategy: fail-fast: false matrix: - python: [ 3.7, 3.9, "3.10" ] + python: [ 3.7, 3.9, "3.11" ] os: [ macos-latest, ubuntu-latest, windows-latest ] defaults: run: @@ -84,7 +83,7 @@ jobs: /usr/share/miniconda/envs/anaconda-client-env ~/osx-conda ~/.profile - key: ${{ runner.os }}-${{ matrix.python}}-conda-v11-${{ hashFiles('python/requirements/CI-tests-conda/requirements.txt') }}-${{ hashFiles('python/requirements/CI-tests-pip/requirements.txt') }} + key: ${{ runner.os }}-${{ matrix.python}}-conda-v12-${{ hashFiles('python/requirements/CI-tests-conda/requirements.txt') }}-${{ hashFiles('python/requirements/CI-tests-pip/requirements.txt') }} - name: Install Conda uses: conda-incubator/setup-miniconda@v2 @@ -143,6 +142,11 @@ jobs: conda activate anaconda-client-env python setup.py build_ext --inplace + - name: Remove py311 incompatible tests (lack of numba support for 3.11, needed for lshmm) + if: matrix.python == '3.11' + run: | + rm python/tests/test_genotype_matching_* + - name: Run tests working-directory: python run: | @@ -152,11 +156,11 @@ jobs: python -m pytest -x --cov=tskit --cov-report=xml --cov-branch -n2 tests - name: Upload coverage to Codecov - if: matrix.os == 'ubuntu-latest' - uses: codecov/codecov-action@v2 + uses: codecov/codecov-action@v3 with: + token: ${{ secrets.CODECOV_TOKEN }} working-directory: python - fail_ci_if_error: true + fail_ci_if_error: false flags: python-tests name: codecov-umbrella verbose: true diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index bdd48f89cd..caa1c0837c 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -15,12 +15,12 @@ jobs: runs-on: macos-latest strategy: matrix: - python: [3.7, 3.8, 3.9, "3.10"] + python: [3.7, 3.8, 3.9, "3.10", 3.11] steps: - name: Checkout uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Install deps @@ -55,7 +55,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python: [3.7, 3.8, 
3.9, "3.10"] + python: [3.7, 3.8, 3.9, "3.10", 3.11] wordsize: [64] steps: - name: Checkout @@ -108,7 +108,7 @@ jobs: uses: actions/checkout@v2 - name: Set up Python 3.8 - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.8 @@ -127,7 +127,7 @@ jobs: - name: Build wheels in docker shell: bash run: | - docker run --rm -v `pwd`:/project -w /project quay.io/pypa/manylinux2010_x86_64 bash .github/workflows/docker/buildwheel.sh + docker run --rm -v `pwd`:/project -w /project quay.io/pypa/manylinux2014_x86_64 bash .github/workflows/docker/buildwheel.sh - name: Upload Wheels uses: actions/upload-artifact@v2 @@ -140,14 +140,14 @@ jobs: runs-on: macos-latest strategy: matrix: - python: [3.7, 3.8, 3.9, "3.10"] + python: [3.7, 3.8, 3.9, "3.10", 3.11] steps: - name: Download wheels uses: actions/download-artifact@v2 with: name: osx-wheel-${{ matrix.python }} - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Install wheel and test @@ -162,7 +162,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python: [3.7, 3.8, 3.9, "3.10"] + python: [3.7, 3.8, 3.9, "3.10", 3.11] wordsize: [64] steps: - name: Download wheels @@ -170,7 +170,7 @@ jobs: with: name: win-wheel-${{ matrix.python }}-${{ matrix.wordsize }} - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Install wheel and test @@ -186,7 +186,7 @@ jobs: needs: ['manylinux'] strategy: matrix: - python: [3.7, 3.8, 3.9, "3.10"] + python: [3.7, 3.8, 3.9, "3.10", 3.11] include: - python: 3.7 wheel: cp37 @@ -196,13 +196,15 @@ jobs: wheel: cp39 - python: "3.10" wheel: cp310 + - python: 3.11 + wheel: cp311 steps: - name: Download wheels uses: actions/download-artifact@v2 with: name: linux-wheels - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Install wheel and test diff --git a/.gitignore b/.gitignore index 32c8ed68d4..fcf4695e11 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ build-gcc python/benchmark/*.trees python/benchmark/*.json python/benchmark/*.html +.venv +.env diff --git a/.mergify.yml b/.mergify.yml index 8a2b976d8b..8b0d980744 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -7,10 +7,13 @@ queue_rules: - status-success=Lint - status-success=Python (3.7, macos-latest) - status-success=Python (3.9, macos-latest) + - status-success=Python (3.11, macos-latest) - status-success=Python (3.7, ubuntu-latest) - status-success=Python (3.9, ubuntu-latest) + - status-success=Python (3.11, ubuntu-latest) - status-success=Python (3.7, windows-latest) - status-success=Python (3.9, windows-latest) + - status-success=Python (3.11, windows-latest) - "status-success=ci/circleci: build" pull_request_rules: - name: Automatic rebase, CI and merge @@ -24,10 +27,13 @@ pull_request_rules: - status-success=Lint - status-success=Python (3.7, macos-latest) - status-success=Python (3.9, macos-latest) + - status-success=Python (3.11, macos-latest) - status-success=Python (3.7, ubuntu-latest) - status-success=Python (3.9, ubuntu-latest) + - status-success=Python (3.11, ubuntu-latest) - status-success=Python (3.7, windows-latest) - status-success=Python (3.9, windows-latest) + - status-success=Python (3.11, windows-latest) - "status-success=ci/circleci: build" #- status-success=codecov/patch #- status-success=codecov/project/c-tests @@ 
-37,7 +43,6 @@ pull_request_rules: queue: name: default method: rebase - rebase_fallback: none update_method: rebase - name: Remove label after merge @@ -49,29 +54,3 @@ pull_request_rules: remove: - AUTOMERGE-REQUESTED - - name: Automatic dep update - conditions: - - author~=^dependabot(|-preview)\[bot\]$ - - "-merged" - - base=main - - label=dependancy-upgrade - - status-success=Docs - - status-success=Lint - - status-success=Python (3.7, macos-latest) - - status-success=Python (3.9, macos-latest) - - status-success=Python (3.7, ubuntu-latest) - - status-success=Python (3.9, ubuntu-latest) - - status-success=Python (3.7, windows-latest) - - status-success=Python (3.9, windows-latest) - - "status-success=ci/circleci: build" - - "status-success=ci/circleci: build-32" - - status-success=codecov/patch - - status-success=codecov/project/c-tests - - status-success=codecov/project/python-c-tests - - status-success=codecov/project/python-tests - actions: - queue: - name: default - method: rebase - rebase_fallback: none - update_method: rebase diff --git a/c/CHANGELOG.rst b/c/CHANGELOG.rst index 461d3d788a..11606a42f2 100644 --- a/c/CHANGELOG.rst +++ b/c/CHANGELOG.rst @@ -1,7 +1,12 @@ -------------------- -[1.1.2] - 2022-XX-XX +[1.1.2] - 2023-05-17 -------------------- +**Performance improvements** + +- ``tsk_tree_seek`` is now much faster at seeking to arbitrary points along + the sequence from the null tree (:user:`molpopgen`, :pr:`2661`). + **Features** - The struct ``tsk_treeseq_t`` now has the variables ``min_time`` and ``max_time``, @@ -10,6 +15,22 @@ ``tsk_treeseq_get_min_time`` and ``tsk_treeseq_get_max_time``, respectively. (:user:`szhan`, :pr:`2612`, :issue:`2271`) +- Add the ``TSK_SIMPLIFY_NO_FILTER_NODES`` option to simplify to allow unreferenced + nodes to be kept in the output (:user:`jeromekelleher`, :user:`hyanwong`, + :issue:`2606`, :pr:`2619`). + +- Add the ``TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS`` option to simplify, which ensures + that no node sample flags are changed, allowing calling code to manage sample status + (:user:`jeromekelleher`, :issue:`2662`, :pr:`2663`). + +- Guarantee that unfiltered tables are not written to unnecessarily + during simplify (:user:`jeromekelleher`, :pr:`2619`). + +- Add ``x_table_keep_rows`` methods to provide efficient in-place table subsetting + (:user:`jeromekelleher`, :pr:`2700`). + +- Add the ``tsk_tree_seek_index`` function. + -------------------- [1.1.1] - 2022-07-29 -------------------- diff --git a/c/VERSION.txt b/c/VERSION.txt index 8cfbc905b3..8428158dc5 100644 --- a/c/VERSION.txt +++ b/c/VERSION.txt @@ -1 +1 @@ -1.1.1 \ No newline at end of file +1.1.2 \ No newline at end of file diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 35991288d4..39f2a063e0 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -262,6 +262,48 @@ verify_mean_descendants(tsk_treeseq_t *ts) free(C); } +/* Check the divergence matrix by running against the stats API equivalent + * code. NOTE: this will not always be equal in site mode, because of a slightly + * different definition with respect to multiple mutations at a site. 
+ */ +static void +verify_divergence_matrix(tsk_treeseq_t *ts, tsk_flags_t mode) +{ + int ret; + const tsk_size_t n = tsk_treeseq_get_num_samples(ts); + const tsk_id_t *samples = tsk_treeseq_get_samples(ts); + tsk_size_t sample_set_sizes[n]; + tsk_id_t index_tuples[2 * n * n]; + double D1[n * n], D2[n * n]; + tsk_size_t i, j, k; + + for (j = 0; j < n; j++) { + sample_set_sizes[j] = 1; + for (k = 0; k < n; k++) { + index_tuples[2 * (j * n + k)] = (tsk_id_t) j; + index_tuples[2 * (j * n + k) + 1] = (tsk_id_t) k; + } + } + ret = tsk_treeseq_divergence( + ts, n, sample_set_sizes, samples, n * n, index_tuples, 0, NULL, mode, D1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_divergence_matrix(ts, 0, NULL, 0, NULL, mode, D2); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < n; j++) { + for (k = 0; k < n; k++) { + i = j * n + k; + /* printf("%d\t%d\t%f\t%f\n", (int) j, (int) k, D1[i], D2[i]); */ + if (j == k) { + CU_ASSERT_EQUAL(D2[i], 0); + } else { + CU_ASSERT_DOUBLE_EQUAL(D1[i], D2[i], 1E-6); + } + } + } +} + typedef struct { int call_count; int error_on; @@ -303,6 +345,16 @@ verify_window_errors(tsk_treeseq_t *ts, tsk_flags_t mode) ts, 1, W, 1, general_stat_error, NULL, 2, windows, options, sigma); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + windows[0] = -1; + ret = tsk_treeseq_general_stat( + ts, 1, W, 1, general_stat_error, NULL, 2, windows, options, sigma); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[1] = -1; + ret = tsk_treeseq_general_stat( + ts, 1, W, 1, general_stat_error, NULL, 1, windows, options, sigma); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + windows[0] = 10; ret = tsk_treeseq_general_stat( ts, 1, W, 1, general_stat_error, NULL, 2, windows, options, sigma); @@ -396,11 +448,10 @@ verify_node_general_stat_errors(tsk_treeseq_t *ts) static void verify_one_way_weighted_func_errors(tsk_treeseq_t *ts, one_way_weighted_method *method) { - // we don't have any specific errors for this function - // but we might add some in the future int ret; tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts); double *weights = tsk_malloc(num_samples * sizeof(double)); + double bad_windows[] = { 0, -1 }; double result; tsk_size_t j; @@ -409,7 +460,10 @@ verify_one_way_weighted_func_errors(tsk_treeseq_t *ts, one_way_weighted_method * } ret = method(ts, 0, weights, 0, NULL, 0, &result); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_STATE_DIMS); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_WEIGHTS); + + ret = method(ts, 1, weights, 1, bad_windows, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); free(weights); } @@ -418,12 +472,11 @@ static void verify_one_way_weighted_covariate_func_errors( tsk_treeseq_t *ts, one_way_covariates_method *method) { - // we don't have any specific errors for this function - // but we might add some in the future int ret; tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts); double *weights = tsk_malloc(num_samples * sizeof(double)); double *covariates = NULL; + double bad_windows[] = { 0, -1 }; double result; tsk_size_t j; @@ -432,7 +485,10 @@ verify_one_way_weighted_covariate_func_errors( } ret = method(ts, 0, weights, 0, covariates, 0, NULL, 0, &result); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_STATE_DIMS); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_WEIGHTS); + + ret = method(ts, 1, weights, 0, covariates, 1, bad_windows, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); free(weights); } @@ -516,6 +572,28 @@ verify_two_way_stat_func_errors(tsk_treeseq_t *ts, 
general_sample_stat_method *m CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SAMPLE_SET_INDEX); } +static void +verify_two_way_weighted_stat_func_errors( + tsk_treeseq_t *ts, two_way_weighted_method *method) +{ + int ret; + tsk_id_t indexes[] = { 0, 0, 0, 1 }; + double bad_windows[] = { -1, -1 }; + double weights[10]; + double result[10]; + + memset(weights, 0, sizeof(weights)); + + ret = method(ts, 2, weights, 2, indexes, 0, NULL, result, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = method(ts, 0, weights, 2, indexes, 0, NULL, result, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_WEIGHTS); + + ret = method(ts, 2, weights, 2, indexes, 1, bad_windows, result, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); +} + static void verify_three_way_stat_func_errors(tsk_treeseq_t *ts, general_sample_stat_method *method) { @@ -973,6 +1051,128 @@ test_single_tree_general_stat_errors(void) tsk_treeseq_free(&ts); } +static void +test_single_tree_divergence_matrix(void) +{ + tsk_treeseq_t ts; + int ret; + double result[16]; + double D_branch[16] = { 0, 2, 6, 6, 2, 0, 6, 6, 6, 6, 0, 4, 6, 6, 4, 0 }; + double D_site[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL, + NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_branch); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_site); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_divergence_matrix_internal_samples(void) +{ + tsk_treeseq_t ts; + int ret; + double result[16]; + double D[16] = { 0, 2, 4, 3, 2, 0, 4, 3, 4, 4, 0, 1, 3, 3, 1, 0 }; + + const char *nodes = "1 0 -1 -1\n" /* 2.00┊ 6 ┊ */ + "1 0 -1 -1\n" /* ┊ ┏━┻━┓ ┊ */ + "1 0 -1 -1\n" /* 1.00┊ 4 5* ┊ */ + "0 0 -1 -1\n" /* ┊ ┏┻┓ ┏┻┓ ┊ */ + "0 1 -1 -1\n" /* 0.00┊ 0 1 2 3 ┊ */ + "1 1 -1 -1\n" /* 0 * * * 1 */ + "0 2 -1 -1\n"; + const char *edges = "0 1 4 0,1\n" + "0 1 5 2,3\n" + "0 1 6 4,5\n"; + /* One mutation per branch so we get the same as the branch length value */ + const char *sites = "0.1 A\n" + "0.2 A\n" + "0.3 A\n" + "0.4 A\n" + "0.5 A\n" + "0.6 A\n"; + const char *mutations = "0 0 T -1\n" + "1 1 T -1\n" + "2 2 T -1\n" + "3 3 T -1\n" + "4 4 T -1\n" + "5 5 T -1\n"; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_divergence_matrix_multi_root(void) +{ + tsk_treeseq_t ts; + int ret; + double result[16]; + double D_branch[16] = { 0, 2, 3, 3, 2, 0, 3, 3, 3, 3, 0, 4, 3, 3, 4, 0 }; + double D_site[16] = { 0, 4, 6, 6, 4, 0, 6, 6, 6, 6, 0, 8, 6, 6, 8, 0 }; + + const char *nodes = "1 0 -1 -1\n" + "1 0 -1 -1\n" /* 2.00┊ 5 ┊ */ + "1 0 -1 -1\n" /* 1.00┊ 4 ┊ */ + "1 0 -1 -1\n" /* ┊ ┏┻┓ ┏┻┓ ┊ */ + "0 1 -1 -1\n" /* 0.00┊ 0 1 2 3 ┊ */ + "0 2 -1 -1\n"; /* 0 * * * * 1 */ + const char *edges = "0 1 4 0,1\n" + "0 1 5 2,3\n"; + /* Two mutations per branch unit so we get twice the branch length value */ + const char *sites = "0.1 A\n" + "0.2 A\n" + "0.3 A\n" + "0.4 A\n"; + const char *mutations = "0 0 B -1\n" + "0 0 C 0\n" + "1 1 B -1\n" + "1 1 C 2\n" + "2 2 B -1\n" + "2 2 C 4\n" + "2 2 D 5\n" + "2 2 E 6\n" + "3 3 B -1\n" + "3 3 C 8\n" + "3 3 D 9\n" + "3 3 E 10\n"; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_branch); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D_site); + + tsk_treeseq_free(&ts); +} + static void test_paper_ex_ld(void) { @@ -1351,6 +1551,46 @@ test_paper_ex_genetic_relatedness_errors(void) tsk_treeseq_free(&ts); } +static void +test_paper_ex_genetic_relatedness_weighted(void) +{ + tsk_treeseq_t ts; + double weights[] = { 1.2, 0.1, 0.0, 0.0, 3.4, 5.0, 1.0, -1.0 }; + tsk_id_t indexes[] = { 0, 0, 0, 1 }; + double result[100]; + tsk_size_t num_weights; + int ret; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + for (num_weights = 1; num_weights < 3; num_weights++) { + ret = tsk_treeseq_genetic_relatedness_weighted( + &ts, num_weights, weights, 2, indexes, 0, NULL, result, TSK_STAT_SITE); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_genetic_relatedness_weighted( + &ts, num_weights, weights, 2, indexes, 0, NULL, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_genetic_relatedness_weighted( + &ts, num_weights, weights, 2, indexes, 0, NULL, result, TSK_STAT_NODE); + CU_ASSERT_EQUAL_FATAL(ret, 0); + } + + tsk_treeseq_free(&ts); +} + +static void +test_paper_ex_genetic_relatedness_weighted_errors(void) +{ + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + verify_two_way_weighted_stat_func_errors( + &ts, tsk_treeseq_genetic_relatedness_weighted); + tsk_treeseq_free(&ts); +} + static void test_paper_ex_Y2_errors(void) { @@ -1592,6 +1832,20 @@ test_paper_ex_afs(void) tsk_treeseq_free(&ts); } +static void +test_paper_ex_divergence_matrix(void) +{ + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + static void test_nonbinary_ex_ld(void) { @@ -1726,6 +1980,158 @@ test_ld_silent_mutations(void) free(base_ts); } +static void +test_simplest_divergence_matrix(void) +{ + const char *nodes = "1 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts; + tsk_id_t sample_ids[] = { 0, 1 }; + double D_branch[4] = { 0, 2, 2, 0 }; + double D_site[4] = { 0, 0, 0, 0 }; + double result[4]; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_branch, result); + + ret = 
tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_site, result); + + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_site, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_branch, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D_site, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_NODE, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_SPAN_NORMALISE, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_POLARISED, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_POLARISED_UNSUPPORTED); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, 0, NULL, TSK_STAT_SITE | TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_STAT_MODES); + + sample_ids[0] = -1; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + sample_ids[0] = 3; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + tsk_treeseq_free(&ts); +} + +static void +test_simplest_divergence_matrix_windows(void) +{ + const char *nodes = "1 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts; + tsk_id_t sample_ids[] = { 0, 1 }; + double D_branch[8] = { 0, 1, 1, 0, 0, 1, 1, 0 }; + double D_site[8] = { 0, 0, 0, 0, 0, 0, 0, 0 }; + double result[8]; + double windows[] = { 0, 0.5, 1 }; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(8, D_site, result); + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_ids, 2, windows, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(8, D_branch, result); + + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS); + + windows[0] = -1; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[0] = 0.45; + windows[2] = 1.5; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + windows[0] = 0.55; + windows[2] = 1.0; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); + + tsk_treeseq_free(&ts); +} + +static void +test_simplest_divergence_matrix_internal_sample(void) +{ + const char *nodes = "1 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts; + tsk_id_t sample_ids[] = { 0, 1, 2 }; + double result[9]; + double D_branch[9] = { 0, 2, 1, 2, 0, 1, 1, 1, 0 }; + double D_site[9] = { 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + int ret; + + tsk_treeseq_from_text(&ts, 
1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix( + &ts, 3, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(9, D_branch, result); + + ret = tsk_treeseq_divergence_matrix( + &ts, 3, sample_ids, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(9, D_site, result); + + tsk_treeseq_free(&ts); +} + +static void +test_multiroot_divergence_matrix(void) +{ + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, multiroot_ex_nodes, multiroot_ex_edges, NULL, + multiroot_ex_sites, multiroot_ex_mutations, NULL, NULL, 0); + + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); + verify_divergence_matrix(&ts, TSK_STAT_SITE); + + tsk_treeseq_free(&ts); +} + int main(int argc, char **argv) { @@ -1745,6 +2151,11 @@ main(int argc, char **argv) test_single_tree_genealogical_nearest_neighbours }, { "test_single_tree_general_stat", test_single_tree_general_stat }, { "test_single_tree_general_stat_errors", test_single_tree_general_stat_errors }, + { "test_single_tree_divergence_matrix", test_single_tree_divergence_matrix }, + { "test_single_tree_divergence_matrix_internal_samples", + test_single_tree_divergence_matrix_internal_samples }, + { "test_single_tree_divergence_matrix_multi_root", + test_single_tree_divergence_matrix_multi_root }, { "test_paper_ex_ld", test_paper_ex_ld }, { "test_paper_ex_mean_descendants", test_paper_ex_mean_descendants }, @@ -1773,6 +2184,10 @@ main(int argc, char **argv) { "test_paper_ex_genetic_relatedness_errors", test_paper_ex_genetic_relatedness_errors }, { "test_paper_ex_genetic_relatedness", test_paper_ex_genetic_relatedness }, + { "test_paper_ex_genetic_relatedness_weighted", + test_paper_ex_genetic_relatedness_weighted }, + { "test_paper_ex_genetic_relatedness_weighted_errors", + test_paper_ex_genetic_relatedness_weighted_errors }, { "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors }, { "test_paper_ex_Y2", test_paper_ex_Y2 }, { "test_paper_ex_f2_errors", test_paper_ex_f2_errors }, @@ -1785,6 +2200,7 @@ main(int argc, char **argv) { "test_paper_ex_f4", test_paper_ex_f4 }, { "test_paper_ex_afs_errors", test_paper_ex_afs_errors }, { "test_paper_ex_afs", test_paper_ex_afs }, + { "test_paper_ex_divergence_matrix", test_paper_ex_divergence_matrix }, { "test_nonbinary_ex_ld", test_nonbinary_ex_ld }, { "test_nonbinary_ex_mean_descendants", test_nonbinary_ex_mean_descendants }, @@ -1798,6 +2214,13 @@ main(int argc, char **argv) { "test_ld_multi_mutations", test_ld_multi_mutations }, { "test_ld_silent_mutations", test_ld_silent_mutations }, + { "test_simplest_divergence_matrix", test_simplest_divergence_matrix }, + { "test_simplest_divergence_matrix_windows", + test_simplest_divergence_matrix_windows }, + { "test_simplest_divergence_matrix_internal_sample", + test_simplest_divergence_matrix_internal_sample }, + { "test_multiroot_divergence_matrix", test_multiroot_divergence_matrix }, + { NULL, NULL }, }; return test_main(tests, argc, argv); diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 0520750099..c6c6ac0053 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -23,6 +23,7 @@ */ #include "testlib.h" +#include "tskit/core.h" 
#include #include @@ -345,10 +346,32 @@ test_table_collection_simplify_errors(void) tsk_id_t samples[] = { 0, 1 }; tsk_id_t ret_id; const char *individuals = "1 0.25 -2\n"; + ret = tsk_table_collection_init(&tables, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); tables.sequence_length = 1; + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 0, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 0, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + + /* Bad samples */ + samples[0] = -1; + ret = tsk_table_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + samples[0] = 10; + ret = tsk_table_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + samples[0] = 0; + + /* Duplicate samples */ + samples[0] = 0; + samples[1] = 0; + ret = tsk_table_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + samples[0] = 0; + ret_id = tsk_site_table_add_row(&tables.sites, 0, "A", 1, NULL, 0); CU_ASSERT_FATAL(ret_id >= 0); ret_id = tsk_site_table_add_row(&tables.sites, 0, "A", 1, NULL, 0); @@ -1468,6 +1491,106 @@ test_node_table_update_row(void) tsk_node_table_free(&table); } +static void +test_node_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_node_table_t source, t1, t2; + tsk_node_t row; + tsk_bool_t keep[3] = { 1, 1, 1 }; + tsk_id_t id_map[3]; + const char *metadata = "ABC"; + tsk_id_t indexes[] = { 0, 1, 2 }; + + ret = tsk_node_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_node_table_add_row(&source, 0, 1.0, 2, 3, metadata, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_node_table_add_row(&source, 1, 2.0, 3, 4, metadata, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_node_table_add_row(&source, 2, 3.0, 4, 5, metadata, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_node_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_node_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_node_table_equals(&t1, &source, 0)); + + ret = tsk_node_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_node_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_node_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_node_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_node_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_node_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.flags, 1); + CU_ASSERT_EQUAL_FATAL(row.time, 2.0); + CU_ASSERT_EQUAL_FATAL(row.population, 3); + CU_ASSERT_EQUAL_FATAL(row.individual, 4); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 2); + CU_ASSERT_EQUAL_FATAL(row.metadata[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.metadata[1], 'B'); + + tsk_node_table_free(&t1); + + keep[0] = 0; + 
keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_node_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_node_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_node_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_node_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_node_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_node_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_node_table_equals(&source, &t2, 0)); + + tsk_node_table_free(&t1); + tsk_node_table_free(&t2); + } + + tsk_node_table_free(&source); +} + static void test_edge_table_with_options(tsk_flags_t options) { @@ -2011,6 +2134,203 @@ test_edge_table_update_row_no_metadata(void) tsk_edge_table_free(&table); } +static void +test_edge_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_edge_table_t source, t1, t2; + tsk_edge_t row; + tsk_bool_t keep[3] = { 1, 1, 1 }; + tsk_id_t id_map[3]; + const char *metadata = "ABC"; + tsk_id_t indexes[] = { 0, 1, 2 }; + + ret = tsk_edge_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_edge_table_add_row(&source, 0, 1.0, 2, 3, metadata, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&source, 1, 2.0, 3, 4, metadata, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&source, 2, 3.0, 4, 5, metadata, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_edge_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_edge_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&t1, &source, 0)); + + ret = tsk_edge_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_edge_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_edge_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_edge_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_edge_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.left, 1); + CU_ASSERT_EQUAL_FATAL(row.right, 2.0); + CU_ASSERT_EQUAL_FATAL(row.parent, 3); + CU_ASSERT_EQUAL_FATAL(row.child, 4); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 2); + CU_ASSERT_EQUAL_FATAL(row.metadata[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.metadata[1], 'B'); + + tsk_edge_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_edge_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_edge_table_copy(&source, &t1, 0); + 
CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_edge_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_edge_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_edge_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&source, &t2, 0)); + + tsk_edge_table_free(&t1); + tsk_edge_table_free(&t2); + } + + tsk_edge_table_free(&source); +} + +static void +test_edge_table_keep_rows_no_metadata(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_edge_table_t source, t1, t2; + tsk_edge_t row; + tsk_bool_t keep[3] = { 1, 1, 1 }; + tsk_id_t id_map[3]; + tsk_id_t indexes[] = { 0, 1, 2 }; + + ret = tsk_edge_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_edge_table_add_row(&source, 0, 1.0, 2, 3, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&source, 1, 2.0, 3, 4, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&source, 2, 3.0, 4, 5, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_edge_table_copy(&source, &t1, TSK_TABLE_NO_METADATA); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_edge_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&t1, &source, 0)); + + ret = tsk_edge_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_edge_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_edge_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_edge_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_edge_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.left, 1); + CU_ASSERT_EQUAL_FATAL(row.right, 2.0); + CU_ASSERT_EQUAL_FATAL(row.parent, 3); + CU_ASSERT_EQUAL_FATAL(row.child, 4); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 0); + + tsk_edge_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_edge_table_copy(&source, &t2, TSK_TABLE_NO_METADATA); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_edge_table_copy(&source, &t1, TSK_TABLE_NO_METADATA); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_edge_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_edge_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_edge_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_edge_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + 
CU_ASSERT_TRUE(tsk_edge_table_equals(&source, &t2, 0)); + + tsk_edge_table_free(&t1); + tsk_edge_table_free(&t2); + } + + tsk_edge_table_free(&source); +} + static void test_edge_table_takeset_with_options(tsk_flags_t table_options) { @@ -2970,6 +3290,107 @@ test_site_table_update_row(void) tsk_site_table_free(&table); } +static void +test_site_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_site_table_t source, t1, t2; + tsk_site_t row; + const char *ancestral_state = "XYZ"; + const char *metadata = "ABC"; + tsk_bool_t keep[3] = { 1, 1, 1 }; + tsk_id_t id_map[3]; + tsk_id_t indexes[] = { 0, 1, 2 }; + + ret = tsk_site_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_site_table_add_row(&source, 0, ancestral_state, 1, metadata, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_site_table_add_row(&source, 1, ancestral_state, 2, metadata, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_site_table_add_row(&source, 2, ancestral_state, 3, metadata, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_site_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_site_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_site_table_equals(&t1, &source, 0)); + + ret = tsk_site_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_site_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_site_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_site_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_site_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_site_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.position, 1); + CU_ASSERT_EQUAL_FATAL(row.ancestral_state_length, 2); + CU_ASSERT_EQUAL_FATAL(row.ancestral_state[0], 'X'); + CU_ASSERT_EQUAL_FATAL(row.ancestral_state[1], 'Y'); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 2); + CU_ASSERT_EQUAL_FATAL(row.metadata[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.metadata[1], 'B'); + + tsk_site_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_site_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_site_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_site_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_site_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_site_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_site_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_site_table_equals(&source, &t2, 0)); + + tsk_site_table_free(&t1); + tsk_site_table_free(&t2); + } + + tsk_site_table_free(&source); +} 
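The keep_rows tests above all exercise the same calling convention for the new ``x_table_keep_rows`` methods: a boolean mask with one entry per row selects the rows to retain in place, and an optional id-map output reports each original row's new id, with -1 for dropped rows. The following is a minimal standalone sketch of that pattern, not part of this diff: the function name is hypothetical, and it assumes the public <tskit/tables.h> header and the same tsk_node_table_* signatures used in the tests above.

#include <tskit/tables.h>

/* Hypothetical example: drop the middle row of a three-row node table
 * in place and observe the id remapping. */
static int
example_node_table_keep_rows(void)
{
    tsk_node_table_t nodes;
    tsk_bool_t keep[3] = { 1, 0, 1 };
    tsk_id_t id_map[3];
    int ret = tsk_node_table_init(&nodes, 0);

    if (ret != 0) {
        return ret;
    }
    /* Arguments: table, flags, time, population, individual,
     * metadata, metadata_length. */
    tsk_node_table_add_row(&nodes, 0, 0.0, TSK_NULL, TSK_NULL, NULL, 0);
    tsk_node_table_add_row(&nodes, 0, 1.0, TSK_NULL, TSK_NULL, NULL, 0);
    tsk_node_table_add_row(&nodes, 0, 2.0, TSK_NULL, TSK_NULL, NULL, 0);
    /* In-place subset: rows with keep[j] == 0 are removed. */
    ret = tsk_node_table_keep_rows(&nodes, keep, 0, id_map);
    /* id_map is now { 0, -1, 1 }: row 0 keeps id 0, row 1 is dropped (-1),
     * and row 2 becomes the new row 1. */
    tsk_node_table_free(&nodes);
    return ret;
}
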
+ static void test_mutation_table(void) { @@ -3655,6 +4076,199 @@ test_mutation_table_update_row(void) tsk_mutation_table_free(&table); } +static void +test_mutation_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_mutation_table_t source, t1, t2; + tsk_mutation_t row; + const char *derived_state = "XYZ"; + const char *metadata = "ABC"; + tsk_bool_t keep[3] = { 1, 1, 1 }; + tsk_id_t id_map[3]; + tsk_id_t indexes[] = { 0, 1, 2 }; + + ret = tsk_mutation_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_mutation_table_add_row( + &source, 0, 1, -1, 3.0, derived_state, 1, metadata, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row( + &source, 1, 2, -1, 4.0, derived_state, 2, metadata, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row( + &source, 2, 3, 0, 5.0, derived_state, 3, metadata, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_mutation_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_mutation_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_mutation_table_equals(&t1, &source, 0)); + + ret = tsk_mutation_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_mutation_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_mutation_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_mutation_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_mutation_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_mutation_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.site, 1); + CU_ASSERT_EQUAL_FATAL(row.node, 2); + CU_ASSERT_EQUAL_FATAL(row.parent, -1); + CU_ASSERT_EQUAL_FATAL(row.time, 4); + CU_ASSERT_EQUAL_FATAL(row.derived_state_length, 2); + CU_ASSERT_EQUAL_FATAL(row.derived_state[0], 'X'); + CU_ASSERT_EQUAL_FATAL(row.derived_state[1], 'Y'); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 2); + CU_ASSERT_EQUAL_FATAL(row.metadata[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.metadata[1], 'B'); + + tsk_mutation_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_mutation_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_mutation_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_mutation_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_mutation_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_mutation_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_mutation_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_mutation_table_equals(&source, &t2, 0)); + + tsk_mutation_table_free(&t1); + 
tsk_mutation_table_free(&t2); + } + + tsk_mutation_table_free(&source); +} + +static void +test_mutation_table_keep_rows_parent_references(void) +{ + int ret; + tsk_id_t ret_id; + tsk_mutation_table_t source, t; + tsk_bool_t keep[4] = { 1, 1, 1, 1 }; + tsk_id_t id_map[4]; + + ret = tsk_mutation_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_mutation_table_add_row(&source, 0, 1, -1, 3.0, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row(&source, 1, 2, -1, 4.0, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row(&source, 2, 3, 1, 5.0, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row(&source, 3, 4, 1, 6.0, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_mutation_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* OOB errors */ + t.parent[0] = -2; + ret = tsk_mutation_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MUTATION_OUT_OF_BOUNDS); + CU_ASSERT_EQUAL_FATAL(t.num_rows, 4); + + t.parent[0] = 4; + ret = tsk_mutation_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MUTATION_OUT_OF_BOUNDS); + CU_ASSERT_EQUAL_FATAL(t.num_rows, 4); + /* But ignored if row is not kept */ + keep[0] = false; + ret = tsk_mutation_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_mutation_table_free(&t); + + ret = tsk_mutation_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + /* Try to remove referenced row 1 */ + keep[0] = true; + keep[1] = false; + ret = tsk_mutation_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_KEEP_ROWS_MAP_TO_DELETED); + CU_ASSERT_TRUE(tsk_mutation_table_equals(&source, &t, 0)); + tsk_mutation_table_free(&t); + + ret = tsk_mutation_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + /* remove unreferenced row 0 */ + keep[0] = false; + keep[1] = true; + ret = tsk_mutation_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t.num_rows, 3); + CU_ASSERT_EQUAL_FATAL(t.parent[0], TSK_NULL); + CU_ASSERT_EQUAL_FATAL(t.parent[1], 0); + CU_ASSERT_EQUAL_FATAL(t.parent[2], 0); + tsk_mutation_table_free(&t); + + /* Check that we don't change the table in error cases. */ + source.parent[3] = -2; + ret = tsk_mutation_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = true; + ret = tsk_mutation_table_keep_rows(&t, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MUTATION_OUT_OF_BOUNDS); + CU_ASSERT_TRUE(tsk_mutation_table_equals(&source, &t, 0)); + tsk_mutation_table_free(&t); + + /* Check that we don't change the table in error cases. 
*/ + source.parent[3] = 0; + ret = tsk_mutation_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = false; + ret = tsk_mutation_table_keep_rows(&t, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_KEEP_ROWS_MAP_TO_DELETED); + CU_ASSERT_TRUE(tsk_mutation_table_equals(&source, &t, 0)); + tsk_mutation_table_free(&t); + + tsk_mutation_table_free(&source); +} + static void test_migration_table(void) { @@ -4256,6 +4870,108 @@ test_migration_table_update_row(void) tsk_migration_table_free(&table); } +static void +test_migration_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_migration_table_t source, t1, t2; + tsk_migration_t row; + const char *metadata = "ABC"; + tsk_bool_t keep[3] = { 1, 1, 1 }; + tsk_id_t id_map[3]; + tsk_id_t indexes[] = { 0, 1, 2 }; + + ret = tsk_migration_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_migration_table_add_row(&source, 0, 1.0, 2, 3, 4, 5, metadata, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_migration_table_add_row(&source, 1, 2.0, 3, 4, 5, 6, metadata, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_migration_table_add_row(&source, 2, 3.0, 4, 5, 6, 7, metadata, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_migration_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_migration_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_migration_table_equals(&t1, &source, 0)); + + ret = tsk_migration_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_migration_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_migration_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_migration_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_migration_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_migration_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.left, 1); + CU_ASSERT_EQUAL_FATAL(row.right, 2); + CU_ASSERT_EQUAL_FATAL(row.node, 3); + CU_ASSERT_EQUAL_FATAL(row.source, 4); + CU_ASSERT_EQUAL_FATAL(row.dest, 5); + CU_ASSERT_EQUAL_FATAL(row.time, 6); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 2); + CU_ASSERT_EQUAL_FATAL(row.metadata[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.metadata[1], 'B'); + + tsk_migration_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_migration_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_migration_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_migration_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_migration_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_migration_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * 
table */ + ret = tsk_migration_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_migration_table_equals(&source, &t2, 0)); + + tsk_migration_table_free(&t1); + tsk_migration_table_free(&t2); + } + + tsk_migration_table_free(&source); +} + static void test_individual_table(void) { @@ -4981,6 +5697,201 @@ test_individual_table_update_row(void) tsk_individual_table_free(&table); } +static void +test_individual_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_individual_t row; + double location[] = { 0, 1, 2 }; + tsk_id_t parents[] = { -1, 1, -1 }; + const char *metadata = "ABC"; + tsk_bool_t keep[3] = { 1, 1, 1 }; + tsk_id_t indexes[] = { 0, 1, 2 }; + tsk_id_t id_map[3]; + tsk_individual_table_t source, t1, t2; + tsk_size_t j; + + ret = tsk_individual_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id + = tsk_individual_table_add_row(&source, 0, location, 1, parents, 1, metadata, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id + = tsk_individual_table_add_row(&source, 1, location, 2, parents, 2, metadata, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id + = tsk_individual_table_add_row(&source, 2, location, 3, parents, 3, metadata, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_individual_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_individual_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_individual_table_equals(&t1, &source, 0)); + + ret = tsk_individual_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_individual_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_individual_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_individual_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_individual_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_individual_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.flags, 1); + CU_ASSERT_EQUAL_FATAL(row.parents_length, 2); + CU_ASSERT_EQUAL_FATAL(row.parents[0], -1); + CU_ASSERT_EQUAL_FATAL(row.parents[1], 0); + CU_ASSERT_EQUAL_FATAL(row.location_length, 2); + CU_ASSERT_EQUAL_FATAL(row.location[0], 0); + CU_ASSERT_EQUAL_FATAL(row.location[1], 1); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 2); + CU_ASSERT_EQUAL_FATAL(row.metadata[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.metadata[1], 'B'); + + tsk_individual_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_individual_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_individual_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_individual_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_individual_table_keep_rows(&t2, keep, 0, NULL); + 
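+        /* keep_rows should succeed and leave the same rows as truncating to j + 1 */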
CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_individual_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_individual_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_individual_table_equals(&source, &t2, 0)); + + tsk_individual_table_free(&t1); + tsk_individual_table_free(&t2); + } + + tsk_individual_table_free(&source); +} + +static void +test_individual_table_keep_rows_parent_references(void) +{ + int ret; + tsk_id_t ret_id; + tsk_individual_table_t source, t; + tsk_bool_t keep[] = { 1, 1, 1, 1 }; + tsk_id_t parents[] = { -1, 1, 2 }; + tsk_id_t id_map[4]; + + ret = tsk_individual_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_individual_table_add_row(&source, 0, NULL, 0, parents, 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_individual_table_add_row(&source, 0, NULL, 0, parents, 3, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_individual_table_add_row(&source, 0, NULL, 0, parents, 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_individual_table_add_row(&source, 0, NULL, 0, parents, 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_individual_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* OOB errors */ + t.parents[0] = -2; + ret = tsk_individual_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); + CU_ASSERT_EQUAL_FATAL(t.num_rows, 4); + + t.parents[0] = 4; + ret = tsk_individual_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); + CU_ASSERT_EQUAL_FATAL(t.num_rows, 4); + /* But ignored if row is not kept */ + keep[0] = false; + ret = tsk_individual_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_individual_table_free(&t); + + ret = tsk_individual_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + /* Try to remove referenced row 2 */ + keep[0] = true; + keep[2] = false; + ret = tsk_individual_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_KEEP_ROWS_MAP_TO_DELETED); + CU_ASSERT_TRUE(tsk_individual_table_equals(&source, &t, 0)); + tsk_individual_table_free(&t); + + ret = tsk_individual_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + /* remove unreferenced row 0 */ + keep[0] = false; + keep[2] = true; + ret = tsk_individual_table_keep_rows(&t, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t.num_rows, 3); + CU_ASSERT_EQUAL_FATAL(t.parents[0], TSK_NULL); + CU_ASSERT_EQUAL_FATAL(t.parents[1], 0); + CU_ASSERT_EQUAL_FATAL(t.parents[2], 1); + tsk_individual_table_free(&t); + + /* Check that we don't change the table in error cases. */ + source.parents[1] = -2; + ret = tsk_individual_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = true; + ret = tsk_individual_table_keep_rows(&t, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); + CU_ASSERT_TRUE(tsk_individual_table_equals(&source, &t, 0)); + tsk_individual_table_free(&t); + + /* Check that we don't change the table in error cases. 
*/ + source.parents[1] = 0; + ret = tsk_individual_table_copy(&source, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = false; + ret = tsk_individual_table_keep_rows(&t, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_KEEP_ROWS_MAP_TO_DELETED); + CU_ASSERT_TRUE(tsk_individual_table_equals(&source, &t, 0)); + tsk_individual_table_free(&t); + + tsk_individual_table_free(&source); +} + static void test_population_table(void) { @@ -5358,6 +6269,102 @@ test_population_table_update_row(void) tsk_population_table_free(&table); } +static void +test_population_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_population_table_t source, t1, t2; + tsk_population_t row; + const char *metadata = "ABC"; + tsk_bool_t keep[3] = { 1, 1, 1 }; + tsk_id_t id_map[3]; + tsk_id_t indexes[] = { 0, 1, 2 }; + + ret = tsk_population_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_population_table_add_row(&source, metadata, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_population_table_add_row(&source, metadata, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_population_table_add_row(&source, metadata, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_population_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_population_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_population_table_equals(&t1, &source, 0)); + + ret = tsk_population_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_population_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_population_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_population_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_population_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_population_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.metadata_length, 2); + CU_ASSERT_EQUAL_FATAL(row.metadata[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.metadata[1], 'B'); + + tsk_population_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_population_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_population_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_population_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_population_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_population_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_population_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_population_table_equals(&source, &t2, 0)); + + tsk_population_table_free(&t1); + 
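+    /* The calling pattern exercised by these keep_rows tests is, in outline
+     * (an illustrative sketch only, assuming an initialised
+     * tsk_population_table_t t with n rows):
+     *
+     *     tsk_bool_t *keep = tsk_malloc(n * sizeof(*keep));
+     *     tsk_id_t *id_map = tsk_malloc(n * sizeof(*id_map));
+     *     int ret;
+     *
+     *     // ... set keep[j] to true for each row j that should survive ...
+     *     ret = tsk_population_table_keep_rows(&t, keep, 0, id_map);
+     *     // On success, rows with keep[j] false have been dropped in place
+     *     // and id_map[j] holds the new ID of row j, or TSK_NULL if dropped.
+     */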
tsk_population_table_free(&t2); + } + + tsk_population_table_free(&source); +} + static void test_provenance_table(void) { @@ -5797,6 +6804,106 @@ test_provenance_table_update_row(void) tsk_provenance_table_free(&table); } +static void +test_provenance_table_keep_rows(void) +{ + int ret; + tsk_id_t ret_id; + tsk_size_t j; + tsk_provenance_table_t source, t1, t2; + tsk_provenance_t row; + const char *timestamp = "XYZ"; + const char *record = "ABC"; + tsk_bool_t keep[3] = { 1, 1, 1 }; + tsk_id_t indexes[] = { 0, 1, 2 }; + tsk_id_t id_map[3]; + + ret = tsk_provenance_table_init(&source, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret_id = tsk_provenance_table_add_row(&source, timestamp, 1, record, 1); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_provenance_table_add_row(&source, timestamp, 2, record, 2); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_provenance_table_add_row(&source, timestamp, 3, record, 3); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_provenance_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_provenance_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_provenance_table_equals(&t1, &source, 0)); + + ret = tsk_provenance_table_keep_rows(&t1, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_provenance_table_equals(&t1, &source, 0)); + CU_ASSERT_EQUAL_FATAL(id_map[0], 0); + CU_ASSERT_EQUAL_FATAL(id_map[1], 1); + CU_ASSERT_EQUAL_FATAL(id_map[2], 2); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + ret = tsk_provenance_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 0); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], -1); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_provenance_table_copy(&source, &t1, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[0] = 0; + keep[1] = 1; + keep[2] = 0; + ret = tsk_provenance_table_keep_rows(&t1, keep, 0, id_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.num_rows, 1); + CU_ASSERT_EQUAL_FATAL(id_map[0], -1); + CU_ASSERT_EQUAL_FATAL(id_map[1], 0); + CU_ASSERT_EQUAL_FATAL(id_map[2], -1); + + ret = tsk_provenance_table_get_row(&t1, 0, &row); + CU_ASSERT_EQUAL_FATAL(row.timestamp_length, 2); + CU_ASSERT_EQUAL_FATAL(row.timestamp[0], 'X'); + CU_ASSERT_EQUAL_FATAL(row.timestamp[1], 'Y'); + CU_ASSERT_EQUAL_FATAL(row.record_length, 2); + CU_ASSERT_EQUAL_FATAL(row.record[0], 'A'); + CU_ASSERT_EQUAL_FATAL(row.record[1], 'B'); + + tsk_provenance_table_free(&t1); + + keep[0] = 0; + keep[1] = 0; + keep[2] = 0; + /* Keeping first n rows equivalent to truncate */ + for (j = 0; j < source.num_rows; j++) { + ret = tsk_provenance_table_copy(&source, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_provenance_table_copy(&source, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_provenance_table_truncate(&t1, j + 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + keep[j] = 1; + ret = tsk_provenance_table_keep_rows(&t2, keep, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_provenance_table_equals(&t1, &t2, 0)); + + /* Adding the remaining rows back on to the table gives the original + * table */ + ret = tsk_provenance_table_extend( + &t2, &source, source.num_rows - j - 1, indexes + j + 1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_provenance_table_equals(&source, &t2, 0)); + + tsk_provenance_table_free(&t1); + tsk_provenance_table_free(&t2); + } + + tsk_provenance_table_free(&source); +} + static void 
test_table_size_increments(void) { @@ -10433,17 +11540,50 @@ test_table_collection_delete_older(void) tsk_treeseq_free(&ts); } +static void +test_table_collection_edge_diffs_errors(void) +{ + int ret; + tsk_id_t ret_id; + tsk_table_collection_t tables; + tsk_diff_iter_t iter; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL(ret, 0); + tables.sequence_length = 1; + ret_id + = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0, -1, -1, NULL, 0); + CU_ASSERT_EQUAL(ret_id, 0); + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 1, -1, -1, NULL, 0); + CU_ASSERT_EQUAL(ret_id, 1); + ret = (int) tsk_edge_table_add_row(&tables.edges, 0., 1., 1, 0, NULL, 0); + CU_ASSERT_EQUAL(ret, 0); + + ret = tsk_diff_iter_init(&iter, &tables, -1, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_TABLES_NOT_INDEXED); + + tables.nodes.time[0] = 1; + ret = tsk_diff_iter_init(&iter, &tables, -1, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_NODE_TIME_ORDERING); + + tsk_table_collection_free(&tables); +} + int main(int argc, char **argv) { CU_TestInfo tests[] = { { "test_node_table", test_node_table }, { "test_node_table_update_row", test_node_table_update_row }, + { "test_node_table_keep_rows", test_node_table_keep_rows }, { "test_node_table_takeset", test_node_table_takeset }, { "test_edge_table", test_edge_table }, { "test_edge_table_update_row", test_edge_table_update_row }, { "test_edge_table_update_row_no_metadata", test_edge_table_update_row_no_metadata }, + { "test_edge_table_keep_rows", test_edge_table_keep_rows }, + { "test_edge_table_keep_rows_no_metadata", + test_edge_table_keep_rows_no_metadata }, { "test_edge_table_takeset", test_edge_table_takeset }, { "test_edge_table_copy_semantics", test_edge_table_copy_semantics }, { "test_edge_table_squash", test_edge_table_squash }, @@ -10455,21 +11595,31 @@ main(int argc, char **argv) { "test_edge_table_squash_metadata", test_edge_table_squash_metadata }, { "test_site_table", test_site_table }, { "test_site_table_update_row", test_site_table_update_row }, + { "test_site_table_keep_rows", test_site_table_keep_rows }, { "test_site_table_takeset", test_site_table_takeset }, { "test_mutation_table", test_mutation_table }, { "test_mutation_table_update_row", test_mutation_table_update_row }, { "test_mutation_table_takeset", test_mutation_table_takeset }, + { "test_mutation_table_keep_rows", test_mutation_table_keep_rows }, + { "test_mutation_table_keep_rows_parent_references", + test_mutation_table_keep_rows_parent_references }, { "test_migration_table", test_migration_table }, { "test_migration_table_update_row", test_migration_table_update_row }, + { "test_migration_table_keep_rows", test_migration_table_keep_rows }, { "test_migration_table_takeset", test_migration_table_takeset }, { "test_individual_table", test_individual_table }, { "test_individual_table_takeset", test_individual_table_takeset }, { "test_individual_table_update_row", test_individual_table_update_row }, + { "test_individual_table_keep_rows", test_individual_table_keep_rows }, + { "test_individual_table_keep_rows_parent_references", + test_individual_table_keep_rows_parent_references }, { "test_population_table", test_population_table }, { "test_population_table_update_row", test_population_table_update_row }, + { "test_population_table_keep_rows", test_population_table_keep_rows }, { "test_population_table_takeset", test_population_table_takeset }, { "test_provenance_table", test_provenance_table }, { "test_provenance_table_update_row", test_provenance_table_update_row }, + { 
"test_provenance_table_keep_rows", test_provenance_table_keep_rows }, { "test_provenance_table_takeset", test_provenance_table_takeset }, { "test_table_size_increments", test_table_size_increments }, { "test_table_expansion", test_table_expansion }, @@ -10556,6 +11706,8 @@ main(int argc, char **argv) { "test_table_collection_takeset_indexes", test_table_collection_takeset_indexes }, { "test_table_collection_delete_older", test_table_collection_delete_older }, + { "test_table_collection_edge_diffs_errors", + test_table_collection_edge_diffs_errors }, { NULL, NULL }, }; diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 18b904171b..5acf465db6 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -175,6 +175,97 @@ verify_individual_nodes(tsk_treeseq_t *ts) } } +static void +verify_tree_pos(const tsk_treeseq_t *ts, tsk_size_t num_trees, tsk_id_t *tree_parents) +{ + int ret; + const tsk_size_t N = tsk_treeseq_get_num_nodes(ts); + const tsk_id_t *edges_parent = ts->tables->edges.parent; + const tsk_id_t *edges_child = ts->tables->edges.child; + tsk_tree_position_t tree_pos; + tsk_id_t *known_parent; + tsk_id_t *parent = tsk_malloc(N * sizeof(*parent)); + tsk_id_t u, index, j, e; + bool valid; + + CU_ASSERT_FATAL(parent != NULL); + + ret = tsk_tree_position_init(&tree_pos, ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (u = 0; u < (tsk_id_t) N; u++) { + parent[u] = TSK_NULL; + } + + for (index = 0; index < (tsk_id_t) num_trees; index++) { + known_parent = tree_parents + N * (tsk_size_t) index; + + valid = tsk_tree_position_next(&tree_pos); + CU_ASSERT_TRUE(valid); + CU_ASSERT_EQUAL(index, tree_pos.index); + + for (j = tree_pos.out.start; j < tree_pos.out.stop; j++) { + e = tree_pos.out.order[j]; + parent[edges_child[e]] = TSK_NULL; + } + + for (j = tree_pos.in.start; j < tree_pos.in.stop; j++) { + e = tree_pos.in.order[j]; + parent[edges_child[e]] = edges_parent[e]; + } + + for (u = 0; u < (tsk_id_t) N; u++) { + CU_ASSERT_EQUAL(parent[u], known_parent[u]); + } + } + + valid = tsk_tree_position_next(&tree_pos); + CU_ASSERT_FALSE(valid); + for (j = tree_pos.out.start; j < tree_pos.out.stop; j++) { + e = tree_pos.out.order[j]; + parent[edges_child[e]] = TSK_NULL; + } + for (u = 0; u < (tsk_id_t) N; u++) { + CU_ASSERT_EQUAL(parent[u], TSK_NULL); + } + + for (index = (tsk_id_t) num_trees - 1; index >= 0; index--) { + known_parent = tree_parents + N * (tsk_size_t) index; + + valid = tsk_tree_position_prev(&tree_pos); + CU_ASSERT_TRUE(valid); + CU_ASSERT_EQUAL(index, tree_pos.index); + + for (j = tree_pos.out.start; j > tree_pos.out.stop; j--) { + e = tree_pos.out.order[j]; + parent[edges_child[e]] = TSK_NULL; + } + + for (j = tree_pos.in.start; j > tree_pos.in.stop; j--) { + CU_ASSERT_FATAL(j >= 0); + e = tree_pos.in.order[j]; + parent[edges_child[e]] = edges_parent[e]; + } + + for (u = 0; u < (tsk_id_t) N; u++) { + CU_ASSERT_EQUAL(parent[u], known_parent[u]); + } + } + + valid = tsk_tree_position_prev(&tree_pos); + CU_ASSERT_FALSE(valid); + for (j = tree_pos.out.start; j > tree_pos.out.stop; j--) { + e = tree_pos.out.order[j]; + parent[edges_child[e]] = TSK_NULL; + } + for (u = 0; u < (tsk_id_t) N; u++) { + CU_ASSERT_EQUAL(parent[u], TSK_NULL); + } + + tsk_tree_position_free(&tree_pos); + 
tsk_safe_free(parent); +} + static void verify_trees(tsk_treeseq_t *ts, tsk_size_t num_trees, tsk_id_t *parents) { @@ -233,6 +324,8 @@ verify_trees(tsk_treeseq_t *ts, tsk_size_t num_trees, tsk_id_t *parents) CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(ts), breakpoints[j]); tsk_tree_free(&tree); + + verify_tree_pos(ts, num_trees, parents); } static tsk_tree_t * @@ -407,7 +500,7 @@ verify_tree_diffs(tsk_treeseq_t *ts, tsk_flags_t options) child[j] = TSK_NULL; sib[j] = TSK_NULL; } - ret = tsk_diff_iter_init(&iter, ts, options); + ret = tsk_diff_iter_init_from_ts(&iter, ts, options); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_init(&tree, ts, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -3257,6 +3350,102 @@ test_simplest_individual_filter(void) tsk_table_collection_free(&tables); } +static void +test_simplest_no_node_filter(void) +{ + const char *nodes = "1 0 0\n" + "1 0 0\n" + "0 1 0\n" + "0 1 0"; /* unreferenced node */ + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts, simplified; + tsk_id_t sample_ids[] = { 0, 1 }; + tsk_id_t node_map[] = { -1, -1, -1, -1 }; + tsk_id_t j; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_simplify( + &ts, NULL, 0, TSK_SIMPLIFY_NO_FILTER_NODES, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, simplified.tables, 0)); + tsk_treeseq_free(&simplified); + + ret = tsk_treeseq_simplify( + &ts, sample_ids, 2, TSK_SIMPLIFY_NO_FILTER_NODES, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, simplified.tables, 0)); + tsk_treeseq_free(&simplified); + + /* Reversing sample order makes no difference */ + sample_ids[0] = 1; + sample_ids[1] = 0; + ret = tsk_treeseq_simplify( + &ts, sample_ids, 2, TSK_SIMPLIFY_NO_FILTER_NODES, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, simplified.tables, 0)); + tsk_treeseq_free(&simplified); + + ret = tsk_treeseq_simplify( + &ts, sample_ids, 1, TSK_SIMPLIFY_NO_FILTER_NODES, &simplified, node_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&simplified), 4); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_edges(&simplified), 0); + for (j = 0; j < 4; j++) { + CU_ASSERT_EQUAL(node_map[j], j); + } + tsk_treeseq_free(&simplified); + + ret = tsk_treeseq_simplify(&ts, sample_ids, 1, + TSK_SIMPLIFY_NO_FILTER_NODES | TSK_SIMPLIFY_KEEP_INPUT_ROOTS + | TSK_SIMPLIFY_KEEP_UNARY, + &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&simplified), 4); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_edges(&simplified), 1); + tsk_treeseq_free(&simplified); + + sample_ids[0] = 0; + sample_ids[1] = 0; + ret = tsk_treeseq_simplify( + &ts, sample_ids, 2, TSK_SIMPLIFY_NO_FILTER_NODES, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + tsk_treeseq_free(&simplified); + + tsk_treeseq_free(&ts); +} + +static void +test_simplest_no_update_flags(void) +{ + const char *nodes = "0 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = "0 1 2 0,1\n"; + tsk_treeseq_t ts, simplified; + tsk_id_t sample_ids[] = { 0, 1 }; + int ret; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + /* We have a mixture of sample and non-samples in the input tables */ + ret = tsk_treeseq_simplify( + &ts, sample_ids, 2, TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + 
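+    /* Because the sample flags are not rewritten, the mixed sample and
+     * non-sample flags survive and the tables round-trip unchanged */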
CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, simplified.tables, 0)); + tsk_treeseq_free(&simplified); + + ret = tsk_treeseq_simplify(&ts, sample_ids, 2, + TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS | TSK_SIMPLIFY_NO_FILTER_NODES, &simplified, + NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, simplified.tables, 0)); + tsk_treeseq_free(&simplified); + + tsk_treeseq_free(&ts); +} + static void test_simplest_map_mutations(void) { @@ -5137,6 +5326,65 @@ test_single_tree_tracked_samples(void) tsk_tree_free(&tree); } +static void +test_single_tree_tree_pos(void) +{ + tsk_treeseq_t ts; + tsk_tree_position_t tree_pos; + bool valid; + int ret; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL, + NULL, NULL, NULL, 0); + + ret = tsk_tree_position_init(&tree_pos, &ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + valid = tsk_tree_position_next(&tree_pos); + CU_ASSERT_FATAL(valid); + + CU_ASSERT_EQUAL_FATAL(tree_pos.interval.left, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.interval.right, 1); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.start, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.stop, 6); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.order, ts.tables->indexes.edge_insertion_order); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.start, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.stop, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.order, ts.tables->indexes.edge_removal_order); + + valid = tsk_tree_position_next(&tree_pos); + CU_ASSERT_FATAL(!valid); + + tsk_tree_position_print_state(&tree_pos, _devnull); + + CU_ASSERT_EQUAL_FATAL(tree_pos.index, -1); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.start, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.stop, 6); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.order, ts.tables->indexes.edge_removal_order); + + valid = tsk_tree_position_prev(&tree_pos); + CU_ASSERT_FATAL(valid); + + CU_ASSERT_EQUAL_FATAL(tree_pos.interval.left, 0); + CU_ASSERT_EQUAL_FATAL(tree_pos.interval.right, 1); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.start, 5); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.stop, -1); + CU_ASSERT_EQUAL_FATAL(tree_pos.in.order, ts.tables->indexes.edge_removal_order); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.start, 5); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.stop, 5); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.order, ts.tables->indexes.edge_insertion_order); + + valid = tsk_tree_position_prev(&tree_pos); + CU_ASSERT_FATAL(!valid); + + CU_ASSERT_EQUAL_FATAL(tree_pos.index, -1); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.start, 5); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.stop, -1); + CU_ASSERT_EQUAL_FATAL(tree_pos.out.order, ts.tables->indexes.edge_insertion_order); + + tsk_tree_position_free(&tree_pos); + tsk_treeseq_free(&ts); +} + /*======================================================= * Multi tree tests. 
*======================================================*/ @@ -5299,7 +5547,6 @@ test_simplify_keep_input_roots_multi_tree(void) tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - tsk_treeseq_dump(&ts, "tmp.trees", 0); ret = tsk_treeseq_simplify( &ts, samples, 2, TSK_SIMPLIFY_KEEP_INPUT_ROOTS, &simplified, NULL); CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -6044,10 +6291,16 @@ test_seek_multi_tree(void) ret = tsk_tree_seek(&t, breakpoints[j], 0); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, j); + ret = tsk_tree_seek_index(&t, j, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t.index, j); for (k = 0; k < num_trees; k++) { ret = tsk_tree_seek(&t, breakpoints[k], 0); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, k); + ret = tsk_tree_seek_index(&t, k, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t.index, k); } } @@ -6109,6 +6362,10 @@ test_seek_errors(void) CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SEEK_OUT_OF_BOUNDS); ret = tsk_tree_seek(&t, 11, 0); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SEEK_OUT_OF_BOUNDS); + ret = tsk_tree_seek_index(&t, (tsk_id_t) ts.num_trees, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SEEK_OUT_OF_BOUNDS); + ret = tsk_tree_seek_index(&t, -1, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SEEK_OUT_OF_BOUNDS); tsk_tree_free(&t); tsk_treeseq_free(&ts); @@ -7695,7 +7952,7 @@ test_time_uncalibrated(void) tsk_size_t sample_set_sizes[] = { 2, 2 }; tsk_id_t samples[] = { 0, 1, 2, 3 }; tsk_size_t num_samples; - double result[10]; + double result[100]; double *W; double *sigma; @@ -7751,6 +8008,12 @@ test_time_uncalibrated(void) TSK_STAT_BRANCH | TSK_STAT_ALLOW_TIME_UNCALIBRATED, sigma); CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_divergence_matrix(&ts2, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TIME_UNCALIBRATED); + ret = tsk_treeseq_divergence_matrix(&ts2, 0, NULL, 0, NULL, + TSK_STAT_BRANCH | TSK_STAT_ALLOW_TIME_UNCALIBRATED, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_safe_free(W); tsk_safe_free(sigma); tsk_treeseq_free(&ts); @@ -7949,6 +8212,131 @@ test_split_edges_errors(void) tsk_treeseq_free(&ts); } +static void +test_extend_edges_simple(void) +{ + int ret; + tsk_treeseq_t ts, ets; + const char *nodes_ex = "1 0 -1 -1\n" + "1 0 -1 -1\n" + "0 2.0 -1 -1\n"; + const char *edges_ex = "0 10 2 0\n" + "0 10 2 1\n"; + + tsk_treeseq_from_text(&ts, 10, nodes_ex, edges_ex, NULL, NULL, NULL, NULL, NULL, 0); + + ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, ets.tables, 0)); + + tsk_treeseq_free(&ts); + tsk_treeseq_free(&ets); +} + +static void +assert_equal_except_edges(const tsk_treeseq_t *ts1, const tsk_treeseq_t *ts2) +{ + tsk_table_collection_t t1, t2; + int ret; + + ret = tsk_table_collection_copy(ts1->tables, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_table_collection_copy(ts2->tables, &t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tsk_edge_table_clear(&t1.edges); + tsk_edge_table_clear(&t2.edges); + + CU_ASSERT_TRUE(tsk_table_collection_equals(&t1, &t2, 0)); + + tsk_table_collection_free(&t1); + tsk_table_collection_free(&t2); +} + +static void +test_extend_edges(void) +{ + int ret, max_iter; + tsk_table_collection_t tables; + tsk_treeseq_t ts, ets; + /* 7 and 8 should be extended to the whole sequence + + 6 6 6 6 + +-+-+ +-+-+ +-+-+ +-+-+ + | | 7 | | 8 | | + | | ++-+ | | +-++ | | + 4 5 4 | | 4 | 5 4 5 + +++ 
+++ +++ | | | | +++ +++ +++ + 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + */ + + const char *nodes_ex = "1 0 -1 -1\n" + "1 0 -1 -1\n" + "1 0 -1 -1\n" + "1 0 -1 -1\n" + "0 1.0 -1 -1\n" + "0 1.0 -1 -1\n" + "0 3.0 -1 -1\n" + "0 2.0 -1 -1\n" + "0 2.0 -1 -1\n"; + // l, r, p, c + const char *edges_ex = "0 10 4 0\n" + "0 5 4 1\n" + "7 10 4 1\n" + "0 2 5 2\n" + "5 10 5 2\n" + "0 2 5 3\n" + "5 10 5 3\n" + "0 2 6 4\n" + "5 10 6 4\n" + "2 5 6 3\n" + "0 2 6 5\n" + "7 10 6 5\n" + "2 5 6 7\n" + "5 7 6 8\n" + "2 5 7 2\n" + "2 5 7 4\n" + "5 7 8 1\n" + "5 7 8 5\n"; + + /* Doing this rather than tsk_treeseq_from_text because the edges are unsorted */ + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 10; + parse_nodes(nodes_ex, &tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes.num_rows, 9); + parse_edges(edges_ex, &tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges.num_rows, 18); + ret = tsk_table_collection_sort(&tables, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_extend_edges(&ts, 0, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE_FATAL(tsk_table_collection_equals(ts.tables, ets.tables, 0)); + /* tsk_treeseq_print_state(&ets, stdout); */ + tsk_treeseq_free(&ets); + + for (max_iter = 1; max_iter < 10; max_iter++) { + ret = tsk_treeseq_extend_edges(&ts, max_iter, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_equal_except_edges(&ts, &ets); + CU_ASSERT_TRUE(ets.tables->edges.num_rows >= 13); + tsk_treeseq_free(&ets); + } + + ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(ets.tables->nodes.num_rows, 9); + CU_ASSERT_EQUAL_FATAL(ets.tables->edges.num_rows, 13); + tsk_treeseq_free(&ets); + + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + static void test_init_take_ownership_no_edge_metadata(void) { @@ -8026,6 +8414,8 @@ main(int argc, char **argv) { "test_simplest_simplify_defragment", test_simplest_simplify_defragment }, { "test_simplest_population_filter", test_simplest_population_filter }, { "test_simplest_individual_filter", test_simplest_individual_filter }, + { "test_simplest_no_node_filter", test_simplest_no_node_filter }, + { "test_simplest_no_update_flags", test_simplest_no_update_flags }, { "test_simplest_map_mutations", test_simplest_map_mutations }, { "test_simplest_nonbinary_map_mutations", test_simplest_nonbinary_map_mutations }, @@ -8072,6 +8462,7 @@ main(int argc, char **argv) { "test_single_tree_map_mutations_internal_samples", test_single_tree_map_mutations_internal_samples }, { "test_single_tree_tracked_samples", test_single_tree_tracked_samples }, + { "test_single_tree_tree_pos", test_single_tree_tree_pos }, /* Multi tree tests */ { "test_simple_multi_tree", test_simple_multi_tree }, @@ -8165,6 +8556,8 @@ main(int argc, char **argv) { "test_split_edges_no_populations", test_split_edges_no_populations }, { "test_split_edges_populations", test_split_edges_populations }, { "test_split_edges_errors", test_split_edges_errors }, + { "test_extend_edges_simple", test_extend_edges_simple }, + { "test_extend_edges", test_extend_edges }, { "test_init_take_ownership_no_edge_metadata", test_init_take_ownership_no_edge_metadata }, { NULL, NULL }, diff --git a/c/tests/testlib.c b/c/tests/testlib.c index 823068d136..043ae5ceab 100644 --- a/c/tests/testlib.c +++ b/c/tests/testlib.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 
Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -966,6 +966,16 @@ tskit_suite_init(void) return CUE_SUCCESS; } +void +assert_arrays_almost_equal(tsk_size_t len, double *a, double *b) +{ + tsk_size_t j; + + for (j = 0; j < len; j++) { + CU_ASSERT_DOUBLE_EQUAL(a[j], b[j], 1e-9); + } +} + static int tskit_suite_cleanup(void) { diff --git a/c/tests/testlib.h b/c/tests/testlib.h index d042d60b55..69efb14781 100644 --- a/c/tests/testlib.h +++ b/c/tests/testlib.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2021 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -54,6 +54,8 @@ void parse_individuals(const char *text, tsk_individual_table_t *individual_tabl void unsort_edges(tsk_edge_table_t *edges, size_t start); +void assert_arrays_almost_equal(tsk_size_t len, double *a, double *b); + extern const char *single_tree_ex_nodes; extern const char *single_tree_ex_edges; extern const char *single_tree_ex_sites; diff --git a/c/tskit/core.c b/c/tskit/core.c index bc50a21a5f..5a8ed6d9ac 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -222,6 +222,10 @@ tsk_strerror_internal(int err) case TSK_ERR_SEEK_OUT_OF_BOUNDS: ret = "Tree seek position out of bounds. (TSK_ERR_SEEK_OUT_OF_BOUNDS)"; break; + case TSK_ERR_KEEP_ROWS_MAP_TO_DELETED: + ret = "One of the kept rows in the table refers to a deleted row. " + "(TSK_ERR_KEEP_ROWS_MAP_TO_DELETED)"; + break; /* Edge errors */ case TSK_ERR_NULL_PARENT: @@ -462,6 +466,19 @@ tsk_strerror_internal(int err) ret = "Statistics using branch lengths cannot be calculated when time_units " "is 'uncalibrated'. (TSK_ERR_TIME_UNCALIBRATED)"; break; + case TSK_ERR_STAT_POLARISED_UNSUPPORTED: + ret = "The TSK_STAT_POLARISED option is not supported by this statistic. " + "(TSK_ERR_STAT_POLARISED_UNSUPPORTED)"; + break; + case TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED: + ret = "The TSK_STAT_SPAN_NORMALISE option is not supported by this " + "statistic. " + "(TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED)"; + break; + case TSK_ERR_INSUFFICIENT_WEIGHTS: + ret = "Insufficient weights provided (at least 1 required). " + "(TSK_ERR_INSUFFICIENT_WEIGHTS)"; + break; /* Mutation mapping errors */ case TSK_ERR_GENOTYPES_ALL_MISSING: diff --git a/c/tskit/core.h b/c/tskit/core.h index 0e7d528b0c..45a33dd8b7 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -123,6 +123,15 @@ specify options to API functions. typedef uint32_t tsk_flags_t; #define TSK_FLAGS_STORAGE_TYPE KAS_UINT32 +/** +@brief Boolean type. + +@rst +Fixed-size (1 byte) boolean values. +@endrst +*/ +typedef uint8_t tsk_bool_t; + // clang-format off /** @defgroup API_VERSION_GROUP API version macros. 
@@ -143,7 +152,7 @@ to the API or ABI are introduced, i.e., the addition of a new function. The library patch version. Incremented when any changes not relevant to the to the API or ABI are introduced, i.e., internal refactors of bugfixes. */ -#define TSK_VERSION_PATCH 1 +#define TSK_VERSION_PATCH 2 /** @} */ /* @@ -356,6 +365,12 @@ A time value was non-finite (NaN counts as finite) A genomic position was non-finite */ #define TSK_ERR_GENOME_COORDS_NONFINITE -211 +/** +One of the rows in the retained table refers to a row that has been +deleted. +*/ +#define TSK_ERR_KEEP_ROWS_MAP_TO_DELETED -212 + /** @} */ /** @@ -660,6 +675,20 @@ Statistics based on branch lengths were attempted when the ``time_units`` were ``uncalibrated``. */ #define TSK_ERR_TIME_UNCALIBRATED -910 +/** +The TSK_STAT_POLARISED option was passed to a statistic that does +not support it. +*/ +#define TSK_ERR_STAT_POLARISED_UNSUPPORTED -911 +/** +The TSK_STAT_SPAN_NORMALISE option was passed to a statistic that does +not support it. +*/ +#define TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED -912 +/** +Insufficient weights were provided. +*/ +#define TSK_ERR_INSUFFICIENT_WEIGHTS -913 /** @} */ /** diff --git a/c/tskit/haplotype_matching.c b/c/tskit/haplotype_matching.c index 41c1bd23a0..d6fdfd7f46 100644 --- a/c/tskit/haplotype_matching.c +++ b/c/tskit/haplotype_matching.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -209,7 +209,7 @@ int tsk_ls_hmm_free(tsk_ls_hmm_t *self) { tsk_tree_free(&self->tree); - tsk_diff_iter_free(&self->diffs); + tsk_tree_position_free(&self->tree_pos); tsk_safe_free(self->recombination_rate); tsk_safe_free(self->mutation_rate); tsk_safe_free(self->recombination_rate); @@ -248,9 +248,8 @@ tsk_ls_hmm_reset(tsk_ls_hmm_t *self) tsk_memset(self->transition_parent, 0xff, self->max_transitions * sizeof(*self->transition_parent)); - /* This is safe because we've already zero'd out the memory. 
*/ - tsk_diff_iter_free(&self->diffs); - ret = tsk_diff_iter_init(&self->diffs, self->tree_sequence, false); + tsk_tree_position_free(&self->tree_pos); + ret = tsk_tree_position_init(&self->tree_pos, self->tree_sequence, 0); if (ret != 0) { goto out; } @@ -306,21 +305,20 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) int ret = 0; tsk_id_t *restrict parent = self->parent; tsk_id_t *restrict T_index = self->transition_index; + const tsk_id_t *restrict edges_child = self->tree_sequence->tables->edges.child; + const tsk_id_t *restrict edges_parent = self->tree_sequence->tables->edges.parent; tsk_value_transition_t *restrict T = self->transitions; - tsk_edge_list_node_t *record; - tsk_edge_list_t records_out, records_in; - tsk_edge_t edge; - double left, right; - tsk_id_t u; + tsk_id_t u, c, p, j, e; tsk_value_transition_t *vt; - ret = tsk_diff_iter_next(&self->diffs, &left, &right, &records_out, &records_in); - if (ret < 0) { - goto out; - } + tsk_tree_position_next(&self->tree_pos); + tsk_bug_assert(self->tree_pos.index != -1); + tsk_bug_assert(self->tree_pos.index == self->tree.index); - for (record = records_out.head; record != NULL; record = record->next) { - u = record->edge.child; + for (j = self->tree_pos.out.start; j < self->tree_pos.out.stop; j++) { + e = self->tree_pos.out.order[j]; + c = edges_child[e]; + u = c; if (T_index[u] == TSK_NULL) { /* Ensure the subtree we're detaching has a transition at the root */ while (T_index[u] == TSK_NULL) { @@ -328,25 +326,27 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) tsk_bug_assert(u != TSK_NULL); } tsk_bug_assert(self->num_transitions < self->max_transitions); - T_index[record->edge.child] = (tsk_id_t) self->num_transitions; - T[self->num_transitions].tree_node = record->edge.child; + T_index[c] = (tsk_id_t) self->num_transitions; + T[self->num_transitions].tree_node = c; T[self->num_transitions].value = T[T_index[u]].value; self->num_transitions++; } - parent[record->edge.child] = TSK_NULL; + parent[c] = TSK_NULL; } - for (record = records_in.head; record != NULL; record = record->next) { - edge = record->edge; - parent[edge.child] = edge.parent; - u = edge.parent; - if (parent[edge.parent] == TSK_NULL) { + for (j = self->tree_pos.in.start; j < self->tree_pos.in.stop; j++) { + e = self->tree_pos.in.order[j]; + c = edges_child[e]; + p = edges_parent[e]; + parent[c] = p; + u = p; + if (parent[p] == TSK_NULL) { /* Grafting onto a new root. 
*/ - if (T_index[record->edge.parent] == TSK_NULL) { - T_index[edge.parent] = (tsk_id_t) self->num_transitions; + if (T_index[p] == TSK_NULL) { + T_index[p] = (tsk_id_t) self->num_transitions; tsk_bug_assert(self->num_transitions < self->max_transitions); - T[self->num_transitions].tree_node = edge.parent; - T[self->num_transitions].value = T[T_index[edge.child]].value; + T[self->num_transitions].tree_node = p; + T[self->num_transitions].value = T[T_index[c]].value; self->num_transitions++; } } else { @@ -356,18 +356,17 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) } tsk_bug_assert(u != TSK_NULL); } - tsk_bug_assert(T_index[u] != -1 && T_index[edge.child] != -1); - if (T[T_index[u]].value == T[T_index[edge.child]].value) { - vt = &T[T_index[edge.child]]; + tsk_bug_assert(T_index[u] != -1 && T_index[c] != -1); + if (T[T_index[u]].value == T[T_index[c]].value) { + vt = &T[T_index[c]]; /* Mark the value transition as unusued */ vt->value = -1; vt->tree_node = TSK_NULL; - T_index[edge.child] = TSK_NULL; + T_index[c] = TSK_NULL; } } ret = tsk_ls_hmm_remove_dead_roots(self); -out: return ret; } diff --git a/c/tskit/haplotype_matching.h b/c/tskit/haplotype_matching.h index 46631fb086..e4d82bef03 100644 --- a/c/tskit/haplotype_matching.h +++ b/c/tskit/haplotype_matching.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -98,7 +98,10 @@ typedef struct _tsk_ls_hmm_t { tsk_size_t num_nodes; /* state */ tsk_tree_t tree; - tsk_diff_iter_t diffs; + /* NOTE: this tree_position will be redundant once we integrate the top-level + * tree class with this. + */ + tsk_tree_position_t tree_pos; tsk_id_t *parent; /* The probability value transitions on the tree */ tsk_value_transition_t *transitions; diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 6205e36236..3e4c880302 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * Copyright (c) 2017-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -732,6 +732,188 @@ write_metadata_schema_header( return fprintf(out, fmt, (int) metadata_schema_length, metadata_schema); } +/* Utilities for in-place subsetting columns */ + +static tsk_size_t +count_true(tsk_size_t num_rows, const tsk_bool_t *restrict keep) +{ + tsk_size_t j; + tsk_size_t count = 0; + + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + count++; + } + } + return count; +} + +static void +keep_mask_to_id_map( + tsk_size_t num_rows, const tsk_bool_t *restrict keep, tsk_id_t *restrict id_map) +{ + tsk_size_t j; + tsk_id_t next_id = 0; + + for (j = 0; j < num_rows; j++) { + id_map[j] = TSK_NULL; + if (keep[j]) { + id_map[j] = next_id; + next_id++; + } + } +} + +static tsk_size_t +subset_remap_id_column(tsk_id_t *restrict column, tsk_size_t num_rows, + const tsk_bool_t *restrict keep, const tsk_id_t *restrict id_map) +{ + tsk_size_t j, k; + tsk_id_t value; + + k = 0; + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + value = column[j]; + if (value != TSK_NULL) { + value = id_map[value]; + } + column[k] = value; + k++; + } + } + return k; +} + +/* Trigger warning: C++ programmers should look away... 
This may be one of the + * few cases where some macro funkiness is warranted, as these are exact + * duplicates of the same function with just the type of the column + * parameter changed. */ + +static tsk_size_t +subset_id_column( + tsk_id_t *restrict column, tsk_size_t num_rows, const tsk_bool_t *restrict keep) +{ + tsk_size_t j, k; + + k = 0; + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + column[k] = column[j]; + k++; + } + } + return k; +} + +static tsk_size_t +subset_flags_column( + tsk_flags_t *restrict column, tsk_size_t num_rows, const tsk_bool_t *restrict keep) +{ + tsk_size_t j, k; + + k = 0; + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + column[k] = column[j]; + k++; + } + } + return k; +} + +static tsk_size_t +subset_double_column( + double *restrict column, tsk_size_t num_rows, const tsk_bool_t *restrict keep) +{ + tsk_size_t j, k; + + k = 0; + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + column[k] = column[j]; + k++; + } + } + return k; +} + +static tsk_size_t +subset_ragged_char_column(char *restrict data, tsk_size_t *restrict offset_col, + tsk_size_t num_rows, const tsk_bool_t *restrict keep) +{ + tsk_size_t j, k, i, offset; + + k = 0; + offset = 0; + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + offset_col[k] = offset; + /* Note: Unclear whether it's worth calling memcpy instead here? + * Need to be careful since the regions are overlapping */ + for (i = offset_col[j]; i < offset_col[j + 1]; i++) { + data[offset] = data[i]; + offset++; + } + k++; + } + } + offset_col[k] = offset; + return offset; +} + +static tsk_size_t +subset_ragged_double_column(double *restrict data, tsk_size_t *restrict offset_col, + tsk_size_t num_rows, const tsk_bool_t *restrict keep) +{ + tsk_size_t j, k, i, offset; + + k = 0; + offset = 0; + for (j = 0; j < num_rows; j++) { + if (keep[j]) { + offset_col[k] = offset; + /* Note: Unclear whether it's worth calling memcpy instead here? 
+ * Need to be careful since the regions are overlapping */ + for (i = offset_col[j]; i < offset_col[j + 1]; i++) { + data[offset] = data[i]; + offset++; + } + k++; + } + } + offset_col[k] = offset; + return offset; +} + /************************* * reference sequence *************************/ @@ -1622,6 +1804,71 @@ tsk_individual_table_equals(const tsk_individual_table_t *self, return ret; } +int +tsk_individual_table_keep_rows(tsk_individual_table_t *self, const tsk_bool_t *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *ret_id_map) +{ + int ret = 0; + const tsk_size_t current_num_rows = self->num_rows; + tsk_size_t j, k, remaining_rows; + tsk_id_t pk; + tsk_id_t *id_map = ret_id_map; + tsk_id_t *restrict parents = self->parents; + tsk_size_t *restrict parents_offset = self->parents_offset; + + if (ret_id_map == NULL) { + id_map = tsk_malloc(current_num_rows * sizeof(*id_map)); + if (id_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + } + + keep_mask_to_id_map(current_num_rows, keep, id_map); + + /* See notes in tsk_mutation_table_keep_rows for possibilities + * on making this more flexible */ + for (j = 0; j < current_num_rows; j++) { + if (keep[j]) { + for (k = parents_offset[j]; k < parents_offset[j + 1]; k++) { + pk = parents[k]; + if (pk != TSK_NULL) { + if (pk < 0 || pk >= (tsk_id_t) current_num_rows) { + ret = TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS; + goto out; + } + if (id_map[pk] == TSK_NULL) { + ret = TSK_ERR_KEEP_ROWS_MAP_TO_DELETED; + goto out; + } + } + } + } + } + + remaining_rows = subset_flags_column(self->flags, current_num_rows, keep); + self->parents_length = subset_remap_ragged_id_column( + self->parents, self->parents_offset, current_num_rows, keep, id_map); + self->location_length = subset_ragged_double_column( + self->location, self->location_offset, current_num_rows, keep); + if (self->metadata_length > 0) { + /* Implementation note: we special case metadata here because + * it'll make the common case of no metadata a bit faster, and + * to also potentially support more general use of the + * TSK_TABLE_NO_METADATA option. This is done for all the tables + * but only commented on here. 
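+         * (When TSK_TABLE_NO_METADATA is in effect the metadata column is empty, so metadata_length == 0 and this branch is skipped entirely.)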
*/ + self->metadata_length = subset_ragged_char_column( + self->metadata, self->metadata_offset, current_num_rows, keep); + } + self->num_rows = remaining_rows; +out: + if (ret_id_map == NULL) { + tsk_safe_free(id_map); + } + return ret; +} + static int tsk_individual_table_dump( const tsk_individual_table_t *self, kastore_t *store, tsk_flags_t options) @@ -2271,6 +2518,29 @@ tsk_node_table_get_row(const tsk_node_table_t *self, tsk_id_t index, tsk_node_t return ret; } +int +tsk_node_table_keep_rows(tsk_node_table_t *self, const tsk_bool_t *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) +{ + int ret = 0; + tsk_size_t remaining_rows; + + if (id_map != NULL) { + keep_mask_to_id_map(self->num_rows, keep, id_map); + } + + remaining_rows = subset_flags_column(self->flags, self->num_rows, keep); + subset_double_column(self->time, self->num_rows, keep); + subset_id_column(self->population, self->num_rows, keep); + subset_id_column(self->individual, self->num_rows, keep); + if (self->metadata_length > 0) { + self->metadata_length = subset_ragged_char_column( + self->metadata, self->metadata_offset, self->num_rows, keep); + } + self->num_rows = remaining_rows; + return ret; +} + static int tsk_node_table_dump(const tsk_node_table_t *self, kastore_t *store, tsk_flags_t options) { @@ -2940,6 +3210,29 @@ tsk_edge_table_equals( return ret; } +int +tsk_edge_table_keep_rows(tsk_edge_table_t *self, const tsk_bool_t *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) +{ + int ret = 0; + tsk_size_t remaining_rows; + + if (id_map != NULL) { + keep_mask_to_id_map(self->num_rows, keep, id_map); + } + remaining_rows = subset_double_column(self->left, self->num_rows, keep); + subset_double_column(self->right, self->num_rows, keep); + subset_id_column(self->parent, self->num_rows, keep); + subset_id_column(self->child, self->num_rows, keep); + if (self->metadata_length > 0) { + tsk_bug_assert(!(self->options & TSK_TABLE_NO_METADATA)); + self->metadata_length = subset_ragged_char_column( + self->metadata, self->metadata_offset, self->num_rows, keep); + } + self->num_rows = remaining_rows; + return ret; +} + static int tsk_edge_table_dump(const tsk_edge_table_t *self, kastore_t *store, tsk_flags_t options) { @@ -3675,6 +3968,28 @@ tsk_site_table_dump_text(const tsk_site_table_t *self, FILE *out) return ret; } +int +tsk_site_table_keep_rows(tsk_site_table_t *self, const tsk_bool_t *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) +{ + int ret = 0; + tsk_size_t remaining_rows; + + if (id_map != NULL) { + keep_mask_to_id_map(self->num_rows, keep, id_map); + } + + remaining_rows = subset_double_column(self->position, self->num_rows, keep); + self->ancestral_state_length = subset_ragged_char_column( + self->ancestral_state, self->ancestral_state_offset, self->num_rows, keep); + if (self->metadata_length > 0) { + self->metadata_length = subset_ragged_char_column( + self->metadata, self->metadata_offset, self->num_rows, keep); + } + self->num_rows = remaining_rows; + return ret; +} + static int tsk_site_table_dump(const tsk_site_table_t *self, kastore_t *store, tsk_flags_t options) { @@ -4418,6 +4733,65 @@ tsk_mutation_table_dump_text(const tsk_mutation_table_t *self, FILE *out) return ret; } +int +tsk_mutation_table_keep_rows(tsk_mutation_table_t *self, const tsk_bool_t *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *ret_id_map) +{ + int ret = 0; + const tsk_size_t current_num_rows = self->num_rows; + tsk_size_t j, remaining_rows; + tsk_id_t pj; + tsk_id_t *id_map = ret_id_map; + tsk_id_t 
*restrict parent = self->parent; + + if (ret_id_map == NULL) { + id_map = tsk_malloc(current_num_rows * sizeof(*id_map)); + if (id_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + } + + keep_mask_to_id_map(current_num_rows, keep, id_map); + + /* Note: we could add some options to avoid these checks if we wanted. + * MAP_DELETED_TO_NULL is an obvious one, and I guess it might be + * helpful to also provide NO_REMAP to prevent reference remapping + * entirely. */ + for (j = 0; j < current_num_rows; j++) { + if (keep[j]) { + pj = parent[j]; + if (pj != TSK_NULL) { + if (pj < 0 || pj >= (tsk_id_t) current_num_rows) { + ret = TSK_ERR_MUTATION_OUT_OF_BOUNDS; + goto out; + } + if (id_map[pj] == TSK_NULL) { + ret = TSK_ERR_KEEP_ROWS_MAP_TO_DELETED; + goto out; + } + } + } + } + + remaining_rows = subset_id_column(self->site, current_num_rows, keep); + subset_id_column(self->node, current_num_rows, keep); + subset_remap_id_column(parent, current_num_rows, keep, id_map); + subset_double_column(self->time, current_num_rows, keep); + self->derived_state_length = subset_ragged_char_column( + self->derived_state, self->derived_state_offset, current_num_rows, keep); + if (self->metadata_length > 0) { + self->metadata_length = subset_ragged_char_column( + self->metadata, self->metadata_offset, current_num_rows, keep); + } + self->num_rows = remaining_rows; +out: + if (ret_id_map == NULL) { + tsk_safe_free(id_map); + } + return ret; +} + static int tsk_mutation_table_dump( const tsk_mutation_table_t *self, kastore_t *store, tsk_flags_t options) @@ -5063,6 +5437,31 @@ tsk_migration_table_equals(const tsk_migration_table_t *self, return ret; } +int +tsk_migration_table_keep_rows(tsk_migration_table_t *self, const tsk_bool_t *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) +{ + int ret = 0; + tsk_size_t remaining_rows; + + if (id_map != NULL) { + keep_mask_to_id_map(self->num_rows, keep, id_map); + } + + remaining_rows = subset_double_column(self->left, self->num_rows, keep); + subset_double_column(self->right, self->num_rows, keep); + subset_id_column(self->node, self->num_rows, keep); + subset_id_column(self->source, self->num_rows, keep); + subset_id_column(self->dest, self->num_rows, keep); + subset_double_column(self->time, self->num_rows, keep); + if (self->metadata_length > 0) { + self->metadata_length = subset_ragged_char_column( + self->metadata, self->metadata_offset, self->num_rows, keep); + } + self->num_rows = remaining_rows; + return ret; +} + static int tsk_migration_table_dump( const tsk_migration_table_t *self, kastore_t *store, tsk_flags_t options) @@ -5632,6 +6031,24 @@ tsk_population_table_equals(const tsk_population_table_t *self, return ret; } +int +tsk_population_table_keep_rows(tsk_population_table_t *self, const tsk_bool_t *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) +{ + int ret = 0; + + if (id_map != NULL) { + keep_mask_to_id_map(self->num_rows, keep, id_map); + } + + if (self->metadata_length > 0) { + self->metadata_length = subset_ragged_char_column( + self->metadata, self->metadata_offset, self->num_rows, keep); + } + self->num_rows = count_true(self->num_rows, keep); + return ret; +} + static int tsk_population_table_dump( const tsk_population_table_t *self, kastore_t *store, tsk_flags_t options) @@ -6244,6 +6661,24 @@ tsk_provenance_table_equals(const tsk_provenance_table_t *self, return ret; } +int +tsk_provenance_table_keep_rows(tsk_provenance_table_t *self, const tsk_bool_t *keep, + tsk_flags_t TSK_UNUSED(options), tsk_id_t *id_map) +{ + 
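+    /* Provenance rows are never referenced by ID from other tables, so there
+     * are no references to validate or remap here; we just subset the
+     * timestamp and record columns in place. */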
int ret = 0; + + if (id_map != NULL) { + keep_mask_to_id_map(self->num_rows, keep, id_map); + } + self->timestamp_length = subset_ragged_char_column( + self->timestamp, self->timestamp_offset, self->num_rows, keep); + self->record_length = subset_ragged_char_column( + self->record, self->record_offset, self->num_rows, keep); + self->num_rows = count_true(self->num_rows, keep); + + return ret; +} + static int tsk_provenance_table_dump( const tsk_provenance_table_t *self, kastore_t *store, tsk_flags_t options) @@ -7159,7 +7594,6 @@ typedef struct { } segment_overlapper_t; typedef struct { - tsk_id_t *samples; tsk_size_t num_samples; tsk_flags_t options; tsk_table_collection_t *tables; @@ -7168,6 +7602,7 @@ typedef struct { /* State for topology */ tsk_segment_t **ancestor_map_head; tsk_segment_t **ancestor_map_tail; + /* Mapping of input node IDs to output node IDs. */ tsk_id_t *node_id_map; bool *is_sample; /* Segments for a particular parent that are processed together */ @@ -7185,8 +7620,6 @@ typedef struct { tsk_size_t num_buffered_children; /* For each mutation, map its output node. */ tsk_id_t *mutation_node_map; - /* Map of input mutation IDs to output mutation IDs. */ - tsk_id_t *mutation_id_map; /* Map of input nodes to the list of input mutation IDs */ mutation_id_list_t **node_mutation_list_map_head; mutation_id_list_t **node_mutation_list_map_tail; @@ -8697,6 +9130,8 @@ simplifier_print_state(simplifier_t *self, FILE *out) fprintf(out, "options:\n"); fprintf(out, "\tfilter_unreferenced_sites : %d\n", !!(self->options & TSK_SIMPLIFY_FILTER_SITES)); + fprintf(out, "\tno_filter_nodes : %d\n", + !!(self->options & TSK_SIMPLIFY_NO_FILTER_NODES)); fprintf(out, "\treduce_to_site_topology : %d\n", !!(self->options & TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY)); fprintf(out, "\tkeep_unary : %d\n", @@ -8805,19 +9240,21 @@ simplifier_alloc_interval_list(simplifier_t *self, double left, double right) /* Add a new node to the output node table corresponding to the specified input id. * Returns the new ID. 
*/ static tsk_id_t TSK_WARN_UNUSED -simplifier_record_node(simplifier_t *self, tsk_id_t input_id, bool is_sample) +simplifier_record_node(simplifier_t *self, tsk_id_t input_id) { tsk_node_t node; - tsk_flags_t flags; + bool update_flags = !(self->options & TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS); tsk_node_table_get_row_unsafe(&self->input_tables.nodes, (tsk_id_t) input_id, &node); - /* Zero out the sample bit */ - flags = node.flags & (tsk_flags_t) ~TSK_NODE_IS_SAMPLE; - if (is_sample) { - flags |= TSK_NODE_IS_SAMPLE; + if (update_flags) { + /* Zero out the sample bit */ + node.flags &= (tsk_flags_t) ~TSK_NODE_IS_SAMPLE; + if (self->is_sample[input_id]) { + node.flags |= TSK_NODE_IS_SAMPLE; + } } self->node_id_map[input_id] = (tsk_id_t) self->tables->nodes.num_rows; - return tsk_node_table_add_row(&self->tables->nodes, flags, node.time, + return tsk_node_table_add_row(&self->tables->nodes, node.flags, node.time, node.population, node.individual, node.metadata, node.metadata_length); } @@ -8876,7 +9313,7 @@ simplifier_init_position_lookup(simplifier_t *self) goto out; } self->position_lookup[0] = 0; - self->position_lookup[num_sites + 1] = self->tables->sequence_length; + self->position_lookup[num_sites + 1] = self->input_tables.sequence_length; tsk_memcpy(self->position_lookup + 1, self->input_tables.sites.position, num_sites * sizeof(double)); out: @@ -8920,7 +9357,7 @@ simplifier_record_edge(simplifier_t *self, double left, double right, tsk_id_t c interval_list_t *tail, *x; bool skip; - if (!!(self->options & TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY)) { + if (self->options & TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY) { skip = simplifier_map_reduced_coordinates(self, &left, &right); /* NOTE: we exit early here when reduce_coordindates has told us to * skip this edge, as it is not visible in the reduced tree sequence */ @@ -8966,8 +9403,6 @@ simplifier_init_sites(simplifier_t *self) mutation_id_list_t *list_node; tsk_size_t j; - self->mutation_id_map - = tsk_calloc(self->input_tables.mutations.num_rows, sizeof(tsk_id_t)); self->mutation_node_map = tsk_calloc(self->input_tables.mutations.num_rows, sizeof(tsk_id_t)); self->node_mutation_list_mem @@ -8976,15 +9411,12 @@ simplifier_init_sites(simplifier_t *self) = tsk_calloc(self->input_tables.nodes.num_rows, sizeof(mutation_id_list_t *)); self->node_mutation_list_map_tail = tsk_calloc(self->input_tables.nodes.num_rows, sizeof(mutation_id_list_t *)); - if (self->mutation_id_map == NULL || self->mutation_node_map == NULL - || self->node_mutation_list_mem == NULL + if (self->mutation_node_map == NULL || self->node_mutation_list_mem == NULL || self->node_mutation_list_map_head == NULL || self->node_mutation_list_map_tail == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } - tsk_memset(self->mutation_id_map, 0xff, - self->input_tables.mutations.num_rows * sizeof(tsk_id_t)); tsk_memset(self->mutation_node_map, 0xff, self->input_tables.mutations.num_rows * sizeof(tsk_id_t)); @@ -9058,32 +9490,96 @@ simplifier_add_ancestry( return ret; } +/* Sets up the internal working copies of the various tables, as needed + * depending on the specified options. 
*/ +static int +simplifier_init_tables(simplifier_t *self) +{ + int ret; + bool filter_nodes = !(self->options & TSK_SIMPLIFY_NO_FILTER_NODES); + bool filter_populations = self->options & TSK_SIMPLIFY_FILTER_POPULATIONS; + bool filter_individuals = self->options & TSK_SIMPLIFY_FILTER_INDIVIDUALS; + bool filter_sites = self->options & TSK_SIMPLIFY_FILTER_SITES; + tsk_bookmark_t rows_to_retain; + + /* NOTE: this is a bit inefficient here as we're taking copies of + * the tables even in the no-filter case where the original tables + * won't be touched (beyond references to external tables that may + * need updating). Future versions may do something a bit more + * complicated like temporarily stealing the pointers to the + * underlying column memory in these tables, and then being careful + * not to free the table at the end. + */ + ret = tsk_table_collection_copy(self->tables, &self->input_tables, 0); + if (ret != 0) { + goto out; + } + memset(&rows_to_retain, 0, sizeof(rows_to_retain)); + rows_to_retain.provenances = self->tables->provenances.num_rows; + if (!filter_nodes) { + rows_to_retain.nodes = self->tables->nodes.num_rows; + } + if (!filter_populations) { + rows_to_retain.populations = self->tables->populations.num_rows; + } + if (!filter_individuals) { + rows_to_retain.individuals = self->tables->individuals.num_rows; + } + if (!filter_sites) { + rows_to_retain.sites = self->tables->sites.num_rows; + } + + ret = tsk_table_collection_truncate(self->tables, &rows_to_retain); + if (ret != 0) { + goto out; + } +out: + return ret; +} + static int -simplifier_init_samples(simplifier_t *self, const tsk_id_t *samples) +simplifier_init_nodes(simplifier_t *self, const tsk_id_t *samples) { int ret = 0; tsk_id_t node_id; tsk_size_t j; - - /* Go through the samples to check for errors. */ - for (j = 0; j < self->num_samples; j++) { - if (samples[j] < 0 - || samples[j] > (tsk_id_t) self->input_tables.nodes.num_rows) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; - goto out; + const tsk_size_t num_nodes = self->input_tables.nodes.num_rows; + bool filter_nodes = !(self->options & TSK_SIMPLIFY_NO_FILTER_NODES); + bool update_flags = !(self->options & TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS); + tsk_flags_t *node_flags = self->tables->nodes.flags; + tsk_id_t *node_id_map = self->node_id_map; + + if (filter_nodes) { + tsk_bug_assert(self->tables->nodes.num_rows == 0); + /* The node table has been cleared. Add nodes for the samples. 
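 * For illustration: in this branch node_id_map[samples[j]] is filled in
 * by simplifier_record_node as each sample is added, so the samples
 * become output nodes 0 .. num_samples - 1, while every other entry
 * stays TSK_NULL until (and unless) the corresponding node is recorded.
 * In the no-filter branch below, node_id_map is the identity mapping.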
*/ + for (j = 0; j < self->num_samples; j++) { + node_id = simplifier_record_node(self, samples[j]); + if (node_id < 0) { + ret = (int) node_id; + goto out; + } } - if (self->is_sample[samples[j]]) { - ret = TSK_ERR_DUPLICATE_SAMPLE; - goto out; + } else { + tsk_bug_assert(self->tables->nodes.num_rows == num_nodes); + if (update_flags) { + for (j = 0; j < num_nodes; j++) { + /* Reset the sample flags */ + node_flags[j] &= (tsk_flags_t) ~TSK_NODE_IS_SAMPLE; + if (self->is_sample[j]) { + node_flags[j] |= TSK_NODE_IS_SAMPLE; + } + } } - self->is_sample[samples[j]] = true; - node_id = simplifier_record_node(self, samples[j], true); - if (node_id < 0) { - ret = (int) node_id; - goto out; + + for (j = 0; j < num_nodes; j++) { + node_id_map[j] = (tsk_id_t) j; } - ret = simplifier_add_ancestry( - self, samples[j], 0, self->tables->sequence_length, node_id); + } + /* Add the initial ancestry */ + for (j = 0; j < self->num_samples; j++) { + node_id = samples[j]; + ret = simplifier_add_ancestry(self, node_id, 0, + self->input_tables.sequence_length, self->node_id_map[node_id]); if (ret != 0) { goto out; } @@ -9097,6 +9593,7 @@ simplifier_init(simplifier_t *self, const tsk_id_t *samples, tsk_size_t num_samp tsk_table_collection_t *tables, tsk_flags_t options) { int ret = 0; + tsk_size_t j; tsk_id_t ret_id; tsk_size_t num_nodes; @@ -9118,19 +9615,6 @@ simplifier_init(simplifier_t *self, const tsk_id_t *samples, tsk_size_t num_samp goto out; } - ret = tsk_table_collection_copy(self->tables, &self->input_tables, 0); - if (ret != 0) { - goto out; - } - - /* Take a copy of the input samples */ - self->samples = tsk_malloc(num_samples * sizeof(tsk_id_t)); - if (self->samples == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - tsk_memcpy(self->samples, samples, num_samples * sizeof(tsk_id_t)); - /* Allocate the heaps used for small objects-> Assuming 8K is a good chunk size */ ret = tsk_blkalloc_init(&self->segment_heap, 8192); @@ -9164,26 +9648,40 @@ simplifier_init(simplifier_t *self, const tsk_id_t *samples, tsk_size_t num_samp ret = TSK_ERR_NO_MEMORY; goto out; } - ret = tsk_table_collection_clear(self->tables, 0); + + /* Go through the samples to check for errors before we clear the tables. 
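 * Note that the bounds check below uses samples[j] >= num_nodes,
 * tightening the `>` comparison used by the version of this check
 * removed above.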
*/ + for (j = 0; j < self->num_samples; j++) { + if (samples[j] < 0 || samples[j] >= (tsk_id_t) num_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + if (self->is_sample[samples[j]]) { + ret = TSK_ERR_DUPLICATE_SAMPLE; + goto out; + } + self->is_sample[samples[j]] = true; + } + tsk_memset(self->node_id_map, 0xff, num_nodes * sizeof(tsk_id_t)); + + ret = simplifier_init_tables(self); if (ret != 0) { goto out; } - tsk_memset( - self->node_id_map, 0xff, self->input_tables.nodes.num_rows * sizeof(tsk_id_t)); ret = simplifier_init_sites(self); if (ret != 0) { goto out; } - ret = simplifier_init_samples(self, samples); + ret = simplifier_init_nodes(self, samples); if (ret != 0) { goto out; } - if (!!(self->options & TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY)) { + if (self->options & TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY) { ret = simplifier_init_position_lookup(self); if (ret != 0) { goto out; } } + self->edge_sort_offset = TSK_NULL; out: return ret; @@ -9196,7 +9694,6 @@ simplifier_free(simplifier_t *self) tsk_blkalloc_free(&self->segment_heap); tsk_blkalloc_free(&self->interval_list_heap); segment_overlapper_free(&self->segment_overlapper); - tsk_safe_free(self->samples); tsk_safe_free(self->ancestor_map_head); tsk_safe_free(self->ancestor_map_tail); tsk_safe_free(self->child_edge_map_head); @@ -9204,7 +9701,6 @@ simplifier_free(simplifier_t *self) tsk_safe_free(self->node_id_map); tsk_safe_free(self->segment_queue); tsk_safe_free(self->is_sample); - tsk_safe_free(self->mutation_id_map); tsk_safe_free(self->mutation_node_map); tsk_safe_free(self->node_mutation_list_mem); tsk_safe_free(self->node_mutation_list_map_head); @@ -9252,12 +9748,10 @@ simplifier_merge_ancestors(simplifier_t *self, tsk_id_t input_id) double left, right, prev_right; tsk_id_t ancestry_node; tsk_id_t output_id = self->node_id_map[input_id]; + bool is_sample = self->is_sample[input_id]; + bool filter_nodes = !(self->options & TSK_SIMPLIFY_NO_FILTER_NODES); + bool keep_unary = self->options & TSK_SIMPLIFY_KEEP_UNARY; - bool is_sample = output_id != TSK_NULL; - bool keep_unary = false; - if (self->options & TSK_SIMPLIFY_KEEP_UNARY) { - keep_unary = true; - } if ((self->options & TSK_SIMPLIFY_KEEP_UNARY_IN_INDIVIDUALS) && (self->input_tables.nodes.individual[input_id] != TSK_NULL)) { keep_unary = true; @@ -9292,7 +9786,7 @@ simplifier_merge_ancestors(simplifier_t *self, tsk_id_t input_id) ancestry_node = output_id; } else if (keep_unary) { if (output_id == TSK_NULL) { - output_id = simplifier_record_node(self, input_id, false); + output_id = simplifier_record_node(self, input_id); } ret = simplifier_record_edge(self, left, right, ancestry_node); if (ret != 0) { @@ -9301,7 +9795,7 @@ simplifier_merge_ancestors(simplifier_t *self, tsk_id_t input_id) } } else { if (output_id == TSK_NULL) { - output_id = simplifier_record_node(self, input_id, false); + output_id = simplifier_record_node(self, input_id); if (output_id < 0) { ret = (int) output_id; goto out; @@ -9348,7 +9842,7 @@ simplifier_merge_ancestors(simplifier_t *self, tsk_id_t input_id) if (ret != 0) { goto out; } - if (num_flushed_edges == 0 && !is_sample) { + if (filter_nodes && (num_flushed_edges == 0) && !is_sample) { ret = simplifier_rewind_node(self, input_id, output_id); } } @@ -9454,133 +9948,60 @@ simplifier_process_parent_edges( } static int TSK_WARN_UNUSED -simplifier_output_sites(simplifier_t *self) +simplifier_finalise_site_references( + simplifier_t *self, const bool *site_referenced, tsk_id_t *site_id_map) { int ret = 0; tsk_id_t ret_id; - tsk_id_t 
input_site; - tsk_id_t input_mutation, mapped_parent, site_start, site_end; - tsk_id_t num_input_sites = (tsk_id_t) self->input_tables.sites.num_rows; - tsk_id_t num_input_mutations = (tsk_id_t) self->input_tables.mutations.num_rows; - tsk_id_t num_output_mutations, num_output_site_mutations; - tsk_id_t mapped_node; - bool keep_site; - bool filter_sites = !!(self->options & TSK_SIMPLIFY_FILTER_SITES); + tsk_size_t j; tsk_site_t site; - tsk_mutation_t mutation; - - input_mutation = 0; - num_output_mutations = 0; - for (input_site = 0; input_site < num_input_sites; input_site++) { - tsk_site_table_get_row_unsafe( - &self->input_tables.sites, (tsk_id_t) input_site, &site); - site_start = input_mutation; - num_output_site_mutations = 0; - while (input_mutation < num_input_mutations - && self->input_tables.mutations.site[input_mutation] == site.id) { - mapped_node = self->mutation_node_map[input_mutation]; - if (mapped_node != TSK_NULL) { - self->mutation_id_map[input_mutation] = num_output_mutations; - num_output_mutations++; - num_output_site_mutations++; - } - input_mutation++; - } - site_end = input_mutation; - - keep_site = true; - if (filter_sites && num_output_site_mutations == 0) { - keep_site = false; - } - if (keep_site) { - for (input_mutation = site_start; input_mutation < site_end; - input_mutation++) { - if (self->mutation_id_map[input_mutation] != TSK_NULL) { - tsk_bug_assert( - self->tables->mutations.num_rows - == (tsk_size_t) self->mutation_id_map[input_mutation]); - mapped_node = self->mutation_node_map[input_mutation]; - tsk_bug_assert(mapped_node != TSK_NULL); - mapped_parent = self->input_tables.mutations.parent[input_mutation]; - if (mapped_parent != TSK_NULL) { - mapped_parent = self->mutation_id_map[mapped_parent]; - } - tsk_mutation_table_get_row_unsafe(&self->input_tables.mutations, - (tsk_id_t) input_mutation, &mutation); - ret_id = tsk_mutation_table_add_row(&self->tables->mutations, - (tsk_id_t) self->tables->sites.num_rows, mapped_node, - mapped_parent, mutation.time, mutation.derived_state, - mutation.derived_state_length, mutation.metadata, - mutation.metadata_length); - if (ret_id < 0) { - ret = (int) ret_id; - goto out; - } + const tsk_size_t num_sites = self->input_tables.sites.num_rows; + + if (self->options & TSK_SIMPLIFY_FILTER_SITES) { + for (j = 0; j < num_sites; j++) { + tsk_site_table_get_row_unsafe( + &self->input_tables.sites, (tsk_id_t) j, &site); + site_id_map[j] = TSK_NULL; + if (site_referenced[j]) { + ret_id = tsk_site_table_add_row(&self->tables->sites, site.position, + site.ancestral_state, site.ancestral_state_length, site.metadata, + site.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; } - } - ret_id = tsk_site_table_add_row(&self->tables->sites, site.position, - site.ancestral_state, site.ancestral_state_length, site.metadata, - site.metadata_length); - if (ret_id < 0) { - ret = (int) ret_id; - goto out; + site_id_map[j] = ret_id; } } - tsk_bug_assert( - num_output_mutations == (tsk_id_t) self->tables->mutations.num_rows); - input_mutation = site_end; + } else { + tsk_bug_assert(self->tables->sites.num_rows == num_sites); + for (j = 0; j < num_sites; j++) { + site_id_map[j] = (tsk_id_t) j; + } } - tsk_bug_assert(input_mutation == num_input_mutations); - ret = 0; out: return ret; } static int TSK_WARN_UNUSED -simplifier_finalise_references(simplifier_t *self) +simplifier_finalise_population_references(simplifier_t *self) { int ret = 0; - tsk_id_t ret_id; tsk_size_t j; - bool keep; - tsk_size_t num_nodes = 
self->tables->nodes.num_rows; - + tsk_id_t pop_id, ret_id; tsk_population_t pop; - tsk_id_t pop_id; - tsk_size_t num_populations = self->input_tables.populations.num_rows; tsk_id_t *node_population = self->tables->nodes.population; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + const tsk_size_t num_populations = self->input_tables.populations.num_rows; bool *population_referenced = tsk_calloc(num_populations, sizeof(*population_referenced)); tsk_id_t *population_id_map = tsk_malloc(num_populations * sizeof(*population_id_map)); - bool filter_populations = !!(self->options & TSK_SIMPLIFY_FILTER_POPULATIONS); - tsk_individual_t ind; - tsk_id_t ind_id; - tsk_size_t num_individuals = self->input_tables.individuals.num_rows; - tsk_id_t *node_individual = self->tables->nodes.individual; - bool *individual_referenced - = tsk_calloc(num_individuals, sizeof(*individual_referenced)); - tsk_id_t *individual_id_map - = tsk_malloc(num_individuals * sizeof(*individual_id_map)); - bool filter_individuals = !!(self->options & TSK_SIMPLIFY_FILTER_INDIVIDUALS); - - if (population_referenced == NULL || population_id_map == NULL - || individual_referenced == NULL || individual_id_map == NULL) { - goto out; - } + tsk_bug_assert(self->options & TSK_SIMPLIFY_FILTER_POPULATIONS); - /* TODO Migrations fit reasonably neatly into the pattern that we have here. We - * can consider references to populations from migration objects in the same way - * as from nodes, so that we only remove a population if its referenced by - * neither. Mapping the population IDs in migrations is then easy. In principle - * nodes are similar, but the semantics are slightly different because we've - * already allocated all the nodes by their references from edges. We then - * need to decide whether we remove migrations that reference unmapped nodes - * or whether to add these nodes back in (probably the former is the correct - * approach).*/ - if (self->input_tables.migrations.num_rows != 0) { - ret = TSK_ERR_SIMPLIFY_MIGRATIONS_NOT_SUPPORTED; + if (population_referenced == NULL || population_id_map == NULL) { + ret = TSK_ERR_NO_MEMORY; goto out; } @@ -9589,20 +10010,13 @@ simplifier_finalise_references(simplifier_t *self) if (pop_id != TSK_NULL) { population_referenced[pop_id] = true; } - ind_id = node_individual[j]; - if (ind_id != TSK_NULL) { - individual_referenced[ind_id] = true; - } } + for (j = 0; j < num_populations; j++) { tsk_population_table_get_row_unsafe( &self->input_tables.populations, (tsk_id_t) j, &pop); - keep = true; - if (filter_populations && !population_referenced[j]) { - keep = false; - } population_id_map[j] = TSK_NULL; - if (keep) { + if (population_referenced[j]) { ret_id = tsk_population_table_add_row( &self->tables->populations, pop.metadata, pop.metadata_length); if (ret_id < 0) { @@ -9613,15 +10027,56 @@ simplifier_finalise_references(simplifier_t *self) } } + /* Remap the IDs in the node table */ + for (j = 0; j < num_nodes; j++) { + pop_id = node_population[j]; + if (pop_id != TSK_NULL) { + node_population[j] = population_id_map[pop_id]; + } + } +out: + tsk_safe_free(population_id_map); + tsk_safe_free(population_referenced); + return ret; +} + +static int TSK_WARN_UNUSED +simplifier_finalise_individual_references(simplifier_t *self) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t pop_id, ret_id; + tsk_individual_t ind; + tsk_id_t *node_individual = self->tables->nodes.individual; + tsk_id_t *parents; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + const tsk_size_t num_individuals 
= self->input_tables.individuals.num_rows; + bool *individual_referenced + = tsk_calloc(num_individuals, sizeof(*individual_referenced)); + tsk_id_t *individual_id_map + = tsk_malloc(num_individuals * sizeof(*individual_id_map)); + + tsk_bug_assert(self->options & TSK_SIMPLIFY_FILTER_INDIVIDUALS); + + if (individual_referenced == NULL || individual_id_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + for (j = 0; j < num_nodes; j++) { + pop_id = node_individual[j]; + if (pop_id != TSK_NULL) { + individual_referenced[pop_id] = true; + } + } + for (j = 0; j < num_individuals; j++) { tsk_individual_table_get_row_unsafe( &self->input_tables.individuals, (tsk_id_t) j, &ind); - keep = true; - if (filter_individuals && !individual_referenced[j]) { - keep = false; - } individual_id_map[j] = TSK_NULL; - if (keep) { + if (individual_referenced[j]) { + /* Can't remap the parents inline here because we have no + * guarantees about sortedness */ ret_id = tsk_individual_table_add_row(&self->tables->individuals, ind.flags, ind.location, ind.location_length, ind.parents, ind.parents_length, ind.metadata, ind.metadata_length); @@ -9633,32 +10088,128 @@ simplifier_finalise_references(simplifier_t *self) } } - /* Remap parent IDs */ - for (j = 0; j < self->tables->individuals.parents_length; j++) { - self->tables->individuals.parents[j] - = self->tables->individuals.parents[j] == TSK_NULL - ? TSK_NULL - : individual_id_map[self->tables->individuals.parents[j]]; - } - - /* Remap node IDs referencing the above */ + /* Remap the IDs in the node table */ for (j = 0; j < num_nodes; j++) { - pop_id = node_population[j]; + pop_id = node_individual[j]; if (pop_id != TSK_NULL) { - node_population[j] = population_id_map[pop_id]; + node_individual[j] = individual_id_map[pop_id]; } - ind_id = node_individual[j]; - if (ind_id != TSK_NULL) { - node_individual[j] = individual_id_map[ind_id]; + } + + /* Remap parent IDs. * + * NOTE! 
must take the pointer reference here as it can change from + * the start of the function */ + parents = self->tables->individuals.parents; + for (j = 0; j < self->tables->individuals.parents_length; j++) { + if (parents[j] != TSK_NULL) { + parents[j] = individual_id_map[parents[j]]; } } - ret = 0; out: - tsk_safe_free(population_referenced); - tsk_safe_free(individual_referenced); - tsk_safe_free(population_id_map); tsk_safe_free(individual_id_map); + tsk_safe_free(individual_referenced); + return ret; +} + +static int TSK_WARN_UNUSED +simplifier_output_sites(simplifier_t *self) +{ + int ret = 0; + tsk_id_t ret_id; + tsk_size_t j; + tsk_mutation_t mutation; + const tsk_size_t num_sites = self->input_tables.sites.num_rows; + const tsk_size_t num_mutations = self->input_tables.mutations.num_rows; + bool *site_referenced = tsk_calloc(num_sites, sizeof(*site_referenced)); + tsk_id_t *site_id_map = tsk_malloc(num_sites * sizeof(*site_id_map)); + tsk_id_t *mutation_id_map = tsk_malloc(num_mutations * sizeof(*mutation_id_map)); + const tsk_id_t *mutation_node_map = self->mutation_node_map; + const tsk_id_t *mutation_site = self->input_tables.mutations.site; + + if (site_referenced == NULL || site_id_map == NULL || mutation_id_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + for (j = 0; j < num_mutations; j++) { + if (mutation_node_map[j] != TSK_NULL) { + site_referenced[mutation_site[j]] = true; + } + } + ret = simplifier_finalise_site_references(self, site_referenced, site_id_map); + if (ret != 0) { + goto out; + } + + for (j = 0; j < num_mutations; j++) { + mutation_id_map[j] = TSK_NULL; + if (mutation_node_map[j] != TSK_NULL) { + tsk_mutation_table_get_row_unsafe( + &self->input_tables.mutations, (tsk_id_t) j, &mutation); + mutation.node = mutation_node_map[j]; + mutation.site = site_id_map[mutation.site]; + if (mutation.parent != TSK_NULL) { + mutation.parent = mutation_id_map[mutation.parent]; + } + ret_id = tsk_mutation_table_add_row(&self->tables->mutations, mutation.site, + mutation.node, mutation.parent, mutation.time, mutation.derived_state, + mutation.derived_state_length, mutation.metadata, + mutation.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + mutation_id_map[j] = ret_id; + } + } +out: + tsk_safe_free(site_referenced); + tsk_safe_free(site_id_map); + tsk_safe_free(mutation_id_map); + return ret; +} + +/* Flush the remaining non-edge and node data in the model to the + * output tables. */ +static int TSK_WARN_UNUSED +simplifier_flush_output(simplifier_t *self) +{ + int ret = 0; + + /* TODO Migrations fit reasonably neatly into the pattern that we have here. We + * can consider references to populations from migration objects in the same way + * as from nodes, so that we only remove a population if its referenced by + * neither. Mapping the population IDs in migrations is then easy. In principle + * nodes are similar, but the semantics are slightly different because we've + * already allocated all the nodes by their references from edges. 
We then + * need to decide whether we remove migrations that reference unmapped nodes + * or whether to add these nodes back in (probably the former is the correct + * approach).*/ + if (self->input_tables.migrations.num_rows != 0) { + ret = TSK_ERR_SIMPLIFY_MIGRATIONS_NOT_SUPPORTED; + goto out; + } + + ret = simplifier_output_sites(self); + if (ret != 0) { + goto out; + } + + if (self->options & TSK_SIMPLIFY_FILTER_POPULATIONS) { + ret = simplifier_finalise_population_references(self); + if (ret != 0) { + goto out; + } + } + if (self->options & TSK_SIMPLIFY_FILTER_INDIVIDUALS) { + ret = simplifier_finalise_individual_references(self); + if (ret != 0) { + goto out; + } + } + +out: return ret; } @@ -9707,7 +10258,7 @@ simplifier_insert_input_roots(simplifier_t *self) if (x != NULL) { output_id = self->node_id_map[input_id]; if (output_id == TSK_NULL) { - output_id = simplifier_record_node(self, input_id, false); + output_id = simplifier_record_node(self, input_id); if (output_id < 0) { ret = (int) output_id; goto out; @@ -9772,11 +10323,7 @@ simplifier_run(simplifier_t *self, tsk_id_t *node_map) goto out; } } - ret = simplifier_output_sites(self); - if (ret != 0) { - goto out; - } - ret = simplifier_finalise_references(self); + ret = simplifier_flush_output(self); if (ret != 0) { goto out; } @@ -13122,3 +13669,165 @@ tsk_squash_edges(tsk_edge_t *edges, tsk_size_t num_edges, tsk_size_t *num_output out: return ret; } + +/* ======================================================== * + * Tree diff iterator. + * ======================================================== */ + +int TSK_WARN_UNUSED +tsk_diff_iter_init(tsk_diff_iter_t *self, const tsk_table_collection_t *tables, + tsk_id_t num_trees, tsk_flags_t options) +{ + int ret = 0; + + tsk_bug_assert(tables != NULL); + tsk_memset(self, 0, sizeof(tsk_diff_iter_t)); + self->num_nodes = tables->nodes.num_rows; + self->num_edges = tables->edges.num_rows; + self->tables = tables; + self->insertion_index = 0; + self->removal_index = 0; + self->tree_left = 0; + self->tree_index = -1; + if (num_trees < 0) { + num_trees = tsk_table_collection_check_integrity(self->tables, TSK_CHECK_TREES); + if (num_trees < 0) { + ret = (int) num_trees; + goto out; + } + } + self->last_index = num_trees; + + if (options & TSK_INCLUDE_TERMINAL) { + self->last_index = self->last_index + 1; + } + self->edge_list_nodes = tsk_malloc(self->num_edges * sizeof(*self->edge_list_nodes)); + if (self->edge_list_nodes == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } +out: + return ret; +} + +int +tsk_diff_iter_free(tsk_diff_iter_t *self) +{ + tsk_safe_free(self->edge_list_nodes); + return 0; +} + +void +tsk_diff_iter_print_state(const tsk_diff_iter_t *self, FILE *out) +{ + fprintf(out, "tree_diff_iterator state\n"); + fprintf(out, "num_edges = %lld\n", (long long) self->num_edges); + fprintf(out, "insertion_index = %lld\n", (long long) self->insertion_index); + fprintf(out, "removal_index = %lld\n", (long long) self->removal_index); + fprintf(out, "tree_left = %f\n", self->tree_left); + fprintf(out, "tree_index = %lld\n", (long long) self->tree_index); +} + +int TSK_WARN_UNUSED +tsk_diff_iter_next(tsk_diff_iter_t *self, double *ret_left, double *ret_right, + tsk_edge_list_t *edges_out_ret, tsk_edge_list_t *edges_in_ret) +{ + int ret = 0; + tsk_id_t k; + const double sequence_length = self->tables->sequence_length; + double left = self->tree_left; + double right = sequence_length; + tsk_size_t next_edge_list_node = 0; + tsk_edge_list_node_t *out_head = NULL; + 
tsk_edge_list_node_t *out_tail = NULL; + tsk_edge_list_node_t *in_head = NULL; + tsk_edge_list_node_t *in_tail = NULL; + tsk_edge_list_node_t *w = NULL; + tsk_edge_list_t edges_out; + tsk_edge_list_t edges_in; + const tsk_edge_table_t *edges = &self->tables->edges; + const tsk_id_t *insertion_order = self->tables->indexes.edge_insertion_order; + const tsk_id_t *removal_order = self->tables->indexes.edge_removal_order; + + tsk_memset(&edges_out, 0, sizeof(edges_out)); + tsk_memset(&edges_in, 0, sizeof(edges_in)); + + if (self->tree_index + 1 < self->last_index) { + /* First we remove the stale records */ + while (self->removal_index < (tsk_id_t) self->num_edges + && left == edges->right[removal_order[self->removal_index]]) { + k = removal_order[self->removal_index]; + tsk_bug_assert(next_edge_list_node < self->num_edges); + w = &self->edge_list_nodes[next_edge_list_node]; + next_edge_list_node++; + w->edge.id = k; + w->edge.left = edges->left[k]; + w->edge.right = edges->right[k]; + w->edge.parent = edges->parent[k]; + w->edge.child = edges->child[k]; + w->edge.metadata = edges->metadata + edges->metadata_offset[k]; + w->edge.metadata_length + = edges->metadata_offset[k + 1] - edges->metadata_offset[k]; + w->next = NULL; + w->prev = NULL; + if (out_head == NULL) { + out_head = w; + out_tail = w; + } else { + out_tail->next = w; + w->prev = out_tail; + out_tail = w; + } + self->removal_index++; + } + edges_out.head = out_head; + edges_out.tail = out_tail; + + /* Now insert the new records */ + while (self->insertion_index < (tsk_id_t) self->num_edges + && left == edges->left[insertion_order[self->insertion_index]]) { + k = insertion_order[self->insertion_index]; + tsk_bug_assert(next_edge_list_node < self->num_edges); + w = &self->edge_list_nodes[next_edge_list_node]; + next_edge_list_node++; + w->edge.id = k; + w->edge.left = edges->left[k]; + w->edge.right = edges->right[k]; + w->edge.parent = edges->parent[k]; + w->edge.child = edges->child[k]; + w->edge.metadata = edges->metadata + edges->metadata_offset[k]; + w->edge.metadata_length + = edges->metadata_offset[k + 1] - edges->metadata_offset[k]; + w->next = NULL; + w->prev = NULL; + if (in_head == NULL) { + in_head = w; + in_tail = w; + } else { + in_tail->next = w; + w->prev = in_tail; + in_tail = w; + } + self->insertion_index++; + } + edges_in.head = in_head; + edges_in.tail = in_tail; + + right = sequence_length; + if (self->insertion_index < (tsk_id_t) self->num_edges) { + right = TSK_MIN(right, edges->left[insertion_order[self->insertion_index]]); + } + if (self->removal_index < (tsk_id_t) self->num_edges) { + right = TSK_MIN(right, edges->right[removal_order[self->removal_index]]); + } + self->tree_index++; + ret = TSK_TREE_OK; + } + *edges_out_ret = edges_out; + *edges_in_ret = edges_in; + *ret_left = left; + *ret_right = right; + /* Set the left coordinate for the next tree */ + self->tree_left = right; + return ret; +} diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 299d2f8200..872b9b8fa1 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * Copyright (c) 2017-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -678,6 +678,30 @@ typedef struct { bool store_pairs; } tsk_identity_segments_t; +/* Diff iterator. 
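 * tsk_edge_list_t below is a doubly linked list of tsk_edge_list_node_t
 * values; tsk_diff_iter_next fills one such list with the edges removed
 * and another with the edges inserted at each tree transition.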
*/
+typedef struct _tsk_edge_list_node_t {
+    tsk_edge_t edge;
+    struct _tsk_edge_list_node_t *next;
+    struct _tsk_edge_list_node_t *prev;
+} tsk_edge_list_node_t;
+
+typedef struct {
+    tsk_edge_list_node_t *head;
+    tsk_edge_list_node_t *tail;
+} tsk_edge_list_t;
+
+typedef struct {
+    tsk_size_t num_nodes;
+    tsk_size_t num_edges;
+    double tree_left;
+    const tsk_table_collection_t *tables;
+    tsk_id_t insertion_index;
+    tsk_id_t removal_index;
+    tsk_id_t tree_index;
+    tsk_id_t last_index;
+    tsk_edge_list_node_t *edge_list_nodes;
+} tsk_diff_iter_t;
+
 /****************************************************************************/
 /* Common function options */
 /****************************************************************************/
@@ -694,6 +718,17 @@ reference them. */
 #define TSK_SIMPLIFY_FILTER_POPULATIONS (1 << 1)
 /** Remove individuals from the output if there are no nodes that reference them.*/
 #define TSK_SIMPLIFY_FILTER_INDIVIDUALS (1 << 2)
+/** Do not remove nodes from the output if there are no edges that reference
+them and do not reorder nodes so that the samples are nodes 0 to num_samples - 1.
+Note that this flag is negated compared to other filtering options because
+the default behaviour is to filter unreferenced nodes and reorder to put samples
+first.
+*/
+#define TSK_SIMPLIFY_NO_FILTER_NODES (1 << 7)
+/**
+Do not update the sample status of nodes as a result of simplification.
+*/
+#define TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS (1 << 8)
 /**
 Reduce the topological information in the tables to the minimum necessary to
 represent the trees that contain sites. If there are zero sites this will
@@ -889,6 +924,16 @@ top-level information of the table collections being compared.
 #define TSK_CLEAR_PROVENANCE (1 << 2)
 /** @} */
+/* For the edge diff iterator */
+#define TSK_INCLUDE_TERMINAL (1 << 0)
+
+/** @brief Value returned by seeking methods when they have successfully
+    seeked to a non-null tree.
+
+    @ingroup TREE_API_SEEKING_GROUP
+*/
+#define TSK_TREE_OK 1
+
 /****************************************************************************/
 /* Function signatures */
 /****************************************************************************/
@@ -1040,6 +1085,55 @@ int tsk_individual_table_extend(tsk_individual_table_t *self,
     const tsk_individual_table_t *other, tsk_size_t num_rows,
     const tsk_id_t *row_indexes, tsk_flags_t options);
+/**
+@brief Subset this table by keeping rows according to a boolean mask.
+
+@rst
+Deletes rows from this table and optionally returns the mapping from IDs in
+the current table to the updated table. Rows are kept or deleted according to
+the specified boolean array ``keep`` such that for each row ``j`` if
+``keep[j]`` is false (zero) the row is deleted, and otherwise the row is
+retained. Thus, ``keep`` must be an array of at least ``num_rows``
+:c:type:`bool` values.
+
+If the ``id_map`` argument is non-null, this array will be updated to represent
+the mapping between IDs before and after row deletion. For row ``j``,
+``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or
+:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an
+array of at least ``num_rows`` :c:type:`tsk_id_t` values.
+
+The values in the ``parents`` column are updated according to this map, so that
+reference integrity within the table is maintained. As a consequence of this,
+the values in the ``parents`` column for kept rows are bounds-checked and an
+error is raised if they are not valid. Rows that are deleted are not checked for
+parent ID integrity.
+
+If an attempt is made to delete rows that are referred to by the ``parents``
+column of rows that are retained, an error is raised.
+
+These error conditions are checked before any alterations to the table are
+made.
+
+.. warning::
+    C++ users need to be careful to specify the correct type when
+    passing in values for the ``keep`` array,
+    using ``std::vector<tsk_bool_t>`` and not ``std::vector<bool>``,
+    as the latter may not be the correct size.
+
+@endrst
+
+@param self A pointer to a tsk_individual_table_t object.
+@param keep Array of boolean flags describing whether a particular
+    row should be kept or not. Must be at least ``num_rows`` long.
+@param options Bitwise option flags. Currently unused; should be
+    set to zero to ensure compatibility with later versions of tskit.
+@param id_map An array in which to store the mapping between new
+    and old IDs. If NULL, this will be ignored.
+@return Return 0 on success or a negative value on failure.
+*/
+int tsk_individual_table_keep_rows(tsk_individual_table_t *self, const tsk_bool_t *keep,
+    tsk_flags_t options, tsk_id_t *id_map);
+
 /**
 @brief Returns true if the data in the specified table is identical to the data
 in this table.
@@ -1388,6 +1482,43 @@ and is not checked for compatibility with any existing schema on this table.
 int tsk_node_table_extend(tsk_node_table_t *self, const tsk_node_table_t *other,
     tsk_size_t num_rows, const tsk_id_t *row_indexes, tsk_flags_t options);
+/**
+@brief Subset this table by keeping rows according to a boolean mask.
+
+@rst
+Deletes rows from this table and optionally returns the mapping from IDs in
+the current table to the updated table. Rows are kept or deleted according to
+the specified boolean array ``keep`` such that for each row ``j`` if
+``keep[j]`` is false (zero) the row is deleted, and otherwise the row is
+retained. Thus, ``keep`` must be an array of at least ``num_rows``
+:c:type:`bool` values.
+
+If the ``id_map`` argument is non-null, this array will be updated to represent
+the mapping between IDs before and after row deletion. For row ``j``,
+``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or
+:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an
+array of at least ``num_rows`` :c:type:`tsk_id_t` values.
+
+.. warning::
+    C++ users need to be careful to specify the correct type when
+    passing in values for the ``keep`` array,
+    using ``std::vector<tsk_bool_t>`` and not ``std::vector<bool>``,
+    as the latter may not be the correct size.
+
+@endrst
+
+@param self A pointer to a tsk_node_table_t object.
+@param keep Array of boolean flags describing whether a particular
+    row should be kept or not. Must be at least ``num_rows`` long.
+@param options Bitwise option flags. Currently unused; should be
+    set to zero to ensure compatibility with later versions of tskit.
+@param id_map An array in which to store the mapping between new
+    and old IDs. If NULL, this will be ignored.
+@return Return 0 on success or a negative value on failure.
+*/
+int tsk_node_table_keep_rows(tsk_node_table_t *self, const tsk_bool_t *keep,
+    tsk_flags_t options, tsk_id_t *id_map);
+
 /**
 @brief Returns true if the data in the specified table is identical to the data
 in this table.
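The ``id_map`` contract above is easiest to see in code. A minimal sketch,
assuming an initialised table collection (the function and variable names here
are illustrative; only functions declared in this header are used):

#include <stdlib.h>
#include <tskit/tables.h>

/* Drop node 0 and remap a node reference held outside the table. */
static int
drop_first_node(tsk_table_collection_t *tables, tsk_id_t *external_node_ref)
{
    int ret = 0;
    tsk_size_t j, n = tables->nodes.num_rows;
    tsk_bool_t *keep = malloc(n * sizeof(*keep));
    tsk_id_t *id_map = malloc(n * sizeof(*id_map));

    if (keep == NULL || id_map == NULL) {
        ret = TSK_ERR_NO_MEMORY;
        goto out;
    }
    for (j = 0; j < n; j++) {
        keep[j] = true;
    }
    keep[0] = false;
    ret = tsk_node_table_keep_rows(&tables->nodes, keep, 0, id_map);
    if (ret != 0) {
        goto out;
    }
    /* keep_rows knows nothing about references held outside the table;
     * the caller remaps them through id_map (edges referencing dropped
     * nodes would likewise need updating). */
    if (*external_node_ref != TSK_NULL) {
        *external_node_ref = id_map[*external_node_ref];
    }
out:
    free(keep);
    free(id_map);
    return ret;
}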
@@ -1698,6 +1829,43 @@ as-is and is not checked for compatibility with any existing schema on this tabl
 int tsk_edge_table_extend(tsk_edge_table_t *self, const tsk_edge_table_t *other,
     tsk_size_t num_rows, const tsk_id_t *row_indexes, tsk_flags_t options);
+/**
+@brief Subset this table by keeping rows according to a boolean mask.
+
+@rst
+Deletes rows from this table and optionally returns the mapping from IDs in
+the current table to the updated table. Rows are kept or deleted according to
+the specified boolean array ``keep`` such that for each row ``j`` if
+``keep[j]`` is false (zero) the row is deleted, and otherwise the row is
+retained. Thus, ``keep`` must be an array of at least ``num_rows``
+:c:type:`bool` values.
+
+If the ``id_map`` argument is non-null, this array will be updated to represent
+the mapping between IDs before and after row deletion. For row ``j``,
+``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or
+:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an
+array of at least ``num_rows`` :c:type:`tsk_id_t` values.
+
+.. warning::
+    C++ users need to be careful to specify the correct type when
+    passing in values for the ``keep`` array,
+    using ``std::vector<tsk_bool_t>`` and not ``std::vector<bool>``,
+    as the latter may not be the correct size.
+
+@endrst
+
+@param self A pointer to a tsk_edge_table_t object.
+@param keep Array of boolean flags describing whether a particular
+    row should be kept or not. Must be at least ``num_rows`` long.
+@param options Bitwise option flags. Currently unused; should be
+    set to zero to ensure compatibility with later versions of tskit.
+@param id_map An array in which to store the mapping between new
+    and old IDs. If NULL, this will be ignored.
+@return Return 0 on success or a negative value on failure.
+*/
+int tsk_edge_table_keep_rows(tsk_edge_table_t *self, const tsk_bool_t *keep,
+    tsk_flags_t options, tsk_id_t *id_map);
+
 /**
 @brief Returns true if the data in the specified table is identical to the data
 in this table.
@@ -2032,6 +2200,43 @@ int tsk_migration_table_extend(tsk_migration_table_t *self,
     const tsk_migration_table_t *other, tsk_size_t num_rows,
     const tsk_id_t *row_indexes, tsk_flags_t options);
+/**
+@brief Subset this table by keeping rows according to a boolean mask.
+
+@rst
+Deletes rows from this table and optionally returns the mapping from IDs in
+the current table to the updated table. Rows are kept or deleted according to
+the specified boolean array ``keep`` such that for each row ``j`` if
+``keep[j]`` is false (zero) the row is deleted, and otherwise the row is
+retained. Thus, ``keep`` must be an array of at least ``num_rows``
+:c:type:`bool` values.
+
+If the ``id_map`` argument is non-null, this array will be updated to represent
+the mapping between IDs before and after row deletion. For row ``j``,
+``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or
+:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an
+array of at least ``num_rows`` :c:type:`tsk_id_t` values.
+
+.. warning::
+    C++ users need to be careful to specify the correct type when
+    passing in values for the ``keep`` array,
+    using ``std::vector<tsk_bool_t>`` and not ``std::vector<bool>``,
+    as the latter may not be the correct size.
+
+@endrst
+
+@param self A pointer to a tsk_migration_table_t object.
+@param keep Array of boolean flags describing whether a particular
+    row should be kept or not. Must be at least ``num_rows`` long.
+@param options Bitwise option flags. Currently unused; should be
+    set to zero to ensure compatibility with later versions of tskit.
+@param id_map An array in which to store the mapping between new
+    and old IDs. If NULL, this will be ignored.
+@return Return 0 on success or a negative value on failure.
+*/
+int tsk_migration_table_keep_rows(tsk_migration_table_t *self, const tsk_bool_t *keep,
+    tsk_flags_t options, tsk_id_t *id_map);
+
 /**
 @brief Returns true if the data in the specified table is identical to the data
 in this table.
@@ -2340,6 +2545,43 @@ and is not checked for compatibility with any existing schema on this table.
 int tsk_site_table_extend(tsk_site_table_t *self, const tsk_site_table_t *other,
     tsk_size_t num_rows, const tsk_id_t *row_indexes, tsk_flags_t options);
+/**
+@brief Subset this table by keeping rows according to a boolean mask.
+
+@rst
+Deletes rows from this table and optionally returns the mapping from IDs in
+the current table to the updated table. Rows are kept or deleted according to
+the specified boolean array ``keep`` such that for each row ``j`` if
+``keep[j]`` is false (zero) the row is deleted, and otherwise the row is
+retained. Thus, ``keep`` must be an array of at least ``num_rows``
+:c:type:`bool` values.
+
+If the ``id_map`` argument is non-null, this array will be updated to represent
+the mapping between IDs before and after row deletion. For row ``j``,
+``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or
+:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an
+array of at least ``num_rows`` :c:type:`tsk_id_t` values.
+
+.. warning::
+    C++ users need to be careful to specify the correct type when
+    passing in values for the ``keep`` array,
+    using ``std::vector<tsk_bool_t>`` and not ``std::vector<bool>``,
+    as the latter may not be the correct size.
+
+@endrst
+
+@param self A pointer to a tsk_site_table_t object.
+@param keep Array of boolean flags describing whether a particular
+    row should be kept or not. Must be at least ``num_rows`` long.
+@param options Bitwise option flags. Currently unused; should be
+    set to zero to ensure compatibility with later versions of tskit.
+@param id_map An array in which to store the mapping between new
+    and old IDs. If NULL, this will be ignored.
+@return Return 0 on success or a negative value on failure.
+*/
+int tsk_site_table_keep_rows(tsk_site_table_t *self, const tsk_bool_t *keep,
+    tsk_flags_t options, tsk_id_t *id_map);
+
 /**
 @brief Returns true if the data in the specified table is identical to the data
 in this table.
@@ -2676,6 +2918,55 @@ int tsk_mutation_table_extend(tsk_mutation_table_t *self,
     const tsk_mutation_table_t *other, tsk_size_t num_rows,
     const tsk_id_t *row_indexes, tsk_flags_t options);
+/**
+@brief Subset this table by keeping rows according to a boolean mask.
+
+@rst
+Deletes rows from this table and optionally returns the mapping from IDs in
+the current table to the updated table. Rows are kept or deleted according to
+the specified boolean array ``keep`` such that for each row ``j`` if
+``keep[j]`` is false (zero) the row is deleted, and otherwise the row is
+retained. Thus, ``keep`` must be an array of at least ``num_rows``
+:c:type:`bool` values.
+
+If the ``id_map`` argument is non-null, this array will be updated to represent
+the mapping between IDs before and after row deletion. For row ``j``,
+``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or
+:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an
+array of at least ``num_rows`` :c:type:`tsk_id_t` values.
+
+The values in the ``parent`` column are updated according to this map, so that
+reference integrity within the table is maintained. As a consequence of this,
+the values in the ``parent`` column for kept rows are bounds-checked and an
+error is raised if they are not valid. Rows that are deleted are not checked for
+parent ID integrity.
+
+If an attempt is made to delete rows that are referred to by the ``parent``
+column of rows that are retained, an error is raised.
+
+These error conditions are checked before any alterations to the table are
+made.
+
+.. warning::
+    C++ users need to be careful to specify the correct type when
+    passing in values for the ``keep`` array,
+    using ``std::vector<tsk_bool_t>`` and not ``std::vector<bool>``,
+    as the latter may not be the correct size.
+
+@endrst
+
+@param self A pointer to a tsk_mutation_table_t object.
+@param keep Array of boolean flags describing whether a particular
+    row should be kept or not. Must be at least ``num_rows`` long.
+@param options Bitwise option flags. Currently unused; should be
+    set to zero to ensure compatibility with later versions of tskit.
+@param id_map An array in which to store the mapping between new
+    and old IDs. If NULL, this will be ignored.
+@return Return 0 on success or a negative value on failure.
+*/
+int tsk_mutation_table_keep_rows(tsk_mutation_table_t *self, const tsk_bool_t *keep,
+    tsk_flags_t options, tsk_id_t *id_map);
+
 /**
 @brief Returns true if the data in the specified table is identical to the data
 in this table.
@@ -3003,6 +3294,43 @@ int tsk_population_table_extend(tsk_population_table_t *self,
     const tsk_population_table_t *other, tsk_size_t num_rows,
     const tsk_id_t *row_indexes, tsk_flags_t options);
+/**
+@brief Subset this table by keeping rows according to a boolean mask.
+
+@rst
+Deletes rows from this table and optionally returns the mapping from IDs in
+the current table to the updated table. Rows are kept or deleted according to
+the specified boolean array ``keep`` such that for each row ``j`` if
+``keep[j]`` is false (zero) the row is deleted, and otherwise the row is
+retained. Thus, ``keep`` must be an array of at least ``num_rows``
+:c:type:`bool` values.
+
+If the ``id_map`` argument is non-null, this array will be updated to represent
+the mapping between IDs before and after row deletion. For row ``j``,
+``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or
+:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an
+array of at least ``num_rows`` :c:type:`tsk_id_t` values.
+
+.. warning::
+    C++ users need to be careful to specify the correct type when
+    passing in values for the ``keep`` array,
+    using ``std::vector<tsk_bool_t>`` and not ``std::vector<bool>``,
+    as the latter may not be the correct size.
+
+@endrst
+
+@param self A pointer to a tsk_population_table_t object.
+@param keep Array of boolean flags describing whether a particular
+    row should be kept or not. Must be at least ``num_rows`` long.
+@param options Bitwise option flags. Currently unused; should be
+    set to zero to ensure compatibility with later versions of tskit.
+@param id_map An array in which to store the mapping between new
+    and old IDs. If NULL, this will be ignored.
+@return Return 0 on success or a negative value on failure.
+*/
+int tsk_population_table_keep_rows(tsk_population_table_t *self, const tsk_bool_t *keep,
+    tsk_flags_t options, tsk_id_t *id_map);
+
 /**
 @brief Returns true if the data in the specified table is identical to the data
 in this table.
@@ -3297,6 +3625,43 @@ int tsk_provenance_table_extend(tsk_provenance_table_t *self,
     const tsk_provenance_table_t *other, tsk_size_t num_rows,
     const tsk_id_t *row_indexes, tsk_flags_t options);
+/**
+@brief Subset this table by keeping rows according to a boolean mask.
+
+@rst
+Deletes rows from this table and optionally returns the mapping from IDs in
+the current table to the updated table. Rows are kept or deleted according to
+the specified boolean array ``keep`` such that for each row ``j`` if
+``keep[j]`` is false (zero) the row is deleted, and otherwise the row is
+retained. Thus, ``keep`` must be an array of at least ``num_rows``
+:c:type:`bool` values.
+
+If the ``id_map`` argument is non-null, this array will be updated to represent
+the mapping between IDs before and after row deletion. For row ``j``,
+``id_map[j]`` will contain the new ID for row ``j`` if it is retained, or
+:c:macro:`TSK_NULL` if the row has been removed. Thus, ``id_map`` must be an
+array of at least ``num_rows`` :c:type:`tsk_id_t` values.
+
+.. warning::
+    C++ users need to be careful to specify the correct type when
+    passing in values for the ``keep`` array,
+    using ``std::vector<tsk_bool_t>`` and not ``std::vector<bool>``,
+    as the latter may not be the correct size.
+
+@endrst
+
+@param self A pointer to a tsk_provenance_table_t object.
+@param keep Array of boolean flags describing whether a particular
+    row should be kept or not. Must be at least ``num_rows`` long.
+@param options Bitwise option flags. Currently unused; should be
+    set to zero to ensure compatibility with later versions of tskit.
+@param id_map An array in which to store the mapping between new
+    and old IDs. If NULL, this will be ignored.
+@return Return 0 on success or a negative value on failure.
+*/
+int tsk_provenance_table_keep_rows(tsk_provenance_table_t *self, const tsk_bool_t *keep,
+    tsk_flags_t options, tsk_id_t *id_map);
+
 /**
 @brief Returns true if the data in the specified table is identical to the data
 in this table.
@@ -3917,8 +4282,44 @@ A mapping from the node IDs in the table before simplification to their equivale
 values after simplification can be obtained via the ``node_map`` argument. If this
 is non NULL, ``node_map[u]`` will contain the new ID for node ``u`` after
 simplification, or :c:macro:`TSK_NULL` if the node has been removed. Thus, ``node_map`` must be an array
-of at least ``self->nodes.num_rows`` :c:type:`tsk_id_t` values. The table collection will
-always be unindexed after simplify successfully completes.
+of at least ``self->nodes.num_rows`` :c:type:`tsk_id_t` values.
+
+If the `TSK_SIMPLIFY_NO_FILTER_NODES` option is specified, the node table will be
+unaltered except for changes to the sample status of nodes (but see the
+`TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS` option below) and updates to references
+to other tables that may have changed as a result of filtering (see below).
+The ``node_map`` (if specified) will always be the identity mapping, such that
+``node_map[u] == u`` for all nodes. Note also that the order of the list of
+samples is not important in this case.
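For illustration, a minimal sketch of combining these options (this assumes the
usual ``tsk_table_collection_simplify`` entry point documented here; error
handling and setup elided)::

    ret = tsk_table_collection_simplify(tables, samples, num_samples,
        TSK_SIMPLIFY_NO_FILTER_NODES | TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS,
        NULL);
    /* Node IDs are unchanged, so no node_map is needed. */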
+
+When a table is not filtered (i.e., if the `TSK_SIMPLIFY_NO_FILTER_NODES`
+option is provided or the `TSK_SIMPLIFY_FILTER_SITES`,
+`TSK_SIMPLIFY_FILTER_POPULATIONS` or `TSK_SIMPLIFY_FILTER_INDIVIDUALS`
+options are *not* provided) the corresponding table is modified as
+little as possible, and all pointers are guaranteed to remain valid
+after simplification. The only changes made to an unfiltered table are
+to update any references to tables that may have changed (for example,
+remapping population IDs in the node table if
+`TSK_SIMPLIFY_FILTER_POPULATIONS` was specified) or altering the
+sample status flag of nodes.
+
+.. note:: It is possible for populations and individuals to be filtered
+    even if `TSK_SIMPLIFY_NO_FILTER_NODES` is specified because there
+    may be entirely unreferenced entities in the input tables, which
+    are not affected by whether we filter nodes or not.
+
+By default, the node sample flags are updated by unsetting the
+:c:macro:`TSK_NODE_IS_SAMPLE` flag for all nodes and subsequently setting it
+for the nodes provided as input to this function. The
+`TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS` option will prevent this from occurring,
+making it the responsibility of calling code to keep track of the ultimate
+sample status of nodes. Using this option in conjunction with
+`TSK_SIMPLIFY_NO_FILTER_NODES` (and without the
+`TSK_SIMPLIFY_FILTER_POPULATIONS` and `TSK_SIMPLIFY_FILTER_INDIVIDUALS`
+options) guarantees that the node table will not be written to during the
+lifetime of this function.
+
+The table collection will always be unindexed after simplify successfully completes.
 
 .. note:: Migrations are currently not supported by simplify, and an error
     will be raised if we attempt call simplify on a table collection with greater
@@ -3932,6 +4333,8 @@ Options can be specified by providing one or more of the following bitwise
 - :c:macro:`TSK_SIMPLIFY_FILTER_SITES`
 - :c:macro:`TSK_SIMPLIFY_FILTER_POPULATIONS`
 - :c:macro:`TSK_SIMPLIFY_FILTER_INDIVIDUALS`
+- :c:macro:`TSK_SIMPLIFY_NO_FILTER_NODES`
+- :c:macro:`TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS`
 - :c:macro:`TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY`
 - :c:macro:`TSK_SIMPLIFY_KEEP_UNARY`
 - :c:macro:`TSK_SIMPLIFY_KEEP_INPUT_ROOTS`
@@ -4392,6 +4795,19 @@ int tsk_identity_segments_get(const tsk_identity_segments_t *self, tsk_id_t a,
 void tsk_identity_segments_print_state(tsk_identity_segments_t *self, FILE *out);
 int tsk_identity_segments_free(tsk_identity_segments_t *self);
 
+/* Edge differences */
+
+/* Internal API - currently used in a few places, but a better API is envisaged
+ * at some point.
+ * IMPORTANT: tskit-rust uses this API, so don't break without discussing!
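 *
 * A minimal usage sketch (error handling elided; passing num_trees < 0
 * asks tsk_diff_iter_init to count the trees itself, as in the
 * implementation in tables.c above):
 *
 *     tsk_diff_iter_t iter;
 *     tsk_edge_list_t edges_out, edges_in;
 *     double left, right;
 *     tsk_diff_iter_init(&iter, &tables, -1, 0);
 *     while (tsk_diff_iter_next(&iter, &left, &right, &edges_out, &edges_in)
 *             == TSK_TREE_OK) {
 *         // edges_out and edges_in list the edges removed and inserted
 *         // at position left; the current tree spans [left, right).
 *     }
 *     tsk_diff_iter_free(&iter);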
+ */ +int tsk_diff_iter_init(tsk_diff_iter_t *self, const tsk_table_collection_t *tables, + tsk_id_t num_trees, tsk_flags_t options); +int tsk_diff_iter_free(tsk_diff_iter_t *self); +int tsk_diff_iter_next(tsk_diff_iter_t *self, double *left, double *right, + tsk_edge_list_t *edges_out, tsk_edge_list_t *edges_in); +void tsk_diff_iter_print_state(const tsk_diff_iter_t *self, FILE *out); + #ifdef __cplusplus } #endif diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 4fcb2ee376..04936afb71 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -1191,9 +1191,11 @@ tsk_treeseq_mean_descendants(const tsk_treeseq_t *self, * General stats framework ***********************************/ +#define TSK_REQUIRE_FULL_SPAN 1 + static int -tsk_treeseq_check_windows( - const tsk_treeseq_t *self, tsk_size_t num_windows, const double *windows) +tsk_treeseq_check_windows(const tsk_treeseq_t *self, tsk_size_t num_windows, + const double *windows, tsk_flags_t options) { int ret = TSK_ERR_BAD_WINDOWS; tsk_size_t j; @@ -1202,12 +1204,23 @@ tsk_treeseq_check_windows( ret = TSK_ERR_BAD_NUM_WINDOWS; goto out; } - /* TODO these restrictions can be lifted later if we want a specific interval. */ - if (windows[0] != 0) { - goto out; - } - if (windows[num_windows] != self->tables->sequence_length) { - goto out; + if (options & TSK_REQUIRE_FULL_SPAN) { + /* TODO the general stat code currently requires that we include the + * entire tree sequence span. This should be relaxed, so hopefully + * this branch (and the option) can be removed at some point */ + if (windows[0] != 0) { + goto out; + } + if (windows[num_windows] != self->tables->sequence_length) { + goto out; + } + } else { + if (windows[0] < 0) { + goto out; + } + if (windows[num_windows] > self->tables->sequence_length) { + goto out; + } } for (j = 0; j < num_windows; j++) { if (windows[j] >= windows[j + 1]) { @@ -1264,7 +1277,7 @@ tsk_treeseq_branch_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, { int ret = 0; tsk_id_t u, v; - tsk_size_t j, k, tree_index, window_index; + tsk_size_t j, k, window_index; tsk_size_t num_nodes = self->tables->nodes.num_rows; const tsk_id_t num_edges = (tsk_id_t) self->tables->edges.num_rows; const tsk_id_t *restrict I = self->tables->indexes.edge_insertion_order; @@ -1315,7 +1328,6 @@ tsk_treeseq_branch_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tj = 0; tk = 0; t_left = 0; - tree_index = 0; window_index = 0; while (tj < num_edges || t_left < sequence_length) { while (tk < num_edges && edge_right[O[tk]] == t_left) { @@ -1400,7 +1412,6 @@ tsk_treeseq_branch_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, } /* Move to the next tree */ t_left = t_right; - tree_index++; } tsk_bug_assert(window_index == num_windows); out: @@ -1962,7 +1973,8 @@ tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, num_windows = 1; windows = default_windows; } else { - ret = tsk_treeseq_check_windows(self, num_windows, windows); + ret = tsk_treeseq_check_windows( + self, num_windows, windows, TSK_REQUIRE_FULL_SPAN); if (ret != 0) { goto out; } @@ -2059,6 +2071,11 @@ typedef struct { const tsk_id_t *set_indexes; } sample_count_stat_params_t; +typedef struct { + double *total_weights; + const tsk_id_t *index_tuples; +} 
indexed_weight_stat_params_t; + static int tsk_treeseq_sample_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, @@ -2470,7 +2487,7 @@ tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); bool stat_node = !!(options & TSK_STAT_NODE); - double default_windows[] = { 0, self->tables->sequence_length }; + const double default_windows[] = { 0, self->tables->sequence_length }; const tsk_size_t num_nodes = self->tables->nodes.num_rows; const tsk_size_t K = num_sample_sets + 1; tsk_size_t j, k, l, afs_size; @@ -2498,7 +2515,8 @@ tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, num_windows = 1; windows = default_windows; } else { - ret = tsk_treeseq_check_windows(self, num_windows, windows); + ret = tsk_treeseq_check_windows( + self, num_windows, windows, TSK_REQUIRE_FULL_SPAN); if (ret != 0) { goto out; } @@ -2621,6 +2639,10 @@ tsk_treeseq_trait_covariance(const tsk_treeseq_t *self, tsk_size_t num_weights, ret = TSK_ERR_NO_MEMORY; goto out; } + if (num_weights == 0) { + ret = TSK_ERR_INSUFFICIENT_WEIGHTS; + goto out; + } // center weights for (j = 0; j < num_samples; j++) { @@ -2692,7 +2714,7 @@ tsk_treeseq_trait_correlation(const tsk_treeseq_t *self, tsk_size_t num_weights, } if (num_weights < 1) { - ret = TSK_ERR_BAD_STATE_DIMS; + ret = TSK_ERR_INSUFFICIENT_WEIGHTS; goto out; } @@ -2805,7 +2827,7 @@ tsk_treeseq_trait_linear_model(const tsk_treeseq_t *self, tsk_size_t num_weights } if (num_weights < 1) { - ret = TSK_ERR_BAD_STATE_DIMS; + ret = TSK_ERR_INSUFFICIENT_WEIGHTS; goto out; } @@ -3012,6 +3034,79 @@ tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, tsk_size_t num_sample return ret; } +static int +genetic_relatedness_weighted_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t result_dim, double *result, void *params) +{ + indexed_weight_stat_params_t args = *(indexed_weight_stat_params_t *) params; + const double *x = state; + tsk_id_t i, j; + tsk_size_t k; + double meanx, ni, nj; + + meanx = state[state_dim - 1] / args.total_weights[state_dim - 1]; + for (k = 0; k < result_dim; k++) { + i = args.index_tuples[2 * k]; + j = args.index_tuples[2 * k + 1]; + ni = args.total_weights[i]; + nj = args.total_weights[j]; + result[k] = (x[i] - ni * meanx) * (x[j] - nj * meanx) / 2; + } + return 0; +} + +int +tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, + tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, + double *result, tsk_flags_t options) +{ + int ret = 0; + tsk_size_t num_samples = self->num_samples; + size_t j, k; + indexed_weight_stat_params_t args; + const double *row; + double *new_row; + double *total_weights = tsk_calloc((num_weights + 1), sizeof(*total_weights)); + double *new_weights + = tsk_malloc((num_weights + 1) * num_samples * sizeof(*new_weights)); + + if (total_weights == NULL || new_weights == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + if (num_weights == 0) { + ret = TSK_ERR_INSUFFICIENT_WEIGHTS; + goto out; + } + + // Add a column of ones to W + for (j = 0; j < num_samples; j++) { + row = GET_2D_ROW(weights, num_weights, j); + new_row = GET_2D_ROW(new_weights, num_weights + 1, j); + for (k = 0; k < num_weights; k++) { + new_row[k] = row[k]; + total_weights[k] += row[k]; + } + new_row[num_weights] = 1.0; + } + 
total_weights[num_weights] = (double) num_samples; + + args.total_weights = total_weights; + args.index_tuples = index_tuples; + ret = tsk_treeseq_general_stat(self, num_weights + 1, new_weights, num_index_tuples, + genetic_relatedness_weighted_summary_func, &args, num_windows, windows, options, + result); + if (ret != 0) { + goto out; + } + +out: + tsk_safe_free(total_weights); + tsk_safe_free(new_weights); + return ret; +} + static int Y2_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, tsk_size_t result_dim, double *result, void *params) @@ -3333,7 +3428,7 @@ tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, } ret = tsk_treeseq_init( output, tables, TSK_TS_INIT_BUILD_INDEXES | TSK_TAKE_OWNERSHIP); - /* Once tsk_tree_init has returned ownership of tables is transferred */ + /* Once tsk_treeseq_init has returned ownership of tables is transferred */ tables = NULL; out: if (tables != NULL) { @@ -3458,10 +3553,182 @@ tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flag return ret; } +/* ======================================================== * + * tree_position + * ======================================================== */ + +static void +tsk_tree_position_set_null(tsk_tree_position_t *self) +{ + self->index = -1; + self->interval.left = 0; + self->interval.right = 0; +} + +int +tsk_tree_position_init(tsk_tree_position_t *self, const tsk_treeseq_t *tree_sequence, + tsk_flags_t TSK_UNUSED(options)) +{ + memset(self, 0, sizeof(*self)); + self->tree_sequence = tree_sequence; + tsk_tree_position_set_null(self); + return 0; +} + +int +tsk_tree_position_free(tsk_tree_position_t *TSK_UNUSED(self)) +{ + return 0; +} + +int +tsk_tree_position_print_state(const tsk_tree_position_t *self, FILE *out) +{ + fprintf(out, "Tree position state\n"); + fprintf(out, "index = %d\n", (int) self->index); + fprintf( + out, "out = start=%d\tstop=%d\n", (int) self->out.start, (int) self->out.stop); + fprintf( + out, "in = start=%d\tstop=%d\n", (int) self->in.start, (int) self->in.stop); + return 0; +} + +bool +tsk_tree_position_next(tsk_tree_position_t *self) +{ + const tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t M = (tsk_id_t) tables->edges.num_rows; + const tsk_id_t num_trees = (tsk_id_t) self->tree_sequence->num_trees; + const double *restrict left_coords = tables->edges.left; + const tsk_id_t *restrict left_order = tables->indexes.edge_insertion_order; + const double *restrict right_coords = tables->edges.right; + const tsk_id_t *restrict right_order = tables->indexes.edge_removal_order; + const double *restrict breakpoints = self->tree_sequence->breakpoints; + tsk_id_t j, left_current_index, right_current_index; + double left; + + if (self->index == -1) { + self->interval.right = 0; + self->in.stop = 0; + self->out.stop = 0; + self->direction = TSK_DIR_FORWARD; + } + + if (self->direction == TSK_DIR_FORWARD) { + left_current_index = self->in.stop; + right_current_index = self->out.stop; + } else { + left_current_index = self->out.stop + 1; + right_current_index = self->in.stop + 1; + } + + left = self->interval.right; + + j = right_current_index; + self->out.start = j; + while (j < M && right_coords[right_order[j]] == left) { + j++; + } + self->out.stop = j; + self->out.order = right_order; + + j = left_current_index; + self->in.start = j; + while (j < M && left_coords[left_order[j]] == left) { + j++; + } + self->in.stop = j; + self->in.order = left_order; + + self->direction = TSK_DIR_FORWARD; + 
self->index++;
+    if (self->index == num_trees) {
+        tsk_tree_position_set_null(self);
+    } else {
+        self->interval.left = left;
+        self->interval.right = breakpoints[self->index + 1];
+    }
+    return self->index != -1;
+}
+
+bool
+tsk_tree_position_prev(tsk_tree_position_t *self)
+{
+    const tsk_table_collection_t *tables = self->tree_sequence->tables;
+    const tsk_id_t M = (tsk_id_t) tables->edges.num_rows;
+    const double sequence_length = tables->sequence_length;
+    const tsk_id_t num_trees = (tsk_id_t) self->tree_sequence->num_trees;
+    const double *restrict left_coords = tables->edges.left;
+    const tsk_id_t *restrict left_order = tables->indexes.edge_insertion_order;
+    const double *restrict right_coords = tables->edges.right;
+    const tsk_id_t *restrict right_order = tables->indexes.edge_removal_order;
+    const double *restrict breakpoints = self->tree_sequence->breakpoints;
+    tsk_id_t j, left_current_index, right_current_index;
+    double right;
+
+    if (self->index == -1) {
+        self->index = num_trees;
+        self->interval.left = sequence_length;
+        self->in.stop = M - 1;
+        self->out.stop = M - 1;
+        self->direction = TSK_DIR_REVERSE;
+    }
+
+    if (self->direction == TSK_DIR_REVERSE) {
+        left_current_index = self->out.stop;
+        right_current_index = self->in.stop;
+    } else {
+        left_current_index = self->in.stop - 1;
+        right_current_index = self->out.stop - 1;
+    }
+
+    right = self->interval.left;
+
+    j = left_current_index;
+    self->out.start = j;
+    while (j >= 0 && left_coords[left_order[j]] == right) {
+        j--;
+    }
+    self->out.stop = j;
+    self->out.order = left_order;
+
+    j = right_current_index;
+    self->in.start = j;
+    while (j >= 0 && right_coords[right_order[j]] == right) {
+        j--;
+    }
+    self->in.stop = j;
+    self->in.order = right_order;
+
+    self->index--;
+    self->direction = TSK_DIR_REVERSE;
+    if (self->index == -1) {
+        tsk_tree_position_set_null(self);
+    } else {
+        self->interval.left = breakpoints[self->index];
+        self->interval.right = right;
+    }
+    return self->index != -1;
+}
+
 /* ======================================================== *
  * Tree
  * ======================================================== */
 
+/* Return the root for the specified node.
+ * NOTE: no bounds checking is done here.
+ */
+static tsk_id_t
+tsk_tree_get_node_root(const tsk_tree_t *self, tsk_id_t u)
+{
+    const tsk_id_t *restrict parent = self->parent;
+
+    while (parent[u] != TSK_NULL) {
+        u = parent[u];
+    }
+    return u;
+}
+
 int TSK_WARN_UNUSED
 tsk_tree_init(tsk_tree_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options)
 {
@@ -4551,19 +4818,138 @@ tsk_tree_position_in_interval(const tsk_tree_t *self, double x)
     return self->interval.left <= x && x < self->interval.right;
 }
 
-int TSK_WARN_UNUSED
-tsk_tree_seek(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options))
+/* NOTE:
+ *
+ * Notes from Kevin Thornton:
+ *
+ * This method inserts the edges for an arbitrary tree
+ * in linear time and requires no additional memory.
+ *
+ * During design, the following alternatives were tested
+ * (in a combination of rust + C):
+ * 1. Indexing edge insertion/removal locations by tree.
+ *    The indexing can be done in O(n) time, giving O(1)
+ *    access to the first edge in a tree. We can then add
+ *    edges to the tree in O(e) time, where e is the number
+ *    of edges. This approach requires O(n) additional memory
+ *    and is only marginally faster than the implementation below.
+ * 2. Building an interval tree mapping edge id -> span.
+ *    This approach adds a lot of complexity and wasn't any faster
+ *    than the indexing described above.
+ */
+static int
+tsk_tree_seek_from_null(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options))
 {
     int ret = 0;
+    tsk_size_t edge;
+    tsk_id_t p, c, e, j, k, tree_index;
     const double L = tsk_treeseq_get_sequence_length(self->tree_sequence);
-    const double t_l = self->interval.left;
-    const double t_r = self->interval.right;
-    double distance_left, distance_right;
+    const tsk_treeseq_t *treeseq = self->tree_sequence;
+    const tsk_table_collection_t *tables = treeseq->tables;
+    const tsk_id_t *restrict edge_parent = tables->edges.parent;
+    const tsk_id_t *restrict edge_child = tables->edges.child;
+    const tsk_size_t num_edges = tables->edges.num_rows;
+    const tsk_size_t num_trees = self->tree_sequence->num_trees;
+    const double *restrict edge_left = tables->edges.left;
+    const double *restrict edge_right = tables->edges.right;
+    const double *restrict breakpoints = treeseq->breakpoints;
+    const tsk_id_t *restrict insertion = tables->indexes.edge_insertion_order;
+    const tsk_id_t *restrict removal = tables->indexes.edge_removal_order;
+
+    // NOTE: it may be better to get the
+    // index first and then ask if we are
+    // searching in the first or last 1/2
+    // of trees.
+    j = -1;
+    if (x <= L / 2.0) {
+        for (edge = 0; edge < num_edges; edge++) {
+            e = insertion[edge];
+            if (edge_left[e] > x) {
+                j = (tsk_id_t) edge;
+                break;
+            }
+            if (x >= edge_left[e] && x < edge_right[e]) {
+                p = edge_parent[e];
+                c = edge_child[e];
+                tsk_tree_insert_edge(self, p, c, e);
+            }
+        }
+    } else {
+        for (edge = 0; edge < num_edges; edge++) {
+            e = removal[num_edges - edge - 1];
+            if (edge_right[e] < x) {
+                j = (tsk_id_t)(num_edges - edge - 1);
+                while (j < (tsk_id_t) num_edges && edge_left[insertion[j]] <= x) {
+                    j++;
+                }
+                break;
+            }
+            if (x >= edge_left[e] && x < edge_right[e]) {
+                p = edge_parent[e];
+                c = edge_child[e];
+                tsk_tree_insert_edge(self, p, c, e);
+            }
+        }
+    }
 
-    if (x < 0 || x >= L) {
+    if (j == -1) {
+        j = 0;
+        while (j < (tsk_id_t) num_edges && edge_left[insertion[j]] <= x) {
+            j++;
+        }
+    }
+    k = 0;
+    while (k < (tsk_id_t) num_edges && edge_right[removal[k]] <= x) {
+        k++;
+    }
+
+    /* NOTE: tsk_search_sorted finds the first
+     * insertion location >= the query point, which
+     * finds a RIGHT value for queries not at the left edge.
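+     *
+     * For example (illustrative values): with breakpoints {0, 10, 20, L}
+     * and a query x = 5, the search returns index 1 (breakpoints[1] == 10
+     * is the first value >= 5); since breakpoints[1] > 5 we step back to
+     * tree_index 0, giving the correct interval [0, 10). A query at
+     * x = 10 returns index 1 and needs no adjustment.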
+ */ + tree_index = (tsk_id_t) tsk_search_sorted(breakpoints, num_trees + 1, x); + if (breakpoints[tree_index] > x) { + tree_index--; + } + self->index = tree_index; + self->interval.left = breakpoints[tree_index]; + self->interval.right = breakpoints[tree_index + 1]; + self->left_index = j; + self->right_index = k; + self->direction = TSK_DIR_FORWARD; + self->num_nodes = tables->nodes.num_rows; + if (tables->sites.num_rows > 0) { + self->sites = treeseq->tree_sites[self->index]; + self->sites_length = treeseq->tree_sites_length[self->index]; + } + + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_seek_index(tsk_tree_t *self, tsk_id_t tree, tsk_flags_t options) +{ + int ret = 0; + double x; + + if (tree < 0 || tree >= (tsk_id_t) self->tree_sequence->num_trees) { ret = TSK_ERR_SEEK_OUT_OF_BOUNDS; goto out; } + x = self->tree_sequence->breakpoints[tree]; + ret = tsk_tree_seek(self, x, options); +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_tree_seek_linear(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)) +{ + const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); + const double t_l = self->interval.left; + const double t_r = self->interval.right; + int ret = 0; + double distance_left, distance_right; if (x < t_l) { /* |-----|-----|========|---------| */ @@ -4596,6 +4982,27 @@ tsk_tree_seek(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)) return ret; } +int TSK_WARN_UNUSED +tsk_tree_seek(tsk_tree_t *self, double x, tsk_flags_t options) +{ + int ret = 0; + const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); + + if (x < 0 || x >= L) { + ret = TSK_ERR_SEEK_OUT_OF_BOUNDS; + goto out; + } + + if (self->index == -1) { + ret = tsk_tree_seek_from_null(self, x, options); + } else { + ret = tsk_tree_seek_linear(self, x, options); + } + +out: + return ret; +} + int TSK_WARN_UNUSED tsk_tree_clear(tsk_tree_t *self) { @@ -5344,159 +5751,16 @@ tsk_tree_map_mutations(tsk_tree_t *self, int32_t *genotypes, return ret; } -/* ======================================================== * - * Tree diff iterator. - * ======================================================== */ - +/* Compatibility shim for initialising the diff iterator from a tree sequence. We are + * using this function in a small number of places internally, so simplest to keep it + * until a more satisfactory "diff" API comes along. 
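+ * Calling this shim with a tree sequence is equivalent to initialising
+ * directly from its tables, i.e. to
+ *     tsk_diff_iter_init(self, tree_sequence->tables,
+ *         (tsk_id_t) tree_sequence->num_trees, options);
+ * as the body below shows.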
+ */ int TSK_WARN_UNUSED -tsk_diff_iter_init( +tsk_diff_iter_init_from_ts( tsk_diff_iter_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options) { - int ret = 0; - - tsk_bug_assert(tree_sequence != NULL); - tsk_memset(self, 0, sizeof(tsk_diff_iter_t)); - self->num_nodes = tsk_treeseq_get_num_nodes(tree_sequence); - self->num_edges = tsk_treeseq_get_num_edges(tree_sequence); - self->tree_sequence = tree_sequence; - self->insertion_index = 0; - self->removal_index = 0; - self->tree_left = 0; - self->tree_index = -1; - self->last_index = (tsk_id_t) tsk_treeseq_get_num_trees(tree_sequence); - if (options & TSK_INCLUDE_TERMINAL) { - self->last_index = self->last_index + 1; - } - self->edge_list_nodes = tsk_malloc(self->num_edges * sizeof(*self->edge_list_nodes)); - if (self->edge_list_nodes == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } -out: - return ret; -} - -int -tsk_diff_iter_free(tsk_diff_iter_t *self) -{ - tsk_safe_free(self->edge_list_nodes); - return 0; -} - -void -tsk_diff_iter_print_state(const tsk_diff_iter_t *self, FILE *out) -{ - fprintf(out, "tree_diff_iterator state\n"); - fprintf(out, "num_edges = %lld\n", (long long) self->num_edges); - fprintf(out, "insertion_index = %lld\n", (long long) self->insertion_index); - fprintf(out, "removal_index = %lld\n", (long long) self->removal_index); - fprintf(out, "tree_left = %f\n", self->tree_left); - fprintf(out, "tree_index = %lld\n", (long long) self->tree_index); -} - -int TSK_WARN_UNUSED -tsk_diff_iter_next(tsk_diff_iter_t *self, double *ret_left, double *ret_right, - tsk_edge_list_t *edges_out_ret, tsk_edge_list_t *edges_in_ret) -{ - int ret = 0; - tsk_id_t k; - const double sequence_length = self->tree_sequence->tables->sequence_length; - double left = self->tree_left; - double right = sequence_length; - tsk_size_t next_edge_list_node = 0; - const tsk_treeseq_t *s = self->tree_sequence; - tsk_edge_list_node_t *out_head = NULL; - tsk_edge_list_node_t *out_tail = NULL; - tsk_edge_list_node_t *in_head = NULL; - tsk_edge_list_node_t *in_tail = NULL; - tsk_edge_list_node_t *w = NULL; - tsk_edge_list_t edges_out; - tsk_edge_list_t edges_in; - const tsk_edge_table_t *edges = &s->tables->edges; - const tsk_id_t *insertion_order = s->tables->indexes.edge_insertion_order; - const tsk_id_t *removal_order = s->tables->indexes.edge_removal_order; - - tsk_memset(&edges_out, 0, sizeof(edges_out)); - tsk_memset(&edges_in, 0, sizeof(edges_in)); - - if (self->tree_index + 1 < self->last_index) { - /* First we remove the stale records */ - while (self->removal_index < (tsk_id_t) self->num_edges - && left == edges->right[removal_order[self->removal_index]]) { - k = removal_order[self->removal_index]; - tsk_bug_assert(next_edge_list_node < self->num_edges); - w = &self->edge_list_nodes[next_edge_list_node]; - next_edge_list_node++; - w->edge.id = k; - w->edge.left = edges->left[k]; - w->edge.right = edges->right[k]; - w->edge.parent = edges->parent[k]; - w->edge.child = edges->child[k]; - w->edge.metadata = edges->metadata + edges->metadata_offset[k]; - w->edge.metadata_length - = edges->metadata_offset[k + 1] - edges->metadata_offset[k]; - w->next = NULL; - w->prev = NULL; - if (out_head == NULL) { - out_head = w; - out_tail = w; - } else { - out_tail->next = w; - w->prev = out_tail; - out_tail = w; - } - self->removal_index++; - } - edges_out.head = out_head; - edges_out.tail = out_tail; - - /* Now insert the new records */ - while (self->insertion_index < (tsk_id_t) self->num_edges - && left == 
edges->left[insertion_order[self->insertion_index]]) { - k = insertion_order[self->insertion_index]; - tsk_bug_assert(next_edge_list_node < self->num_edges); - w = &self->edge_list_nodes[next_edge_list_node]; - next_edge_list_node++; - w->edge.id = k; - w->edge.left = edges->left[k]; - w->edge.right = edges->right[k]; - w->edge.parent = edges->parent[k]; - w->edge.child = edges->child[k]; - w->edge.metadata = edges->metadata + edges->metadata_offset[k]; - w->edge.metadata_length - = edges->metadata_offset[k + 1] - edges->metadata_offset[k]; - w->next = NULL; - w->prev = NULL; - if (in_head == NULL) { - in_head = w; - in_tail = w; - } else { - in_tail->next = w; - w->prev = in_tail; - in_tail = w; - } - self->insertion_index++; - } - edges_in.head = in_head; - edges_in.tail = in_tail; - - right = sequence_length; - if (self->insertion_index < (tsk_id_t) self->num_edges) { - right = TSK_MIN(right, edges->left[insertion_order[self->insertion_index]]); - } - if (self->removal_index < (tsk_id_t) self->num_edges) { - right = TSK_MIN(right, edges->right[removal_order[self->removal_index]]); - } - self->tree_index++; - ret = TSK_TREE_OK; - } - *edges_out_ret = edges_out; - *edges_in_ret = edges_in; - *ret_left = left; - *ret_right = right; - /* Set the left coordinate for the next tree */ - self->tree_left = right; - return ret; + return tsk_diff_iter_init( + self, tree_sequence->tables, (tsk_id_t) tree_sequence->num_trees, options); } /* ======================================================== * @@ -5840,25 +6104,29 @@ update_kc_subtree_state( } static int -update_kc_incremental(tsk_tree_t *self, kc_vectors *kc, tsk_edge_list_t *edges_out, - tsk_edge_list_t *edges_in, tsk_size_t *depths) +update_kc_incremental( + tsk_tree_t *tree, kc_vectors *kc, tsk_tree_position_t *tree_pos, tsk_size_t *depths) { int ret = 0; - tsk_edge_list_node_t *record; - tsk_edge_t *e; - tsk_id_t u; + tsk_id_t u, v, e, j; double root_time, time; - const double *times = self->tree_sequence->tables->nodes.time; + const double *restrict times = tree->tree_sequence->tables->nodes.time; + const tsk_id_t *restrict edges_child = tree->tree_sequence->tables->edges.child; + const tsk_id_t *restrict edges_parent = tree->tree_sequence->tables->edges.parent; + + tsk_bug_assert(tree_pos->index == tree->index); + tsk_bug_assert(tree_pos->interval.left == tree->interval.left); + tsk_bug_assert(tree_pos->interval.right == tree->interval.right); /* Update state of detached subtrees */ - for (record = edges_out->tail; record != NULL; record = record->prev) { - e = &record->edge; - u = e->child; + for (j = tree_pos->out.stop - 1; j >= tree_pos->out.start; j--) { + e = tree_pos->out.order[j]; + u = edges_child[e]; depths[u] = 0; - if (self->parent[u] == TSK_NULL) { - root_time = times[tsk_tree_node_root(self, u)]; - ret = update_kc_subtree_state(self, kc, u, depths, root_time); + if (tree->parent[u] == TSK_NULL) { + root_time = times[tsk_tree_node_root(tree, u)]; + ret = update_kc_subtree_state(tree, kc, u, depths, root_time); if (ret != 0) { goto out; } @@ -5866,25 +6134,25 @@ update_kc_incremental(tsk_tree_t *self, kc_vectors *kc, tsk_edge_list_t *edges_o } /* Propagate state change down into reattached subtrees. 
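 * (In the updated implementation this is done by walking tree_pos->in from
 * stop - 1 down to start, which preserves the tail-to-head order of the
 * old edge-list traversal it replaces.)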
*/ - for (record = edges_in->tail; record != NULL; record = record->prev) { - e = &record->edge; - u = e->child; + for (j = tree_pos->in.stop - 1; j >= tree_pos->in.start; j--) { + e = tree_pos->in.order[j]; + u = edges_child[e]; + v = edges_parent[e]; - tsk_bug_assert(depths[e->child] == 0); - depths[u] = depths[e->parent] + 1; + tsk_bug_assert(depths[u] == 0); + depths[u] = depths[v] + 1; - root_time = times[tsk_tree_node_root(self, u)]; - ret = update_kc_subtree_state(self, kc, u, depths, root_time); + root_time = times[tsk_tree_node_root(tree, u)]; + ret = update_kc_subtree_state(tree, kc, u, depths, root_time); if (ret != 0) { goto out; } - if (tsk_tree_is_sample(self, u)) { - time = tsk_tree_get_branch_length_unsafe(self, u); - update_kc_vectors_single_sample(self->tree_sequence, kc, u, time); + if (tsk_tree_is_sample(tree, u)) { + time = tsk_tree_get_branch_length_unsafe(tree, u); + update_kc_vectors_single_sample(tree->tree_sequence, kc, u, time); } } - out: return ret; } @@ -5900,19 +6168,18 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, const tsk_treeseq_t *treeseqs[2] = { self, other }; tsk_tree_t trees[2]; kc_vectors kcs[2]; - tsk_diff_iter_t diff_iters[2]; - tsk_edge_list_t edges_out[2]; - tsk_edge_list_t edges_in[2]; + /* TODO the tree_pos here is redundant because we should be using this interally + * in the trees to do the advancing. Once we have converted the tree over to using + * tree_pos internally, we can get rid of these tree_pos variables and use + * the values stored in the trees themselves */ + tsk_tree_position_t tree_pos[2]; tsk_size_t *depths[2]; - double t0_left, t0_right, t1_left, t1_right; int ret = 0; for (i = 0; i < 2; i++) { tsk_memset(&trees[i], 0, sizeof(trees[i])); - tsk_memset(&diff_iters[i], 0, sizeof(diff_iters[i])); + tsk_memset(&tree_pos[i], 0, sizeof(tree_pos[i])); tsk_memset(&kcs[i], 0, sizeof(kcs[i])); - tsk_memset(&edges_out[i], 0, sizeof(edges_out[i])); - tsk_memset(&edges_in[i], 0, sizeof(edges_in[i])); depths[i] = NULL; } @@ -5927,7 +6194,7 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, if (ret != 0) { goto out; } - ret = tsk_diff_iter_init(&diff_iters[i], treeseqs[i], false); + ret = tsk_tree_position_init(&tree_pos[i], treeseqs[i], 0); if (ret != 0) { goto out; } @@ -5954,11 +6221,10 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, if (ret != 0) { goto out; } - ret = tsk_diff_iter_next( - &diff_iters[0], &t0_left, &t0_right, &edges_out[0], &edges_in[0]); - tsk_bug_assert(ret == TSK_TREE_OK); - ret = update_kc_incremental( - &trees[0], &kcs[0], &edges_out[0], &edges_in[0], depths[0]); + tsk_tree_position_next(&tree_pos[0]); + tsk_bug_assert(tree_pos[0].index == 0); + + ret = update_kc_incremental(&trees[0], &kcs[0], &tree_pos[0], depths[0]); if (ret != 0) { goto out; } @@ -5967,37 +6233,37 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, if (ret != 0) { goto out; } - ret = tsk_diff_iter_next( - &diff_iters[1], &t1_left, &t1_right, &edges_out[1], &edges_in[1]); - tsk_bug_assert(ret == TSK_TREE_OK); + tsk_tree_position_next(&tree_pos[1]); + tsk_bug_assert(tree_pos[1].index != -1); - ret = update_kc_incremental( - &trees[1], &kcs[1], &edges_out[1], &edges_in[1], depths[1]); + ret = update_kc_incremental(&trees[1], &kcs[1], &tree_pos[1], depths[1]); if (ret != 0) { goto out; } - while (t0_right < t1_right) { - span = t0_right - left; + tsk_bug_assert(trees[0].interval.left == tree_pos[0].interval.left); + 
tsk_bug_assert(trees[0].interval.right == tree_pos[0].interval.right);
+    tsk_bug_assert(trees[1].interval.left == tree_pos[1].interval.left);
+    tsk_bug_assert(trees[1].interval.right == tree_pos[1].interval.right);
+    while (trees[0].interval.right < trees[1].interval.right) {
+        span = trees[0].interval.right - left;
         total += norm_kc_vectors(&kcs[0], &kcs[1], lambda_) * span;
-        left = t0_right;
+        left = trees[0].interval.right;
         ret = tsk_tree_next(&trees[0]);
         tsk_bug_assert(ret == TSK_TREE_OK);
         ret = check_kc_distance_tree_inputs(&trees[0]);
         if (ret != 0) {
             goto out;
         }
-        ret = tsk_diff_iter_next(
-            &diff_iters[0], &t0_left, &t0_right, &edges_out[0], &edges_in[0]);
-        tsk_bug_assert(ret == TSK_TREE_OK);
-        ret = update_kc_incremental(
-            &trees[0], &kcs[0], &edges_out[0], &edges_in[0], depths[0]);
+        tsk_tree_position_next(&tree_pos[0]);
+        tsk_bug_assert(tree_pos[0].index != -1);
+        ret = update_kc_incremental(&trees[0], &kcs[0], &tree_pos[0], depths[0]);
         if (ret != 0) {
             goto out;
         }
     }
-    span = t1_right - left;
-    left = t1_right;
+    span = trees[1].interval.right - left;
+    left = trees[1].interval.right;
     total += norm_kc_vectors(&kcs[0], &kcs[1], lambda_) * span;
 }
 if (ret != 0) {
@@ -6008,9 +6274,809 @@ tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other,
 out:
     for (i = 0; i < 2; i++) {
         tsk_tree_free(&trees[i]);
-        tsk_diff_iter_free(&diff_iters[i]);
+        tsk_tree_position_free(&tree_pos[i]);
         kc_vectors_free(&kcs[i]);
         tsk_safe_free(depths[i]);
     }
     return ret;
 }
+
+/*
+ * Divergence matrix
+ */
+
+typedef struct {
+    /* Note it's a waste storing the triply linked tree here, but the code
+     * is written on the assumption of 1-based trees and the algorithm is
+     * frighteningly subtle, so it doesn't seem worth messing with it
+     * unless we really need to save some memory */
+    tsk_id_t *parent;
+    tsk_id_t *child;
+    tsk_id_t *sib;
+    tsk_id_t *lambda;
+    tsk_id_t *pi;
+    tsk_id_t *tau;
+    tsk_id_t *beta;
+    tsk_id_t *alpha;
+} sv_tables_t;
+
+static int
+sv_tables_init(sv_tables_t *self, tsk_size_t n)
+{
+    int ret = 0;
+
+    self->parent = tsk_malloc(n * sizeof(*self->parent));
+    self->child = tsk_malloc(n * sizeof(*self->child));
+    self->sib = tsk_malloc(n * sizeof(*self->sib));
+    self->pi = tsk_malloc(n * sizeof(*self->pi));
+    self->lambda = tsk_malloc(n * sizeof(*self->lambda));
+    self->tau = tsk_malloc(n * sizeof(*self->tau));
+    self->beta = tsk_malloc(n * sizeof(*self->beta));
+    self->alpha = tsk_malloc(n * sizeof(*self->alpha));
+    if (self->parent == NULL || self->child == NULL || self->sib == NULL
+        || self->pi == NULL || self->lambda == NULL || self->tau == NULL
+        || self->beta == NULL || self->alpha == NULL) {
+        ret = TSK_ERR_NO_MEMORY;
+        goto out;
+    }
+out:
+    return ret;
+}
+
+static int
+sv_tables_free(sv_tables_t *self)
+{
+    tsk_safe_free(self->parent);
+    tsk_safe_free(self->child);
+    tsk_safe_free(self->sib);
+    tsk_safe_free(self->lambda);
+    tsk_safe_free(self->pi);
+    tsk_safe_free(self->tau);
+    tsk_safe_free(self->beta);
+    tsk_safe_free(self->alpha);
+    return 0;
+}
+static void
+sv_tables_reset(sv_tables_t *self, tsk_tree_t *tree)
+{
+    const tsk_size_t n = 1 + tree->num_nodes;
+    tsk_memset(self->parent, 0, n * sizeof(*self->parent));
+    tsk_memset(self->child, 0, n * sizeof(*self->child));
+    tsk_memset(self->sib, 0, n * sizeof(*self->sib));
+    tsk_memset(self->pi, 0, n * sizeof(*self->pi));
+    tsk_memset(self->lambda, 0, n * sizeof(*self->lambda));
+    tsk_memset(self->tau, 0, n * sizeof(*self->tau));
+    tsk_memset(self->beta, 0, n * sizeof(*self->beta));
+    tsk_memset(self->alpha, 0, n *
sizeof(*self->alpha)); +} + +static void +sv_tables_convert_tree(sv_tables_t *self, tsk_tree_t *tree) +{ + const tsk_size_t n = 1 + tree->num_nodes; + const tsk_id_t *restrict tsk_parent = tree->parent; + tsk_id_t *restrict child = self->child; + tsk_id_t *restrict parent = self->parent; + tsk_id_t *restrict sib = self->sib; + tsk_size_t j; + tsk_id_t u, v; + + for (j = 0; j < n - 1; j++) { + u = (tsk_id_t) j + 1; + v = tsk_parent[j] + 1; + sib[u] = child[v]; + child[v] = u; + parent[u] = v; + } +} + +#define LAMBDA 0 + +static void +sv_tables_build_index(sv_tables_t *self) +{ + const tsk_id_t *restrict child = self->child; + const tsk_id_t *restrict parent = self->parent; + const tsk_id_t *restrict sib = self->sib; + tsk_id_t *restrict lambda = self->lambda; + tsk_id_t *restrict pi = self->pi; + tsk_id_t *restrict tau = self->tau; + tsk_id_t *restrict beta = self->beta; + tsk_id_t *restrict alpha = self->alpha; + tsk_id_t a, n, p, h; + + p = child[LAMBDA]; + n = 0; + lambda[0] = -1; + while (p != LAMBDA) { + while (true) { + n++; + pi[p] = n; + tau[n] = LAMBDA; + lambda[n] = 1 + lambda[n >> 1]; + if (child[p] != LAMBDA) { + p = child[p]; + } else { + break; + } + } + beta[p] = n; + while (true) { + tau[beta[p]] = parent[p]; + if (sib[p] != LAMBDA) { + p = sib[p]; + break; + } else { + p = parent[p]; + if (p != LAMBDA) { + h = lambda[n & -pi[p]]; + beta[p] = ((n >> h) | 1) << h; + } else { + break; + } + } + } + } + + /* Begin the second traversal */ + lambda[0] = lambda[n]; + pi[LAMBDA] = 0; + beta[LAMBDA] = 0; + alpha[LAMBDA] = 0; + p = child[LAMBDA]; + while (p != LAMBDA) { + while (true) { + a = alpha[parent[p]] | (beta[p] & -beta[p]); + alpha[p] = a; + if (child[p] != LAMBDA) { + p = child[p]; + } else { + break; + } + } + while (true) { + if (sib[p] != LAMBDA) { + p = sib[p]; + break; + } else { + p = parent[p]; + if (p == LAMBDA) { + break; + } + } + } + } +} + +static void +sv_tables_build(sv_tables_t *self, tsk_tree_t *tree) +{ + sv_tables_reset(self, tree); + sv_tables_convert_tree(self, tree); + sv_tables_build_index(self); +} + +static tsk_id_t +sv_tables_mrca_one_based(const sv_tables_t *self, tsk_id_t x, tsk_id_t y) +{ + const tsk_id_t *restrict lambda = self->lambda; + const tsk_id_t *restrict pi = self->pi; + const tsk_id_t *restrict tau = self->tau; + const tsk_id_t *restrict beta = self->beta; + const tsk_id_t *restrict alpha = self->alpha; + tsk_id_t h, k, xhat, yhat, ell, j, z; + + if (beta[x] <= beta[y]) { + h = lambda[beta[y] & -beta[x]]; + } else { + h = lambda[beta[x] & -beta[y]]; + } + k = alpha[x] & alpha[y] & -(1 << h); + h = lambda[k & -k]; + j = ((beta[x] >> h) | 1) << h; + if (j == beta[x]) { + xhat = x; + } else { + ell = lambda[alpha[x] & ((1 << h) - 1)]; + xhat = tau[((beta[x] >> ell) | 1) << ell]; + } + if (j == beta[y]) { + yhat = y; + } else { + ell = lambda[alpha[y] & ((1 << h) - 1)]; + yhat = tau[((beta[y] >> ell) | 1) << ell]; + } + if (pi[xhat] <= pi[yhat]) { + z = xhat; + } else { + z = yhat; + } + return z; +} + +static tsk_id_t +sv_tables_mrca(const sv_tables_t *self, tsk_id_t x, tsk_id_t y) +{ + /* Convert to 1-based indexes and back */ + return sv_tables_mrca_one_based(self, x + 1, y + 1) - 1; +} + +static int +tsk_treeseq_check_node_bounds( + const tsk_treeseq_t *self, tsk_size_t num_nodes, const tsk_id_t *nodes) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t u; + const tsk_id_t N = (tsk_id_t) self->tables->nodes.num_rows; + + for (j = 0; j < num_nodes; j++) { + u = nodes[j]; + if (u < 0 || u >= N) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto 
out; + } + } +out: + return ret; +} + +static int +tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *restrict samples, tsk_size_t num_windows, + const double *restrict windows, tsk_flags_t options, double *restrict result) +{ + int ret = 0; + tsk_tree_t tree; + const double *restrict nodes_time = self->tables->nodes.time; + const tsk_size_t n = num_samples; + tsk_size_t i, j, k; + tsk_id_t u, v, w, u_root, v_root; + double tu, tv, d, span, left, right, span_left, span_right; + double *restrict D; + sv_tables_t sv; + + memset(&sv, 0, sizeof(sv)); + ret = tsk_tree_init(&tree, self, 0); + if (ret != 0) { + goto out; + } + ret = sv_tables_init(&sv, self->tables->nodes.num_rows + 1); + if (ret != 0) { + goto out; + } + + if (self->time_uncalibrated && !(options & TSK_STAT_ALLOW_TIME_UNCALIBRATED)) { + ret = TSK_ERR_TIME_UNCALIBRATED; + goto out; + } + + for (i = 0; i < num_windows; i++) { + left = windows[i]; + right = windows[i + 1]; + D = result + i * n * n; + ret = tsk_tree_seek(&tree, left, 0); + if (ret != 0) { + goto out; + } + while (tree.interval.left < right && tree.index != -1) { + span_left = TSK_MAX(tree.interval.left, left); + span_right = TSK_MIN(tree.interval.right, right); + span = span_right - span_left; + sv_tables_build(&sv, &tree); + for (j = 0; j < n; j++) { + u = samples[j]; + for (k = j + 1; k < n; k++) { + v = samples[k]; + w = sv_tables_mrca(&sv, u, v); + if (w != TSK_NULL) { + u_root = w; + v_root = w; + } else { + /* Slow path - only happens for nodes in disconnected + * subtrees in a tree with multiple roots */ + u_root = tsk_tree_get_node_root(&tree, u); + v_root = tsk_tree_get_node_root(&tree, v); + } + tu = nodes_time[u_root] - nodes_time[u]; + tv = nodes_time[v_root] - nodes_time[v]; + d = (tu + tv) * span; + D[j * n + k] += d; + } + } + ret = tsk_tree_next(&tree); + if (ret < 0) { + goto out; + } + } + } + ret = 0; +out: + tsk_tree_free(&tree); + sv_tables_free(&sv); + return ret; +} + +static tsk_size_t +count_mutations_on_path(tsk_id_t u, tsk_id_t v, const tsk_id_t *restrict parent, + const double *restrict time, const tsk_size_t *restrict mutations_per_node) +{ + double tu, tv; + tsk_size_t count = 0; + + tu = time[u]; + tv = time[v]; + while (u != v) { + if (tu < tv) { + count += mutations_per_node[u]; + u = parent[u]; + if (u == TSK_NULL) { + break; + } + tu = time[u]; + } else { + count += mutations_per_node[v]; + v = parent[v]; + if (v == TSK_NULL) { + break; + } + tv = time[v]; + } + } + if (u != v) { + while (u != TSK_NULL) { + count += mutations_per_node[u]; + u = parent[u]; + } + while (v != TSK_NULL) { + count += mutations_per_node[v]; + v = parent[v]; + } + } + return count; +} + +static int +tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *restrict samples, tsk_size_t num_windows, + const double *restrict windows, tsk_flags_t TSK_UNUSED(options), + double *restrict result) +{ + int ret = 0; + tsk_tree_t tree; + const tsk_size_t n = num_samples; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + const double *restrict nodes_time = self->tables->nodes.time; + tsk_size_t i, j, k, tree_site, tree_mut; + tsk_site_t site; + tsk_mutation_t mut; + tsk_id_t u, v; + double left, right, span_left, span_right; + double *restrict D; + tsk_size_t *mutations_per_node = tsk_malloc(num_nodes * sizeof(*mutations_per_node)); + + ret = tsk_tree_init(&tree, self, 0); + if (ret != 0) { + goto out; + } + if (mutations_per_node == NULL) { + ret = 
TSK_ERR_NO_MEMORY; + goto out; + } + + for (i = 0; i < num_windows; i++) { + left = windows[i]; + right = windows[i + 1]; + D = result + i * n * n; + ret = tsk_tree_seek(&tree, left, 0); + if (ret != 0) { + goto out; + } + while (tree.interval.left < right && tree.index != -1) { + span_left = TSK_MAX(tree.interval.left, left); + span_right = TSK_MIN(tree.interval.right, right); + + /* NOTE: we could avoid this full memset across all nodes by doing + * the same loops again and decrementing at the end of the main + * tree-loop. It's probably not worth it though, because of the + * overwhelming O(n^2) below */ + tsk_memset(mutations_per_node, 0, num_nodes * sizeof(*mutations_per_node)); + for (tree_site = 0; tree_site < tree.sites_length; tree_site++) { + site = tree.sites[tree_site]; + if (span_left <= site.position && site.position < span_right) { + for (tree_mut = 0; tree_mut < site.mutations_length; tree_mut++) { + mut = site.mutations[tree_mut]; + mutations_per_node[mut.node]++; + } + } + } + + for (j = 0; j < n; j++) { + u = samples[j]; + for (k = j + 1; k < n; k++) { + v = samples[k]; + D[j * n + k] += (double) count_mutations_on_path( + u, v, tree.parent, nodes_time, mutations_per_node); + } + } + ret = tsk_tree_next(&tree); + if (ret < 0) { + goto out; + } + } + } + ret = 0; +out: + tsk_tree_free(&tree); + tsk_safe_free(mutations_per_node); + return ret; +} + +static void +fill_lower_triangle( + double *restrict result, const tsk_size_t n, const tsk_size_t num_windows) +{ + tsk_size_t i, j, k; + double *restrict D; + + /* TODO there's probably a better striding pattern that could be used here */ + for (i = 0; i < num_windows; i++) { + D = result + i * n * n; + for (j = 0; j < n; j++) { + for (k = j + 1; k < n; k++) { + D[k * n + j] = D[j * n + k]; + } + } + } +} + +int +tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *samples_in, tsk_size_t num_windows, const double *windows, + tsk_flags_t options, double *result) +{ + int ret = 0; + const tsk_id_t *samples = self->samples; + tsk_size_t n = self->num_samples; + const double default_windows[] = { 0, self->tables->sequence_length }; + bool stat_site = !!(options & TSK_STAT_SITE); + bool stat_branch = !!(options & TSK_STAT_BRANCH); + bool stat_node = !!(options & TSK_STAT_NODE); + + if (stat_node) { + ret = TSK_ERR_UNSUPPORTED_STAT_MODE; + goto out; + } + /* If no mode is specified, we default to site mode */ + if (!(stat_site || stat_branch)) { + stat_site = true; + } + /* It's an error to specify more than one mode */ + if (stat_site + stat_branch > 1) { + ret = TSK_ERR_MULTIPLE_STAT_MODES; + goto out; + } + + if (options & TSK_STAT_POLARISED) { + ret = TSK_ERR_STAT_POLARISED_UNSUPPORTED; + goto out; + } + if (options & TSK_STAT_SPAN_NORMALISE) { + ret = TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED; + goto out; + } + + if (windows == NULL) { + num_windows = 1; + windows = default_windows; + } else { + ret = tsk_treeseq_check_windows(self, num_windows, windows, 0); + if (ret != 0) { + goto out; + } + } + + if (samples_in != NULL) { + samples = samples_in; + n = num_samples; + ret = tsk_treeseq_check_node_bounds(self, n, samples); + if (ret != 0) { + goto out; + } + } + + tsk_memset(result, 0, num_windows * n * n * sizeof(*result)); + + if (stat_branch) { + ret = tsk_treeseq_divergence_matrix_branch( + self, n, samples, num_windows, windows, options, result); + } else { + tsk_bug_assert(stat_site); + ret = tsk_treeseq_divergence_matrix_site( + self, n, samples, num_windows, windows, options, 
result);
+    }
+    if (ret != 0) {
+        goto out;
+    }
+    fill_lower_triangle(result, n, num_windows);
+
+out:
+    return ret;
+}
+
+/* ======================================================== *
+ * Extend edges
+ * ======================================================== */
+
+typedef struct _edge_list_t {
+    tsk_id_t edge;
+    // the `extended` flag records whether we have decided to extend
+    // this entry to the current tree
+    bool extended;
+    struct _edge_list_t *next;
+} edge_list_t;
+
+static int
+extend_edges_append_entry(
+    edge_list_t **head, edge_list_t **tail, tsk_blkalloc_t *heap, tsk_id_t edge)
+{
+    int ret = 0;
+    edge_list_t *x = NULL;
+
+    x = tsk_blkalloc_get(heap, sizeof(*x));
+    if (x == NULL) {
+        ret = TSK_ERR_NO_MEMORY;
+        goto out;
+    }
+
+    x->edge = edge;
+    x->extended = false;
+    x->next = NULL;
+
+    if (*tail == NULL) {
+        *head = x;
+    } else {
+        (*tail)->next = x;
+    }
+    *tail = x;
+out:
+    return ret;
+}
+
+static void
+remove_unextended(edge_list_t **head, edge_list_t **tail)
+{
+    edge_list_t *px, *x;
+
+    px = *head;
+    while (px != NULL && !px->extended) {
+        px = px->next;
+    }
+    *head = px;
+    if (px != NULL) {
+        px->extended = false;
+        x = px->next;
+        while (x != NULL) {
+            if (x->extended) {
+                x->extended = false;
+                px->next = x;
+                px = x;
+            }
+            x = x->next;
+        }
+    }
+    *tail = px;
+}
+
+static int
+tsk_treeseq_extend_edges_iter(
+    const tsk_treeseq_t *self, int direction, tsk_edge_table_t *edges)
+{
+    // Note: this modifies the edge table, but it does this by (a) removing
+    // some edges, and (b) extending left/right endpoints of others,
+    // while keeping order the same, and so this maintains sortedness
+    // (so, there is no need to sort afterwards).
+    int ret = 0;
+    tsk_id_t tj;
+    tsk_id_t e, e1, e2, e_in;
+    tsk_blkalloc_t edge_list_heap;
+    double *near_side, *far_side;
+    edge_list_t *edges_in_head, *edges_in_tail;
+    edge_list_t *edges_out_head, *edges_out_tail;
+    edge_list_t *ex1, *ex2, *ex_in;
+    double there, left, right;
+    bool forwards = (direction == TSK_DIR_FORWARD);
+    tsk_tree_position_t tree_pos;
+    bool valid;
+    const tsk_table_collection_t *tables = self->tables;
+    const tsk_size_t num_nodes = tables->nodes.num_rows;
+    const tsk_size_t num_edges = tables->edges.num_rows;
+    tsk_id_t *degree = tsk_calloc(num_nodes, sizeof(*degree));
+    tsk_bool_t *keep = tsk_calloc(num_edges, sizeof(*keep));
+
+    memset(&edge_list_heap, 0, sizeof(edge_list_heap));
+    memset(&tree_pos, 0, sizeof(tree_pos));
+
+    if (keep == NULL || degree == NULL) {
+        ret = TSK_ERR_NO_MEMORY;
+        goto out;
+    }
+    ret = tsk_blkalloc_init(&edge_list_heap, 8192);
+    if (ret != 0) {
+        goto out;
+    }
+    ret = tsk_tree_position_init(&tree_pos, self, 0);
+    if (ret != 0) {
+        goto out;
+    }
+    ret = tsk_edge_table_copy(&tables->edges, edges, TSK_NO_INIT);
+    if (ret != 0) {
+        goto out;
+    }
+
+    if (forwards) {
+        near_side = edges->left;
+        far_side = edges->right;
+    } else {
+        near_side = edges->right;
+        far_side = edges->left;
+    }
+    edges_in_head = NULL;
+    edges_in_tail = NULL;
+    edges_out_head = NULL;
+    edges_out_tail = NULL;
+
+    if (forwards) {
+        valid = tsk_tree_position_next(&tree_pos);
+    } else {
+        valid = tsk_tree_position_prev(&tree_pos);
+    }
+
+    while (valid) {
+        left = tree_pos.interval.left;
+        right = tree_pos.interval.right;
+        there = forwards ?
right : left; + + // remove entries that aren't being extended/postponed + remove_unextended(&edges_in_head, &edges_in_tail); + remove_unextended(&edges_out_head, &edges_out_tail); + + for (tj = tree_pos.out.start; tj != tree_pos.out.stop; tj += direction) { + e = tree_pos.out.order[tj]; + // add edge to pending_out + ret = extend_edges_append_entry( + &edges_out_head, &edges_out_tail, &edge_list_heap, e); + if (ret != 0) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + } + for (tj = tree_pos.in.start; tj != tree_pos.in.stop; tj += direction) { + e = tree_pos.in.order[tj]; + // add edge to pending_in + ret = extend_edges_append_entry( + &edges_in_head, &edges_in_tail, &edge_list_heap, e); + if (ret != 0) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + } + for (ex1 = edges_out_head; ex1 != NULL; ex1 = ex1->next) { + degree[edges->parent[ex1->edge]] -= 1; + degree[edges->child[ex1->edge]] -= 1; + } + for (ex1 = edges_in_head; ex1 != NULL; ex1 = ex1->next) { + degree[edges->parent[ex1->edge]] += 1; + degree[edges->child[ex1->edge]] += 1; + } + + // iterate over pairs of out and in: (ex1, ex2, in) + for (ex1 = edges_out_head; ex1 != NULL; ex1 = ex1->next) { + if (!ex1->extended) { + e1 = ex1->edge; + for (ex2 = edges_out_head; ex2 != NULL; ex2 = ex2->next) { + if (!ex2->extended) { + e2 = ex2->edge; + if ((edges->parent[e1] == edges->child[e2]) + && (degree[edges->child[e2]] == 0)) { + for (ex_in = edges_in_head; ex_in != NULL; + ex_in = ex_in->next) { + e_in = ex_in->edge; + if ((edges->left[e_in] < right) + && (edges->right[e_in] > left)) { + if ((edges->child[e1] == edges->child[e_in]) + && (edges->parent[e2] == edges->parent[e_in])) { + ex1->extended = true; + ex2->extended = true; + ex_in->extended = true; + far_side[e1] = there; + far_side[e2] = there; + near_side[e_in] = there; + degree[edges->parent[e1]] += 2; + } + } + } + } + } + } + } + } + if (forwards) { + valid = tsk_tree_position_next(&tree_pos); + } else { + valid = tsk_tree_position_prev(&tree_pos); + } + } + + for (e = 0; e < (tsk_id_t) num_edges; e++) { + keep[e] = edges->left[e] < edges->right[e]; + } + ret = tsk_edge_table_keep_rows(edges, keep, 0, NULL); +out: + tsk_blkalloc_free(&edge_list_heap); + tsk_tree_position_free(&tree_pos); + tsk_safe_free(degree); + tsk_safe_free(keep); + return ret; +} + +int TSK_WARN_UNUSED +tsk_treeseq_extend_edges(const tsk_treeseq_t *self, int max_iter, + tsk_flags_t TSK_UNUSED(options), tsk_treeseq_t *output) +{ + int ret = 0; + tsk_table_collection_t tables; + tsk_treeseq_t ts; + int iter, j; + tsk_size_t last_num_edges; + const int direction[] = { TSK_DIR_FORWARD, TSK_DIR_REVERSE }; + + tsk_memset(&tables, 0, sizeof(tables)); + tsk_memset(&ts, 0, sizeof(ts)); + tsk_memset(output, 0, sizeof(*output)); + + /* Note: there is a fair bit of copying of table data in this implementation + * currently, as we create a new tree sequence for each iteration, which + * takes a full copy of the input tables. We could streamline this by + * adding a flag to treeseq_init which says "steal a reference to these + * tables and *don't* free them at the end". Then, we would only need + * one copy of the full tables, and could pass in a standalone edge + * table to use for in-place updating. 
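+ * (Such a flag is hypothetical at this point; it would also require
+ * tsk_treeseq_free to skip freeing tables that the tree sequence does
+ * not own.)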
+ */ + ret = tsk_table_collection_copy(self->tables, &tables, 0); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_init(&ts, &tables, 0); + if (ret != 0) { + goto out; + } + + last_num_edges = tsk_treeseq_get_num_edges(&ts); + for (iter = 0; iter < max_iter; iter++) { + for (j = 0; j < 2; j++) { + ret = tsk_treeseq_extend_edges_iter(&ts, direction[j], &tables.edges); + if (ret != 0) { + goto out; + } + /* We're done with the current ts now */ + tsk_treeseq_free(&ts); + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + if (ret != 0) { + goto out; + } + } + if (last_num_edges == tsk_treeseq_get_num_edges(&ts)) { + break; + } + last_num_edges = tsk_treeseq_get_num_edges(&ts); + } + + /* Hand ownership of the tree sequence to the calling code */ + tsk_memcpy(output, &ts, sizeof(ts)); + tsk_memset(&ts, 0, sizeof(*output)); +out: + tsk_treeseq_free(&ts); + tsk_table_collection_free(&tables); + return ret; +} diff --git a/c/tskit/trees.h b/c/tskit/trees.h index d9a93d629e..2305fb5ae3 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2023 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -56,8 +56,8 @@ extern "C" { /* Options for map_mutations */ #define TSK_MM_FIXED_ANCESTRAL_STATE (1 << 0) -/* For the edge diff iterator */ -#define TSK_INCLUDE_TERMINAL (1 << 0) +#define TSK_DIR_FORWARD 1 +#define TSK_DIR_REVERSE -1 /** @defgroup API_FLAGS_TS_INIT_GROUP :c:func:`tsk_treeseq_init` specific flags. @@ -258,30 +258,6 @@ typedef struct { tsk_id_t right_index; } tsk_tree_t; -/* Diff iterator. */ -typedef struct _tsk_edge_list_node_t { - tsk_edge_t edge; - struct _tsk_edge_list_node_t *next; - struct _tsk_edge_list_node_t *prev; -} tsk_edge_list_node_t; - -typedef struct { - tsk_edge_list_node_t *head; - tsk_edge_list_node_t *tail; -} tsk_edge_list_t; - -typedef struct { - tsk_size_t num_nodes; - tsk_size_t num_edges; - double tree_left; - const tsk_treeseq_t *tree_sequence; - tsk_id_t insertion_index; - tsk_id_t removal_index; - tsk_id_t tree_index; - tsk_id_t last_index; - tsk_edge_list_node_t *edge_list_nodes; -} tsk_diff_iter_t; - /****************************************************************************/ /* Tree sequence.*/ /****************************************************************************/ @@ -915,6 +891,39 @@ int tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, tsk_size_t num_samples, tsk_flags_t options, tsk_treeseq_t *output, tsk_id_t *node_map); +/** +@brief Extends edges + +Returns a modified tree sequence in which the span covered by ancestral nodes +is "extended" to regions of the genome according to the following rule: +If an ancestral segment corresponding to node `n` has parent `p` and +child `c` on some portion of the genome, and on an adjacent segment of +genome `p` is the immediate parent of `c`, then `n` is inserted into the +edge from `p` to `c`. This involves extending the span of the edges +from `p` to `n` and `n` to `c` and reducing the span of the edge from +`p` to `c`. Since the latter edge may be removed entirely, this process +reduces (or at least does not increase) the number of edges in the tree +sequence. + +The method works by iterating over the genome to look for edges that can +be extended in this way; the maximum number of such iterations is +controlled by ``max_iter``. 
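+
+For example (an illustrative sketch, not taken from the library's
+documentation): if the edge ``(p, c)`` spans ``[10, 20)`` and, on the
+adjacent interval ``[0, 10)``, the path from ``p`` to ``c`` passes through
+a node ``n`` via edges ``(p, n)`` and ``(n, c)``, then extending widens
+``(p, n)`` and ``(n, c)`` to cover ``[0, 20)`` and removes ``(p, c)``,
+leaving ``n`` as a unary node on ``[10, 20)``.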
+ + +@rst + +**Options**: None currently defined. +@endrst + +@param self A pointer to a tsk_treeseq_t object. +@param max_iter The maximum number of iterations over the tree sequence. +@param options Bitwise option flags. (UNUSED) +@param output A pointer to an uninitialised tsk_treeseq_t object. +@return Return 0 on success or a negative value on failure. +*/ +int tsk_treeseq_extend_edges( + const tsk_treeseq_t *self, int max_iter, tsk_flags_t options, tsk_treeseq_t *output); + /** @} */ int tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flags, @@ -967,6 +976,17 @@ int tsk_treeseq_trait_linear_model(const tsk_treeseq_t *self, tsk_size_t num_wei const double *weights, tsk_size_t num_covariates, const double *covariates, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +/* Two way weighted stats with covariates */ + +typedef int two_way_weighted_method(const tsk_treeseq_t *self, tsk_size_t num_weights, + const double *weights, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, + tsk_size_t num_windows, const double *windows, double *result, tsk_flags_t options); + +int tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, + tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, + double *result, tsk_flags_t options); + /* One way sample set stats */ typedef int one_way_sample_stat_method(const tsk_treeseq_t *self, @@ -1027,6 +1047,10 @@ int tsk_treeseq_f4(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, + const tsk_id_t *samples, tsk_size_t num_windows, const double *windows, + tsk_flags_t options, double *result); + /****************************************************************************/ /* Tree */ /****************************************************************************/ @@ -1111,10 +1135,6 @@ int tsk_tree_copy(const tsk_tree_t *self, tsk_tree_t *dest, tsk_flags_t options) @{ */ -/** @brief Value returned by seeking methods when they have successfully - seeked to a non-null tree. */ -#define TSK_TREE_OK 1 - /** @brief Seek to the first tree in the sequence. @@ -1220,12 +1240,6 @@ we will have ``position < tree.interval.right``. Seeking to a position currently covered by the tree is a constant time operation. - -.. warning:: - The current implementation of ``seek`` does **not** provide efficient - random access to arbitrary positions along the genome. However, - sequentially seeking in either direction is as efficient as calling - :c:func:`tsk_tree_next` or :c:func:`tsk_tree_prev` directly. @endrst @param self A pointer to an initialised tsk_tree_t object. @@ -1236,6 +1250,22 @@ a constant time operation. */ int tsk_tree_seek(tsk_tree_t *self, double position, tsk_flags_t options); +/** +@brief Seek to a specific tree in a tree sequence. + +@rst +Set the state of this tree to reflect the tree in parent +tree sequence whose index is ``0 <= tree < num_trees``. +@endrst + +@param self A pointer to an initialised tsk_tree_t object. +@param tree The target tree index. +@param options Seek options. Currently unused. Set to 0 for compatibility + with future versions of tskit. +@return Return 0 on success or a negative value on failure. 
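+
+@rst
+A minimal usage sketch (illustrative only; assumes an initialised tree
+sequence ``ts`` and omits error checking)::
+
+    tsk_tree_t tree;
+
+    tsk_tree_init(&tree, &ts, 0);
+    // Jump straight to the tree at index 2
+    tsk_tree_seek_index(&tree, 2, 0);
+    tsk_tree_free(&tree);
+@endrst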
+*/
+int tsk_tree_seek_index(tsk_tree_t *self, tsk_id_t tree, tsk_flags_t options);
+
 /** @} */
 
 /**
@@ -1739,16 +1769,42 @@ bool tsk_tree_is_sample(const tsk_tree_t *self, tsk_id_t u);
 */
 bool tsk_tree_equals(const tsk_tree_t *self, const tsk_tree_t *other);
 
-/****************************************************************************/
-/* Diff iterator */
-/****************************************************************************/
-
-int tsk_diff_iter_init(
+int tsk_diff_iter_init_from_ts(
     tsk_diff_iter_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options);
-int tsk_diff_iter_free(tsk_diff_iter_t *self);
-int tsk_diff_iter_next(tsk_diff_iter_t *self, double *left, double *right,
-    tsk_edge_list_t *edges_out, tsk_edge_list_t *edges_in);
-void tsk_diff_iter_print_state(const tsk_diff_iter_t *self, FILE *out);
+
+/* Temporarily putting this here to avoid problems with doxygen. Will need to
+ * move up the file later when it gets incorporated into the tsk_tree_t object.
+ */
+typedef struct {
+    tsk_id_t index;
+    struct {
+        double left;
+        double right;
+    } interval;
+    struct {
+        tsk_id_t start;
+        tsk_id_t stop;
+        const tsk_id_t *order;
+    } in;
+    struct {
+        tsk_id_t start;
+        tsk_id_t stop;
+        const tsk_id_t *order;
+    } out;
+    tsk_id_t left_current_index;
+    tsk_id_t right_current_index;
+    int direction;
+    const tsk_treeseq_t *tree_sequence;
+} tsk_tree_position_t;
+
+int tsk_tree_position_init(
+    tsk_tree_position_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options);
+int tsk_tree_position_free(tsk_tree_position_t *self);
+int tsk_tree_position_print_state(const tsk_tree_position_t *self, FILE *out);
+bool tsk_tree_position_next(tsk_tree_position_t *self);
+bool tsk_tree_position_prev(tsk_tree_position_t *self);
+int tsk_tree_position_seek_forward(tsk_tree_position_t *self, tsk_id_t index);
+int tsk_tree_position_seek_backward(tsk_tree_position_t *self, tsk_id_t index);
 
 #ifdef __cplusplus
 }
diff --git a/docs/_config.yml b/docs/_config.yml
index cba8b909e9..c781ee3529 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -38,10 +38,14 @@ sphinx:
     - sphinx.ext.intersphinx
     - sphinx_issues
     - sphinxarg.ext
+    - IPython.sphinxext.ipython_console_highlighting
     #- sphinxcontrib.prettyspecialmethods
     config:
       html_theme: tskit_book_theme
+      html_theme_options:
+        pygment_light_style: monokai
+        pygment_dark_style: monokai
       pygments_style: monokai
       myst_enable_extensions:
       - colon_fence
@@ -66,15 +70,18 @@ sphinx:
       # Note we have to use the regex version here because of
       # https://github.com/sphinx-doc/sphinx/issues/9748
       nitpick_ignore_regex: [
+          ["c:identifier", "uint8_t"],
          ["c:identifier", "int32_t"],
          ["c:identifier", "uint32_t"],
          ["c:identifier", "uint64_t"],
          ["c:identifier", "FILE"],
+          ["c:identifier", "bool"],
          # This is for the anonymous interval struct embedded in the tsk_tree_t.
          ["c:identifier", "tsk_tree_t.@1"],
          ["c:type", "int32_t"],
          ["c:type", "uint32_t"],
          ["c:type", "uint64_t"],
+          ["c:type", "bool"],
          # TODO these have been triaged here to make the docs compile, but we should
          # sort them out properly. https://github.com/tskit-dev/tskit/issues/336
          ["py:class", "array_like"],
@@ -88,6 +95,18 @@ sphinx:
          ["py:class", "dtype=np.int64"],
       ]
 
+      # Added to allow "bool" to be used as a :ctype: - this list has to be
+      # manually specified in order to remove "bool" from it.
+ c_extra_keywords: [ + "alignas", + "alignof", + "complex", + "imaginary", + "noreturn", + "static_assert", + "thread_local" + ] + autodoc_member_order: bysource # Without this option, autodoc tries to put links for all return types diff --git a/docs/c-api.rst b/docs/c-api.rst index 33246cf6cd..bd8233ed6e 100644 --- a/docs/c-api.rst +++ b/docs/c-api.rst @@ -233,6 +233,7 @@ Basic Types .. doxygentypedef:: tsk_id_t .. doxygentypedef:: tsk_size_t .. doxygentypedef:: tsk_flags_t +.. doxygentypedef:: tsk_bool_t ************** Common options diff --git a/docs/data-model.md b/docs/data-model.md index 20b8957175..3339f425cb 100644 --- a/docs/data-model.md +++ b/docs/data-model.md @@ -830,12 +830,13 @@ HTML(html_quintuple_table(ts, show_convenience_arrays=True)) ### Roots -The roots of a tree are defined as the unique endpoints of upward paths -starting from sample nodes ({ref}`isolated` -sample nodes also count as roots). Thus, trees can have multiple roots in `tskit`. -For example, if we delete the edge joining `6` and `7` in the previous -example, we get a tree with two roots: - +In the `tskit` {class}`trees ` we have shown so far, all the sample nodes have +been connected to each other. This means each tree has only a single {attr}`~Tree.root` +(i.e. the oldest node found when tracing a path backwards in time from any sample). +However, a tree can contain {ref}`sec_data_model_tree_isolated_sample_nodes` +or unconnected topologies, and can therefore have *multiple* {attr}`~Tree.roots`. +Here's an example, created by deleting the edge joining `6` and `7` in the tree sequence +used above: ```{code-cell} ipython3 :tags: ["hide-input"] @@ -845,7 +846,7 @@ ts_multiroot = tables.tree_sequence() SVG(ts_multiroot.first().draw_svg(time_scale="rank")) ``` -Note that in tree sequence terminology, this should *not* be thought +In `tskit` terminology, this should *not* be thought of as two separate trees, but as a single multi-root "tree", comprising two unlinked topologies. This fits with the definition of a tree in a tree sequence: a tree describes the ancestry of the same @@ -853,19 +854,34 @@ fixed set of sample nodes at a single position in the genome. In the picture above, *both* the left and right hand topologies are required to describe the genealogy of samples 0..4 at this position. -Here's what it looks like for an entire tree sequence: +Here's what the entire tree sequence now looks like: ```{code-cell} ipython3 :tags: ["hide-input"] SVG(ts_multiroot.draw_svg(time_scale="rank")) ``` -This tree sequence consists of three trees. The first tree, which applies from -position 0 to 20, is the one used in our example. As we saw, removing the edge -connecting node 6 to node 7 has created a tree with 2 roots (and thus 2 -unconnected topologies in a single tree). In contrast, the second tree, from -position 20 to 40, has a single root. Finally the third tree, from position -40 to 60, again has two roots. +From the terminology above, it can be seen that this tree sequence consists of only +three trees (not five). The first tree, which applies from position 0 to 20, is the one +used in our example. As we saw, removing the edge connecting node 6 to node 7 has +created a tree with 2 roots (and thus 2 unconnected topologies in a single tree). +In contrast, the second tree, from position 20 to 40, has a single root. Finally the +third tree, from position 40 to 60, again has two roots. 
+ +(sec_data_model_tree_root_threshold)= + +#### The root threshold + +The roots of a tree are defined by reference to the +{ref}`sample nodes`. By default, roots are the unique +endpoints of the paths traced upwards from the sample nodes; equivalently, each root +counts one or more samples among its descendants (or is itself a sample node). This is +the case when the {attr}`~Tree.root_threshold` property of a tree is left at its default +value of `1`. If, however, the `root_threshold` is (say) `2`, then a node is +considered a root only if it counts at least two samples among its descendants. Setting +an alternative `root_threshold` value can be used to avoid visiting +{ref}`sec_data_model_tree_isolated_sample_nodes`, for example when dealing with trees +containing {ref}`sec_data_model_missing_data`. (sec_data_model_tree_virtual_root)= @@ -940,11 +956,18 @@ for tree in ts_multiroot.trees(): ) ``` -However, it is also possible for a {ref}`sample node` -to be isolated. Unlike other nodes, isolated *sample* nodes are still considered as -being present on the tree (meaning they will still returned by the {meth}`Tree.nodes` -and {meth}`Tree.samples` methods): they are therefore plotted, but unconnected to any -other nodes. To illustrate, we can remove the edge from node 2 to node 7. + +(sec_data_model_tree_isolated_sample_nodes)= + +#### Isolated sample nodes + +It is also possible for a {ref}`sample node` +to be isolated. As long as the {ref}`root threshold` +is set to its default value, an isolated *sample* node will count as a root, and +therefore be considered as being present on the tree (meaning it will be +returned by the {meth}`Tree.nodes` +and {meth}`Tree.samples` methods). When displaying a tree, isolated samples are shown +unconnected to other nodes. To illustrate, we can remove the edge from node 2 to node 7: ```{code-cell} ipython3 :tags: ["hide-input"] @@ -955,9 +978,9 @@ ts_isolated = tables.tree_sequence() SVG(ts_isolated.draw_svg(time_scale="rank")) ``` -The rightmost tree now contains an isolated sample node (node 2). Isolated -sample nodes count as one of the {ref}`sec_data_model_tree_roots` of the tree, -so that tree has three roots, one of which is node 2: +The rightmost tree now contains an isolated sample node (node 2), which counts as +one of the {ref}`sec_data_model_tree_roots` of the tree. This tree therefore has three +roots, one of which is node 2: ```{code-cell} ipython3 rightmost_tree = ts_isolated.at_index(-1) diff --git a/docs/python-api.md b/docs/python-api.md index ac53d3fce9..20a2d541b4 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -268,6 +268,7 @@ which perform the same actions but modify the {class}`TableCollection` in place. TreeSequence.trim TreeSequence.split_edges TreeSequence.decapitate + TreeSequence.extend_edges ``` (sec_python_api_tree_sequences_ibd)= @@ -321,6 +322,7 @@ Single site TreeSequence.Fst TreeSequence.genealogical_nearest_neighbours TreeSequence.genetic_relatedness + TreeSequence.genetic_relatedness_weighted TreeSequence.general_stat TreeSequence.segregating_sites TreeSequence.sample_count_stat diff --git a/docs/stats.md b/docs/stats.md index 39257e1017..72aa5d615b 100644 --- a/docs/stats.md +++ b/docs/stats.md @@ -71,6 +71,7 @@ appears beside the listed method. 
* Multi-way
    * {meth}`~TreeSequence.divergence`
    * {meth}`~TreeSequence.genetic_relatedness`
+     {meth}`~TreeSequence.genetic_relatedness_weighted`
    * {meth}`~TreeSequence.f4`
      {meth}`~TreeSequence.f3`
      {meth}`~TreeSequence.f2`
@@ -593,6 +594,12 @@ and boolean expressions (e.g., {math}`(x > 0)`) are interpreted as 0/1.

   where {math}`m = \frac{1}{n}\sum_{k=1}^n x_k` with {math}`n` the total number
   of samples.

+`genetic_relatedness_weighted`
+: {math}`f(w_i, w_j, x_i, x_j) = \frac{1}{2}(x_i - w_i m) (x_j - w_j m)`,
+
+  where {math}`m = \frac{1}{n}\sum_{k=1}^n x_k` with {math}`n` the total number
+  of samples, and {math}`w_j = \sum_{k=1}^n W_{kj}` is the sum of the weights in
+  the {math}`j`th column of the weight matrix.
+
`Y2`
: {math}`f(x_1, x_2) = \frac{x_1 (n_2 - x_2) (n_2 - x_2 - 1)}{n_1 n_2 (n_2 - 1)}`
diff --git a/docs/substitutions/table_keep_rows_main.rst b/docs/substitutions/table_keep_rows_main.rst
new file mode 100644
index 0000000000..95652527a2
--- /dev/null
+++ b/docs/substitutions/table_keep_rows_main.rst
@@ -0,0 +1,14 @@
+Updates this table in-place according to the specified boolean
+array, and returns the resulting mapping from old to new row IDs.
+For each row ``j``, if ``keep[j]`` is True, that row will be
+retained in the output; otherwise, the row will be deleted.
+Rows are retained in their original ordering.
+
+The returned ``id_map`` is an array of the same length as
+this table before the operation, such that ``id_map[j] = -1``
+(:data:`tskit.NULL`) if row ``j`` was deleted, and ``id_map[j]``
+is the new ID of that row, otherwise.
+
+.. todo::
+    This needs some examples to link to. See
+    https://github.com/tskit-dev/tskit/issues/2708
diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst
index 268680a378..d2e19f2be6 100644
--- a/python/CHANGELOG.rst
+++ b/python/CHANGELOG.rst
@@ -1,17 +1,84 @@
 --------------------
-[0.5.4] - 2022-XX-XX
+[0.5.6] - 2023-XX-XX
 --------------------

 **Features**

+- Add ``TreeSequence.genetic_relatedness_weighted`` stats method.
+  (:user:`petrelharp`, :user:`brieuclehmann`, :user:`jeromekelleher`,
+  :pr:`2785`, :pr:`1246`)
+
+- Add ``TreeSequence.impute_unknown_mutations_time`` method to return an
+  array of mutation times based on the times of associated nodes
+  (:user:`duncanMR`, :pr:`2760`, :issue:`2758`)
+
+- Add ``asdict`` to all dataclasses. These are returned when you access a row or
+  other tree sequence object. (:user:`benjeffery`, :pr:`2759`, :issue:`2719`)
+
+- Add ``TreeSequence.extend_edges`` method that extends ancestral haplotypes
+  using recombination information, leading to unary nodes in many trees and
+  fewer edges. (:user:`petrelharp`, :user:`hfr1tz3`, :user:`avabamf`, :pr:`2651`)
+
+--------------------
+[0.5.5] - 2023-05-17
+--------------------
+
+**Performance improvements**
+
+- Methods like ts.at() which seek to a specified position on the sequence from
+  a new Tree instance are now much faster (:user:`molpopgen`, :pr:`2661`).
+
+**Features**
+
+- Add ``__repr__`` for variants to return a string representation of the raw data
+  without spewing megabytes of text (:user:`chriscrsmith`, :pr:`2695`, :issue:`2694`)
+
+- Add ``keep_rows`` method to table classes to support efficient in-place
+  table subsetting (:user:`jeromekelleher`, :pr:`2700`)
+
+**Bugfixes**
+
+- Fix `UnicodeDecodeError` when calling `Variant.alleles` on the `emscripten` platform.
+  (:user:`benjeffery`, :pr:`2754`, :issue:`2737`)
+
+--------------------
+[0.5.4] - 2023-01-13
+--------------------
+
+**Features**
+
+- A new ``Tree.is_root`` method avoids the need to search the potentially
+  large list of ``Tree.roots`` (:user:`hyanwong`, :pr:`2669`, :issue:`2620`)
+
 - The ``TreeSequence`` object now has the attributes ``min_time`` and
   ``max_time``, which are the minimum and maximum among the node times
   and mutation times, respectively. (:user:`szhan`, :pr:`2612`, :issue:`2271`)

+- The ``draw_svg`` methods now have a ``max_num_trees`` parameter to truncate
+  the total number of trees shown, giving a readable display for tree
+  sequences with many trees (:user:`hyanwong`, :pr:`2652`)
+
 - The ``draw_svg`` methods now accept a ``canvas_size`` parameter to allow
   extra room on the canvas e.g. for long labels or repositioned graphical
   elements (:user:`hyanwong`, :pr:`2646`, :issue:`2645`)

+- The ``Tree`` object now has the method ``siblings`` to get
+  the siblings of a node. It returns an empty tuple if the node
+  has no siblings, is not a node in the tree, is the virtual root,
+  or is an isolated non-sample node.
+  (:user:`szhan`, :pr:`2618`, :issue:`2616`)
+
+- The ``msprime.RateMap`` class has been ported into tskit: functionality should
+  be identical to the version in msprime, apart from minor changes in the formatting
+  of tabular text output (:user:`hyanwong`, :user:`jeromekelleher`, :pr:`2678`)
+
+- Tskit now supports and has wheels for Python 3.11. This Python version has a significant
+  performance boost (:user:`benjeffery`, :pr:`2624`, :issue:`2248`)
+
+- Add the `update_sample_flags` option to `simplify` which ensures
+  no node sample flags are changed to allow calling code to manage sample status.
+  (:user:`jeromekelleher`, :issue:`2662`, :pr:`2663`).
+
 **Breaking Changes**

 - the ``filter_populations``, ``filter_individuals``, and ``filter_sites``
@@ -49,7 +116,7 @@

 - Single statistics computed with ``TreeSequence.general_stat`` are now
   returned as numpy scalars if windows=None, AND; samples is a single
-  list or None (for a 1-way stat), OR indexes is None or a single list of
+  list or None (for a 1-way stat), OR indexes is None or a single list of
   length k (instead of a list of length-k lists).
   (:user:`gtsambos`, :pr:`2417`, :issue:`2308`)

@@ -64,10 +131,10 @@

 **Performance improvements**

 - TreeSequence.link_ancestors no longer continues to process edges once all
-  of the sample and ancestral nodes have been accounted for, improving memory
+  of the sample and ancestral nodes have been accounted for, improving memory
   overhead and overall performance (:user:`gtsambos`, :pr:`2456`, :issue:`2442`)
-
+
 --------------------
 [0.5.2] - 2022-07-29
 --------------------
diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c
index f67ed0df39..71a690ce7e 100644
--- a/python/_tskitmodule.c
+++ b/python/_tskitmodule.c
@@ -1,7 +1,7 @@
 /*
  * MIT License
  *
- * Copyright (c) 2019-2022 Tskit Developers
+ * Copyright (c) 2019-2023 Tskit Developers
  * Copyright (c) 2015-2018 University of Oxford
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -221,10 +221,10 @@ handle_library_error(int err)
 {
     int kas_err;
     const char *not_kas_format_msg
-        = "File not in kastore format. If this file "
-          "was generated by msprime < 0.6.0 (June 2018) it uses the old HDF5-based "
-          "format which can no longer be read directly. Please convert to the new "
-          "kastore format using the ``tskit upgrade`` command.";
+        = "File not in kastore format. 
Either the file is corrupt or it is not a " + "tskit tree sequence file. It may be a legacy HDF file upgradable with " + "`tskit upgrade` or a compressed tree sequence file that can be decompressed " + "with `tszip`."; const char *ibd_pairs_not_stored_msg = "Sample pairs are not stored by default " "in the IdentitySegments object returned by ibd_segments(), and you have " @@ -578,7 +578,8 @@ make_alleles(tsk_variant_t *variant) goto out; } for (j = 0; j < variant->num_alleles; j++) { - item = Py_BuildValue("s#", variant->alleles[j], variant->allele_lengths[j]); + item = Py_BuildValue( + "s#", variant->alleles[j], (Py_ssize_t) variant->allele_lengths[j]); if (item == NULL) { Py_DECREF(t); goto out; @@ -942,13 +943,14 @@ tsk_id_converter(PyObject *py_obj, tsk_id_t *id_out) } static int -int32_array_converter(PyObject *py_obj, PyArrayObject **array_out) +array_converter(int type, PyObject *py_obj, PyArrayObject **array_out) { int ret = 0; PyArrayObject *temp_array; temp_array = (PyArrayObject *) PyArray_FromAny( - py_obj, PyArray_DescrFromType(NPY_INT32), 1, 1, NPY_ARRAY_IN_ARRAY, NULL); + py_obj, PyArray_DescrFromType(type), 1, 1, NPY_ARRAY_IN_ARRAY, NULL); + if (temp_array == NULL) { goto out; } @@ -958,6 +960,64 @@ int32_array_converter(PyObject *py_obj, PyArrayObject **array_out) return ret; } +static int +int32_array_converter(PyObject *py_obj, PyArrayObject **array_out) +{ + return array_converter(NPY_INT32, py_obj, array_out); +} + +static int +bool_array_converter(PyObject *py_obj, PyArrayObject **array_out) +{ + return array_converter(NPY_BOOL, py_obj, array_out); +} + +/* Note: it doesn't seem to be possible to cast pointers to the actual + * table functions to this type because the first argument must be a + * void *, so the simplest option is to put in a small shim that + * wraps the library function and casts to the correct table type. 
+ */ +typedef int keep_row_func_t( + void *self, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map); + +static PyObject * +table_keep_rows( + PyObject *args, void *table, tsk_size_t num_rows, keep_row_func_t keep_row_func) +{ + + PyObject *ret = NULL; + PyArrayObject *keep = NULL; + PyArrayObject *id_map = NULL; + npy_intp n = (npy_intp) num_rows; + npy_intp array_len; + int err; + + if (!PyArg_ParseTuple(args, "O&", &bool_array_converter, &keep)) { + goto out; + } + array_len = PyArray_DIMS(keep)[0]; + if (array_len != n) { + PyErr_SetString(PyExc_ValueError, "keep array must be of length Table.num_rows"); + goto out; + } + id_map = (PyArrayObject *) PyArray_SimpleNew(1, &n, NPY_INT32); + if (id_map == NULL) { + goto out; + } + err = keep_row_func(table, PyArray_DATA(keep), 0, PyArray_DATA(id_map)); + + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) id_map; + id_map = NULL; +out: + Py_XDECREF(keep); + Py_XDECREF(id_map); + return ret; +} + /*=================================================================== * IndividualTable *=================================================================== @@ -1332,6 +1392,28 @@ IndividualTable_extend(IndividualTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +individual_table_keep_rows_generic( + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) +{ + return tsk_individual_table_keep_rows( + (tsk_individual_table_t *) table, keep, options, id_map); +} + +static PyObject * +IndividualTable_keep_rows(IndividualTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (IndividualTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows(args, (void *) self->table, self->table->num_rows, + individual_table_keep_rows_generic); +out: + return ret; +} + static PyObject * IndividualTable_get_max_rows_increment(IndividualTable *self, void *closure) { @@ -1578,6 +1660,10 @@ static PyMethodDef IndividualTable_methods[] = { .ml_meth = (PyCFunction) IndividualTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) IndividualTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; @@ -1911,6 +1997,27 @@ NodeTable_extend(NodeTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +node_table_keep_rows_generic( + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) +{ + return tsk_node_table_keep_rows((tsk_node_table_t *) table, keep, options, id_map); +} + +static PyObject * +NodeTable_keep_rows(NodeTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (NodeTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows( + args, (void *) self->table, self->table->num_rows, node_table_keep_rows_generic); +out: + return ret; +} + static PyObject * NodeTable_get_max_rows_increment(NodeTable *self, void *closure) { @@ -2138,6 +2245,10 @@ static PyMethodDef NodeTable_methods[] = { .ml_meth = (PyCFunction) NodeTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) NodeTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; @@ -2482,6 +2593,27 @@ EdgeTable_extend(EdgeTable *self, PyObject 
*args, PyObject *kwds) return ret; } +static int +edge_table_keep_rows_generic( + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) +{ + return tsk_edge_table_keep_rows((tsk_edge_table_t *) table, keep, options, id_map); +} + +static PyObject * +EdgeTable_keep_rows(EdgeTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (EdgeTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows( + args, (void *) self->table, self->table->num_rows, edge_table_keep_rows_generic); +out: + return ret; +} + static PyObject * EdgeTable_get_max_rows_increment(EdgeTable *self, void *closure) { @@ -2707,11 +2839,14 @@ static PyMethodDef EdgeTable_methods[] = { .ml_meth = (PyCFunction) EdgeTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, - { .ml_name = "squash", .ml_meth = (PyCFunction) EdgeTable_squash, .ml_flags = METH_NOARGS, .ml_doc = "Squashes sets of edges with adjacent L,R and identical P,C values." }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) EdgeTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; @@ -3039,6 +3174,28 @@ MigrationTable_extend(MigrationTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +migration_table_keep_rows_generic( + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) +{ + return tsk_migration_table_keep_rows( + (tsk_migration_table_t *) table, keep, options, id_map); +} + +static PyObject * +MigrationTable_keep_rows(MigrationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (MigrationTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows(args, (void *) self->table, self->table->num_rows, + migration_table_keep_rows_generic); +out: + return ret; +} + static PyObject * MigrationTable_get_max_rows_increment(MigrationTable *self, void *closure) { @@ -3296,6 +3453,10 @@ static PyMethodDef MigrationTable_methods[] = { .ml_meth = (PyCFunction) MigrationTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) MigrationTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; @@ -3623,6 +3784,27 @@ SiteTable_extend(SiteTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +site_table_keep_rows_generic( + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) +{ + return tsk_site_table_keep_rows((tsk_site_table_t *) table, keep, options, id_map); +} + +static PyObject * +SiteTable_keep_rows(SiteTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (SiteTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows( + args, (void *) self->table, self->table->num_rows, site_table_keep_rows_generic); +out: + return ret; +} + static PyObject * SiteTable_get_max_rows_increment(SiteTable *self, void *closure) { @@ -3837,6 +4019,10 @@ static PyMethodDef SiteTable_methods[] = { .ml_meth = (PyCFunction) SiteTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) SiteTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; @@ 
-4173,6 +4359,28 @@ MutationTable_extend(MutationTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +mutation_table_keep_rows_generic( + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) +{ + return tsk_mutation_table_keep_rows( + (tsk_mutation_table_t *) table, keep, options, id_map); +} + +static PyObject * +MutationTable_keep_rows(MutationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (MutationTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows(args, (void *) self->table, self->table->num_rows, + mutation_table_keep_rows_generic); +out: + return ret; +} + static PyObject * MutationTable_get_max_rows_increment(MutationTable *self, void *closure) { @@ -4432,6 +4640,10 @@ static PyMethodDef MutationTable_methods[] = { .ml_meth = (PyCFunction) MutationTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) MutationTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; @@ -4754,6 +4966,28 @@ PopulationTable_extend(PopulationTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +population_table_keep_rows_generic( + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) +{ + return tsk_population_table_keep_rows( + (tsk_population_table_t *) table, keep, options, id_map); +} + +static PyObject * +PopulationTable_keep_rows(PopulationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (PopulationTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows(args, (void *) self->table, self->table->num_rows, + population_table_keep_rows_generic); +out: + return ret; +} + static PyObject * PopulationTable_get_max_rows_increment(PopulationTable *self, void *closure) { @@ -4918,6 +5152,10 @@ static PyMethodDef PopulationTable_methods[] = { .ml_meth = (PyCFunction) PopulationTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) PopulationTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean array" }, { NULL } /* Sentinel */ }; @@ -5232,6 +5470,28 @@ ProvenanceTable_extend(ProvenanceTable *self, PyObject *args, PyObject *kwds) return ret; } +static int +provenance_table_keep_rows_generic( + void *table, const tsk_bool_t *keep, tsk_flags_t options, tsk_id_t *id_map) +{ + return tsk_provenance_table_keep_rows( + (tsk_provenance_table_t *) table, keep, options, id_map); +} + +static PyObject * +ProvenanceTable_keep_rows(ProvenanceTable *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (ProvenanceTable_check_state(self) != 0) { + goto out; + } + ret = table_keep_rows(args, (void *) self->table, self->table->num_rows, + provenance_table_keep_rows_generic); +out: + return ret; +} + static PyObject * ProvenanceTable_get_max_rows_increment(ProvenanceTable *self, void *closure) { @@ -5385,6 +5645,10 @@ static PyMethodDef ProvenanceTable_methods[] = { .ml_meth = (PyCFunction) ProvenanceTable_extend, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Extend this table from another using specified row_indexes" }, + { .ml_name = "keep_rows", + .ml_meth = (PyCFunction) ProvenanceTable_keep_rows, + .ml_flags = METH_VARARGS, + .ml_doc = "Keep rows in this table according to boolean 
array" }, { NULL } /* Sentinel */ }; @@ -6585,24 +6849,27 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds) npy_intp *shape, dims; tsk_size_t num_samples; tsk_flags_t options = 0; - int filter_sites = true; + int filter_sites = false; int filter_individuals = false; int filter_populations = false; + int filter_nodes = true; + int update_sample_flags = true; int keep_unary = false; int keep_unary_in_individuals = false; int keep_input_roots = false; int reduce_to_site_topology = false; - static char *kwlist[] = { "samples", "filter_sites", "filter_populations", - "filter_individuals", "reduce_to_site_topology", "keep_unary", - "keep_unary_in_individuals", "keep_input_roots", NULL }; + static char *kwlist[] + = { "samples", "filter_sites", "filter_populations", "filter_individuals", + "filter_nodes", "update_sample_flags", "reduce_to_site_topology", + "keep_unary", "keep_unary_in_individuals", "keep_input_roots", NULL }; if (TableCollection_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiiii", kwlist, &samples, - &filter_sites, &filter_populations, &filter_individuals, - &reduce_to_site_topology, &keep_unary, &keep_unary_in_individuals, - &keep_input_roots)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiiiiii", kwlist, &samples, + &filter_sites, &filter_populations, &filter_individuals, &filter_nodes, + &update_sample_flags, &reduce_to_site_topology, &keep_unary, + &keep_unary_in_individuals, &keep_input_roots)) { goto out; } samples_array = (PyArrayObject *) PyArray_FROMANY( @@ -6621,6 +6888,12 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds) if (filter_populations) { options |= TSK_SIMPLIFY_FILTER_POPULATIONS; } + if (!filter_nodes) { + options |= TSK_SIMPLIFY_NO_FILTER_NODES; + } + if (!update_sample_flags) { + options |= TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS; + } if (reduce_to_site_topology) { options |= TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY; } @@ -8696,6 +8969,46 @@ TreeSequence_mean_descendants(TreeSequence *self, PyObject *args, PyObject *kwds return ret; } +static PyObject * +TreeSequence_extend_edges(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + int err; + PyObject *ret = NULL; + int max_iter; + tsk_flags_t options = 0; + static char *kwlist[] = { "max_iter", NULL }; + TreeSequence *output = NULL; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "i", kwlist, &max_iter)) { + goto out; + } + + output = (TreeSequence *) _PyObject_New((PyTypeObject *) &TreeSequenceType); + if (output == NULL) { + goto out; + } + output->tree_sequence = PyMem_Malloc(sizeof(*output->tree_sequence)); + if (output->tree_sequence == NULL) { + PyErr_NoMemory(); + goto out; + } + + err = tsk_treeseq_extend_edges( + self->tree_sequence, max_iter, options, output->tree_sequence); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) output; + output = NULL; +out: + Py_XDECREF(output); + return ret; +} + /* Error value returned from summary_func callback if an error occured. 
* This is chosen so that it is not a valid tskit error code and so can * never be mistaken for a different error */ @@ -9353,6 +9666,93 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd return ret; } +static PyObject * +TreeSequence_k_way_weighted_stat_method(TreeSequence *self, PyObject *args, + PyObject *kwds, npy_intp tuple_size, two_way_weighted_method *method) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "weights", "indexes", "windows", "mode", "span_normalise", + "polarised", NULL }; + PyObject *weights = NULL; + PyObject *indexes = NULL; + PyObject *windows = NULL; + PyArrayObject *weights_array = NULL; + PyArrayObject *indexes_array = NULL; + PyArrayObject *windows_array = NULL; + PyArrayObject *result_array = NULL; + tsk_size_t num_windows, num_index_tuples; + npy_intp *w_shape, *shape; + tsk_flags_t options = 0; + char *mode = NULL; + int span_normalise = true; + int polarised = false; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|sii", kwlist, &weights, &indexes, + &windows, &mode, &span_normalise, &polarised)) { + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + if (span_normalise) { + options |= TSK_STAT_SPAN_NORMALISE; + } + if (polarised) { + options |= TSK_STAT_POLARISED; + } + if (parse_windows(windows, &windows_array, &num_windows) != 0) { + goto out; + } + weights_array = (PyArrayObject *) PyArray_FROMANY( + weights, NPY_FLOAT64, 2, 2, NPY_ARRAY_IN_ARRAY); + if (weights_array == NULL) { + goto out; + } + w_shape = PyArray_DIMS(weights_array); + if (w_shape[0] != (npy_intp) tsk_treeseq_get_num_samples(self->tree_sequence)) { + PyErr_SetString(PyExc_ValueError, "First dimension must be num_samples"); + goto out; + } + + indexes_array = (PyArrayObject *) PyArray_FROMANY( + indexes, NPY_INT32, 2, 2, NPY_ARRAY_IN_ARRAY); + if (indexes_array == NULL) { + goto out; + } + shape = PyArray_DIMS(indexes_array); + if (shape[0] < 1 || shape[1] != tuple_size) { + PyErr_Format( + PyExc_ValueError, "indexes must be a k x %d array.", (int) tuple_size); + goto out; + } + num_index_tuples = shape[0]; + + result_array = TreeSequence_allocate_results_array( + self, options, num_windows, num_index_tuples); + if (result_array == NULL) { + goto out; + } + err = method(self->tree_sequence, w_shape[1], PyArray_DATA(weights_array), + num_index_tuples, PyArray_DATA(indexes_array), num_windows, + PyArray_DATA(windows_array), PyArray_DATA(result_array), options); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_array; + result_array = NULL; +out: + Py_XDECREF(weights_array); + Py_XDECREF(indexes_array); + Py_XDECREF(windows_array); + Py_XDECREF(result_array); + return ret; +} + static PyObject * TreeSequence_divergence(TreeSequence *self, PyObject *args, PyObject *kwds) { @@ -9366,6 +9766,14 @@ TreeSequence_genetic_relatedness(TreeSequence *self, PyObject *args, PyObject *k self, args, kwds, 2, tsk_treeseq_genetic_relatedness); } +static PyObject * +TreeSequence_genetic_relatedness_weighted( + TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_weighted_stat_method( + self, args, kwds, 2, tsk_treeseq_genetic_relatedness_weighted); +} + static PyObject * TreeSequence_Y2(TreeSequence *self, PyObject *args, PyObject *kwds) { @@ -9396,6 +9804,78 @@ TreeSequence_f4(TreeSequence *self, PyObject *args, PyObject *kwds) return TreeSequence_k_way_stat_method(self, args, kwds, 4, 
tsk_treeseq_f4); } +static PyObject * +TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "windows", "samples", "mode", NULL }; + PyArrayObject *result_array = NULL; + PyObject *windows = NULL; + PyObject *py_samples = Py_None; + char *mode = NULL; + PyArrayObject *windows_array = NULL; + PyArrayObject *samples_array = NULL; + tsk_flags_t options = 0; + npy_intp *shape, dims[3]; + tsk_size_t num_samples, num_windows; + tsk_id_t *samples = NULL; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords( + args, kwds, "O|Os", kwlist, &windows, &py_samples, &mode)) { + goto out; + } + num_samples = tsk_treeseq_get_num_samples(self->tree_sequence); + if (py_samples != Py_None) { + samples_array = (PyArrayObject *) PyArray_FROMANY( + py_samples, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); + if (samples_array == NULL) { + goto out; + } + shape = PyArray_DIMS(samples_array); + samples = PyArray_DATA(samples_array); + num_samples = (tsk_size_t) shape[0]; + } + if (parse_windows(windows, &windows_array, &num_windows) != 0) { + goto out; + } + dims[0] = num_windows; + dims[1] = num_samples; + dims[2] = num_samples; + result_array = (PyArrayObject *) PyArray_SimpleNew(3, dims, NPY_FLOAT64); + if (result_array == NULL) { + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + // clang-format off + Py_BEGIN_ALLOW_THREADS + err = tsk_treeseq_divergence_matrix( + self->tree_sequence, + num_samples, samples, + num_windows, PyArray_DATA(windows_array), + options, PyArray_DATA(result_array)); + Py_END_ALLOW_THREADS + // clang-format on + /* Clang-format insists on doing this in spite of the "off" instruction above */ + if (err != 0) + { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_array; + result_array = NULL; +out: + Py_XDECREF(result_array); + Py_XDECREF(windows_array); + Py_XDECREF(samples_array); + return ret; +} + static PyObject * TreeSequence_get_num_mutations(TreeSequence *self) { @@ -10080,6 +10560,10 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_genetic_relatedness, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes genetic relatedness between sample sets." }, + { .ml_name = "genetic_relatedness_weighted", + .ml_meth = (PyCFunction) TreeSequence_genetic_relatedness_weighted, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes genetic relatedness between weighted sums of samples." }, { .ml_name = "Y1", .ml_meth = (PyCFunction) TreeSequence_Y1, .ml_flags = METH_VARARGS | METH_KEYWORDS, @@ -10104,10 +10588,18 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_f4, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes the f4 statistic." }, + { .ml_name = "divergence_matrix", + .ml_meth = (PyCFunction) TreeSequence_divergence_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the pairwise divergence matrix." }, { .ml_name = "split_edges", .ml_meth = (PyCFunction) TreeSequence_split_edges, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Returns a copy of this tree sequence edges split at time t" }, + { .ml_name = "extend_edges", + .ml_meth = (PyCFunction) TreeSequence_extend_edges, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Extends edges, creating unary nodes." 
}, { .ml_name = "has_reference_sequence", .ml_meth = (PyCFunction) TreeSequence_has_reference_sequence, .ml_flags = METH_NOARGS, @@ -10417,6 +10909,29 @@ Tree_seek(Tree *self, PyObject *args) return ret; } +static PyObject * +Tree_seek_index(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + tsk_id_t index = 0; + int err; + + if (Tree_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "O&", tsk_id_converter, &index)) { + goto out; + } + err = tsk_tree_seek_index(self->tree, index, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + static PyObject * Tree_clear(Tree *self) { @@ -11555,6 +12070,10 @@ static PyMethodDef Tree_methods[] = { .ml_meth = (PyCFunction) Tree_seek, .ml_flags = METH_VARARGS, .ml_doc = "Seeks to the tree at the specified position" }, + { .ml_name = "seek_index", + .ml_meth = (PyCFunction) Tree_seek_index, + .ml_flags = METH_VARARGS, + .ml_doc = "Seeks to the tree at the specified index" }, { .ml_name = "clear", .ml_meth = (PyCFunction) Tree_clear, .ml_flags = METH_NOARGS, diff --git a/python/requirements/CI-complete/requirements.txt b/python/requirements/CI-complete/requirements.txt index 6dcee67d6d..a582f42387 100644 --- a/python/requirements/CI-complete/requirements.txt +++ b/python/requirements/CI-complete/requirements.txt @@ -19,4 +19,5 @@ pytest==7.1.3 pytest-cov==4.0.0 pytest-xdist==2.5.0 svgwrite==1.4.3 +tszip==0.2.2 xmlunittest==0.5.0 diff --git a/python/requirements/CI-docs/requirements.txt b/python/requirements/CI-docs/requirements.txt index d4a1f0c9fc..e9733957ae 100644 --- a/python/requirements/CI-docs/requirements.txt +++ b/python/requirements/CI-docs/requirements.txt @@ -1,13 +1,9 @@ -breathe==4.34.0 -jupyter-book==0.13.1 -h5py==3.7.0 -jsonschema==3.2.0 #jupyter-book 0.13.1 depends on jsonschema<4 -msprime==1.2.0 -numpy==1.21.6 # Held at 1.21.6 for Python 3.7 compatibility -PyGithub==1.55 -sphinx-argparse==0.3.1 -sphinx-autodoc-typehints==1.18.3 # Held at 1.18.3 as that depends on sphinx>=5.2.1 while jupyter-book 0.13.1 depends on sphinx<5 +jupyter-book==0.15.1 +breathe==4.35.0 +sphinx-autodoc-typehints==1.19.1 sphinx-issues==3.0.1 -sphinxcontrib-prettyspecialmethods==0.1.0 +sphinx-argparse==0.4.0 +numpy==1.25.1 svgwrite==1.4.3 -tskit-book-theme==0.3.2 \ No newline at end of file +msprime==1.2.0 +tskit-book-theme \ No newline at end of file diff --git a/python/requirements/CI-tests-conda/requirements.txt b/python/requirements/CI-tests-conda/requirements.txt index 5a632c2821..453d829799 100644 --- a/python/requirements/CI-tests-conda/requirements.txt +++ b/python/requirements/CI-tests-conda/requirements.txt @@ -1,3 +1,4 @@ msprime==1.2.0 kastore==0.3.2 jsonschema==4.16.0 +h5py==3.7.0 diff --git a/python/requirements/CI-tests-pip/requirements.txt b/python/requirements/CI-tests-pip/requirements.txt index f168af3e01..9f3b31ef3d 100644 --- a/python/requirements/CI-tests-pip/requirements.txt +++ b/python/requirements/CI-tests-pip/requirements.txt @@ -1,9 +1,9 @@ lshmm==0.0.4 -numpy==1.21.6 # Held at 1.21.6 for Python 3.7 compatibility +numpy==1.21.6; python_version < '3.11' # Held at 1.21.6 for Python 3.7 compatibility +numpy==1.24.1; python_version > '3.10' pytest==7.1.3 pytest-cov==4.0.0 pytest-xdist==2.5.0 -h5py==3.7.0 svgwrite==1.4.3 portion==2.3.0 xmlunittest==0.5.0 @@ -11,4 +11,5 @@ biopython==1.79 dendropy==4.5.2 networkx==2.6.3 # Held at 2.6.3 for Python 3.7 compatibility msgpack==1.0.4 -newick==1.3.2 \ No newline at end of file +newick==1.3.2 +tszip==0.2.2 \ 
No newline at end of file
diff --git a/python/requirements/benchmark.txt b/python/requirements/benchmark.txt
new file mode 100644
index 0000000000..12a0be4060
--- /dev/null
+++ b/python/requirements/benchmark.txt
@@ -0,0 +1,9 @@
+click
+psutil
+tqdm
+matplotlib
+si-prefix
+jsonschema
+svgwrite
+msprime
+PyYAML
\ No newline at end of file
diff --git a/python/requirements/development.txt b/python/requirements/development.txt
index 4b3753de4c..ba48881723 100644
--- a/python/requirements/development.txt
+++ b/python/requirements/development.txt
@@ -36,6 +36,7 @@
 sphinx-jupyterbook-latex
 sphinxcontrib-prettyspecialmethods
 tqdm
 tskit-book-theme
+tszip
 pydata_sphinx_theme>=0.7.2
 svgwrite>=1.1.10
 xmlunittest
diff --git a/python/setup.cfg b/python/setup.cfg
index aca17749de..13a1506f25 100644
--- a/python/setup.cfg
+++ b/python/setup.cfg
@@ -20,6 +20,7 @@ classifiers =
     Programming Language :: Python :: 3.8
     Programming Language :: Python :: 3.9
     Programming Language :: Python :: 3.10
+    Programming Language :: Python :: 3.11
     Programming Language :: Python :: 3 :: Only
     Development Status :: 5 - Production/Stable
     Environment :: Other Environment
diff --git a/python/tests/__init__.py b/python/tests/__init__.py
index 1f064b6a8d..f069f04f2e 100644
--- a/python/tests/__init__.py
+++ b/python/tests/__init__.py
@@ -1,6 +1,6 @@
 # MIT License
 #
-# Copyright (c) 2018-2022 Tskit Developers
+# Copyright (c) 2018-2023 Tskit Developers
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
@@ -53,7 +53,7 @@ def __init__(self, num_nodes):

     @classmethod
     def from_tree(cls, tree):
-        ret = PythonTree(tree.num_nodes)
+        ret = PythonTree(tree.tree_sequence.num_nodes)
         ret.left, ret.right = tree.get_interval()
         ret.site_list = list(tree.sites())
         ret.index = tree.get_index()
diff --git a/python/tests/conftest.py b/python/tests/conftest.py
index d2539ed0fb..d23c019003 100644
--- a/python/tests/conftest.py
+++ b/python/tests/conftest.py
@@ -111,6 +111,14 @@ def ts_fixture():
     return tsutil.all_fields_ts()


+@fixture(scope="session")
+def ts_fixture_for_simplify():
+    """
+    A tree sequence with data in all fields except edge metadata and migrations
+    """
+    return tsutil.all_fields_ts(edge_metadata=False, migrations=False)
+
+
 @fixture(scope="session")
 def replicate_ts_fixture():
     """
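(Editorial aside, not part of the diff: the ``keep_rows`` contract specified in ``table_keep_rows_main.rst`` above, and wired through each table class in the ``_tskitmodule.c`` hunks, can be exercised in a few lines. This is a sketch against the public Python API; the table contents here are arbitrary.)

```python
# Sketch of the documented keep_rows contract: rows with keep[j] True are
# retained in their original order; id_map maps old row IDs to new ones,
# with -1 (tskit.NULL) marking deleted rows.
import numpy as np
import tskit

tables = tskit.TableCollection(sequence_length=1.0)
for _ in range(4):
    tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)

id_map = tables.nodes.keep_rows(np.array([True, False, True, False]))
assert tables.nodes.num_rows == 2
assert list(id_map) == [0, tskit.NULL, 1, tskit.NULL]
```

This is exactly the ``id_map`` shape the substitution text promises: the same length as the table before the call, with ``-1`` for dropped rows.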
[SVG test-fixture diffs omitted here: the XML markup in these files did not survive extraction, leaving only stray label text (axis titles such as "Genome position" and "Time (generations)", tick values, and node IDs). The hunks make small updates near the top of the existing fixtures internal_sample_ts.svg, tree.svg, tree_both_axes.svg, tree_muts.svg, tree_muts_all_edge.svg, tree_timed_muts.svg, tree_x_axis.svg, tree_y_axis_rank.svg, ts.svg, ts_multiroot.svg, ts_mut_highlight.svg, ts_mut_times.svg, ts_mut_times_logscale.svg, ts_mutations_no_edges.svg, ts_mutations_timed_no_edges.svg, ts_no_axes.svg, ts_plain.svg, ts_plain_no_xlab.svg, ts_plain_y.svg, ts_rank.svg, ts_x_lim.svg, ts_xlabel.svg, ts_y_axis.svg, ts_y_axis_log.svg and ts_y_axis_regular.svg, and add two new fixtures, ts_max_trees.svg and ts_max_trees_treewise.svg, whose recoverable text shows labelled axes and a "31 trees skipped" annotation.]
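(Editorial aside, not part of the diff: the two new ``ts_max_trees*.svg`` fixtures summarised above exercise the ``max_num_trees`` option that, per the changelog earlier in this patch, was added to the ``draw_svg`` methods. A sketch of the call follows; the msprime parameters are arbitrary values chosen to produce many trees.)

```python
# Sketch: render at most five trees; the elided middle trees are replaced
# by a "skipped" annotation, matching the "31 trees skipped" text in the
# fixtures above.
import msprime

ts = msprime.sim_ancestry(
    6, sequence_length=100, recombination_rate=0.1, random_seed=42
)
svg_string = ts.draw_svg(max_num_trees=5, y_axis=True)
```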
diff --git a/python/tests/simplify.py b/python/tests/simplify.py
index 5f7e838e6c..6505aec05d 100644
--- a/python/tests/simplify.py
+++ b/python/tests/simplify.py
@@ -111,6 +111,8 @@ def __init__(
         keep_unary=False,
         keep_unary_in_individuals=False,
         keep_input_roots=False,
+        filter_nodes=True,
+        update_sample_flags=True,
     ):
         self.ts = ts
         self.n = len(sample)
@@ -119,6 +121,8 @@ def __init__(
         self.filter_sites = filter_sites
         self.filter_populations = filter_populations
         self.filter_individuals = filter_individuals
+        self.filter_nodes = filter_nodes
+        self.update_sample_flags = update_sample_flags
         self.keep_unary = keep_unary
         self.keep_unary_in_individuals = keep_unary_in_individuals
         self.keep_input_roots = keep_input_roots
@@ -130,35 +134,57 @@ def __init__(
         self.tables.clear()
         self.edge_buffer = {}
         self.node_id_map = np.zeros(ts.num_nodes, dtype=np.int32) - 1
+        self.is_sample = np.zeros(ts.num_nodes, dtype=np.int8)
         self.mutation_node_map = [-1 for _ in range(self.num_mutations)]
         self.samples = set(sample)
         self.sort_offset = -1
         # We keep a map of input nodes to mutations.
         self.mutation_map = [[] for _ in range(ts.num_nodes)]
-        position = ts.tables.sites.position
-        site = ts.tables.mutations.site
-        node = ts.tables.mutations.node
+        position = ts.sites_position
+        site = ts.mutations_site
+        node = ts.mutations_node
         for mutation_id in range(ts.num_mutations):
             site_position = position[site[mutation_id]]
             self.mutation_map[node[mutation_id]].append((site_position, mutation_id))
+
         for sample_id in sample:
-            output_id = self.record_node(sample_id, is_sample=True)
-            self.add_ancestry(sample_id, 0, self.sequence_length, output_id)
+            self.is_sample[sample_id] = 1
+
+        if not self.filter_nodes:
+            # NOTE In the C implementation we would really just not touch the
+            # original tables.
+ self.tables.nodes.replace_with(self.ts.tables.nodes) + if self.update_sample_flags: + flags = self.tables.nodes.flags + # Zero out other sample flags + flags = np.bitwise_and(flags, ~tskit.NODE_IS_SAMPLE) + flags[sample] |= tskit.NODE_IS_SAMPLE + self.tables.nodes.flags = flags.astype(np.uint32) + + self.node_id_map[:] = np.arange(ts.num_nodes) + for sample_id in sample: + self.add_ancestry(sample_id, 0, self.sequence_length, sample_id) + else: + for sample_id in sample: + output_id = self.record_node(sample_id) + self.add_ancestry(sample_id, 0, self.sequence_length, output_id) + self.position_lookup = None if self.reduce_to_site_topology: self.position_lookup = np.hstack([[0], position, [self.sequence_length]]) - def record_node(self, input_id, is_sample=False): + def record_node(self, input_id): """ Adds a new node to the output table corresponding to the specified input node ID. """ node = self.ts.node(input_id) flags = node.flags - # Need to zero out the sample flag - flags &= ~tskit.NODE_IS_SAMPLE - if is_sample: - flags |= tskit.NODE_IS_SAMPLE + if self.update_sample_flags: + # Need to zero out the sample flag + flags &= ~tskit.NODE_IS_SAMPLE + if self.is_sample[input_id]: + flags |= tskit.NODE_IS_SAMPLE output_id = self.tables.nodes.append(node.replace(flags=flags)) self.node_id_map[input_id] = output_id return output_id @@ -275,7 +301,7 @@ def merge_labeled_ancestors(self, S, input_id): The new parent must be assigned and any overlapping segments coalesced. """ output_id = self.node_id_map[input_id] - is_sample = output_id != -1 + is_sample = self.is_sample[input_id] if is_sample: # Free up the existing ancestry mapping. x = self.A_tail[input_id] @@ -319,7 +345,7 @@ def merge_labeled_ancestors(self, S, input_id): self.add_ancestry(input_id, prev_right, self.sequence_length, output_id) if output_id != -1: num_edges = self.flush_edges() - if num_edges == 0 and not is_sample: + if self.filter_nodes and num_edges == 0 and not is_sample: self.rewind_node(input_id, output_id) def extract_ancestry(self, edge): diff --git a/python/tests/test_coaltime_distribution.py b/python/tests/test_coaltime_distribution.py index 30ade948a1..715677d99e 100644 --- a/python/tests/test_coaltime_distribution.py +++ b/python/tests/test_coaltime_distribution.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (C) 2016 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -25,6 +25,7 @@ """ import msprime import numpy as np +import pytest import tests import tskit @@ -181,7 +182,7 @@ def ts_two_trees_ten_leaves(self): @tests.cached_example def ts_many_edge_diffs(self): ts = msprime.sim_ancestry( - samples=75, + samples=80, ploidy=1, sequence_length=4, recombination_rate=10, @@ -216,31 +217,31 @@ def test_time(self): t = np.array([0, 1, 5, 8, 29]) distr = self.coalescence_time_distribution() tt = distr.tables[0].time - assert np.allclose(t, tt) + np.testing.assert_allclose(t, tt) def test_block(self): b = np.array([0, 0, 0, 0, 0]) distr = self.coalescence_time_distribution() tb = distr.tables[0].block - assert np.allclose(b, tb) + np.testing.assert_allclose(b, tb) def test_weights(self): w = np.array([[0, 1, 1, 1, 1]]).T distr = self.coalescence_time_distribution() tw = distr.tables[0].weights - assert np.allclose(w, tw) + np.testing.assert_allclose(w, tw) def test_cum_weights(self): c = np.array([[0, 1, 2, 3, 4]]).T distr = self.coalescence_time_distribution() tc = 
distr.tables[0].cum_weights
-        assert np.allclose(c, tc)
+        np.testing.assert_allclose(c, tc)

     def test_quantile(self):
         q = np.array(
@@ -325,7 +326,7 @@
         )
         distr = self.coalescence_time_distribution()
         tq = distr.tables[0].quantile
-        assert np.allclose(q, tq)
+        np.testing.assert_allclose(q, tq)


 class TestTrioFirstWeightedCoalescenceTimeTable(TestCoalescenceTimeDistribution):
@@ -376,13 +377,13 @@
     def test_time(self):
         t = np.array([0.0, 1.0, 2.0, 2.0, 6.0, 8.00])
         distr = self.coalescence_time_distribution()
         tt = distr.tables[0].time
-        assert np.allclose(t, tt)
+        np.testing.assert_allclose(t, tt)

     def test_block(self):
         b = np.array([0, 0, 0, 0, 0, 0])
         distr = self.coalescence_time_distribution()
         tb = distr.tables[0].block
-        assert np.allclose(b, tb)
+        np.testing.assert_allclose(b, tb)

     def test_weights(self):
         w = np.array(
@@ -397,7 +398,7 @@
         )
         distr = self.coalescence_time_distribution()
         tw = distr.tables[0].weights
-        assert np.allclose(w, tw)
+        np.testing.assert_allclose(w, tw)

     def test_cum_weights(self):
         c = np.array(
@@ -412,7 +413,7 @@
         )
         distr = self.coalescence_time_distribution()
         tc = distr.tables[0].cum_weights
-        assert np.allclose(c, tc)
+        np.testing.assert_allclose(c, tc)

     def test_quantile(self):
         q = np.array(
@@ -429,7 +430,8 @@
         )
         q /= q[-1, :]
         distr = self.coalescence_time_distribution()
         tq = distr.tables[0].quantile
-        assert np.allclose(q, tq[:, :-1]) and np.all(np.isnan(tq[:, -1]))
+        np.testing.assert_allclose(q, tq[:, :-1])
+        assert np.all(np.isnan(tq[:, -1]))


 class TestSingleBlockCoalescenceTimeTable(TestCoalescenceTimeDistribution):
@@ -461,31 +462,32 @@
     def test_time(self):
         t = np.array([0.0, 0.54, 0.59, 0.73, 1.74])
         distr = self.coalescence_time_distribution()
         tt = distr.tables[0].time
-        assert np.allclose(t, tt)
+        np.testing.assert_allclose(t, tt)

     def test_block(self):
         b = np.array([0, 0, 0, 0, 0])
         distr = self.coalescence_time_distribution()
         tb = distr.tables[0].block
-        assert np.allclose(b, tb)
+        np.testing.assert_allclose(b, tb)

     def test_weights(self):
         w = np.array([[0, 1, 2, 1, 2]]).T
         distr = self.coalescence_time_distribution()
         tw = distr.tables[0].weights
-        assert np.allclose(w, tw)
+        np.testing.assert_allclose(w, tw)

     def test_cum_weights(self):
         c = np.array([[0, 1, 3, 4, 6]]).T
         distr = self.coalescence_time_distribution()
         tc = distr.tables[0].cum_weights
-        assert
np.allclose(c, tc) and np.allclose(c, tc) + np.testing.assert_allclose(c, tc) def test_quantile(self): q = np.array([[0.0, 1 / 6, 3 / 6, 4 / 6, 1.0]]).T distr = self.coalescence_time_distribution() tq = distr.tables[0].quantile - assert np.allclose(q, tq) + np.testing.assert_allclose(q, tq) class TestWindowedCoalescenceTimeTable(TestCoalescenceTimeDistribution): @@ -523,7 +525,8 @@ def test_time(self): distr = self.coalescence_time_distribution() tt1 = distr.tables[0].time tt2 = distr.tables[1].time - assert np.allclose(t1, tt1) and np.allclose(t2, tt2) + np.testing.assert_allclose(t1, tt1) + np.testing.assert_allclose(t2, tt2) def test_block(self): b1 = np.array([0, 0, 0, 0]) @@ -531,7 +534,8 @@ def test_block(self): distr = self.coalescence_time_distribution() tb1 = distr.tables[0].block tb2 = distr.tables[1].block - assert np.allclose(b1, tb1) and np.allclose(b2, tb2) + np.testing.assert_allclose(b1, tb1) + np.testing.assert_allclose(b2, tb2) def test_weights(self): w1 = np.array([[0, 1, 1, 1]]).T @@ -539,7 +543,8 @@ def test_weights(self): distr = self.coalescence_time_distribution() tw1 = distr.tables[0].weights tw2 = distr.tables[1].weights - assert np.allclose(w1, tw1) and np.allclose(w2, tw2) + np.testing.assert_allclose(w1, tw1) + np.testing.assert_allclose(w2, tw2) def test_cum_weights(self): c1 = np.array([[0, 1, 2, 3]]).T @@ -547,7 +552,8 @@ def test_cum_weights(self): distr = self.coalescence_time_distribution() tc1 = distr.tables[0].cum_weights tc2 = distr.tables[1].cum_weights - assert np.allclose(c1, tc1) and np.allclose(c2, tc2) + np.testing.assert_allclose(c1, tc1) + np.testing.assert_allclose(c2, tc2) def test_quantile(self): e1 = np.array([[0.0, 1 / 3, 2 / 3, 1.0]]).T @@ -555,7 +561,8 @@ def test_quantile(self): distr = self.coalescence_time_distribution() te1 = distr.tables[0].quantile te2 = distr.tables[1].quantile - assert np.allclose(e1, te1) and np.allclose(e2, te2) + np.testing.assert_allclose(e1, te1) + np.testing.assert_allclose(e2, te2) class TestCoalescenceTimeDistributionPointMethods(TestCoalescenceTimeDistribution): @@ -595,7 +602,7 @@ def test_ecdf(self): [0.0, 0.25, et[1], 0.57, et[2], 0.65, et[3], 1.00, et[4], 2.00], ) te = distr.ecdf(t) - assert np.allclose(e, te) + np.testing.assert_allclose(e, te) def test_num_coalesced(self): c = np.array([0, 0, 1, 1, 3, 3, 4, 4, 6, 6]).reshape(1, 10, 1) @@ -605,7 +612,7 @@ def test_num_coalesced(self): [0.0, 0.25, et[1], 0.57, et[2], 0.65, et[3], 1.00, et[4], 2.00], ) tc = distr.num_coalesced(t) - assert np.allclose(c, tc) + np.testing.assert_allclose(c, tc) def test_num_uncoalesced(self): u = np.array([6, 6, 5, 5, 3, 3, 2, 2, 0, 0]).reshape(1, 10, 1) @@ -615,7 +622,28 @@ def test_num_uncoalesced(self): [0.0, 0.25, et[1], 0.57, et[2], 0.65, et[3], 1.00, et[4], 2.00], ) tu = distr.num_uncoalesced(t) - assert np.allclose(u, tu) + np.testing.assert_allclose(u, tu) + + def test_interpolated_quantile(self): + x = np.array( + [ + 0.54, + 0.558, + 0.576, + 0.5993, + 0.6413, + 0.6833, + 0.7253, + 0.9609, + 1.2206, + 1.4803, + 1.74, + ] + ).reshape(1, 11, 1) + distr = self.coalescence_time_distribution() + q = np.linspace(0, 1, 11) + qx = distr.quantile(q).round(4) + np.testing.assert_allclose(x, qx) class TestCoalescenceTimeDistributionIntervalMethods(TestCoalescenceTimeDistribution): @@ -667,7 +695,7 @@ def test_coalescence_probability_in_intervals(self): et = distr.tables[0].time t = np.array([0.00, 0.55, et[3], 2.00]) tp = distr.coalescence_probability_in_intervals(t) - assert
np.allclose(p, tp) + np.testing.assert_allclose(p, tp) def test_coalescence_probability_in_intervals_oor(self): distr = self.coalescence_time_distribution() @@ -681,7 +709,7 @@ def test_coalescence_rate_in_intervals(self): et = distr.tables[0].time t = np.array([0.00, 0.55, et[3], 2.00]) tc = distr.coalescence_rate_in_intervals(t) - assert np.allclose(c, tc) + np.testing.assert_allclose(c, tc, atol=1e-6) def test_coalescence_rate_in_intervals_oor(self): distr = self.coalescence_time_distribution() @@ -694,7 +722,7 @@ def test_mean(self): distr = self.coalescence_time_distribution() et = distr.tables[0].time tm = distr.mean(et[2]) - assert np.allclose(m, tm) + np.testing.assert_allclose(m, tm) def test_mean_oor(self): distr = self.coalescence_time_distribution() @@ -754,7 +782,8 @@ def test_cum_weights(self): boot_distr = self.coalescence_time_distribution_boot() tw1 = boot_distr.tables[0].cum_weights tw2 = boot_distr.tables[1].cum_weights - assert np.allclose(w1, tw1) and np.allclose(w2, tw2) + np.testing.assert_allclose(w1, tw1) + np.testing.assert_allclose(w2, tw2) def test_ecdf(self): e = np.array( @@ -766,20 +795,20 @@ def test_ecdf(self): boot_distr = self.coalescence_time_distribution_boot() t = np.array([0.54, 0.55, 0.59, 0.60, 0.73, 0.74, 1.74]) te = boot_distr.ecdf(t) - assert np.allclose(e, te) + np.testing.assert_allclose(e, te) def test_mean(self): m = np.array([[1.02, 0.9566667]]) boot_distr = self.coalescence_time_distribution_boot() tm = boot_distr.mean() - assert np.allclose(m, tm) + np.testing.assert_allclose(m, tm) def test_boot_of_boot_equivalence(self): boot_distr = self.coalescence_time_distribution_boot() reboot_distr = next(boot_distr.block_bootstrap(1, 3)) cw1 = boot_distr.tables[1].cum_weights cw2 = reboot_distr.tables[1].cum_weights - assert np.allclose(cw1, cw2) + np.testing.assert_allclose(cw1, cw2) class TestCoalescenceTimeDistributionEmpty(TestCoalescenceTimeDistribution): @@ -790,11 +819,16 @@ class TestCoalescenceTimeDistributionEmpty(TestCoalescenceTimeDistribution): def coalescence_time_distribution(self): ts = self.ts_two_trees_four_leaves() - def null_weight(node, tree, sample_sets): - return np.array([0, 0]) + def null_weight_init(node, sample_sets): + blank = np.array([[0, 0]], dtype=np.float64) + return (blank,) + + def null_weight_update(blank): + blank = np.array([[0, 0]], dtype=np.float64) + return blank, (blank,) distr = ts.coalescence_time_distribution( - weight_func=null_weight, + weight_func=(null_weight_init, null_weight_update), span_normalise=False, ) return distr @@ -834,6 +868,18 @@ def test_coalescence_rate_in_intervals(self): tc = distr.coalescence_rate_in_intervals(t) assert np.all(np.isnan(tc)) + def test_quantile(self): + distr = self.coalescence_time_distribution() + t = np.array([0.0, 0.5, 1.0]) + tq = distr.quantile(t) + assert np.all(np.isnan(tq)) + + def test_resample(self): + distr = self.coalescence_time_distribution() + boot_distr = next(distr.block_bootstrap(1, 3)) + assert np.all(boot_distr.tables[0].cum_weights == 0) + assert np.all(np.isnan(boot_distr.tables[0].quantile)) + class TestCoalescenceTimeDistributionNullWeight(TestCoalescenceTimeDistribution): """ @@ -844,11 +890,16 @@ class TestCoalescenceTimeDistributionNullWeight(TestCoalescenceTimeDistribution) def coalescence_time_distribution(self): ts = self.ts_two_trees_four_leaves() - def half_empty(node, tree, sample_sets): - return np.array([1, 0]) + def half_empty_init(node, sample_sets): + blank = np.array([[1, 0]], dtype=np.float64) + return (blank,) + + def 
half_empty_update(blank): + blank = np.array([[1, 0]], dtype=np.float64) + return blank, (blank,) distr = ts.coalescence_time_distribution( - weight_func=half_empty, + weight_func=(half_empty_init, half_empty_update), span_normalise=False, ) return distr @@ -888,6 +939,20 @@ def test_coalescence_rate_in_intervals(self): tr = distr.coalescence_rate_in_intervals(t) assert np.all(np.isnan(tr[1, :])) and np.all(~np.isnan(tr[0, :])) + def test_quantile(self): + distr = self.coalescence_time_distribution() + t = np.array([0.0, 0.5, 1.0]) + tq = distr.quantile(t) + assert np.all(np.isnan(tq[1, :])) and np.all(~np.isnan(tq[0, :])) + + def test_resample(self): + distr = self.coalescence_time_distribution() + boot_distr = next(distr.block_bootstrap(1, 3)) + assert np.all(boot_distr.tables[0].cum_weights[:, 1] == 0) + assert np.all(np.isnan(boot_distr.tables[0].quantile[:, 1])) + assert np.any(boot_distr.tables[0].cum_weights[:, 0] > 0) + assert np.all(~np.isnan(boot_distr.tables[0].quantile[:, 0])) + class TestCoalescenceTimeDistributionTableResize(TestCoalescenceTimeDistribution): """ @@ -921,12 +986,18 @@ def coalescence_time_distribution(self): ts = self.ts_eight_trees_two_leaves() bk = [t.interval.left for t in ts.trees()][::4] + [ts.sequence_length] - def count_root(node, tree, sample_sets): - weight = int(node == tree.get_root()) - return np.array([weight]) + def count_root_init(node, sample_sets): + all_samples = [i for s in sample_sets for i in s] + state = np.array([[node == i for i in all_samples]], dtype=np.float64) + return (state,) + + def count_root_update(child_state): + state = np.sum(child_state, axis=0, keepdims=True) + is_root = np.array([[np.all(state > 0)]], dtype=np.float64) + return is_root, (state,) distr = ts.coalescence_time_distribution( - weight_func=count_root, + weight_func=(count_root_init, count_root_update), window_breaks=np.array(bk), blocks_per_window=2, span_normalise=False, @@ -936,12 +1007,12 @@ def count_root(node, tree, sample_sets): def test_blocks_per_window(self): distr = self.coalescence_time_distribution() bpw = np.array([i.num_blocks for i in distr.tables]) - assert np.allclose(bpw, 2) + np.testing.assert_allclose(bpw, 2) def test_trees_per_window(self): distr = self.coalescence_time_distribution() tpw = np.array([np.sum(distr.tables[i].weights) for i in range(2)]) - assert np.allclose(tpw, 4) + np.testing.assert_allclose(tpw, 4) def test_trees_per_block(self): distr = self.coalescence_time_distribution() @@ -949,7 +1020,80 @@ def test_trees_per_block(self): for table in distr.tables: for block in range(2): tpb += [np.sum(table.weights[table.block == block])] - assert np.allclose(tpb, 2) + np.testing.assert_allclose(tpb, 2) + + +class TestCoalescenceTimeDistributionBlockedVsUnblocked( + TestCoalescenceTimeDistribution +): + """ + Test that methods give the same result regardless of how trees are blocked. 
+ """ + + def coalescence_time_distribution(self, num_blocks=1): + ts = self.ts_many_edge_diffs() + sample_sets = [list(range(10)), list(range(20, 40)), list(range(70, 80))] + distr = ts.coalescence_time_distribution( + sample_sets=sample_sets, + weight_func="pair_coalescence_events", + blocks_per_window=num_blocks, + span_normalise=True, + ) + return distr + + def test_ecdf(self): + distr_noblock = self.coalescence_time_distribution(num_blocks=1) + distr_block = self.coalescence_time_distribution(num_blocks=10) + t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) + np.testing.assert_allclose(distr_noblock.ecdf(t), distr_block.ecdf(t)) + + def test_num_coalesced(self): + distr_noblock = self.coalescence_time_distribution(num_blocks=1) + distr_block = self.coalescence_time_distribution(num_blocks=10) + t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) + np.testing.assert_allclose( + distr_noblock.num_coalesced(t), distr_block.num_coalesced(t) + ) + + def test_num_uncoalesced(self): + distr_noblock = self.coalescence_time_distribution(num_blocks=1) + distr_block = self.coalescence_time_distribution(num_blocks=10) + t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) + np.testing.assert_allclose( + distr_noblock.num_uncoalesced(t), distr_block.num_uncoalesced(t) + ) + + def test_quantile(self): + distr_noblock = self.coalescence_time_distribution(num_blocks=1) + distr_block = self.coalescence_time_distribution(num_blocks=10) + q = np.linspace(0, 1, 11) + np.testing.assert_allclose(distr_noblock.quantile(q), distr_block.quantile(q)) + + def test_mean(self): + distr_noblock = self.coalescence_time_distribution(num_blocks=1) + distr_block = self.coalescence_time_distribution(num_blocks=10) + t = distr_noblock.tables[0].time[-1] / 2 + np.testing.assert_allclose( + distr_noblock.mean(since=t), distr_block.mean(since=t) + ) + + def test_coalescence_rate_in_intervals(self): + distr_noblock = self.coalescence_time_distribution(num_blocks=1) + distr_block = self.coalescence_time_distribution(num_blocks=10) + t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) + np.testing.assert_allclose( + distr_noblock.coalescence_rate_in_intervals(t), + distr_block.coalescence_rate_in_intervals(t), + ) + + def test_coalescence_probability_in_intervals(self): + distr_noblock = self.coalescence_time_distribution(num_blocks=1) + distr_block = self.coalescence_time_distribution(num_blocks=10) + t = np.linspace(0, distr_noblock.tables[0].time[-1] + 1, 5) + np.testing.assert_allclose( + distr_noblock.coalescence_probability_in_intervals(t), + distr_block.coalescence_probability_in_intervals(t), + ) class TestCoalescenceTimeDistributionRunningUpdate(TestCoalescenceTimeDistribution): @@ -957,23 +1101,23 @@ class TestCoalescenceTimeDistributionRunningUpdate(TestCoalescenceTimeDistributi When traversing trees, weights are updated for nodes whose descendant subtree has changed. This is done by taking the parents of added edges, and tracing ancestors down to the root. This class tests that this "running update" - scheme produces the same results as calculating weights separately for each - tree. + scheme produces the correct result. 
""" - # TODO: when missing data handling is implemented, test here - - def coalescence_time_distribution(self, ts): - brk = np.array([t.interval.left for t in ts.trees()] + [ts.sequence_length]) - smp_set = np.arange(0, ts.num_samples) - smp_set = np.floor_divide((len(brk) - 1) * smp_set, ts.num_samples) - smp_set = [np.where(smp_set == i)[0].tolist() for i in range(len(brk) - 1)] + def coalescence_time_distribution_running(self, ts, brk, sets=2): + n = ts.num_samples // sets + smp_set = [list(range(i, i + n)) for i in range(0, ts.num_samples, n)] distr = ts.coalescence_time_distribution( sample_sets=smp_set, window_breaks=brk, weight_func="trio_first_coalescence_events", span_normalise=False, ) + return distr + + def coalescence_time_distribution_split(self, ts, brk, sets=2): + n = ts.num_samples // sets + smp_set = [list(range(i, i + n)) for i in range(0, ts.num_samples, n)] distr_by_win = [] for left, right in zip(brk[:-1], brk[1:]): ts_trim = ts.keep_intervals([[left, right]]).trim() @@ -981,20 +1125,65 @@ def coalescence_time_distribution(self, ts): ts_trim.coalescence_time_distribution( sample_sets=smp_set, weight_func="trio_first_coalescence_events", + span_normalise=False, ) ] - return distr, distr_by_win + return distr_by_win def test_many_edge_diffs(self): + """ + Test that ts windowed by tree gives same result as set of single trees. + """ ts = self.ts_many_edge_diffs() - distr, distr_win = self.coalescence_time_distribution(ts) + brk = np.array([t.interval.left for t in ts.trees()] + [ts.sequence_length]) + distr = self.coalescence_time_distribution_running(ts, brk) + distr_win = self.coalescence_time_distribution_split(ts, brk) time_breaks = np.array([np.inf]) updt = distr.num_coalesced(time_breaks) sepr = np.zeros(updt.shape) for i, d in enumerate(distr_win): c = d.num_coalesced(time_breaks) sepr[:, :, i] = c.reshape((c.shape[0], 1)) - assert np.allclose(sepr, updt) + np.testing.assert_allclose(sepr, updt) + + def test_missing_trees(self): + """ + Test that ts with half of each tree masked gives same result as unmasked ts. + """ + ts = self.ts_many_edge_diffs() + brk = np.array([t.interval.left for t in ts.trees()] + [ts.sequence_length]) + mask = np.array( + [ + [tr.interval.left, (tr.interval.right + tr.interval.left) / 2] + for tr in ts.trees() + ] + ) + ts_mask = ts.delete_intervals(mask) + distr = self.coalescence_time_distribution_running(ts, brk) + distr_mask = self.coalescence_time_distribution_running(ts_mask, brk) + time_breaks = np.array([np.inf]) + updt = distr.num_coalesced(time_breaks) + updt_mask = distr_mask.num_coalesced(time_breaks) + np.testing.assert_allclose(updt, updt_mask) + + def test_unary_nodes(self): + """ + Test that ts with unary nodes gives same result as ts with unary nodes removed. 
+ """ + ts = self.ts_many_edge_diffs() + ts_unary = ts.simplify( + samples=list(range(ts.num_samples // 2)), keep_unary=True + ) + ts_nounary = ts.simplify( + samples=list(range(ts.num_samples // 2)), keep_unary=False + ) + brk = np.array([t.interval.left for t in ts.trees()] + [ts.sequence_length]) + distr_unary = self.coalescence_time_distribution_running(ts_unary, brk) + distr_nounary = self.coalescence_time_distribution_running(ts_nounary, brk) + time_breaks = np.array([np.inf]) + updt_unary = distr_unary.num_coalesced(time_breaks) + updt_nounary = distr_nounary.num_coalesced(time_breaks) + np.testing.assert_allclose(updt_unary, updt_nounary) class TestSpanNormalisedCoalescenceTimeTable(TestCoalescenceTimeDistribution): @@ -1016,24 +1205,39 @@ class TestSpanNormalisedCoalescenceTimeTable(TestCoalescenceTimeDistribution): Uniform weights on nodes summed over trees, weighted by tree span """ - def coalescence_time_distribution(self): + def coalescence_time_distribution(self, mask_half_of_each_tree=False): + """ + Methods should give the same result if half of each tree is masked, + because "span weights" are normalised using the accessible (nonmissing) + portion of the tree sequence. + """ ts = self.ts_two_trees_four_leaves() + if mask_half_of_each_tree: + mask = np.array( + [ + [t.interval.left, (t.interval.right + t.interval.left) / 2] + for t in ts.trees() + ] + ) + ts = ts.delete_intervals(mask) distr = ts.coalescence_time_distribution( span_normalise=True, ) return distr - def test_weights(self): + @pytest.mark.parametrize("with_missing_data", [True, False]) + def test_weights(self, with_missing_data): w = np.array([[0, 0.12, 1.0, 0.88, 1.0]]).T - distr = self.coalescence_time_distribution() + distr = self.coalescence_time_distribution(with_missing_data) tw = distr.tables[0].weights - assert np.allclose(w, tw) + np.testing.assert_allclose(w, tw) - def test_cum_weights(self): + @pytest.mark.parametrize("with_missing_data", [True, False]) + def test_cum_weights(self, with_missing_data): c = np.array([[0, 0.12, 1.12, 2.00, 3.00]]).T - distr = self.coalescence_time_distribution() + distr = self.coalescence_time_distribution(with_missing_data) tc = distr.tables[0].cum_weights - assert np.allclose(c, tc) and np.allclose(c, tc) + np.testing.assert_allclose(c, tc) class TestWindowedSpanNormalisedCoalescenceTimeTable(TestCoalescenceTimeDistribution): @@ -1058,9 +1262,19 @@ class TestWindowedSpanNormalisedCoalescenceTimeTable(TestCoalescenceTimeDistribu """ @tests.cached_example - def coalescence_time_distribution(self): + def coalescence_time_distribution(self, mask_half_of_each_tree=False): + """ + Methods should give the same result if half of each tree is masked, + because "span weights" are normalised using the accessible (nonmissing) + portion of the tree sequence. 
+ """ ts = self.ts_two_trees_four_leaves() gen_breaks = np.array([0.0, 0.5, 1.0]) + if mask_half_of_each_tree: + breaks = [i for i in ts.breakpoints()] + breaks = np.unique(np.concatenate([breaks, gen_breaks])) + mask = np.array([[a, (a + b) / 2] for a, b in zip(breaks[:-1], breaks[1:])]) + ts = ts.keep_intervals(mask) distr = ts.coalescence_time_distribution( window_breaks=gen_breaks, blocks_per_window=2, @@ -1068,26 +1282,32 @@ def coalescence_time_distribution(self): ) return distr - def test_time(self): + @pytest.mark.parametrize("with_missing_data", [True, False]) + def test_time(self, with_missing_data): t1 = np.array([0.0, 0.59, 0.73, 1.74]) t2 = np.array([0.0, 0.54, 0.59, 0.59, 0.73, 1.74, 1.74]) - distr = self.coalescence_time_distribution() + distr = self.coalescence_time_distribution(with_missing_data) tt1 = distr.tables[0].time tt2 = distr.tables[1].time - assert np.allclose(t1, tt1) and np.allclose(t2, tt2) + np.testing.assert_allclose(t1, tt1) + np.testing.assert_allclose(t2, tt2) - def test_block(self): + @pytest.mark.parametrize("with_missing_data", [True, False]) + def test_block(self, with_missing_data): b1 = np.array([0, 0, 0, 0]) b2 = np.array([0, 1, 0, 1, 0, 0, 1]) - distr = self.coalescence_time_distribution() + distr = self.coalescence_time_distribution(with_missing_data) tb1 = distr.tables[0].block tb2 = distr.tables[1].block - assert np.allclose(b1, tb1) and np.allclose(b2, tb2) + np.testing.assert_allclose(b1, tb1) + np.testing.assert_allclose(b2, tb2) - def test_weights(self): + @pytest.mark.parametrize("with_missing_data", [True, False]) + def test_weights(self, with_missing_data): w1 = np.array([[0, 1.0, 1.0, 1.0]]).T w2 = np.array([[0, 0.24, 0.76, 0.24, 0.76, 0.76, 0.24]]).T - distr = self.coalescence_time_distribution() + distr = self.coalescence_time_distribution(with_missing_data) tw1 = distr.tables[0].weights tw2 = distr.tables[1].weights - assert np.allclose(w1, tw1) and np.allclose(w2, tw2) + np.testing.assert_allclose(w1, tw1) + np.testing.assert_allclose(w2, tw2) diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py new file mode 100644 index 0000000000..acb2403d41 --- /dev/null +++ b/python/tests/test_divmat.py @@ -0,0 +1,1064 @@ +# MIT License +# +# Copyright (c) 2023 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+""" +Test cases for divergence matrix based pairwise stats +""" +import collections + +import msprime +import numpy as np +import pytest + +import tskit +from tests import tsutil +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. + +DIVMAT_MODES = ["branch", "site"] + +# NOTE: this implementation of Schieber-Vishkin algorithm is done like +# this so it's easy to run with numba. It would be more naturally +# packaged as a class. We don't actually use numba here, but it's +# handy to have a version of the SV code lying around that can be +# run directly with numba. + + +def sv_tables_init(parent_array): + n = 1 + parent_array.shape[0] + + LAMBDA = 0 + # Triply-linked tree. FIXME we shouldn't need to build this as it's + # available already in tskit + child = np.zeros(n, dtype=np.int32) + parent = np.zeros(n, dtype=np.int32) + sib = np.zeros(n, dtype=np.int32) + + for j in range(n - 1): + u = j + 1 + v = parent_array[j] + 1 + sib[u] = child[v] + child[v] = u + parent[u] = v + + lambd = np.zeros(n, dtype=np.int32) + pi = np.zeros(n, dtype=np.int32) + tau = np.zeros(n, dtype=np.int32) + beta = np.zeros(n, dtype=np.int32) + alpha = np.zeros(n, dtype=np.int32) + + p = child[LAMBDA] + n = 0 + lambd[0] = -1 + while p != LAMBDA: + while True: + n += 1 + pi[p] = n + tau[n] = LAMBDA + lambd[n] = 1 + lambd[n >> 1] + if child[p] != LAMBDA: + p = child[p] + else: + break + beta[p] = n + while True: + tau[beta[p]] = parent[p] + if sib[p] != LAMBDA: + p = sib[p] + break + else: + p = parent[p] + if p != LAMBDA: + h = lambd[n & -pi[p]] + beta[p] = ((n >> h) | 1) << h + else: + break + + # Begin the second traversal + lambd[0] = lambd[n] + pi[LAMBDA] = 0 + beta[LAMBDA] = 0 + alpha[LAMBDA] = 0 + p = child[LAMBDA] + while p != LAMBDA: + while True: + a = alpha[parent[p]] | (beta[p] & -beta[p]) + alpha[p] = a + if child[p] != LAMBDA: + p = child[p] + else: + break + while True: + if sib[p] != LAMBDA: + p = sib[p] + break + else: + p = parent[p] + if p == LAMBDA: + break + + return lambd, pi, tau, beta, alpha + + +def _sv_mrca(x, y, lambd, pi, tau, beta, alpha): + if beta[x] <= beta[y]: + h = lambd[beta[y] & -beta[x]] + else: + h = lambd[beta[x] & -beta[y]] + k = alpha[x] & alpha[y] & -(1 << h) + h = lambd[k & -k] + j = ((beta[x] >> h) | 1) << h + if j == beta[x]: + xhat = x + else: + ell = lambd[alpha[x] & ((1 << h) - 1)] + xhat = tau[((beta[x] >> ell) | 1) << ell] + if j == beta[y]: + yhat = y + else: + ell = lambd[alpha[y] & ((1 << h) - 1)] + yhat = tau[((beta[y] >> ell) | 1) << ell] + if pi[xhat] <= pi[yhat]: + z = xhat + else: + z = yhat + return z + + +def sv_mrca(x, y, lambd, pi, tau, beta, alpha): + # Convert to 1-based indexes + return _sv_mrca(x + 1, y + 1, lambd, pi, tau, beta, alpha) - 1 + + +def local_root(tree, u): + while tree.parent(u) != tskit.NULL: + u = tree.parent(u) + return u + + +def branch_divergence_matrix(ts, windows=None, samples=None): + windows_specified = windows is not None + windows = [0, ts.sequence_length] if windows is None else windows + num_windows = len(windows) - 1 + samples = ts.samples() if samples is None else samples + + n = len(samples) + D = np.zeros((num_windows, n, n)) + tree = tskit.Tree(ts) + for i in range(num_windows): + left = windows[i] + right = windows[i + 1] + # print(f"WINDOW {i} [{left}, {right})") + tree.seek(left) + # Iterate over the trees in this window + while tree.interval.left < right and tree.index != -1: + span_left = max(tree.interval.left, left) 
+ span_right = min(tree.interval.right, right) + span = span_right - span_left + # print(f"\ttree {tree.interval} [{span_left}, {span_right})") + tables = sv_tables_init(tree.parent_array) + for j in range(n): + u = samples[j] + for k in range(j + 1, n): + v = samples[k] + w = sv_mrca(u, v, *tables) + assert w == tree.mrca(u, v) + if w != tskit.NULL: + tu = ts.nodes_time[w] - ts.nodes_time[u] + tv = ts.nodes_time[w] - ts.nodes_time[v] + else: + tu = ts.nodes_time[local_root(tree, u)] - ts.nodes_time[u] + tv = ts.nodes_time[local_root(tree, v)] - ts.nodes_time[v] + d = (tu + tv) * span + D[i, j, k] += d + tree.next() + # Fill out symmetric triangle in the matrix + for j in range(n): + for k in range(j + 1, n): + D[i, k, j] = D[i, j, k] + if not windows_specified: + D = D[0] + return D + + +def divergence_matrix(ts, windows=None, samples=None, mode="site"): + assert mode in ["site", "branch"] + if mode == "site": + return site_divergence_matrix(ts, samples=samples, windows=windows) + else: + return branch_divergence_matrix(ts, samples=samples, windows=windows) + + +def stats_api_divergence_matrix(ts, windows=None, samples=None, mode="site"): + samples = ts.samples() if samples is None else samples + windows_specified = windows is not None + windows = [0, ts.sequence_length] if windows is None else list(windows) + num_windows = len(windows) - 1 + + if len(samples) == 0: + # FIXME: the general stat code doesn't seem to handle the zero-samples + # case; need to identify an MWE and file an issue. + if windows_specified: + return np.zeros(shape=(num_windows, 0, 0)) + else: + return np.zeros(shape=(0, 0)) + + # Make sure that all the specified samples have the sample flag set, otherwise + # the library code will complain + tables = ts.dump_tables() + flags = tables.nodes.flags + # NOTE: this is a shortcut, setting all flags unconditionally to zero, so don't + # use this tree sequence outside this method. + flags[:] = 0 + flags[samples] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + ts = tables.tree_sequence() + + # FIXME: We have to go through this annoying rigmarole because windows must start and + # end with 0 and L. We should relax this requirement to just making the windows + # contiguous, so that we just look at specific sections of the genome.
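+ # For example: with sequence_length 20 and windows=[5, 10], the list is + # padded to [0, 5, 10, 20] below, and the two flanking rows are then + # dropped from the result again afterwards.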
+ drop = [] + if windows[0] != 0: + windows = [0] + windows + drop.append(0) + if windows[-1] != ts.sequence_length: + windows.append(ts.sequence_length) + drop.append(-1) + + n = len(samples) + sample_sets = [[u] for u in samples] + indexes = [(i, j) for i in range(n) for j in range(n)] + X = ts.divergence( + sample_sets, + indexes=indexes, + mode=mode, + span_normalise=False, + windows=windows, + ) + keep = np.ones(len(windows) - 1, dtype=bool) + keep[drop] = False + X = X[keep] + out = X.reshape((X.shape[0], n, n)) + for D in out: + np.fill_diagonal(D, 0) + if not windows_specified: + out = out[0] + return out + + +def rootward_path(tree, u, v): + while u != v: + yield u + u = tree.parent(u) + + +def site_divergence_matrix(ts, windows=None, samples=None): + windows_specified = windows is not None + windows = [0, ts.sequence_length] if windows is None else windows + num_windows = len(windows) - 1 + samples = ts.samples() if samples is None else samples + + n = len(samples) + D = np.zeros((num_windows, n, n)) + tree = tskit.Tree(ts) + for i in range(num_windows): + left = windows[i] + right = windows[i + 1] + tree.seek(left) + # Iterate over the trees in this window + while tree.interval.left < right and tree.index != -1: + span_left = max(tree.interval.left, left) + span_right = min(tree.interval.right, right) + mutations_per_node = collections.Counter() + for site in tree.sites(): + if span_left <= site.position < span_right: + for mutation in site.mutations: + mutations_per_node[mutation.node] += 1 + for j in range(n): + u = samples[j] + for k in range(j + 1, n): + v = samples[k] + w = tree.mrca(u, v) + if w != tskit.NULL: + wu = w + wv = w + else: + wu = local_root(tree, u) + wv = local_root(tree, v) + du = sum(mutations_per_node[x] for x in rootward_path(tree, u, wu)) + dv = sum(mutations_per_node[x] for x in rootward_path(tree, v, wv)) + # NOTE: we're just accumulating the raw mutation counts, not + # multiplying by span + D[i, j, k] += du + dv + tree.next() + # Fill out symmetric triangle in the matrix + for j in range(n): + for k in range(j + 1, n): + D[i, k, j] = D[i, j, k] + if not windows_specified: + D = D[0] + return D + + +def check_divmat( + ts, + *, + windows=None, + samples=None, + verbosity=0, + compare_stats_api=True, + compare_lib=True, + mode="site", +): + np.set_printoptions(linewidth=500, precision=4) + # print(ts.draw_text()) + if verbosity > 1: + print(ts.draw_text()) + + D1 = divergence_matrix(ts, windows=windows, samples=samples, mode=mode) + if compare_stats_api: + # Some things like duplicate samples aren't worth hacking around for in + # the stats API.
+ D2 = stats_api_divergence_matrix( + ts, windows=windows, samples=samples, mode=mode + ) + # print("windows = ", windows) + # print(D1) + # print(D2) + np.testing.assert_allclose(D1, D2) + assert D1.shape == D2.shape + if compare_lib: + D3 = ts.divergence_matrix(windows=windows, samples=samples, mode=mode) + # print(D3) + assert D1.shape == D3.shape + np.testing.assert_allclose(D1, D3) + return D1 + + +class TestExamplesWithAnswer: + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_zero_samples(self, mode): + ts = tskit.Tree.generate_balanced(2).tree_sequence + D = check_divmat(ts, samples=[], mode=mode) + assert D.shape == (0, 0) + + @pytest.mark.parametrize("num_windows", [1, 2, 3, 5]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_zero_samples_windows(self, num_windows, mode): + ts = tskit.Tree.generate_balanced(2).tree_sequence + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + D = check_divmat(ts, samples=[], windows=windows, mode=mode) + assert D.shape == (num_windows, 0, 0) + + @pytest.mark.parametrize("m", [0, 1, 2, 10]) + def test_single_tree_sites_per_branch(self, m): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts, m) + D1 = check_divmat(ts, mode="site") + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, m * D2) + + @pytest.mark.parametrize("m", [0, 1, 2, 10]) + def test_single_tree_mutations_per_branch(self, m): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_mutations(ts, m) + # The stats API will produce a different value here, because + # we're just counting up the mutations and not reasoning about + # the state of samples at all.
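+ # (insert_branch_mutations puts m mutations on every branch, so each pair + # of samples should be separated by m times the number of branches on the + # path between them; hence the m * D2 expectation below.)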
+ D1 = check_divmat(ts, mode="site", compare_stats_api=False) + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, m * D2) + + @pytest.mark.parametrize("L", [0.1, 1, 2, 100]) + def test_single_tree_sequence_length(self, L): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4, span=L).tree_sequence + D1 = check_divmat(ts, mode="branch") + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, L * D2) + + @pytest.mark.parametrize("num_windows", [1, 2, 3, 5]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_gap_at_end(self, num_windows, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 1 2 3 + # 0 1 2 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + tables = ts.dump_tables() + tables.sequence_length = 2 + ts = tables.tree_sequence() + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + D1 = check_divmat(ts, windows=windows, mode=mode) + D1 = np.sum(D1, axis=0) + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_subset_permuted_samples(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + D1 = check_divmat(ts, samples=[1, 2, 0], mode=mode) + D2 = np.array( + [ + [0.0, 4.0, 2.0], + [4.0, 0.0, 4.0], + [2.0, 4.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_mixed_non_sample_samples(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + D1 = check_divmat(ts, samples=[0, 5], mode=mode) + D2 = np.array( + [ + [0.0, 3.0], + [3.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_duplicate_samples(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + D1 = check_divmat(ts, samples=[0, 0, 1], compare_stats_api=False, mode=mode) + D2 = np.array( + [ + [0.0, 0.0, 2.0], + [0.0, 0.0, 2.0], + [2.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_multiroot(self, mode): + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + ts = ts.decapitate(1) + D1 = check_divmat(ts, mode=mode) + D2 = np.array( + [ + [0.0, 2.0, 2.0, 2.0], + [2.0, 0.0, 2.0, 2.0], + [2.0, 2.0, 0.0, 2.0], + [2.0, 2.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize( + ["left", "right"], [(0, 10), (1, 3), (3.25, 3.75), (5, 10)] + ) + def test_single_tree_interval(self, left, right): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 
2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4, span=10).tree_sequence + D1 = check_divmat(ts, windows=[left, right], mode="branch") + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + np.testing.assert_array_equal(D1[0], (right - left) * D2) + + @pytest.mark.parametrize("num_windows", [1, 2, 3, 5, 11]) + def test_single_tree_equal_windows(self, num_windows): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4, span=10).tree_sequence + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + x = ts.sequence_length / num_windows + # print(windows) + D1 = check_divmat(ts, windows=windows, mode="branch") + assert D1.shape == (num_windows, 4, 4) + D2 = np.array( + [ + [0.0, 2.0, 4.0, 4.0], + [2.0, 0.0, 4.0, 4.0], + [4.0, 4.0, 0.0, 2.0], + [4.0, 4.0, 2.0, 0.0], + ] + ) + for D in D1: + np.testing.assert_array_almost_equal(D, x * D2) + + @pytest.mark.parametrize("n", [2, 3, 5]) + def test_single_tree_no_sites(self, n): + ts = tskit.Tree.generate_balanced(n, span=10).tree_sequence + D = check_divmat(ts, mode="site") + np.testing.assert_array_equal(D, np.zeros((n, n))) + + +class TestExamples: + @pytest.mark.parametrize( + "interval", [(0, 26), (1, 3), (3.25, 13.75), (5, 10), (25.5, 26)] + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_interval(self, interval, mode): + ts = tsutil.all_trees_ts(4) + ts = tsutil.insert_branch_sites(ts) + assert ts.sequence_length == 26 + check_divmat(ts, windows=interval, mode=mode) + + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 26],), + ([0, 1, 2],), + (list(range(27)),), + ([5, 7, 9, 20],), + ([5.1, 5.2, 5.3, 5.5, 6],), + ([5.1, 5.2, 6.5],), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_windows(self, windows, mode): + ts = tsutil.all_trees_ts(4) + ts = tsutil.insert_branch_sites(ts) + assert ts.sequence_length == 26 + D = check_divmat(ts, windows=windows, mode=mode) + assert D.shape == (len(windows) - 1, 4, 4) + + @pytest.mark.parametrize("num_windows", [1, 5, 28]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_windows_gap_at_end(self, num_windows, mode): + tables = tsutil.all_trees_ts(4).dump_tables() + tables.sequence_length = 30 + ts = tables.tree_sequence() + ts = tsutil.insert_branch_sites(ts) + assert ts.last().num_roots == 4 + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + check_divmat(ts, windows=windows, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5]) + @pytest.mark.parametrize("seed", range(1, 4)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_small_sims(self, n, seed, mode): + ts = msprime.sim_ancestry( + n, + ploidy=1, + sequence_length=1000, + recombination_rate=0.01, + random_seed=seed, + ) + assert ts.num_trees >= 2 + ts = msprime.sim_mutations( + ts, rate=0.1, discrete_genome=False, random_seed=seed + ) + assert ts.num_mutations > 1 + check_divmat(ts, verbosity=0, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("num_windows", range(1, 5)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_sims_windows(self, n, num_windows, mode): + ts = msprime.sim_ancestry( + n, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=79234, + ) + assert ts.num_trees >= 2 + ts = msprime.sim_mutations( + ts, + rate=0.01, + discrete_genome=False, + random_seed=1234, + ) + assert 
ts.num_mutations >= 2 + windows = np.linspace(0, ts.sequence_length, num=num_windows + 1) + check_divmat(ts, windows=windows, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_balanced_tree(self, n, mode): + ts = tskit.Tree.generate_balanced(n).tree_sequence + ts = tsutil.insert_branch_sites(ts) + # print(ts.draw_text()) + check_divmat(ts, verbosity=0, mode=mode) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_internal_sample(self, mode): + tables = tskit.Tree.generate_balanced(4).tree_sequence.dump_tables() + flags = tables.nodes.flags + flags[3] = 0 + flags[5] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + ts = tables.tree_sequence() + ts = tsutil.insert_branch_sites(ts) + check_divmat(ts, verbosity=0, mode=mode) + + @pytest.mark.parametrize("seed", range(1, 5)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_one_internal_sample_sims(self, seed, mode): + ts = msprime.sim_ancestry( + 10, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=seed, + ) + t = ts.dump_tables() + # Add a new sample directly below another sample + u = t.nodes.add_row(time=-1, flags=tskit.NODE_IS_SAMPLE) + t.edges.add_row(parent=0, child=u, left=0, right=ts.sequence_length) + t.sort() + t.build_index() + ts = t.tree_sequence() + ts = tsutil.insert_branch_sites(ts) + check_divmat(ts, mode=mode) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_missing_flanks(self, mode): + ts = msprime.sim_ancestry( + 20, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=1234, + ) + assert ts.num_trees >= 2 + ts = ts.keep_intervals([[20, 80]]) + assert ts.first().interval == (0, 20) + ts = tsutil.insert_branch_sites(ts) + check_divmat(ts, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 10]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_dangling_on_samples(self, n, mode): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(n).tree_sequence + ts1 = tsutil.insert_branch_sites(ts1) + D1 = check_divmat(ts1, mode=mode) + tables = ts1.dump_tables() + for u in ts1.samples(): + v = tables.nodes.add_row(time=-1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=u, child=v) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = check_divmat(ts2, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("n", [2, 3, 10]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_dangling_on_all(self, n, mode): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(n).tree_sequence + ts1 = tsutil.insert_branch_sites(ts1) + D1 = check_divmat(ts1, mode=mode) + tables = ts1.dump_tables() + for u in range(ts1.num_nodes): + v = tables.nodes.add_row(time=-1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=u, child=v) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = check_divmat(ts2, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_disconnected_non_sample_topology(self, mode): + # Adding non sample branches below the samples does not alter + # the overall divergence *between* the samples + ts1 = tskit.Tree.generate_balanced(5).tree_sequence + ts1 = 
tsutil.insert_branch_sites(ts1) + D1 = check_divmat(ts1, mode=mode) + tables = ts1.dump_tables() + # Add an extra bit of disconnected non-sample topology + u = tables.nodes.add_row(time=0) + v = tables.nodes.add_row(time=1) + tables.edges.add_row(left=0, right=ts1.sequence_length, parent=v, child=u) + tables.sort() + tables.build_index() + ts2 = tables.tree_sequence() + D2 = check_divmat(ts2, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + +class TestSuiteExamples: + """ + Compare the stats API method vs the library implementation for the + suite test examples. Some of these examples are too large to run the + Python code above on. + """ + + def check(self, ts, windows=None, samples=None, num_threads=0, mode="branch"): + D1 = ts.divergence_matrix( + windows=windows, + samples=samples, + num_threads=num_threads, + mode=mode, + ) + D2 = stats_api_divergence_matrix( + ts, windows=windows, samples=samples, mode=mode + ) + assert D1.shape == D2.shape + if mode == "branch": + # If we have missing data then parts of the divmat are defined to be zero, + # so relative tolerances aren't useful. Because the stats API + # method necessarily involves subtracting away all of the previous + # values for an empty tree, there is a degree of numerical imprecision + # here. This value for atol is what is needed to get the tests to + # pass in practice. + has_missing_data = any(tree._has_isolated_samples() for tree in ts.trees()) + atol = 1e-12 if has_missing_data else 0 + np.testing.assert_allclose(D1, D2, atol=atol) + else: + assert mode == "site" + if np.any(ts.mutations_parent != tskit.NULL): + # The stats API computes something slightly different when we have + # recurrent mutations, so fall back to the naive version. + D2 = site_divergence_matrix(ts, windows=windows, samples=samples) + np.testing.assert_array_equal(D1, D2) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_defaults(self, ts, mode): + self.check(ts, mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_subset_samples(self, ts, mode): + n = min(ts.num_samples, 2) + self.check(ts, samples=ts.samples()[:n], mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_windows(self, ts, mode): + windows = np.linspace(0, ts.sequence_length, num=13) + self.check(ts, windows=windows, mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_threads_no_windows(self, ts, mode): + self.check(ts, num_threads=5, mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_threads_windows(self, ts, mode): + windows = np.linspace(0, ts.sequence_length, num=11) + self.check(ts, num_threads=5, windows=windows, mode=mode) + + +class TestThreadsNoWindows: + def check(self, ts, num_threads, samples=None, mode=None): + D1 = ts.divergence_matrix(num_threads=0, samples=samples, mode=mode) + D2 = ts.divergence_matrix(num_threads=num_threads, samples=samples, mode=mode) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees(self, num_threads, mode): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + self.check(ts, num_threads, mode=mode)
+ + @pytest.mark.parametrize("samples", [None, [0, 1]]) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_samples(self, samples, mode): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + self.check(ts, 2, samples, mode=mode) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("num_threads", range(1, 5)) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_simple_sims(self, n, num_threads, mode): + ts = msprime.sim_ancestry( + n, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=1234, + ) + assert ts.num_trees >= 2 + self.check(ts, num_threads, mode=mode) + + +class TestThreadsWindows: + def check(self, ts, num_threads, *, windows, samples=None, mode=None): + D1 = ts.divergence_matrix( + num_threads=0, windows=windows, samples=samples, mode=mode + ) + D2 = ts.divergence_matrix( + num_threads=num_threads, windows=windows, samples=samples, mode=mode + ) + np.testing.assert_array_almost_equal(D1, D2) + + @pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27]) + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 26],), + ([0, 1, 2],), + (list(range(27)),), + ([5, 7, 9, 20],), + ([5.1, 5.2, 5.3, 5.5, 6],), + ([5.1, 5.2, 6.5],), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees(self, num_threads, windows, mode): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + self.check(ts, num_threads, windows=windows, mode=mode) + + @pytest.mark.parametrize("samples", [None, [0, 1]]) + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 26],), + (None,), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_all_trees_samples(self, samples, windows, mode): + ts = tsutil.all_trees_ts(4) + self.check(ts, 2, windows=windows, samples=samples, mode=mode) + + @pytest.mark.parametrize("num_threads", range(1, 5)) + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 100],), + ([0, 50, 75, 95, 100],), + ([50, 75, 95, 100],), + ([0, 50, 75, 95],), + (list(range(100)),), + ], + ) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_simple_sims(self, num_threads, windows, mode): + ts = msprime.sim_ancestry( + 15, + ploidy=1, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=1234, + ) + assert ts.num_trees >= 2 + ts = msprime.sim_mutations(ts, rate=0.01, random_seed=1234) + assert ts.num_mutations > 10 + self.check(ts, num_threads, windows=windows, mode=mode) + + +# NOTE: these tests are for more general functionality that might get applied +# across many different functions, and so should probably be tested in +# another file. For now they're only used by divmat, so we can keep them +# here for simplicity. +class TestChunkByTree: + # These are based on what we get from np.array_split; there's nothing + # particularly critical about exactly how we portion things up.
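+ # For instance, np.array_split(np.arange(26), 4) yields chunks of sizes + # 7, 7, 6, 6, which is where the expected breakpoints + # [[0, 7], [7, 14], [14, 20], [20, 26]] in the 26-tree case below come from.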
+ @pytest.mark.parametrize( + ["num_chunks", "expected"], + [ + (1, [[0, 26]]), + (2, [[0, 13], [13, 26]]), + (3, [[0, 9], [9, 18], [18, 26]]), + (4, [[0, 7], [7, 14], [14, 20], [20, 26]]), + (5, [[0, 6], [6, 11], [11, 16], [16, 21], [21, 26]]), + ], + ) + def test_all_trees_ts_26(self, num_chunks, expected): + ts = tsutil.all_trees_ts(4) + actual = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(actual, expected) + + @pytest.mark.parametrize( + ["num_chunks", "expected"], + [ + (1, [[0, 4]]), + (2, [[0, 2], [2, 4]]), + (3, [[0, 2], [2, 3], [3, 4]]), + (4, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (5, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (100, [[0, 1], [1, 2], [2, 3], [3, 4]]), + ], + ) + def test_all_trees_ts_4(self, num_chunks, expected): + ts = tsutil.all_trees_ts(3) + assert ts.num_trees == 4 + actual = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(actual, expected) + + @pytest.mark.parametrize("span", [1, 2, 5, 0.3]) + @pytest.mark.parametrize( + ["num_chunks", "expected"], + [ + (1, [[0, 4]]), + (2, [[0, 2], [2, 4]]), + (3, [[0, 2], [2, 3], [3, 4]]), + (4, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (5, [[0, 1], [1, 2], [2, 3], [3, 4]]), + (100, [[0, 1], [1, 2], [2, 3], [3, 4]]), + ], + ) + def test_all_trees_ts_4_trees_span(self, span, num_chunks, expected): + tables = tsutil.all_trees_ts(3).dump_tables() + tables.edges.left *= span + tables.edges.right *= span + tables.sequence_length *= span + ts = tables.tree_sequence() + assert ts.num_trees == 4 + actual = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(actual, np.array(expected) * span) + + @pytest.mark.parametrize("num_chunks", range(1, 5)) + def test_empty_ts(self, num_chunks): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + chunks = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(chunks, [[0, 1]]) + + @pytest.mark.parametrize("num_chunks", range(1, 5)) + def test_single_tree(self, num_chunks): + L = 10 + ts = tskit.Tree.generate_balanced(2, span=L).tree_sequence + chunks = ts._chunk_sequence_by_tree(num_chunks) + np.testing.assert_equal(chunks, [[0, L]]) + + @pytest.mark.parametrize("num_chunks", [0, -1, 0.5]) + def test_bad_chunks(self, num_chunks): + ts = tskit.Tree.generate_balanced(2).tree_sequence + with pytest.raises(ValueError, match="Number of chunks must be an integer > 0"): + ts._chunk_sequence_by_tree(num_chunks) + + +class TestChunkWindows: + # These are based on what we get from np.array_split; there's nothing + # particularly critical about exactly how we portion things up.
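+ # For instance, the three intervals of windows=[0, 5, 6, 10] split into + # 2 chunks as [[0, 5, 6], [6, 10]]: two intervals in the first chunk and + # one in the second, as np.array_split would portion them.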
+ @pytest.mark.parametrize( + ["windows", "num_chunks", "expected"], + [ + ([0, 10], 1, [[0, 10]]), + ([0, 10], 2, [[0, 10]]), + ([0, 5, 10], 2, [[0, 5], [5, 10]]), + ([0, 5, 6, 10], 2, [[0, 5, 6], [6, 10]]), + ([0, 5, 6, 10], 3, [[0, 5], [5, 6], [6, 10]]), + ], + ) + def test_examples(self, windows, num_chunks, expected): + actual = tskit.TreeSequence._chunk_windows(windows, num_chunks) + np.testing.assert_equal(actual, expected) + + @pytest.mark.parametrize("num_chunks", [0, -1, 0.5]) + def test_bad_chunks(self, num_chunks): + with pytest.raises(ValueError, match="Number of chunks must be an integer > 0"): + tskit.TreeSequence._chunk_windows([0, 1], num_chunks) diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py index bece3bc11f..e95ed10988 100644 --- a/python/tests/test_drawing.py +++ b/python/tests/test_drawing.py @@ -29,6 +29,7 @@ import math import os import pathlib +import platform import re import xml.dom.minidom import xml.etree @@ -44,6 +45,9 @@ from tskit import drawing +IS_WINDOWS = platform.system() == "Windows" + + class TestTreeDraw: """ Tests for the tree drawing functionality. @@ -1453,9 +1457,9 @@ def test_no_repr_svg(self): output._repr_svg_() -class TestDrawSvg(TestTreeDraw, xmlunittest.XmlTestMixin): +class TestDrawSvgBase(TestTreeDraw, xmlunittest.XmlTestMixin): """ - Tests the SVG tree drawing. + Base class for testing the SVG tree drawing method """ def verify_basic_svg(self, svg, width=200, height=200, num_trees=1): @@ -1496,6 +1500,12 @@ def verify_basic_svg(self, svg, width=200, height=200, num_trees=1): cls = group.attrib["class"] assert re.search(r"\broot\b", cls) + +class TestDrawSvg(TestDrawSvgBase): + """ + Simple testing for the draw_svg method + """ + def test_repr_svg(self): ts = self.get_simple_ts() svg = ts.draw_svg() @@ -1535,7 +1545,9 @@ def test_draw_to_file(self, tmp_path): def test_nonimplemented_base_class(self): ts = self.get_simple_ts() - plot = drawing.SvgPlot(ts, (100, 100), {}, "", "dummy-class", None, True, True) + plot = drawing.SvgAxisPlot( + ts, (100, 100), {}, "", "dummy-class", None, True, True + ) plot.set_spacing() with pytest.raises(NotImplementedError): plot.draw_x_axis(tick_positions=ts.breakpoints(as_array=True)) @@ -2422,6 +2434,43 @@ def test_debug_box(self): assert svg.count("outer_plotbox") == ts.num_trees + 1 assert svg.count("inner_plotbox") == ts.num_trees + 1 + @pytest.mark.parametrize("max_trees", [-1, 0, 1]) + def test_bad_max_num_trees(self, max_trees): + ts = self.get_simple_ts() + with pytest.raises(ValueError, match="at least 2"): + ts.draw_svg(max_num_trees=max_trees) + + @pytest.mark.parametrize("max_trees", [2, 4, 9]) + def test_max_num_trees(self, max_trees): + ts = msprime.sim_ancestry( + 3, sequence_length=100, recombination_rate=0.1, random_seed=1 + ) + ts = msprime.sim_mutations(ts, rate=0.1, random_seed=1) + assert ts.num_trees > 10 + num_sites = 0 + num_unplotted_sites = 0 + svg = ts.draw_svg(max_num_trees=max_trees) + for tree in ts.trees(): + if ( + tree.index < (max_trees + 1) // 2 + or ts.num_trees - tree.index <= max_trees // 2 + ): + num_sites += tree.num_sites + assert re.search(rf"t{tree.index}[^\d]", svg) is not None + else: + assert re.search(rf"t{tree.index}[^\d]", svg) is None + num_unplotted_sites += tree.num_sites + assert num_unplotted_sites > 0 + site_strings_in_stylesheet = svg.count(".site") + assert svg.count("site") - site_strings_in_stylesheet == num_sites + self.verify_basic_svg(svg, width=200 * (max_trees + 1)) + + +class TestDrawKnownSvg(TestDrawSvgBase): + """ + 
Compare against known files + """ + def verify_known_svg(self, svg, filename, save=False, **kwargs): # expected SVG files can be inspected in tests/data/svg/*.svg svg = xml.dom.minidom.parseString( @@ -2752,6 +2801,45 @@ def test_known_svg_ts_xlim(self, overwrite_viz, draw_plotbox, caplog): num_trees = sum(1 for b in ts.breakpoints() if 0.051 <= b < 0.9) + 1 self.verify_known_svg(svg, "ts_x_lim.svg", overwrite_viz, width=200 * num_trees) + @pytest.mark.skipif(IS_WINDOWS, reason="Msprime gives different result on Windows") + def test_known_max_num_trees(self, overwrite_viz, draw_plotbox, caplog): + max_trees = 5 + ts = msprime.sim_ancestry( + 3, sequence_length=100, recombination_rate=0.1, random_seed=1 + ) + ts = msprime.sim_mutations(ts, rate=0.01, random_seed=1) + assert ts.num_trees > 10 + first_break = next(ts.trees()).interval.right + # limit to just past the first tree + svg = ts.draw_svg( + max_num_trees=max_trees, + x_lim=(first_break + 0.1, ts.sequence_length - 0.1), + y_axis=True, + time_scale="log_time", + ) + self.verify_known_svg( + svg, "ts_max_trees.svg", overwrite_viz, width=200 * (max_trees + 1) + ) + + @pytest.mark.skipif(IS_WINDOWS, reason="Msprime gives different result on Windows") + def test_known_max_num_trees_treewise(self, overwrite_viz, draw_plotbox, caplog): + max_trees = 5 + ts = msprime.sim_ancestry( + 3, sequence_length=100, recombination_rate=0.1, random_seed=1 + ) + ts = msprime.sim_mutations(ts, rate=0.01, random_seed=1) + assert ts.num_trees > 10 + first_break = next(ts.trees()).interval.right + svg = ts.draw_svg( + max_num_trees=max_trees, + x_lim=(first_break + 0.1, ts.sequence_length - 0.1), + y_axis=True, + x_scale="treewise", + ) + self.verify_known_svg( + svg, "ts_max_trees_treewise.svg", overwrite_viz, width=200 * (max_trees + 1) + ) + class TestRounding: def test_rnd(self): diff --git a/python/tests/test_extend_edges.py b/python/tests/test_extend_edges.py new file mode 100644 index 0000000000..ffca67b2e7 --- /dev/null +++ b/python/tests/test_extend_edges.py @@ -0,0 +1,379 @@ +import msprime +import numpy as np +import pytest + +import tests.test_wright_fisher as wf +import tskit +from tests import tsutil +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. 
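+ +# Plain-Python implementation of the edge-extension operation, used for +# cross-checking in the tests below: sweep forwards and then backwards over +# the tree sequence, extending edges on each pass, until the edge count +# stops changing or max_iter passes have been made.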
+
+
+def extend_edges(ts, max_iter=10):
+    tables = ts.dump_tables()
+
+    last_num_edges = ts.num_edges
+    for _ in range(max_iter):
+        for forwards in [True, False]:
+            edges = _extend(ts, forwards=forwards)
+            tables.edges.replace_with(edges)
+            tables.build_index()
+            ts = tables.tree_sequence()
+        if ts.num_edges == last_num_edges:
+            break
+        else:
+            last_num_edges = ts.num_edges
+    return ts
+
+
+def _extend(ts, forwards=True):
+    degree = np.full(ts.num_nodes, 0)
+    keep = np.full(ts.num_edges, True, dtype=bool)
+
+    edges = ts.tables.edges.copy()
+
+    # "here" will be left if forwards else right;
+    # and "there" is the other
+    new_left = edges.left.copy()
+    new_right = edges.right.copy()
+    if forwards:
+        direction = 1
+        # in C we can just modify these in place, but in
+        # Python they are (silently) immutable
+        new_here = new_left
+        new_there = new_right
+    else:
+        direction = -1
+        new_here = new_right
+        new_there = new_left
+    edges_out = []
+    edges_in = []
+
+    tree_pos = tsutil.TreePosition(ts)
+    if forwards:
+        valid = tree_pos.next()
+    else:
+        valid = tree_pos.prev()
+    while valid:
+        left, right = tree_pos.interval
+        there = right if forwards else left
+
+        # clear out non-extended or postponed edges
+        edges_out = [[e, False] for e, x in edges_out if x]
+        edges_in = [[e, False] for e, x in edges_in if x]
+
+        for j in range(tree_pos.out_range.start, tree_pos.out_range.stop, direction):
+            e = tree_pos.out_range.order[j]
+            edges_out.append([e, False])
+
+        for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, direction):
+            e = tree_pos.in_range.order[j]
+            edges_in.append([e, False])
+
+        for e, _ in edges_out:
+            degree[edges.parent[e]] -= 1
+            degree[edges.child[e]] -= 1
+        for e, _ in edges_in:
+            degree[edges.parent[e]] += 1
+            degree[edges.child[e]] += 1
+
+        assert np.all(degree >= 0)
+        for ex1 in edges_out:
+            if not ex1[1]:
+                e1 = ex1[0]
+                for ex2 in edges_out:
+                    if not ex2[1]:
+                        # the intermediate node should not be present in
+                        # the new tree
+                        e2 = ex2[0]
+                        if (edges.parent[e1] == edges.child[e2]) and (
+                            degree[edges.child[e2]] == 0
+                        ):
+                            for ex_in in edges_in:
+                                e_in = ex_in[0]
+                                # we might have passed the interval that a
+                                # postponed incoming edge covers, in which
+                                # case we should skip it
+                                if new_left[e_in] < right and new_right[e_in] > left:
+                                    if (
+                                        edges.child[e1] == edges.child[e_in]
+                                        and edges.parent[e2] == edges.parent[e_in]
+                                    ):
+                                        ex1[1] = True
+                                        ex2[1] = True
+                                        ex_in[1] = True
+                                        new_there[e1] = there
+                                        new_there[e2] = there
+                                        new_here[e_in] = there
+                                        # amend degree: the intermediate
+                                        # node has 2 edges instead of 0
+                                        degree[edges.parent[e1]] += 2
+        # end of loop, next tree
+        if forwards:
+            valid = tree_pos.next()
+        else:
+            valid = tree_pos.prev()
+
+    for j in range(edges.num_rows):
+        left = new_left[j]
+        right = new_right[j]
+        if left < right:
+            edges[j] = edges[j].replace(left=left, right=right)
+        else:
+            keep[j] = False
+    edges.keep_rows(keep)
+    return edges
+
+
+class TestExtendEdges:
+    """
+    Test the 'extend edges' method
+    """
+
+    def verify_extend_edges(self, ts, max_iter=10):
+        # This can still fail for various weird examples:
+        # for instance, if adjacent trees have
+        # a <- b <- c <- d and a <- d (where say b was
+        # inserted in an earlier pass), then b and c
+        # won't be extended
+
+        ets = ts.extend_edges(max_iter=max_iter)
+        assert ts.num_samples == ets.num_samples
+        assert ts.num_nodes == ets.num_nodes
+        assert ts.num_edges >= ets.num_edges
+        t = ts.simplify().tables
+        et = ets.simplify().tables
+        t.assert_equals(et, ignore_provenance=True)
+        old_edges = {}
+        for e in 
ts.edges(): + k = (e.parent, e.child) + if k not in old_edges: + old_edges[k] = [] + old_edges[k].append((e.left, e.right)) + + for e in ets.edges(): + # e should be in old_edges, + # but with modified limits: + # USUALLY overlapping limits, but + # not necessarily after more than one pass + k = (e.parent, e.child) + assert k in old_edges + if max_iter == 1: + overlaps = False + for left, right in old_edges[k]: + if (left <= e.right) and (right >= e.left): + overlaps = True + assert overlaps + + if max_iter > 1: + chains = [] + for _, tt, ett in ts.coiterate(ets): + this_chains = [] + for a in tt.nodes(): + assert a in ett.nodes() + b = tt.parent(a) + if b != tskit.NULL: + c = tt.parent(b) + if c != tskit.NULL: + this_chains.append((a, b, c)) + assert b in ett.nodes() + # the relationship a <- b should still be in the tree + p = a + while p != tskit.NULL and p != b: + p = ett.parent(p) + assert p == b + chains.append(this_chains) + + extended_ac = {} + not_extended_ac = {} + extended_ab = {} + not_extended_ab = {} + for k, (interval, tt, ett) in enumerate(ts.coiterate(ets)): + for j in (k - 1, k + 1): + if j < 0 or j >= len(chains): + continue + else: + this_chains = chains[j] + for a, b, c in this_chains: + if ( + a in tt.nodes() + and tt.parent(a) == c + and b not in tt.nodes() + ): + # the relationship a <- b <- c should still be in the tree, + # although maybe they aren't direct parent-offspring + # UNLESS we've got an ambiguous case, where on the opposite + # side of the interval a chain a <- b' <- c got extended + # into the region OR b got inserted into another chain + assert a in ett.nodes() + assert c in ett.nodes() + if b not in ett.nodes(): + if (a, c) not in not_extended_ac: + not_extended_ac[(a, c)] = [] + not_extended_ac[(a, c)].append(interval) + else: + if (a, c) not in extended_ac: + extended_ac[(a, c)] = [] + extended_ac[(a, c)].append(interval) + p = a + while p != tskit.NULL and p != b: + p = ett.parent(p) + if p != b: + if (a, b) not in not_extended_ab: + not_extended_ab[(a, b)] = [] + not_extended_ab[(a, b)].append(interval) + else: + if (a, b) not in extended_ab: + extended_ab[(a, b)] = [] + extended_ab[(a, b)].append(interval) + while p != tskit.NULL and p != c: + p = ett.parent(p) + assert p == c + for a, c in not_extended_ac: + # check that a <- ... 
<- c has been extended somewhere + # although not necessarily from an adjacent segment + assert (a, c) in extended_ac + for interval in not_extended_ac[(a, c)]: + ett = ets.at(interval.left) + assert ett.parent(a) != c + for k in not_extended_ab: + assert k in extended_ab + for interval in not_extended_ab[k]: + assert interval in extended_ab[k] + + # finally, compare C version to python version + py_et = extend_edges(ts, max_iter=max_iter).dump_tables() + et = ets.dump_tables() + et.assert_equals(py_et) + + def test_runs(self): + ts = msprime.simulate(5, random_seed=126) + self.verify_extend_edges(ts) + + def test_max_iter(self): + ts = msprime.simulate(5, random_seed=126) + with pytest.raises(ValueError, match="max_iter"): + ets = ts.extend_edges(max_iter=0) + ets = ts.extend_edges(max_iter=1) + et = ets.extend_edges(max_iter=1).dump_tables() + eet = ets.extend_edges(max_iter=2).dump_tables() + eet.assert_equals(et) + + def test_simple_ex(self): + # An example where you need to go forwards *and* backwards: + # 7 and 8 should be extended to the whole sequence + # + # 6 6 6 6 + # +-+-+ +-+-+ +-+-+ +-+-+ + # | | 7 | | 8 | | + # | | ++-+ | | +-++ | | + # 4 5 4 | | 4 | 5 4 5 + # +++ +++ +++ | | | | +++ +++ +++ + # 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + # + node_times = { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 1.0, + 5: 1.0, + 6: 3.0, + 7: 2.0, + 8: 2.0, + } + # (p, c, l, r) + edge_stuff = [ + (4, 0, 0, 10), + (4, 1, 0, 5), + (4, 1, 7, 10), + (5, 2, 0, 2), + (5, 2, 5, 10), + (5, 3, 0, 2), + (5, 3, 5, 10), + (6, 4, 0, 2), + (6, 4, 5, 10), + (6, 5, 0, 2), + (6, 5, 7, 10), + (6, 3, 2, 5), + (6, 7, 2, 5), + (6, 8, 5, 7), + (7, 2, 2, 5), + (7, 4, 2, 5), + (8, 1, 5, 7), + (8, 5, 5, 7), + ] + tables = tskit.TableCollection(sequence_length=10) + nodes = tables.nodes + for n, t in node_times.items(): + flags = tskit.NODE_IS_SAMPLE if n < 4 else 0 + nodes.add_row(time=t, flags=flags) + edges = tables.edges + for p, c, l, r in edge_stuff: + edges.add_row(parent=p, child=c, left=l, right=r) + tables.sort() + ts = tables.tree_sequence() + ets = ts.extend_edges() + assert ts.num_edges == 18 + assert ets.num_edges == 13 + for t in ets.trees(): + assert 7 in t.nodes() + assert 8 in t.nodes() + assert t.parent(4) == 7 + assert t.parent(7) == 6 + assert t.parent(5) == 8 + assert t.parent(8) == 6 + self.verify_extend_edges(ts) + + def test_wright_fisher_trees(self): + tables = wf.wf_sim(N=5, ngens=20, deep_history=False, seed=3) + tables.sort() + tables.simplify() + ts = tables.tree_sequence() + self.verify_extend_edges(ts, max_iter=1) + self.verify_extend_edges(ts) + + def test_wright_fisher_trees_unsimplified(self): + tables = wf.wf_sim(N=6, ngens=22, deep_history=False, seed=4) + tables.sort() + ts = tables.tree_sequence() + self.verify_extend_edges(ts, max_iter=1) + self.verify_extend_edges(ts) + + def test_wright_fisher_trees_with_history(self): + tables = wf.wf_sim(N=8, ngens=15, deep_history=True, seed=5) + tables.sort() + tables.simplify() + ts = tables.tree_sequence() + self.verify_extend_edges(ts, max_iter=1) + self.verify_extend_edges(ts) + + # def test_bigger_wright_fisher(self): + # tables = wf.wf_sim(N=50, ngens=15, deep_history=True, seed=6) + # tables.sort() + # tables.simplify() + # ts = tables.tree_sequence() + # self.verify_extend_edges(ts, max_iter=1) + # self.verify_extend_edges(ts, max_iter=200) + + +class TestExamples: + """ + Compare the ts method with local implementation. 
+ """ + + def check(self, ts): + lib_ts = ts.extend_edges() + py_ts = extend_edges(ts) + lib_ts.tables.assert_equals(py_ts.tables) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_suite_examples_defaults(self, ts): + self.check(ts) + + @pytest.mark.parametrize("n", [3, 4, 5]) + def test_all_trees_ts(self, n): + ts = tsutil.all_trees_ts(n) + self.check(ts) diff --git a/python/tests/test_file_format.py b/python/tests/test_file_format.py index c41327b288..26abc0f2d8 100644 --- a/python/tests/test_file_format.py +++ b/python/tests/test_file_format.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2016-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -36,12 +36,12 @@ import msprime import numpy as np import pytest +import tszip as tszip import tests.tsutil as tsutil import tskit import tskit.exceptions as exceptions - CURRENT_FILE_MAJOR = 12 CURRENT_FILE_MINOR = 7 @@ -262,11 +262,17 @@ def test_format_too_old_raised_for_hdf5(self): ] for filename in files: path = os.path.join(test_data_dir, "hdf5-formats", filename) + with pytest.raises( exceptions.FileFormatError, - match="uses the old HDF5-based format which can no longer", + match="appears to be in HDF5 format", ): tskit.load(path) + with pytest.raises( + exceptions.FileFormatError, + match="appears to be in HDF5 format", + ): + tskit.TableCollection.load(path) def test_msprime_v_0_5_0(self): path = os.path.join(test_data_dir, "hdf5-formats", "msprime-0.5.0_v10.0.hdf5") @@ -511,6 +517,14 @@ def test_no_h5py(self): with pytest.raises(ImportError, match=msg): tskit.dump_legacy(ts, path) + def test_tszip_file(self): + ts = msprime.simulate(5) + tszip.compress(ts, self.temp_file) + with pytest.raises(tskit.FileFormatError, match="appears to be in zip format"): + tskit.load(self.temp_file) + with pytest.raises(tskit.FileFormatError, match="appears to be in zip format"): + tskit.TableCollection.load(self.temp_file) + class TestDumpFormat(TestFileFormat): """ diff --git a/python/tests/test_fileobj.py b/python/tests/test_fileobj.py index e4d05b63b6..1740094462 100644 --- a/python/tests/test_fileobj.py +++ b/python/tests/test_fileobj.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -35,6 +35,7 @@ import traceback import pytest +import tszip from pytest import fixture import tskit @@ -308,3 +309,59 @@ def verify_stream(self, ts_list, client_fd): def test_single_then_multi(self, ts_fixture, replicate_ts_fixture, client_fd): self.verify_stream([ts_fixture], client_fd) self.verify_stream(replicate_ts_fixture, client_fd) + + +def write_to_fifo(path, file_path): + with open(path, "wb") as fifo: + with open(file_path, "rb") as file: + fifo.write(file.read()) + + +def read_from_fifo(path, expected_exception, error_text, read_func): + with open(path) as fifo: + with pytest.raises(expected_exception, match=error_text): + read_func(fifo) + + +def write_and_read_from_fifo(fifo_path, file_path, expected_exception, error_text): + os.mkfifo(fifo_path) + for read_func in [tskit.load, tskit.TableCollection.load]: + read_process = multiprocessing.Process( + target=read_from_fifo, + args=(fifo_path, expected_exception, error_text, read_func), + ) + 
read_process.start() + write_process = multiprocessing.Process( + target=write_to_fifo, args=(fifo_path, file_path) + ) + write_process.start() + write_process.join(timeout=3) + read_process.join(timeout=3) + + +@pytest.mark.skipif(IS_WINDOWS, reason="No FIFOs on Windows") +class TestBadStream: + def test_bad_stream(self, tmp_path): + fifo_path = tmp_path / "fifo" + bad_file_path = tmp_path / "bad_file" + bad_file_path.write_bytes(b"bad data") + write_and_read_from_fifo( + fifo_path, bad_file_path, tskit.FileFormatError, "not in kastore format" + ) + + def test_legacy_stream(self, tmp_path): + fifo_path = tmp_path / "fifo" + legacy_file_path = os.path.join( + os.path.dirname(__file__), "data", "hdf5-formats", "msprime-0.3.0_v2.0.hdf5" + ) + write_and_read_from_fifo( + fifo_path, legacy_file_path, tskit.FileFormatError, "not in kastore format" + ) + + def test_tszip_stream(self, tmp_path, ts_fixture): + fifo_path = tmp_path / "fifo" + zip_file_path = tmp_path / "tszip_file" + tszip.compress(ts_fixture, zip_file_path) + write_and_read_from_fifo( + fifo_path, zip_file_path, tskit.FileFormatError, "not in kastore format" + ) diff --git a/python/tests/test_genotype_matching_fb.py b/python/tests/test_genotype_matching_fb.py index 88dd7a754d..761eadf403 100644 --- a/python/tests/test_genotype_matching_fb.py +++ b/python/tests/test_genotype_matching_fb.py @@ -1,4 +1,3 @@ -# Simulation import copy import itertools @@ -14,6 +13,8 @@ REF_HOM_OBS_HET = 1 REF_HET_OBS_HOM = 2 +MISSING = -1 + def mirror_coordinates(ts): """ @@ -174,8 +175,8 @@ def stupid_compress_dict(self): # Retain the old T_index, because the internal T that's passed up the tree will # retain this ordering. old_T_index = copy.deepcopy(self.T_index) - self.T_index = np.zeros(tree.num_nodes, dtype=int) - 1 - self.N = np.zeros(tree.num_nodes, dtype=int) + self.T_index = np.zeros(tree.tree_sequence.num_nodes, dtype=int) - 1 + self.N = np.zeros(tree.tree_sequence.num_nodes, dtype=int) self.T.clear() # First, create T root. 
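Both this hunk and the next size the per-node work arrays with
`tree.tree_sequence.num_nodes` rather than `tree.num_nodes`: the arrays are
indexed by node ID, and IDs range over every node in the tree sequence, while
a single tree generally reaches only a subset of them. A minimal sketch of the
distinction (the simulation parameters are illustrative, not taken from this
file):

    import msprime
    import numpy as np

    ts = msprime.simulate(5, recombination_rate=2, random_seed=1)
    tree = ts.first()
    # Arrays indexed by node ID must span the whole tree sequence; with
    # recombination, one tree typically contains only some of the nodes.
    N = np.zeros(tree.tree_sequence.num_nodes, dtype=int)
    for u in tree.nodes():
        N[u] += 1  # safe for any node ID reached by this tree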
@@ -345,7 +346,7 @@ def update_tree(self): vt.tree_node = -1 vt.value_index = -1 - self.N = np.zeros(self.tree.num_nodes, dtype=int) + self.N = np.zeros(self.tree.tree_sequence.num_nodes, dtype=int) node_map = {st.tree_node: st for st in self.T} for u in self.tree.samples(): @@ -411,6 +412,7 @@ def update_probabilities(self, site, genotype_state): ] query_is_het = genotype_state == 1 + query_is_missing = genotype_state == MISSING for st1 in T: u1 = st1.tree_node @@ -444,6 +446,7 @@ def update_probabilities(self, site, genotype_state): match, template_is_het, query_is_het, + query_is_missing, ) # This will ensure that allelic_state[:n] is filled @@ -561,7 +564,14 @@ def compute_normalisation_factor_dict(self): raise NotImplementedError() def compute_next_probability_dict( - self, site_id, p_last, inner_summation, is_match, template_is_het, query_is_het + self, + site_id, + p_last, + inner_summation, + is_match, + template_is_het, + query_is_het, + query_is_missing, ): raise NotImplementedError() @@ -670,41 +680,45 @@ def compute_next_probability_dict( is_match, template_is_het, query_is_het, + query_is_missing, ): rho = self.rho[site_id] mu = self.mu[site_id] n = self.ts.num_samples - template_is_hom = np.logical_not(template_is_het) - query_is_hom = np.logical_not(query_is_het) - - EQUAL_BOTH_HOM = np.logical_and( - np.logical_and(is_match, template_is_hom), query_is_hom - ) - UNEQUAL_BOTH_HOM = np.logical_and( - np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom - ) - BOTH_HET = np.logical_and(template_is_het, query_is_het) - REF_HOM_OBS_HET = np.logical_and(template_is_hom, query_is_het) - REF_HET_OBS_HOM = np.logical_and(template_is_het, query_is_hom) - p_t = ( (rho / n) ** 2 + ((1 - rho) * (rho / n)) * inner_normalisation_factor + (1 - rho) ** 2 * p_last ) - p_e = ( - EQUAL_BOTH_HOM * (1 - mu) ** 2 - + UNEQUAL_BOTH_HOM * (mu**2) - + REF_HOM_OBS_HET * (2 * mu * (1 - mu)) - + REF_HET_OBS_HOM * (mu * (1 - mu)) - + BOTH_HET * ((1 - mu) ** 2 + mu**2) - ) + + if query_is_missing: + p_e = 1 + else: + query_is_hom = np.logical_not(query_is_het) + template_is_hom = np.logical_not(template_is_het) + + equal_both_hom = np.logical_and( + np.logical_and(is_match, template_is_hom), query_is_hom + ) + unequal_both_hom = np.logical_and( + np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom + ) + both_het = np.logical_and(template_is_het, query_is_het) + ref_hom_obs_het = np.logical_and(template_is_hom, query_is_het) + ref_het_obs_hom = np.logical_and(template_is_het, query_is_hom) + + p_e = ( + equal_both_hom * (1 - mu) ** 2 + + unequal_both_hom * (mu**2) + + ref_hom_obs_het * (2 * mu * (1 - mu)) + + ref_het_obs_hom * (mu * (1 - mu)) + + both_het * ((1 - mu) ** 2 + mu**2) + ) return p_t * p_e -# DEV: Sort this class BackwardAlgorithm(LsHmmAlgorithm): """Runs the Li and Stephens forward algorithm.""" @@ -737,29 +751,34 @@ def compute_next_probability_dict( is_match, template_is_het, query_is_het, + query_is_missing, ): mu = self.mu[site_id] - template_is_hom = np.logical_not(template_is_het) - query_is_hom = np.logical_not(query_is_het) - EQUAL_BOTH_HOM = np.logical_and( - np.logical_and(is_match, template_is_hom), query_is_hom - ) - UNEQUAL_BOTH_HOM = np.logical_and( - np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom - ) - BOTH_HET = np.logical_and(template_is_het, query_is_het) - REF_HOM_OBS_HET = np.logical_and(template_is_hom, query_is_het) - REF_HET_OBS_HOM = np.logical_and(template_is_het, query_is_hom) - - p_e = ( - EQUAL_BOTH_HOM * 
(1 - mu) ** 2
-            + UNEQUAL_BOTH_HOM * (mu**2)
-            + REF_HOM_OBS_HET * (2 * mu * (1 - mu))
-            + REF_HET_OBS_HOM * (mu * (1 - mu))
-            + BOTH_HET * ((1 - mu) ** 2 + mu**2)
-        )
+        if query_is_missing:
+            p_e = 1
+        else:
+            query_is_hom = np.logical_not(query_is_het)
+            template_is_hom = np.logical_not(template_is_het)
+
+            equal_both_hom = np.logical_and(
+                np.logical_and(is_match, template_is_hom), query_is_hom
+            )
+            unequal_both_hom = np.logical_and(
+                np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom
+            )
+            both_het = np.logical_and(template_is_het, query_is_het)
+            ref_hom_obs_het = np.logical_and(template_is_hom, query_is_het)
+            ref_het_obs_hom = np.logical_and(template_is_het, query_is_hom)
+
+            p_e = (
+                equal_both_hom * (1 - mu) ** 2
+                + unequal_both_hom * (mu**2)
+                + ref_hom_obs_het * (2 * mu * (1 - mu))
+                + ref_het_obs_hom * (mu * (1 - mu))
+                + both_het * ((1 - mu) ** 2 + mu**2)
+            )
+
         return p_next * p_e
 
 
@@ -797,6 +816,21 @@ def example_genotypes(self, ts):
         s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0])
         H = H[:, 2:]
 
+        genotypes = [
+            s,
+            H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]),
+        ]
+
+        s_tmp = s.copy()
+        s_tmp[0, -1] = MISSING
+        genotypes.append(s_tmp)
+        s_tmp = s.copy()
+        s_tmp[0, ts.num_sites // 2] = MISSING
+        genotypes.append(s_tmp)
+        s_tmp = s.copy()
+        s_tmp[0, :] = MISSING
+        genotypes.append(s_tmp)
+
         m = ts.get_num_sites()
         n = H.shape[1]
 
@@ -804,11 +838,11 @@ def example_genotypes(self, ts):
         for i in range(m):
             G[i, :, :] = np.add.outer(H[i, :], H[i, :])
 
-        return H, G, s
+        return H, G, genotypes
 
     def example_parameters_genotypes(self, ts, seed=42):
         np.random.seed(seed)
-        H, G, s = self.example_genotypes(ts)
+        H, G, genotypes = self.example_genotypes(ts)
         n = H.shape[1]
         m = ts.get_num_sites()
 
@@ -819,13 +853,16 @@ def example_parameters_genotypes(self, ts, seed=42):
 
         e = self.genotype_emission(mu, m)
 
-        yield n, m, G, s, e, r, mu
+        for s in genotypes:
+            yield n, m, G, s, e, r, mu
 
         # Mixture of random and extremes
         rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)]
         mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33]
 
-        for r, mu in itertools.product(rs, mus):
+        e = self.genotype_emission(mu, m)
+
+        for s, r, mu in itertools.product(genotypes, rs, mus):
             r[0] = 0
             e = self.genotype_emission(mu, m)
             yield n, m, G, s, e, r, mu
diff --git a/python/tests/test_genotype_matching_viterbi.py b/python/tests/test_genotype_matching_viterbi.py
index 89377bdb33..acab5d1c28 100644
--- a/python/tests/test_genotype_matching_viterbi.py
+++ b/python/tests/test_genotype_matching_viterbi.py
@@ -13,6 +13,8 @@
 REF_HOM_OBS_HET = 1
 REF_HET_OBS_HOM = 2
 
+MISSING = -1
+
 
 class ValueTransition:
     """Simple struct holding value transition values."""
@@ -390,6 +392,7 @@ def update_probabilities(self, site, genotype_state):
         ]
 
         query_is_het = genotype_state == 1
+        query_is_missing = genotype_state == MISSING
 
         for st1 in T:
             u1 = st1.tree_node
@@ -423,6 +426,7 @@ def update_probabilities(self, site, genotype_state):
                     match,
                     template_is_het,
                     query_is_het,
+                    query_is_missing,
                     u1,
                     u2,
                 )
@@ -486,6 +490,7 @@ def compute_next_probability_dict(
         is_match,
         template_is_het,
         query_is_het,
+        query_is_missing,
         node_1,
         node_2,
     ):
@@ -830,6 +835,7 @@ def compute_next_probability_dict(
         is_match,
         template_is_het,
         query_is_het,
+        query_is_missing,
         node_1,
         node_2,
     ):
@@ -841,26 +847,28 @@ def compute_next_probability_dict(
         double_recombination_required = False
         single_recombination_required = False
 
-        template_is_hom = np.logical_not(template_is_het)
-        query_is_hom = np.logical_not(query_is_het)
-
-        
EQUAL_BOTH_HOM = np.logical_and( - np.logical_and(is_match, template_is_hom), query_is_hom - ) - UNEQUAL_BOTH_HOM = np.logical_and( - np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom - ) - BOTH_HET = np.logical_and(template_is_het, query_is_het) - REF_HOM_OBS_HET = np.logical_and(template_is_hom, query_is_het) - REF_HET_OBS_HOM = np.logical_and(template_is_het, query_is_hom) - - p_e = ( - EQUAL_BOTH_HOM * (1 - mu) ** 2 - + UNEQUAL_BOTH_HOM * (mu**2) - + REF_HOM_OBS_HET * (2 * mu * (1 - mu)) - + REF_HET_OBS_HOM * (mu * (1 - mu)) - + BOTH_HET * ((1 - mu) ** 2 + mu**2) - ) + if query_is_missing: + p_e = 1 + else: + template_is_hom = np.logical_not(template_is_het) + query_is_hom = np.logical_not(query_is_het) + equal_both_hom = np.logical_and( + np.logical_and(is_match, template_is_hom), query_is_hom + ) + unequal_both_hom = np.logical_and( + np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom + ) + both_het = np.logical_and(template_is_het, query_is_het) + ref_hom_obs_het = np.logical_and(template_is_hom, query_is_het) + ref_het_obs_hom = np.logical_and(template_is_het, query_is_hom) + + p_e = ( + equal_both_hom * (1 - mu) ** 2 + + unequal_both_hom * (mu**2) + + ref_hom_obs_het * (2 * mu * (1 - mu)) + + ref_het_obs_hom * (mu * (1 - mu)) + + both_het * ((1 - mu) ** 2 + mu**2) + ) no_switch = (1 - r) ** 2 + 2 * (r_n * (1 - r)) + r_n**2 single_switch = r_n * (1 - r) + r_n**2 @@ -919,6 +927,21 @@ def example_genotypes(self, ts): s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0]) H = H[:, 2:] + genotypes = [ + s, + H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]), + ] + + s_tmp = s.copy() + s_tmp[0, -1] = MISSING + genotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, ts.num_sites // 2] = MISSING + genotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, :] = MISSING + genotypes.append(s_tmp) + m = ts.get_num_sites() n = H.shape[1] @@ -926,11 +949,11 @@ def example_genotypes(self, ts): for i in range(m): G[i, :, :] = np.add.outer(H[i, :], H[i, :]) - return H, G, s + return H, G, genotypes def example_parameters_genotypes(self, ts, seed=42): np.random.seed(seed) - H, G, s = self.example_genotypes(ts) + H, G, genotypes = self.example_genotypes(ts) n = H.shape[1] m = ts.get_num_sites() @@ -941,13 +964,16 @@ def example_parameters_genotypes(self, ts, seed=42): e = self.genotype_emission(mu, m) - yield n, m, G, s, e, r, mu + for s in genotypes: + yield n, m, G, s, e, r, mu # Mixture of random and extremes rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] - for r, mu in itertools.product(rs, mus): + e = self.genotype_emission(mu, m) + + for s, r, mu in itertools.product(genotypes, rs, mus): r[0] = 0 e = self.genotype_emission(mu, m) yield n, m, G, s, e, r, mu diff --git a/python/tests/test_genotypes.py b/python/tests/test_genotypes.py index 1443e40061..329867b600 100644 --- a/python/tests/test_genotypes.py +++ b/python/tests/test_genotypes.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2022 Tskit Developers +# Copyright (c) 2019-2023 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -2215,3 +2215,17 @@ def test_variant_html_repr_no_site(self): html = v._repr_html_() ElementTree.fromstring(html) assert len(html) > 1600 + + def test_variant_repr(self, ts_fixture): + v = next(ts_fixture.variants()) + 
str_rep = repr(v) + assert len(str_rep) > 0 and len(str_rep) < 10000 + assert re.search(r"\AVariant", str_rep) + assert re.search(rf"\'site\': Site\(id={v.site.id}", str_rep) + assert re.search(rf"position={v.position}", str_rep) + alleles = re.escape("'alleles': " + str(v.alleles)) + assert re.search(rf"{alleles}", str_rep) + assert re.search(r"\'genotypes\': array\(\[", str_rep) + assert re.search(rf"position={v.position}", str_rep) + assert re.search(rf"\'has_missing_data\': {v.has_missing_data}", str_rep) + assert re.search(rf"\'isolated_as_missing\': {v.isolated_as_missing}", str_rep) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index 55f102939c..b09ebcc005 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2021 Tskit Developers +# Copyright (c) 2019-2023 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -20,332 +20,55 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. """ -Python implementation of the Li and Stephens algorithms. +Python implementation of the Li and Stephens forwards and backwards algorithms. """ import itertools -import unittest +import lshmm as ls import msprime import numpy as np -import pytest -import _tskit # TMP import tskit -from tests import tsutil +MISSING = -1 -def in_sorted(values, j): - # Take advantage of the fact that the numpy array is sorted. - ret = False - index = np.searchsorted(values, j) - if index < values.shape[0]: - ret = values[index] == j - return ret - -def ls_forward_matrix_naive(h, alleles, G, rho, mu): - """ - Simple matrix based method for LS forward algorithm using Python loops. - """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - F = np.zeros((m, n)) - S = np.zeros(m) - f = np.zeros(n) + 1 / n - - for el in range(0, m): - for j in range(n): - # NOTE Careful with the difference between this expression and - # the Viterbi algorithm below. This depends on the different - # normalisation approach. - p_t = f[j] * (1 - rho[el]) + rho[el] / n - p_e = mu[el] - if G[el, j] == h[el] or h[el] == tskit.MISSING_DATA: - p_e = 1 - (len(alleles[el]) - 1) * mu[el] - f[j] = p_t * p_e - S[el] = np.sum(f) - # TODO need to handle the 0 case. - assert S[el] > 0 - f /= S[el] - F[el] = f - return F, S - - -def ls_viterbi_naive(h, alleles, G, rho, mu): - """ - Simple matrix based method for LS Viterbi algorithm using Python loops. - """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - L = np.ones(n) - T = [set() for _ in range(m)] - T_dest = np.zeros(m, dtype=int) - - for el in range(m): - # The calculation below is undefined otherwise. - if len(alleles[el]) > 1: - assert mu[el] <= 1 / (len(alleles[el]) - 1) - L_next = np.zeros(n) - for j in range(n): - # NOTE Careful with the difference between this expression and - # the Forward algorithm above. This depends on the different - # normalisation approach. 
- p_no_recomb = L[j] * (1 - rho[el] + rho[el] / n) - p_recomb = rho[el] / n - if p_no_recomb > p_recomb: - p_t = p_no_recomb - else: - p_t = p_recomb - T[el].add(j) - p_e = mu[el] - if G[el, j] == h[el] or h[el] == tskit.MISSING_DATA: - p_e = 1 - (len(alleles[el]) - 1) * mu[el] - L_next[j] = p_t * p_e - L = L_next - j = np.argmax(L) - T_dest[el] = j - if L[j] == 0: - assert mu[el] == 0 - raise ValueError( - "Trying to match non-existent allele with zero mutation rate" - ) - L /= L[j] - - P = np.zeros(m, dtype=int) - P[m - 1] = T_dest[m - 1] - for el in range(m - 1, 0, -1): - j = P[el] - if j in T[el]: - j = T_dest[el - 1] - P[el - 1] = j - return P - - -def ls_viterbi_vectorised(h, alleles, G, rho, mu): - # We must have a non-zero mutation rate, or we'll end up with - # division by zero problems. - # assert np.all(mu > 0) - - m, n = G.shape - alleles = check_alleles(alleles, m) - V = np.ones(n) - T = [None for _ in range(m)] - max_index = np.zeros(m, dtype=int) - - for site in range(m): - # Transition - p_neq = rho[site] / n - p_t = (1 - rho[site] + rho[site] / n) * V - recombinations = np.where(p_neq > p_t)[0] - p_t[recombinations] = p_neq - T[site] = recombinations - # Emission - p_e = np.zeros(n) + mu[site] - index = G[site] == h[site] - if h[site] == tskit.MISSING_DATA: - # Missing data is considered equal to everything - index[:] = True - p_e[index] = 1 - (len(alleles[site]) - 1) * mu[site] - V = p_t * p_e - # Normalise - max_index[site] = np.argmax(V) - # print(site, ":", V) - if V[max_index[site]] == 0: - assert mu[site] == 0 - raise ValueError( - "Trying to match non-existent allele with zero mutation rate" - ) - V /= V[max_index[site]] - - # Traceback - P = np.zeros(m, dtype=int) - site = m - 1 - P[site] = max_index[site] - while site > 0: - j = P[site] - if in_sorted(T[site], j): - j = max_index[site - 1] - P[site - 1] = j - site -= 1 - return P - - -def check_alleles(alleles, num_sites): +def check_alleles(alleles, m): """ Checks the specified allele list and returns a list of lists of alleles of length num_sites. - If alleles is a 1D list of strings, assume that this list is used for each site and return num_sites copies of this list. - Otherwise, raise a ValueError if alleles is not a list of length num_sites. """ if isinstance(alleles[0], str): - return [alleles for _ in range(num_sites)] - if len(alleles) != num_sites: + return [alleles for _ in range(m)], np.int8([len(alleles) for _ in range(m)]) + if len(alleles) != m: raise ValueError("Malformed alleles list") - return alleles + n_alleles = np.int8([(len(alleles_site)) for alleles_site in alleles]) + return alleles, n_alleles -def ls_forward_matrix(h, alleles, G, rho, mu): +def mirror_coordinates(ts): """ - Simple matrix based method for LS forward algorithm using numpy vectorisation. + Returns a copy of the specified tree sequence in which all + coordinates x are transformed into L - x. """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - F = np.zeros((m, n)) - S = np.zeros(m) - f = np.zeros(n) + 1 / n - p_e = np.zeros(n) - - for el in range(0, m): - p_t = f * (1 - rho[el]) + rho[el] / n - eq = G[el] == h[el] - if h[el] == tskit.MISSING_DATA: - # Missing data is equal to everything - eq[:] = True - p_e[:] = mu[el] - p_e[eq] = 1 - (len(alleles[el]) - 1) * mu[el] - f = p_t * p_e - S[el] = np.sum(f) - # TODO need to handle the 0 case. 
- assert S[el] > 0 - f /= S[el] - F[el] = f - return F, S - - -def forward_matrix_log_proba(F, S): - """ - Given the specified forward matrix and scaling factor array, return the - overall log probability of the input haplotype. - """ - return np.sum(np.log(S)) - np.log(np.sum(F[-1])) - - -def ls_forward_matrix_unscaled(h, alleles, G, rho, mu): - """ - Simple matrix based method for LS forward algorithm. - """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - F = np.zeros((m, n)) - f = np.zeros(n) + 1 / n - - for el in range(0, m): - s = np.sum(f) - for j in range(n): - p_t = f[j] * (1 - rho[el]) + s * rho[el] / n - p_e = mu[el] - if G[el, j] == h[el] or h[el] == tskit.MISSING_DATA: - p_e = 1 - (len(alleles[el]) - 1) * mu[el] - f[j] = p_t * p_e - F[el] = f - return F - - -# TODO change this to use the log_proba function below. -def ls_path_probability(h, path, G, rho, mu): - """ - Returns the probability of the specified path through the genotypes for the - specified haplotype. - """ - # Assuming num_alleles = 2 - assert rho[0] == 0 - m, n = G.shape - # TODO It's not entirely clear why we're starting with a proba of 1 / n for the - # model. This was done because it made it easier to compare with an existing - # HMM implementation. Need to figure this one out when writing up. - proba = 1 / n - for site in range(0, m): - pe = mu[site] - if h[site] == G[site, path[site]] or h[site] == tskit.MISSING_DATA: - pe = 1 - mu[site] - pt = rho[site] / n - if site == 0 or path[site] == path[site - 1]: - pt = 1 - rho[site] + rho[site] / n - proba *= pt * pe - return proba - - -def ls_path_log_probability(h, path, alleles, G, rho, mu): - """ - Returns the log probability of the specified path through the genotypes for the - specified haplotype. - """ - assert rho[0] == 0 - m, n = G.shape - alleles = check_alleles(alleles, m) - # TODO It's not entirely clear why we're starting with a proba of 1 / n for the - # model. This was done because it made it easier to compare with an existing - # HMM implementation. Need to figure this one out when writing up. - log_proba = np.log(1 / n) - for site in range(0, m): - if len(alleles[site]) > 1: - assert mu[site] <= 1 / (len(alleles[site]) - 1) - pe = mu[site] - if h[site] == G[site, path[site]] or h[site] == tskit.MISSING_DATA: - pe = 1 - (len(alleles[site]) - 1) * mu[site] - assert 0 <= pe <= 1 - pt = rho[site] / n - if site == 0 or path[site] == path[site - 1]: - pt = 1 - rho[site] + rho[site] / n - assert 0 <= pt <= 1 - log_proba += np.log(pt) + np.log(pe) - return log_proba - - -def ls_forward_tree(h, alleles, ts, rho, mu, precision=30, use_lib=True): - """ - Forward matrix computation based on a tree sequence. - """ - if use_lib: - acgt_alleles = tuple(alleles) == tskit.ALLELES_ACGT - ls_hmm = _tskit.LsHmm( - ts.ll_tree_sequence, - recombination_rate=rho, - mutation_rate=mu, - precision=precision, - acgt_alleles=acgt_alleles, - ) - cm = _tskit.CompressedMatrix(ts.ll_tree_sequence) - ls_hmm.forward_matrix(h, cm) - return cm - else: - fa = ForwardAlgorithm(ts, rho, mu, alleles, precision=precision) - return fa.run(h) - - -def ls_viterbi_tree(h, alleles, ts, rho, mu, precision=30, use_lib=True): - """ - Viterbi path computation based on a tree sequence. 
- """ - if use_lib: - acgt_alleles = tuple(alleles) == tskit.ALLELES_ACGT - ls_hmm = _tskit.LsHmm( - ts.ll_tree_sequence, - recombination_rate=rho, - mutation_rate=mu, - precision=precision, - acgt_alleles=acgt_alleles, - ) - vm = _tskit.ViterbiMatrix(ts.ll_tree_sequence) - ls_hmm.viterbi_matrix(h, vm) - return vm - else: - va = ViterbiAlgorithm(ts, rho, mu, alleles, precision=precision) - return va.run(h) + L = ts.sequence_length + tables = ts.dump_tables() + left = tables.edges.left + right = tables.edges.right + tables.edges.left = L - right + tables.edges.right = L - left + tables.sites.position = L - tables.sites.position # + 1 + # TODO migrations. + tables.sort() + return tables.tree_sequence() class ValueTransition: - """ - Simple struct holding value transition values. - """ + """Simple struct holding value transition values.""" def __init__(self, tree_node=-1, value=-1, value_index=-1): self.tree_node = tree_node @@ -353,7 +76,11 @@ def __init__(self, tree_node=-1, value=-1, value_index=-1): self.value_index = value_index def copy(self): - return ValueTransition(self.tree_node, self.value, self.value_index) + return ValueTransition( + self.tree_node, + self.value, + self.value_index, + ) def __repr__(self): return repr(self.__dict__) @@ -367,11 +94,12 @@ class LsHmmAlgorithm: Abstract superclass of Li and Stephens HMM algorithm. """ - def __init__(self, ts, rho, mu, alleles, precision=10): + def __init__( + self, ts, rho, mu, alleles, n_alleles, precision=10, scale_mutation=False + ): self.ts = ts self.mu = mu self.rho = rho - self.alleles = check_alleles(alleles, ts.num_sites) self.precision = precision # The array of ValueTransitions. self.T = [] @@ -386,6 +114,10 @@ def __init__(self, ts, rho, mu, alleles, precision=10): self.parent = np.zeros(self.ts.num_nodes, dtype=int) - 1 self.tree = tskit.Tree(self.ts) self.output = None + # Vector of the number of alleles at each site + self.n_alleles = n_alleles + self.alleles = alleles + self.scale_mutation_based_on_n_alleles = scale_mutation def check_integrity(self): M = [st.tree_node for st in self.T if st.tree_node != -1] @@ -422,10 +154,6 @@ def compute(u, parent_state): for j in range(num_values): value_count[j] += child[j] max_value_count = np.max(value_count) - # NOTE: we need to set the set to zero here because we actually - # visit some nodes more than once during the postorder traversal. - # This would seem to be wasteful, so we should revisit this when - # cleaning up the algorithm logic. optimal_set[u, :] = 0 optimal_set[u, value_count == max_value_count] = 1 @@ -566,9 +294,9 @@ def update_probabilities(self, site, haplotype_state): T = self.T alleles = self.alleles[site.id] allelic_state = self.allelic_state - # Set the allelic_state for this site. 
allelic_state[tree.root] = alleles.index(site.ancestral_state) + for mutation in site.mutations: u = mutation.node allelic_state[u] = alleles.index(mutation.derived_state) @@ -590,8 +318,7 @@ def update_probabilities(self, site, haplotype_state): v = tree.parent(v) assert v != -1 match = ( - haplotype_state == tskit.MISSING_DATA - or haplotype_state == allelic_state[v] + haplotype_state == MISSING or haplotype_state == allelic_state[v] ) st.value = self.compute_next_probability(site.id, st.value, match, u) @@ -600,31 +327,41 @@ def update_probabilities(self, site, haplotype_state): for mutation in site.mutations: allelic_state[mutation.node] = -1 - def process_site(self, site, haplotype_state): - # print(site.id, "num_transitions=", len(self.T)) - self.update_probabilities(site, haplotype_state) - # FIXME We don't want to call compress here. - # What we really want to do is just call compress after - # the values have been normalised and rounded. However, we can't - # compute the normalisation factor in the forwards algorithm without - # the N counts (number of samples directly below each value transition - # in T), and these are currently computed during compress. So to make - # things work for now we call compress before and put up with having - # a slightly less than optimally compressed output matrix. It might - # end up that this makes no difference and compressing the - # pre-rounded values is basically the same thing. - self.compress() - s = self.compute_normalisation_factor() - for st in self.T: - if st.tree_node != tskit.NULL: - st.value /= s - st.value = round(st.value, self.precision) - # *This* is where we want to compress (and can, for viterbi). - # self.compress() - self.output.store_site(site.id, s, [(st.tree_node, st.value) for st in self.T]) - - def run(self, h): + def process_site(self, site, haplotype_state, forwards=True): + if forwards: + # Forwards algorithm, or forwards pass in Viterbi + self.update_probabilities(site, haplotype_state) + self.compress() + s = self.compute_normalisation_factor() + for st in self.T: + if st.tree_node != tskit.NULL: + st.value /= s + st.value = round(st.value, self.precision) + self.output.store_site( + site.id, s, [(st.tree_node, st.value) for st in self.T] + ) + else: + # Backwards algorithm + self.output.store_site( + site.id, + self.output.normalisation_factor[site.id], + [(st.tree_node, st.value) for st in self.T], + ) + self.update_probabilities(site, haplotype_state) + self.compress() + b_last_sum = self.compute_normalisation_factor() + s = self.output.normalisation_factor[site.id] + for st in self.T: + if st.tree_node != tskit.NULL: + st.value = ( + self.rho[site.id] / self.ts.num_samples + ) * b_last_sum + (1 - self.rho[site.id]) * st.value + st.value /= s + st.value = round(st.value, self.precision) + + def run_forward(self, h): n = self.ts.num_samples + self.tree.clear() for u in self.ts.samples(): self.T_index[u] = len(self.T) self.T.append(ValueTransition(tree_node=u, value=1 / n)) @@ -634,6 +371,17 @@ def run(self, h): self.process_site(site, h[site.id]) return self.output + def run_backward(self, h): + self.tree.clear() + for u in self.ts.samples(): + self.T_index[u] = len(self.T) + self.T.append(ValueTransition(tree_node=u, value=1)) + while self.tree.next(): + self.update_tree() + for site in self.tree.sites(): + self.process_site(site, h[site.id], forwards=False) + return self.output + def compute_normalisation_factor(self): raise NotImplementedError() @@ -650,12 +398,16 @@ class CompressedMatrix: values are on the 
path). """ - def __init__(self, ts): + def __init__(self, ts, normalisation_factor=None): self.ts = ts self.num_sites = ts.num_sites self.num_samples = ts.num_samples self.value_transitions = [None for _ in range(self.num_sites)] - self.normalisation_factor = np.zeros(self.num_sites) + if normalisation_factor is None: + self.normalisation_factor = np.zeros(self.num_sites) + else: + self.normalisation_factor = normalisation_factor + assert len(self.normalisation_factor) == self.num_sites def store_site(self, site, normalisation_factor, value_transitions): self.normalisation_factor[site] = normalisation_factor @@ -688,39 +440,11 @@ def decode(self): class ForwardMatrix(CompressedMatrix): - """ - Class representing a compressed forward matrix. - """ - - -class ForwardAlgorithm(LsHmmAlgorithm): - """ - Runs the Li and Stephens forward algorithm. - """ - - def __init__(self, ts, rho, mu, alleles, precision=10): - super().__init__(ts, rho, mu, alleles, precision) - self.output = ForwardMatrix(ts) - - def compute_normalisation_factor(self): - s = 0 - for j, st in enumerate(self.T): - assert st.tree_node != tskit.NULL - assert self.N[j] > 0 - s += self.N[j] * st.value - return s + """Class representing a compressed forward matrix.""" - def compute_next_probability(self, site_id, p_last, is_match, node): - rho = self.rho[site_id] - mu = self.mu[site_id] - alleles = self.alleles[site_id] - n = self.ts.num_samples - p_t = p_last * (1 - rho) + rho / n - p_e = mu - if is_match: - p_e = 1 - (len(alleles) - 1) * mu - return p_t * p_e +class BackwardMatrix(CompressedMatrix): + """Class representing a compressed backward matrix.""" class ViterbiMatrix(CompressedMatrix): @@ -730,6 +454,8 @@ class ViterbiMatrix(CompressedMatrix): def __init__(self, ts): super().__init__(ts) + # Tuple containing the site, the node in the tree, and whether + # recombination is required self.recombination_required = [(-1, 0, False)] def add_recombination_required(self, site, node, required): @@ -801,13 +527,144 @@ def traceback(self): return match +class ForwardAlgorithm(LsHmmAlgorithm): + """Runs the Li and Stephens forward algorithm.""" + + def __init__( + self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 + ): + super().__init__( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation, + ) + self.output = ForwardMatrix(ts) + + def compute_normalisation_factor(self): + s = 0 + for j, st in enumerate(self.T): + assert st.tree_node != tskit.NULL + assert self.N[j] > 0 + s += self.N[j] * st.value + return s + + def compute_next_probability( + self, site_id, p_last, is_match, node + ): # Note node only used in Viterbi + rho = self.rho[site_id] + mu = self.mu[site_id] + n = self.ts.num_samples + n_alleles = self.n_alleles[site_id] + + if self.scale_mutation_based_on_n_alleles: + if is_match: + # Scale mutation based on the number of alleles + # - so the mutation rate is the mutation rate to one of the + # alleles. The overall mutation rate is then + # (n_alleles - 1) * mutation_rate. + p_e = 1 - (n_alleles - 1) * mu + else: + p_e = mu - mu * (n_alleles == 1) + # Added boolean in case we're at an invariant site + else: + # No scaling based on the number of alleles + # - so the mutation rate is the mutation rate to anything. + # This means that we must rescale the mutation rate to a different + # allele, by the number of alleles. 
+ if n_alleles == 1: # In case we're at an invariant site + if is_match: + p_e = 1 + else: + p_e = 0 + else: + if is_match: + p_e = 1 - mu + else: + p_e = mu / (n_alleles - 1) + + p_t = p_last * (1 - rho) + rho / n + return p_t * p_e + + +class BackwardAlgorithm(LsHmmAlgorithm): + """Runs the Li and Stephens backward algorithm.""" + + def __init__( + self, + ts, + rho, + mu, + alleles, + n_alleles, + normalisation_factor, + scale_mutation=False, + precision=10, + ): + super().__init__( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation, + ) + self.output = BackwardMatrix(ts, normalisation_factor) + + def compute_normalisation_factor(self): + s = 0 + for j, st in enumerate(self.T): + assert st.tree_node != tskit.NULL + assert self.N[j] > 0 + s += self.N[j] * st.value + return s + + def compute_next_probability( + self, site_id, p_next, is_match, node + ): # Note node only used in Viterbi + mu = self.mu[site_id] + n_alleles = self.n_alleles[site_id] + + if self.scale_mutation_based_on_n_alleles: + if is_match: + p_e = 1 - (n_alleles - 1) * mu + else: + p_e = mu - mu * (n_alleles == 1) + else: + if n_alleles == 1: + if is_match: + p_e = 1 + else: + p_e = 0 + else: + if is_match: + p_e = 1 - mu + else: + p_e = mu / (n_alleles - 1) + return p_next * p_e + + class ViterbiAlgorithm(LsHmmAlgorithm): """ Runs the Li and Stephens Viterbi algorithm. """ - def __init__(self, ts, rho, mu, alleles, precision=10): - super().__init__(ts, rho, mu, alleles, precision) + def __init__( + self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 + ): + super().__init__( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation, + ) self.output = ViterbiMatrix(ts) def compute_normalisation_factor(self): @@ -825,8 +682,8 @@ def compute_normalisation_factor(self): def compute_next_probability(self, site_id, p_last, is_match, node): rho = self.rho[site_id] mu = self.mu[site_id] - alleles = self.alleles[site_id] n = self.ts.num_samples + n_alleles = self.n_alleles[site_id] p_no_recomb = p_last * (1 - rho + rho / n) p_recomb = rho / n @@ -837,474 +694,427 @@ def compute_next_probability(self, site_id, p_last, is_match, node): p_t = p_recomb recombination_required = True self.output.add_recombination_required(site_id, node, recombination_required) - p_e = mu - if is_match: - p_e = 1 - (len(alleles) - 1) * mu - return p_t * p_e + if self.scale_mutation_based_on_n_alleles: + if is_match: + # Scale mutation based on the number of alleles + # - so the mutation rate is the mutation rate to one of the + # alleles. The overall mutation rate is then + # (n_alleles - 1) * mutation_rate. + p_e = 1 - (n_alleles - 1) * mu + else: + p_e = mu - mu * (n_alleles == 1) + # Added boolean in case we're at an invariant site + else: + # No scaling based on the number of alleles + # - so the mutation rate is the mutation rate to anything. + # This means that we must rescale the mutation rate to a different + # allele, by the number of alleles. 
+ if n_alleles == 1: # In case we're at an invariant site + if is_match: + p_e = 1 + else: + p_e = 0 + else: + if is_match: + p_e = 1 - mu + else: + p_e = mu / (n_alleles - 1) -################################################################ -# Tests -################################################################ + return p_t * p_e -class LiStephensBase: +def ls_forward_tree( + h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False +): + if alleles is None: + n_alleles = np.int8( + [ + len(np.unique(np.append(ts.genotype_matrix()[j, :], h[j]))) + for j in range(ts.num_sites) + ] + ) + alleles = tskit.ALLELES_ACGT + if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: + alleles = tskit.ALLELES_01 + if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: + raise ValueError( + """Alleles list could not be identified. + Please pass a list of lists of alleles of length m, + or a list of alleles (e.g. tskit.ALLELES_ACGT)""" + ) + alleles = [alleles for _ in range(ts.num_sites)] + else: + alleles, n_alleles = check_alleles(alleles, ts.num_sites) + + """Forward matrix computation based on a tree sequence.""" + fa = ForwardAlgorithm( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation_based_on_n_alleles, + ) + return fa.run_forward(h) + + +def ls_backward_tree( + h, ts_mirror, rho, mu, normalisation_factor, precision=30, alleles=None +): + if alleles is None: + n_alleles = np.int8( + [ + len(np.unique(np.append(ts_mirror.genotype_matrix()[j, :], h[j]))) + for j in range(ts_mirror.num_sites) + ] + ) + alleles = tskit.ALLELES_ACGT + if len(set(alleles).intersection(next(ts_mirror.variants()).alleles)) == 0: + alleles = tskit.ALLELES_01 + if len(set(alleles).intersection(next(ts_mirror.variants()).alleles)) == 0: + raise ValueError( + """Alleles list could not be identified. + Please pass a list of lists of alleles of length m, + or a list of alleles (e.g. tskit.ALLELES_ACGT)""" + ) + alleles = [alleles for _ in range(ts_mirror.num_sites)] + else: + alleles, n_alleles = check_alleles(alleles, ts_mirror.num_sites) + + """Backward matrix computation based on a tree sequence.""" + ba = BackwardAlgorithm( + ts_mirror, + rho, + mu, + alleles, + n_alleles, + normalisation_factor, + precision=precision, + ) + return ba.run_backward(h) + + +def ls_viterbi_tree( + h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False +): + if alleles is None: + n_alleles = np.int8( + [ + len(np.unique(np.append(ts.genotype_matrix()[j, :], h[j]))) + for j in range(ts.num_sites) + ] + ) + alleles = tskit.ALLELES_ACGT + if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: + alleles = tskit.ALLELES_01 + if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: + raise ValueError( + """Alleles list could not be identified. + Please pass a list of lists of alleles of length m, + or a list of alleles (e.g. tskit.ALLELES_ACGT)""" + ) + alleles = [alleles for _ in range(ts.num_sites)] + else: + alleles, n_alleles = check_alleles(alleles, ts.num_sites) """ - Superclass of Li and Stephens tests. + Viterbi path computation based on a tree sequence. """ - - def assertCompressedMatricesEqual(self, cm1, cm2): - """ - Checks that the specified compressed matrices contain the same data. 
- """ - A1 = cm1.decode() - A2 = cm2.decode() - assert np.allclose(A1, A2) - assert A1.shape == A2.shape - assert cm1.num_sites == cm2.num_sites - nf1 = cm1.normalisation_factor - nf2 = cm1.normalisation_factor - assert np.allclose(nf1, nf2) - assert nf1.shape == nf2.shape - # It seems that we can't rely on the number of transitions in the two - # implementations being equal, which seems odd given that we should - # be doing things identically. Still, once the decoded matrices are the - # same then it seems highly likely to be correct. - - # if not np.array_equal(cm1.num_transitions, cm2.num_transitions): - # print() - # print(cm1.num_transitions) - # print(cm2.num_transitions) - # self.assertTrue(np.array_equal(cm1.num_transitions, cm2.num_transitions)) - # for j in range(cm1.num_sites): - # s1 = dict(cm1.get_site(j)) - # s2 = dict(cm2.get_site(j)) - # self.assertEqual(set(s1.keys()), set(s2.keys())) - # for key in s1.keys(): - # self.assertAlmostEqual(s1[key], s2[key]) - - def example_haplotypes(self, ts, alleles, num_random=10, seed=2): - rng = np.random.RandomState(seed) - H = ts.genotype_matrix(alleles=alleles).T - haplotypes = [H[0], H[-1]] - for _ in range(num_random): - # Choose a random path through H - p = rng.randint(0, ts.num_samples, ts.num_sites) - h = H[p, np.arange(ts.num_sites)] - haplotypes.append(h) - h = H[0].copy() - h[-1] = tskit.MISSING_DATA - haplotypes.append(h) - h = H[0].copy() - h[ts.num_sites // 2] = tskit.MISSING_DATA - haplotypes.append(h) - # All missing is OK tool - h = H[0].copy() - h[:] = tskit.MISSING_DATA - haplotypes.append(h) - return haplotypes - - def example_parameters(self, ts, alleles, seed=1): - """ - Returns an iterator over combinations of haplotype, recombination and mutation - rates. - """ - rng = np.random.RandomState(seed) - haplotypes = self.example_haplotypes(ts, alleles, seed=seed) - - # This is the exact matching limit. 
- rho = np.zeros(ts.num_sites) + 0.01 - mu = np.zeros(ts.num_sites) - rho[0] = 0 - for h in haplotypes: - yield h, rho, mu + va = ViterbiAlgorithm( + ts, + rho, + mu, + alleles, + n_alleles, + precision=precision, + scale_mutation=scale_mutation_based_on_n_alleles, + ) + return va.run_forward(h) + + +class LSBase: + """Superclass of Li and Stephens tests.""" + + def example_haplotypes(self, ts): + + H = ts.genotype_matrix() + s = H[:, 0].reshape(1, H.shape[0]) + H = H[:, 1:] + + haplotypes = [ + s, + H[:, -1].reshape(1, H.shape[0]), + ] + s_tmp = s.copy() + s_tmp[0, -1] = MISSING + haplotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, ts.num_sites // 2] = MISSING + haplotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, :] = MISSING + haplotypes.append(s_tmp) + + return H, haplotypes + + def example_parameters_haplotypes(self, ts, seed=42): + """Returns an iterator over combinations of haplotype, + recombination and mutation rates.""" + np.random.seed(seed) + H, haplotypes = self.example_haplotypes(ts) + n = H.shape[1] + m = ts.get_num_sites() # Here we have equal mutation and recombination - rho = np.zeros(ts.num_sites) + 0.01 - mu = np.zeros(ts.num_sites) + 0.01 - rho[0] = 0 - for h in haplotypes: - yield h, rho, mu + r = np.zeros(m) + 0.01 + mu = np.zeros(m) + 0.01 + r[0] = 0 + + for s in haplotypes: + yield n, H, s, r, mu # Mixture of random and extremes - rhos = [ - np.zeros(ts.num_sites) + 0.999, - np.zeros(ts.num_sites) + 1e-6, - rng.uniform(0, 1, ts.num_sites), - ] - # mu can't be more than 1 / 3 if we have 4 alleles - mus = [ - np.zeros(ts.num_sites) + 0.33, - np.zeros(ts.num_sites) + 1e-6, - rng.uniform(0, 0.33, ts.num_sites), - ] - for h, rho, mu in itertools.product(haplotypes, rhos, mus): - rho[0] = 0 - yield h, rho, mu + rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] + mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] + + for s, r, mu in itertools.product(haplotypes, rs, mus): + r[0] = 0 + yield n, H, s, r, mu def assertAllClose(self, A, B): - assert np.allclose(A, B) + """Assert that all entries of two matrices are 'close'""" + assert np.allclose(A, B, rtol=1e-5, atol=1e-8) + + # Define a bunch of very small tree-sequences for testing a collection + # of parameters on + def test_simple_n_10_no_recombination(self): + ts = msprime.simulate( + 10, recombination_rate=0, mutation_rate=0.5, random_seed=42 + ) + assert ts.num_sites > 3 + self.verify(ts) - def test_simple_n_4_no_recombination(self): - ts = msprime.simulate(4, recombination_rate=0, mutation_rate=0.5, random_seed=1) + def test_simple_n_10_no_recombination_high_mut(self): + ts = msprime.simulate(10, recombination_rate=0, mutation_rate=3, random_seed=42) assert ts.num_sites > 3 self.verify(ts) - def test_simple_n_3(self): - ts = msprime.simulate(3, recombination_rate=2, mutation_rate=7, random_seed=2) - assert ts.num_sites > 5 + def test_simple_n_10_no_recombination_higher_mut(self): + ts = msprime.simulate(20, recombination_rate=0, mutation_rate=3, random_seed=42) + assert ts.num_sites > 3 self.verify(ts) - def test_simple_n_7(self): - ts = msprime.simulate(7, recombination_rate=2, mutation_rate=5, random_seed=2) + def test_simple_n_6(self): + ts = msprime.simulate(6, recombination_rate=2, mutation_rate=7, random_seed=42) assert ts.num_sites > 5 self.verify(ts) - def test_simple_n_8_high_recombination(self): - ts = msprime.simulate(8, recombination_rate=20, mutation_rate=5, random_seed=2) - assert ts.num_trees > 15 + def test_simple_n_8(self): + ts = msprime.simulate(8, 
recombination_rate=2, mutation_rate=5, random_seed=42) assert ts.num_sites > 5 self.verify(ts) - def test_simple_n_15(self): - ts = msprime.simulate(15, recombination_rate=2, mutation_rate=5, random_seed=2) + def test_simple_n_8_high_recombination(self): + ts = msprime.simulate(8, recombination_rate=20, mutation_rate=5, random_seed=42) + assert ts.num_trees > 15 assert ts.num_sites > 5 self.verify(ts) - def test_jukes_cantor_n_3(self): - ts = msprime.simulate(3, mutation_rate=2, random_seed=2) - ts = tsutil.jukes_cantor(ts, num_sites=10, mu=10, seed=4) - self.verify(ts, tskit.ALLELES_ACGT) - - def test_jukes_cantor_n_8_high_recombination(self): - ts = msprime.simulate(8, recombination_rate=20, random_seed=2) - ts = tsutil.jukes_cantor(ts, num_sites=20, mu=5, seed=4) - self.verify(ts, tskit.ALLELES_ACGT) - - def test_jukes_cantor_n_15(self): - ts = msprime.simulate(15, mutation_rate=2, random_seed=2) - ts = tsutil.jukes_cantor(ts, num_sites=10, mu=0.1, seed=10) - self.verify(ts, tskit.ALLELES_ACGT) - - def test_jukes_cantor_balanced_ternary(self): - ts = tskit.Tree.generate_balanced(27, arity=3).tree_sequence - ts = tsutil.jukes_cantor(ts, num_sites=10, mu=0.1, seed=10) - self.verify(ts, tskit.ALLELES_ACGT) - - @pytest.mark.skip(reason="Not supporting internal samples yet") - def test_ancestors_n_3(self): - ts = msprime.simulate(3, recombination_rate=2, mutation_rate=7, random_seed=2) + def test_simple_n_16(self): + ts = msprime.simulate(16, recombination_rate=2, mutation_rate=5, random_seed=42) assert ts.num_sites > 5 - tables = ts.dump_tables() - print(tables.nodes) - tables.nodes.flags = np.ones_like(tables.nodes.flags) - print(tables.nodes) - ts = tables.tree_sequence() self.verify(ts) + # # Define a bunch of very small tree-sequences for testing a collection + # # of parameters on + # def test_simple_n_10_no_recombination_blah(self): + # ts = msprime.sim_ancestry( + # samples=10, + # recombination_rate=0, + # random_seed=42, + # sequence_length=10, + # population_size=10000, + # ) + # ts = msprime.sim_mutations(ts, rate=1e-5, random_seed=42) + # assert ts.num_sites > 3 + # self.verify(ts) + + # def test_simple_n_6_blah(self): + # ts = msprime.sim_ancestry( + # samples=6, + # recombination_rate=1e-4, + # random_seed=42, + # sequence_length=40, + # population_size=10000, + # ) + # ts = msprime.sim_mutations(ts, rate=1e-3, random_seed=42) + # assert ts.num_sites > 5 + # self.verify(ts) + + # def test_simple_n_8_blah(self): + # ts = msprime.sim_ancestry( + # samples=8, + # recombination_rate=1e-4, + # random_seed=42, + # sequence_length=20, + # population_size=10000, + # ) + # ts = msprime.sim_mutations(ts, rate=1e-4, random_seed=42) + # assert ts.num_sites > 5 + # assert ts.num_trees > 15 + # self.verify(ts) + + # def test_simple_n_16_blah(self): + # ts = msprime.sim_ancestry( + # samples=16, + # recombination_rate=1e-2, + # random_seed=42, + # sequence_length=20, + # population_size=10000, + # ) + # ts = msprime.sim_mutations(ts, rate=1e-4, random_seed=42) + # assert ts.num_sites > 5 + # self.verify(ts) + + def verify(self, ts): + raise NotImplementedError() -@pytest.mark.slow -class ForwardAlgorithmBase(LiStephensBase): - """ - Base for forward algorithm tests. - """ +class FBAlgorithmBase(LSBase): + """Base for forwards backwards algorithm tests.""" -class TestNumpyMatrixMethod(ForwardAlgorithmBase): - """ - Tests that we compute the same values from the numpy matrix method as - the naive algorithm. 
- """ - def verify(self, ts, alleles=tskit.ALLELES_01): - G = ts.genotype_matrix(alleles=alleles) - for h, rho, mu in self.example_parameters(ts, alleles): - F1, S1 = ls_forward_matrix(h, alleles, G, rho, mu) - F2, S2 = ls_forward_matrix_naive(h, alleles, G, rho, mu) - self.assertAllClose(F1, F2) - self.assertAllClose(S1, S2) +class VitAlgorithmBase(LSBase): + """Base for viterbi algoritm tests.""" -class ViterbiAlgorithmBase(LiStephensBase): - """ - Base for viterbi algoritm tests. - """ +class TestMirroringHap(FBAlgorithmBase): + """Tests that mirroring the tree sequence and running forwards and backwards + algorithms gives the same log-likelihood of observing the data.""" + def verify(self, ts): + for n, H, s, r, mu in self.example_parameters_haplotypes(ts): + # Note, need to remove the first sample from the ts, and ensure that + # invariant sites aren't removed. + ts_check = ts.simplify(range(1, n + 1), filter_sites=False) + cm = ls_forward_tree(s[0, :], ts_check, r, mu) + ll_tree = np.sum(np.log10(cm.normalisation_factor)) -class TestExactMatchViterbi(ViterbiAlgorithmBase): - def verify(self, ts, alleles=tskit.ALLELES_01): - G = ts.genotype_matrix(alleles=alleles) - H = G.T - # print(H) - rho = np.zeros(ts.num_sites) + 0.1 - mu = np.zeros(ts.num_sites) - rho[0] = 0 - for h in H: - p1 = ls_viterbi_naive(h, alleles, G, rho, mu) - p2 = ls_viterbi_vectorised(h, alleles, G, rho, mu) - cm1 = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=True) - p3 = cm1.traceback() - cm2 = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=False) - p4 = cm1.traceback() - self.assertCompressedMatricesEqual(cm1, cm2) - - assert len(np.unique(p1)) == 1 - assert len(np.unique(p2)) == 1 - assert len(np.unique(p3)) == 1 - assert len(np.unique(p4)) == 1 - m1 = H[p1, np.arange(H.shape[1])] - assert np.array_equal(m1, h) - m2 = H[p2, np.arange(H.shape[1])] - assert np.array_equal(m2, h) - m3 = H[p3, np.arange(H.shape[1])] - assert np.array_equal(m3, h) - m4 = H[p3, np.arange(H.shape[1])] - assert np.array_equal(m4, h) - - -@pytest.mark.slow -class TestGeneralViterbi(ViterbiAlgorithmBase, unittest.TestCase): - def verify(self, ts, alleles=tskit.ALLELES_01): - # np.set_printoptions(linewidth=20000) - # np.set_printoptions(threshold=20000000) - G = ts.genotype_matrix(alleles=alleles) - # m, n = G.shape - for h, rho, mu in self.example_parameters(ts, alleles): - # print("h = ", h) - # print("rho=", rho) - # print("mu = ", mu) - p1 = ls_viterbi_vectorised(h, alleles, G, rho, mu) - p2 = ls_viterbi_naive(h, alleles, G, rho, mu) - cm1 = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=True) - p3 = cm1.traceback() - cm2 = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=False) - p4 = cm1.traceback() - self.assertCompressedMatricesEqual(cm1, cm2) - # print() - # m1 = H[p1, np.arange(m)] - # m2 = H[p2, np.arange(m)] - # m3 = H[p3, np.arange(m)] - # count = np.unique(p1).shape[0] - # print() - # print("\tp1 = ", p1) - # print("\tp2 = ", p2) - # print("\tp3 = ", p3) - # print("\tm1 = ", m1) - # print("\tm2 = ", m2) - # print("\t h = ", h) - proba1 = ls_path_log_probability(h, p1, alleles, G, rho, mu) - proba2 = ls_path_log_probability(h, p2, alleles, G, rho, mu) - proba3 = ls_path_log_probability(h, p3, alleles, G, rho, mu) - proba4 = ls_path_log_probability(h, p4, alleles, G, rho, mu) - # print("\t P = ", proba1, proba2) - self.assertAlmostEqual(proba1, proba2, places=6) - self.assertAlmostEqual(proba1, proba3, places=6) - self.assertAlmostEqual(proba1, proba4, places=6) - - -class TestMissingHaplotypes(LiStephensBase): - def 
verify(self, ts, alleles=tskit.ALLELES_01): - G = ts.genotype_matrix(alleles=alleles) - H = G.T - - rho = np.zeros(ts.num_sites) + 0.1 - rho[0] = 0 - mu = np.zeros(ts.num_sites) + 0.001 - - # When everything is missing data we should have no recombinations. - h = H[0].copy() - h[:] = tskit.MISSING_DATA - path = ls_viterbi_vectorised(h, alleles, G, rho, mu) - assert np.all(path == 0) - cm = ls_viterbi_tree(h, alleles, ts, rho, mu, use_lib=True) - # For the tree base algorithm it's not simple which particular sample - # gets chosen. - path = cm.traceback() - assert len(set(path)) == 1 - - # TODO Not clear what else we can check about missing data. - - -class TestForwardMatrixScaling(ForwardAlgorithmBase, unittest.TestCase): - """ - Tests that we get the correct values from scaling version of the matrix - algorithm works correctly. - """ + ts_check_mirror = mirror_coordinates(ts_check) + r_flip = np.insert(np.flip(r)[:-1], 0, 0) + cm_mirror = ls_forward_tree( + np.flip(s[0, :]), ts_check_mirror, r_flip, np.flip(mu) + ) + ll_mirror_tree = np.sum(np.log10(cm_mirror.normalisation_factor)) + self.assertAllClose(ll_tree, ll_mirror_tree) + + # Ensure that the decoded matrices are the same + F_mirror_matrix, c, ll = ls.forwards( + np.flip(H, axis=0), + np.flip(s, axis=1), + r_flip, + mutation_rate=np.flip(mu), + scale_mutation_based_on_n_alleles=False, + ) - def verify(self, ts, alleles=tskit.ALLELES_01): - G = ts.genotype_matrix(alleles=alleles) - computed_log_proba = False - for h, rho, mu in self.example_parameters(ts, alleles): - F_unscaled = ls_forward_matrix_unscaled(h, alleles, G, rho, mu) - F, S = ls_forward_matrix(h, alleles, G, rho, mu) - column = np.atleast_2d(np.cumprod(S)).T - F_scaled = F * column - self.assertAllClose(F_scaled, F_unscaled) - log_proba1 = forward_matrix_log_proba(F, S) - psum = np.sum(F_unscaled[-1]) - # If the computed probability is close to zero, there's no point in - # computing. - if psum > 1e-20: - computed_log_proba = True - log_proba2 = np.log(psum) - self.assertAlmostEqual(log_proba1, log_proba2) - assert computed_log_proba - - -class TestForwardTree(ForwardAlgorithmBase): - """ - Tests that the tree algorithm computes the same forward matrix as the - simple method. - """ + self.assertAllClose(F_mirror_matrix, cm_mirror.decode()) + self.assertAllClose(ll, ll_tree) - def verify(self, ts, alleles=tskit.ALLELES_01): - G = ts.genotype_matrix(alleles=alleles) - for h, rho, mu in self.example_parameters(ts, alleles): - F, S = ls_forward_matrix(h, alleles, G, rho, mu) - cm1 = ls_forward_tree(h, alleles, ts, rho, mu, use_lib=True) - cm2 = ls_forward_tree(h, alleles, ts, rho, mu, use_lib=False) - self.assertCompressedMatricesEqual(cm1, cm2) - Ft = cm1.decode() - self.assertAllClose(S, cm1.normalisation_factor) - self.assertAllClose(F, Ft) +class TestForwardHapTree(FBAlgorithmBase): + """Tests that the tree algorithm computes the same forward matrix as the + simple method.""" -class TestAllPaths(unittest.TestCase): - """ - Tests that we compute the correct forward probablities if we sum over all - possible paths through the genotype matrix. - """ + def verify(self, ts): + for n, H, s, r, mu in self.example_parameters_haplotypes(ts): + for scale_mutation in [False, True]: + F, c, ll = ls.forwards( + H, + s, + r, + mutation_rate=mu, + scale_mutation_based_on_n_alleles=scale_mutation, + ) + # Note, need to remove the first sample from the ts, and ensure + # that invariant sites aren't removed. 
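+                # (The query s is sample 0's haplotype and the panel H is
+                # samples 1..n in example_haplotypes above, so the tree-based
+                # run must use the ts with sample 0 removed; filter_sites=False
+                # keeps all m sites so that r and mu still line up with the
+                # site indexes.)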
+                ts_check = ts.simplify(range(1, n + 1), filter_sites=False)
+                cm = ls_forward_tree(
+                    s[0, :],
+                    ts_check,
+                    r,
+                    mu,
+                    scale_mutation_based_on_n_alleles=scale_mutation,
+                )
+                self.assertAllClose(cm.decode(), F)
+                ll_tree = np.sum(np.log10(cm.normalisation_factor))
+                self.assertAllClose(ll, ll_tree)
-    def verify(self, G, h):
-        m, n = G.shape
-        rho = np.zeros(m) + 0.1
-        mu = np.zeros(m) + 0.01
-        rho[0] = 0
-        proba = 0
-        for path in itertools.product(range(n), repeat=m):
-            proba += ls_path_probability(h, path, G, rho, mu)
-
-        alleles = [["0", "1"] for _ in range(m)]
-        F = ls_forward_matrix_unscaled(h, alleles, G, rho, mu)
-        forward_proba = np.sum(F[-1])
-        self.assertAlmostEqual(proba, forward_proba)
-
-    def test_n3_m4(self):
-        G = np.array(
-            [
-                # fmt: off
-                [1, 0, 0],
-                [0, 0, 1],
-                [1, 0, 1],
-                [0, 1, 1],
-                # fmt: on
-            ]
-        )
-        self.verify(G, [0, 0, 0, 0])
-        self.verify(G, [1, 1, 1, 1])
-        self.verify(G, [1, 1, 0, 0])
-    def test_n4_m5(self):
-        G = np.array(
-            [
-                # fmt: off
-                [1, 0, 0, 0],
-                [0, 0, 1, 1],
-                [1, 0, 1, 1],
-                [0, 1, 1, 0],
-                # fmt: on
-            ]
-        )
-        self.verify(G, [0, 0, 0, 0, 0])
-        self.verify(G, [1, 1, 1, 1, 1])
-        self.verify(G, [1, 1, 0, 0, 0])
+class TestForwardBackwardTree(FBAlgorithmBase):
+    """Tests that the tree algorithm computes the same forward and backward
+    matrices as the simple method."""
-    def test_n5_m5(self):
-        G = np.zeros((5, 5), dtype=int)
-        np.fill_diagonal(G, 1)
-        self.verify(G, [0, 0, 0, 0, 0])
-        self.verify(G, [1, 1, 1, 1, 1])
-        self.verify(G, [1, 1, 0, 0, 0])
+    def verify(self, ts):
+        for n, H, s, r, mu in self.example_parameters_haplotypes(ts):
+            F, c, ll = ls.forwards(
+                H, s, r, mutation_rate=mu, scale_mutation_based_on_n_alleles=False
+            )
+            B = ls.backwards(
+                H,
+                s,
+                c,
+                r,
+                mutation_rate=mu,
+                scale_mutation_based_on_n_alleles=False,
+            )
+            # Note, need to remove the first sample from the ts, and ensure that
+            # invariant sites aren't removed.
+            ts_check = ts.simplify(range(1, n + 1), filter_sites=False)
+            c_f = ls_forward_tree(s[0, :], ts_check, r, mu)
+            ll_tree = np.sum(np.log10(c_f.normalisation_factor))
+
+            ts_check_mirror = mirror_coordinates(ts_check)
+            r_flip = np.flip(r)
+            c_b = ls_backward_tree(
+                np.flip(s[0, :]),
+                ts_check_mirror,
+                r_flip,
+                np.flip(mu),
+                np.flip(c_f.normalisation_factor),
+            )
+            B_tree = np.flip(c_b.decode(), axis=0)
+            F_tree = c_f.decode()
-class TestBasicViterbi:
-    """
-    Very simple tests of the Viterbi algorithm. 
- """ + self.assertAllClose(B, B_tree) + self.assertAllClose(F, F_tree) + self.assertAllClose(ll, ll_tree) - def verify_exact_match(self, G, h, path): - m, n = G.shape - rho = np.zeros(m) + 1e-9 - mu = np.zeros(m) # Set mu to zero exact match - rho[0] = 0 - alleles = [["0", "1"] for _ in range(m)] - path1 = ls_viterbi_naive(h, alleles, G, rho, mu) - path2 = ls_viterbi_vectorised(h, alleles, G, rho, mu) - assert list(path1) == path - assert list(path2) == path - - def test_n2_m6_exact(self): - G = np.array( - [ - # fmt: off - [1, 0], - [1, 0], - [1, 0], - [0, 1], - [0, 1], - [0, 1], - # fmt: on - ] - ) - self.verify_exact_match(G, [1, 1, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1]) - self.verify_exact_match(G, [0, 0, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0]) - self.verify_exact_match(G, [0, 0, 0, 1, 1, 1], [1, 1, 1, 1, 1, 1]) - self.verify_exact_match(G, [0, 0, 0, 1, 1, 0], [1, 1, 1, 1, 1, 0]) - self.verify_exact_match(G, [0, 0, 0, 0, 1, 0], [1, 1, 1, 0, 1, 0]) - - def test_n3_m6_exact(self): - G = np.array( - [ - # fmt: off - [1, 0, 1], - [1, 0, 0], - [1, 0, 1], - [0, 1, 0], - [0, 1, 1], - [0, 1, 0], - # fmt: on - ] - ) - self.verify_exact_match(G, [1, 1, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1]) - self.verify_exact_match(G, [0, 0, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0]) - self.verify_exact_match(G, [0, 0, 0, 1, 1, 1], [1, 1, 1, 1, 1, 1]) - self.verify_exact_match(G, [1, 0, 1, 0, 1, 0], [2, 2, 2, 2, 2, 2]) - def test_n3_m6(self): - G = np.array( - [ - # fmt: off - [1, 0, 1], - [1, 0, 0], - [1, 0, 1], - [0, 1, 0], - [0, 1, 1], - [0, 1, 0], - # fmt: on - ] - ) +class TestTreeViterbiHap(VitAlgorithmBase): + """Test that we have the same log-likelihood between tree and matrix + implementations""" - m, n = G.shape - rho = np.zeros(m) + 1e-2 - mu = np.zeros(m) - rho[0] = 0 - alleles = [["0", "1"] for _ in range(m)] - h = np.ones(m, dtype=int) - path1 = ls_viterbi_naive(h, alleles, G, rho, mu) - - # Add in mutation at a very low rate. - mu[:] = 1e-8 - path2 = ls_viterbi_naive(h, alleles, G, rho, mu) - path3 = ls_viterbi_vectorised(h, alleles, G, rho, mu) - assert np.array_equal(path1, path2) - assert np.array_equal(path2, path3) + def verify(self, ts): + for n, H, s, r, mu in self.example_parameters_haplotypes(ts): + path, ll = ls.viterbi( + H, s, r, mutation_rate=mu, scale_mutation_based_on_n_alleles=False + ) + ts_check = ts.simplify(range(1, n + 1), filter_sites=False) + cm = ls_viterbi_tree(s[0, :], ts_check, r, mu) + ll_tree = np.sum(np.log10(cm.normalisation_factor)) + self.assertAllClose(ll, ll_tree) + + # Now, need to ensure that the likelihood of the preferred path is + # the same as ll_tree (and ll). + path_tree = cm.traceback() + ll_check = ls.path_ll( + H, + s, + path_tree, + r, + mutation_rate=mu, + scale_mutation_based_on_n_alleles=False, + ) + self.assertAllClose(ll, ll_check) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index b124ea6fca..0529dda001 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -228,14 +228,14 @@ def get_gap_examples(): assert len(t.parent_dict) == 0 found = True assert found - ret.append((f"gap {x}", ts)) + ret.append((f"gap_{x}", ts)) # Give an example with a gap at the end. 
     ts = msprime.simulate(10, random_seed=5, recombination_rate=1)
     tables = get_table_collection_copy(ts.dump_tables(), 2)
     tables.sites.clear()
     tables.mutations.clear()
     insert_uniform_mutations(tables, 100, list(ts.samples()))
-    ret.append(("gap at end", tables.tree_sequence()))
+    ret.append(("gap_at_end", tables.tree_sequence()))
     return ret
@@ -254,19 +254,19 @@ def get_internal_samples_examples():
     # Set all nodes to be samples.
     flags[:] = tskit.NODE_IS_SAMPLE
     nodes.flags = flags
-    ret.append(("all nodes samples", tables.tree_sequence()))
+    ret.append(("all_nodes_samples", tables.tree_sequence()))
     # Set just internal nodes to be samples.
     flags[:] = 0
     flags[n:] = tskit.NODE_IS_SAMPLE
     nodes.flags = flags
-    ret.append(("internal nodes samples", tables.tree_sequence()))
+    ret.append(("internal_nodes_samples", tables.tree_sequence()))
     # Set a mixture of internal and leaf samples.
     flags[:] = 0
     flags[n // 2 : n + n // 2] = tskit.NODE_IS_SAMPLE
     nodes.flags = flags
-    ret.append(("mixture of internal and leaf samples", tables.tree_sequence()))
+    ret.append(("mixed_internal_leaf_samples", tables.tree_sequence()))
     return ret
@@ -281,7 +281,7 @@ def get_decapitated_examples():
     ts = msprime.simulate(20, recombination_rate=1, random_seed=1234)
     assert ts.num_trees > 2
-    ret.append(("decapitate recomb", ts.decapitate(ts.tables.nodes.time[-1] / 4)))
+    ret.append(("decapitate_recomb", ts.decapitate(ts.tables.nodes.time[-1] / 4)))
     return ret
@@ -302,7 +302,7 @@ def get_bottleneck_examples():
         demographic_events=bottlenecks,
         random_seed=n,
     )
-    yield (f"bottleneck n={n}", ts)
+    yield (f"bottleneck_n={n}", ts)
 def get_back_mutation_examples():
@@ -337,13 +337,13 @@ def make_example_tree_sequences():
             )
             ts = tsutil.insert_random_ploidy_individuals(ts, 4, seed=seed)
             yield (
-                f"n={n} m={m} rho={rho}",
+                f"n={n}_m={m}_rho={rho}",
                 tsutil.add_random_metadata(ts, seed=seed),
             )
             seed += 1
     for name, ts in get_bottleneck_examples():
         yield (
-            f"{name} mutated",
+            f"{name}_mutated",
             msprime.mutate(
                 ts,
                 rate=0.1,
@@ -352,7 +352,7 @@
             ),
         )
     ts = tskit.Tree.generate_balanced(8).tree_sequence
-    yield ("rev node order", ts.subset(np.arange(ts.num_nodes - 1, -1, -1)))
+    yield ("rev_node_order", ts.subset(np.arange(ts.num_nodes - 1, -1, -1)))
     ts = msprime.sim_ancestry(
         8, sequence_length=40, recombination_rate=0.1, random_seed=seed
     )
@@ -361,20 +361,20 @@
     ts = tables.tree_sequence()
     assert ts.num_trees > 1
     yield (
-        "back mutations",
+        "back_mutations",
         tsutil.insert_branch_mutations(ts, mutations_per_branch=2),
     )
     ts = tsutil.insert_multichar_mutations(ts)
     yield ("multichar", ts)
-    yield ("multichar w/ metadata", tsutil.add_random_metadata(ts))
+    yield ("multichar_with_metadata", tsutil.add_random_metadata(ts))
     tables = ts.dump_tables()
     tables.nodes.flags = np.zeros_like(tables.nodes.flags)
-    yield ("no samples", tables.tree_sequence())  # no samples
+    yield ("no_samples", tables.tree_sequence())  # no samples
     tables = ts.dump_tables()
     tables.edges.clear()
-    yield ("empty tree", tables.tree_sequence())  # empty tree
+    yield ("empty_tree", tables.tree_sequence())  # empty tree
     yield (
-        "empty ts",
+        "empty_ts",
         tskit.TableCollection(sequence_length=1).tree_sequence(),
     )  # empty tree seq
     yield ("all_fields", tsutil.all_fields_ts())
@@ -384,6 +384,8 @@
 def get_example_tree_sequences(pytest_params=True):
+    # NOTE: pytest names should not contain spaces and be shell safe so
+    # that they can be easily specified on the command line.
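+    # For example, with the ids above a single case can be selected with
+    # something like (hypothetical invocation):
+    #   python -m pytest "tests/test_highlevel.py::TestSimplify::test_simplify[empty_ts]"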
if pytest_params: return [pytest.param(ts, id=name) for name, ts in _examples] else: @@ -1828,181 +1830,6 @@ def test_max_root_time_corner_cases(self): tables.edges.add_row(0, 1, 3, 1) assert tables.tree_sequence().max_root_time == 3 - def verify_simplify_provenance(self, ts): - new_ts = ts.simplify() - assert new_ts.num_provenances == ts.num_provenances + 1 - old = list(ts.provenances()) - new = list(new_ts.provenances()) - assert old == new[:-1] - # TODO call verify_provenance on this. - assert len(new[-1].timestamp) > 0 - assert len(new[-1].record) > 0 - - new_ts = ts.simplify(record_provenance=False) - assert new_ts.tables.provenances == ts.tables.provenances - - def verify_simplify_topology(self, ts, sample): - new_ts, node_map = ts.simplify(sample, map_nodes=True) - if len(sample) == 0: - assert new_ts.num_nodes == 0 - assert new_ts.num_edges == 0 - assert new_ts.num_sites == 0 - assert new_ts.num_mutations == 0 - elif len(sample) == 1: - assert new_ts.num_nodes == 1 - assert new_ts.num_edges == 0 - # The output samples should be 0...n - assert new_ts.num_samples == len(sample) - assert list(range(len(sample))) == list(new_ts.samples()) - for j in range(new_ts.num_samples): - assert node_map[sample[j]] == j - for u in range(ts.num_nodes): - old_node = ts.node(u) - if node_map[u] != tskit.NULL: - new_node = new_ts.node(node_map[u]) - assert old_node.time == new_node.time - assert old_node.population == new_node.population - assert old_node.metadata == new_node.metadata - for u in sample: - old_node = ts.node(u) - new_node = new_ts.node(node_map[u]) - assert old_node.flags == new_node.flags - assert old_node.time == new_node.time - assert old_node.population == new_node.population - assert old_node.metadata == new_node.metadata - old_trees = ts.trees() - old_tree = next(old_trees) - assert ts.get_num_trees() >= new_ts.get_num_trees() - for new_tree in new_ts.trees(): - new_left, new_right = new_tree.get_interval() - old_left, old_right = old_tree.get_interval() - # Skip ahead on the old tree until new_left is within its interval - while old_right <= new_left: - old_tree = next(old_trees) - old_left, old_right = old_tree.get_interval() - # If the MRCA of all pairs of samples is the same, then we have the - # same information. 
We limit this to at most 500 pairs - pairs = itertools.islice(itertools.combinations(sample, 2), 500) - for pair in pairs: - mapped_pair = [node_map[u] for u in pair] - mrca1 = old_tree.get_mrca(*pair) - mrca2 = new_tree.get_mrca(*mapped_pair) - if mrca1 == tskit.NULL: - assert mrca2 == mrca1 - else: - assert mrca2 == node_map[mrca1] - assert old_tree.get_time(mrca1) == new_tree.get_time(mrca2) - assert old_tree.get_population(mrca1) == new_tree.get_population( - mrca2 - ) - - def verify_simplify_equality(self, ts, sample): - for filter_sites in [False, True]: - s1, node_map1 = ts.simplify( - sample, map_nodes=True, filter_sites=filter_sites - ) - t1 = s1.dump_tables() - s2, node_map2 = simplify_tree_sequence( - ts, sample, filter_sites=filter_sites - ) - t2 = s2.dump_tables() - assert s1.num_samples == len(sample) - assert s2.num_samples == len(sample) - assert all(node_map1 == node_map2) - assert t1.individuals == t2.individuals - assert t1.nodes == t2.nodes - assert t1.edges == t2.edges - assert t1.migrations == t2.migrations - assert t1.sites == t2.sites - assert t1.mutations == t2.mutations - assert t1.populations == t2.populations - - def verify_simplify_variants(self, ts, sample): - subset = ts.simplify(sample) - sample_map = {u: j for j, u in enumerate(ts.samples())} - # Need to map IDs back to their sample indexes - s = np.array([sample_map[u] for u in sample]) - # Build a map of genotypes by position - full_genotypes = {} - for variant in ts.variants(isolated_as_missing=False): - alleles = [variant.alleles[g] for g in variant.genotypes] - full_genotypes[variant.position] = alleles - for variant in subset.variants(isolated_as_missing=False): - if variant.position in full_genotypes: - a1 = [full_genotypes[variant.position][u] for u in s] - a2 = [variant.alleles[g] for g in variant.genotypes] - assert a1 == a2 - - def verify_tables_api_equality(self, ts): - for samples in [None, list(ts.samples()), ts.samples()]: - tables = ts.dump_tables() - tables.simplify(samples=samples) - tables.assert_equals( - ts.simplify(samples=samples).tables, ignore_timestamps=True - ) - - @pytest.mark.slow - def test_simplify(self): - num_mutations = 0 - for ts in get_example_tree_sequences(pytest_params=False): - # Can't simplify edges with metadata - if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None): - self.verify_tables_api_equality(ts) - self.verify_simplify_provenance(ts) - n = ts.num_samples - num_mutations += ts.num_mutations - sample_sizes = {0} - if n > 1: - sample_sizes |= {1} - if n > 2: - sample_sizes |= {2, max(2, n // 2), n - 1} - for k in sample_sizes: - subset = random.sample(list(ts.samples()), k) - self.verify_simplify_topology(ts, subset) - self.verify_simplify_equality(ts, subset) - self.verify_simplify_variants(ts, subset) - assert num_mutations > 0 - - def test_simplify_bugs(self): - prefix = os.path.join(os.path.dirname(__file__), "data", "simplify-bugs") - j = 1 - while True: - nodes_file = os.path.join(prefix, f"{j:02d}-nodes.txt") - if not os.path.exists(nodes_file): - break - edges_file = os.path.join(prefix, f"{j:02d}-edges.txt") - sites_file = os.path.join(prefix, f"{j:02d}-sites.txt") - mutations_file = os.path.join(prefix, f"{j:02d}-mutations.txt") - with open(nodes_file) as nodes, open(edges_file) as edges, open( - sites_file - ) as sites, open(mutations_file) as mutations: - ts = tskit.load_text( - nodes=nodes, - edges=edges, - sites=sites, - mutations=mutations, - strict=False, - ) - samples = list(ts.samples()) - self.verify_simplify_equality(ts, 
samples) - j += 1 - assert j > 1 - - def test_simplify_migrations_fails(self): - ts = msprime.simulate( - population_configurations=[ - msprime.PopulationConfiguration(10), - msprime.PopulationConfiguration(10), - ], - migration_matrix=[[0, 1], [1, 0]], - random_seed=2, - record_migrations=True, - ) - assert ts.num_migrations > 0 - # We don't support simplify with migrations, so should fail. - with pytest.raises(_tskit.LibraryError): - ts.simplify() - def test_subset_reverse_all_nodes(self): ts = tskit.Tree.generate_comb(5).tree_sequence assert np.all(ts.samples() == np.arange(ts.num_samples)) @@ -2769,6 +2596,239 @@ def test_arrays_equal_to_tables(self, ts_fixture): ts.indexes_edge_removal_order, tables.indexes.edge_removal_order ) + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_impute_unknown_mutations_time(self, ts): + # Tests for method='min' + imputed_time = ts.impute_unknown_mutations_time(method="min") + mutations = ts.tables.mutations + nodes_time = ts.nodes_time + table_time = np.zeros(len(mutations)) + + for mut_idx, mut in enumerate(mutations): + if tskit.is_unknown_time(mut.time): + node_time = nodes_time[mut.node] + table_time[mut_idx] = node_time + else: + table_time[mut_idx] = mut.time + + assert np.allclose(imputed_time, table_time, rtol=1e-10, atol=1e-10) + + # Check we have valid times + tables = ts.dump_tables() + tables.mutations.time = imputed_time + tables.sort() + tables.tree_sequence() + + # Test for unallowed methods + with pytest.raises( + ValueError, match="Mutations time imputation method must be chosen" + ): + ts.impute_unknown_mutations_time(method="foobar") + + +class TestSimplify: + # This class was factored out of the old TestHighlevel class 2022-12-13, + # and is a mishmash of different testing paradigms. There is some valuable + # testing done here, so it would be good to fully bring it up to date. + + def verify_simplify_provenance(self, ts): + new_ts = ts.simplify() + assert new_ts.num_provenances == ts.num_provenances + 1 + old = list(ts.provenances()) + new = list(new_ts.provenances()) + assert old == new[:-1] + # TODO call verify_provenance on this. 
+ assert len(new[-1].timestamp) > 0 + assert len(new[-1].record) > 0 + + new_ts = ts.simplify(record_provenance=False) + assert new_ts.tables.provenances == ts.tables.provenances + + def verify_simplify_topology(self, ts, sample): + new_ts, node_map = ts.simplify(sample, map_nodes=True) + if len(sample) == 0: + assert new_ts.num_nodes == 0 + assert new_ts.num_edges == 0 + assert new_ts.num_sites == 0 + assert new_ts.num_mutations == 0 + elif len(sample) == 1: + assert new_ts.num_nodes == 1 + assert new_ts.num_edges == 0 + # The output samples should be 0...n + assert new_ts.num_samples == len(sample) + assert list(range(len(sample))) == list(new_ts.samples()) + for j in range(new_ts.num_samples): + assert node_map[sample[j]] == j + for u in range(ts.num_nodes): + old_node = ts.node(u) + if node_map[u] != tskit.NULL: + new_node = new_ts.node(node_map[u]) + assert old_node.time == new_node.time + assert old_node.population == new_node.population + assert old_node.metadata == new_node.metadata + for u in sample: + old_node = ts.node(u) + new_node = new_ts.node(node_map[u]) + assert old_node.flags == new_node.flags + assert old_node.time == new_node.time + assert old_node.population == new_node.population + assert old_node.metadata == new_node.metadata + old_trees = ts.trees() + old_tree = next(old_trees) + assert ts.get_num_trees() >= new_ts.get_num_trees() + for new_tree in new_ts.trees(): + new_left, new_right = new_tree.get_interval() + old_left, old_right = old_tree.get_interval() + # Skip ahead on the old tree until new_left is within its interval + while old_right <= new_left: + old_tree = next(old_trees) + old_left, old_right = old_tree.get_interval() + # If the MRCA of all pairs of samples is the same, then we have the + # same information. We limit this to at most 500 pairs + pairs = itertools.islice(itertools.combinations(sample, 2), 500) + for pair in pairs: + mapped_pair = [node_map[u] for u in pair] + mrca1 = old_tree.get_mrca(*pair) + mrca2 = new_tree.get_mrca(*mapped_pair) + if mrca1 == tskit.NULL: + assert mrca2 == mrca1 + else: + assert mrca2 == node_map[mrca1] + assert old_tree.get_time(mrca1) == new_tree.get_time(mrca2) + assert old_tree.get_population(mrca1) == new_tree.get_population( + mrca2 + ) + + def verify_simplify_equality(self, ts, sample): + for filter_sites in [False, True]: + s1, node_map1 = ts.simplify( + sample, map_nodes=True, filter_sites=filter_sites + ) + t1 = s1.dump_tables() + s2, node_map2 = simplify_tree_sequence( + ts, sample, filter_sites=filter_sites + ) + t2 = s2.dump_tables() + assert s1.num_samples == len(sample) + assert s2.num_samples == len(sample) + assert all(node_map1 == node_map2) + assert t1.individuals == t2.individuals + assert t1.nodes == t2.nodes + assert t1.edges == t2.edges + assert t1.migrations == t2.migrations + assert t1.sites == t2.sites + assert t1.mutations == t2.mutations + assert t1.populations == t2.populations + + def verify_simplify_variants(self, ts, sample): + subset = ts.simplify(sample) + sample_map = {u: j for j, u in enumerate(ts.samples())} + # Need to map IDs back to their sample indexes + s = np.array([sample_map[u] for u in sample]) + # Build a map of genotypes by position + full_genotypes = {} + for variant in ts.variants(isolated_as_missing=False): + alleles = [variant.alleles[g] for g in variant.genotypes] + full_genotypes[variant.position] = alleles + for variant in subset.variants(isolated_as_missing=False): + if variant.position in full_genotypes: + a1 = [full_genotypes[variant.position][u] for u in s] + 
a2 = [variant.alleles[g] for g in variant.genotypes] + assert a1 == a2 + + def verify_tables_api_equality(self, ts): + for samples in [None, list(ts.samples()), ts.samples()]: + tables = ts.dump_tables() + tables.simplify(samples=samples) + tables.assert_equals( + ts.simplify(samples=samples).tables, ignore_timestamps=True + ) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_simplify_tables_equality(self, ts): + # Can't simplify edges with metadata + if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None): + self.verify_tables_api_equality(ts) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_simplify_provenance(self, ts): + # Can't simplify edges with metadata + if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None): + self.verify_simplify_provenance(ts) + + # TODO this test needs to be broken up into discrete bits, so that we can + # test them independently. A way of getting a random-ish subset of samples + # from the pytest param would be useful. + @pytest.mark.slow + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_simplify(self, ts): + # Can't simplify edges with metadata + if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None): + n = ts.num_samples + sample_sizes = {0} + if n > 1: + sample_sizes |= {1} + if n > 2: + sample_sizes |= {2, max(2, n // 2), n - 1} + for k in sample_sizes: + subset = random.sample(list(ts.samples()), k) + self.verify_simplify_topology(ts, subset) + self.verify_simplify_equality(ts, subset) + self.verify_simplify_variants(ts, subset) + + def test_simplify_bugs(self): + prefix = os.path.join(os.path.dirname(__file__), "data", "simplify-bugs") + j = 1 + while True: + nodes_file = os.path.join(prefix, f"{j:02d}-nodes.txt") + if not os.path.exists(nodes_file): + break + edges_file = os.path.join(prefix, f"{j:02d}-edges.txt") + sites_file = os.path.join(prefix, f"{j:02d}-sites.txt") + mutations_file = os.path.join(prefix, f"{j:02d}-mutations.txt") + with open(nodes_file) as nodes, open(edges_file) as edges, open( + sites_file + ) as sites, open(mutations_file) as mutations: + ts = tskit.load_text( + nodes=nodes, + edges=edges, + sites=sites, + mutations=mutations, + strict=False, + ) + samples = list(ts.samples()) + self.verify_simplify_equality(ts, samples) + j += 1 + assert j > 1 + + def test_simplify_migrations_fails(self): + ts = msprime.simulate( + population_configurations=[ + msprime.PopulationConfiguration(10), + msprime.PopulationConfiguration(10), + ], + migration_matrix=[[0, 1], [1, 0]], + random_seed=2, + record_migrations=True, + ) + assert ts.num_migrations > 0 + # We don't support simplify with migrations, so should fail. 
+        with pytest.raises(_tskit.LibraryError):
+            ts.simplify()
+
+    @pytest.mark.parametrize("ts", get_example_tree_sequences())
+    def test_no_update_sample_flags_no_filter_nodes(self, ts):
+        # Can't simplify edges with metadata
+        if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None):
+            k = min(ts.num_samples, 3)
+            subset = ts.samples()[:k]
+            ts1 = ts.simplify(subset)
+            ts2 = ts.simplify(subset, update_sample_flags=False, filter_nodes=False)
+            assert ts1.num_samples == len(subset)
+            assert ts2.num_samples == ts.num_samples
+            assert ts1.num_edges == ts2.num_edges
+            assert ts2.tables.nodes == ts.tables.nodes
+
 class TestMinMaxTime:
     def get_example_tree_sequence(self, use_unknown_time):
@@ -3867,6 +3927,37 @@ def test_branch_length_empty_tree(self):
         assert tree.branch_length(1) == 0
         assert tree.total_branch_length == 0
+    @pytest.mark.parametrize("r_threshold", [0, -1])
+    def test_bad_val_root_threshold(self, r_threshold):
+        with pytest.raises(ValueError, match="greater than 0"):
+            tskit.Tree.generate_balanced(2, root_threshold=r_threshold)
+
+    @pytest.mark.parametrize("r_threshold", [None, 0.5, 1.5, np.inf])
+    def test_bad_type_root_threshold(self, r_threshold):
+        with pytest.raises(TypeError):
+            tskit.Tree.generate_balanced(2, root_threshold=r_threshold)
+
+    def test_simple_root_threshold(self):
+        tree = tskit.Tree.generate_balanced(3, root_threshold=3)
+        assert tree.num_roots == 1
+        tree = tskit.Tree.generate_balanced(3, root_threshold=4)
+        assert tree.num_roots == 0
+
+    @pytest.mark.parametrize("root_threshold", [1, 2, 3])
+    def test_is_root(self, root_threshold):
+        # Make a tree with multiple roots with different numbers of samples under each
+        ts = tskit.Tree.generate_balanced(5).tree_sequence
+        ts = ts.decapitate(ts.max_root_time - 0.1)
+        tables = ts.dump_tables()
+        tables.nodes.add_row(flags=0)  # Isolated non-sample
+        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE)  # Isolated sample
+        ts = tables.tree_sequence()
+        assert {ts.first().num_samples(u) for u in ts.first().roots} == {1, 2, 3}
+        tree = ts.first(root_threshold=root_threshold)
+        roots = set(tree.roots)
+        for u in range(ts.num_nodes):  # Will also test isolated nodes
+            assert tree.is_root(u) == (u in roots)
+
     def test_is_descendant(self):
         def is_descendant(tree, u, v):
             path = []
@@ -4270,6 +4361,69 @@ def test_node_edges(self):
         assert edge == tskit.NULL
+
+class TestSiblings:
+    def test_balanced_binary_tree(self):
+        t = tskit.Tree.generate_balanced(num_leaves=3)
+        assert t.has_single_root
+        # Nodes 0 to 2 are leaves
+        for u in range(3):
+            assert t.is_leaf(u)
+        assert t.siblings(0) == (3,)
+        assert t.siblings(1) == (2,)
+        assert t.siblings(2) == (1,)
+        # Node 3 is the internal node
+        assert t.is_internal(3)
+        assert t.siblings(3) == (0,)
+        # Node 4 is the root
+        assert 4 == t.root
+        assert t.siblings(4) == tuple()
+        # Node 5 is the virtual root
+        assert 5 == t.virtual_root
+        assert t.siblings(5) == tuple()
+
+    def test_star(self):
+        t = tskit.Tree.generate_star(num_leaves=3)
+        assert t.has_single_root
+        # Nodes 0 to 2 are leaves
+        for u in range(3):
+            assert t.is_leaf(u)
+        assert t.siblings(0) == (1, 2)
+        assert t.siblings(1) == (0, 2)
+        assert t.siblings(2) == (0, 1)
+        # Node 3 is the root
+        assert 3 == t.root
+        assert t.siblings(3) == tuple()
+        # Node 4 is the virtual root
+        assert 4 == t.virtual_root
+        assert t.siblings(4) == tuple()
+
+    def test_multiroot_tree(self):
+        ts = tskit.Tree.generate_balanced(4, arity=2).tree_sequence
+        t = ts.decapitate(ts.node(5).time).first()
+        assert t.has_multiple_roots
+        # Nodes 0 
to 3 are leaves + assert t.siblings(0) == (1,) + assert t.siblings(1) == (0,) + assert t.siblings(2) == (3,) + assert t.siblings(3) == (2,) + # Nodes 4 and 5 are both roots + assert 4 in t.roots + assert t.siblings(4) == (5,) + assert 5 in t.roots + assert t.siblings(5) == (4,) + # Node 7 is the virtual root + assert 7 == t.virtual_root + assert t.siblings(7) == tuple() + + @pytest.mark.parametrize("flag,expected", [(0, ()), (1, (2,))]) + def test_isolated_node(self, flag, expected): + tables = tskit.Tree.generate_balanced(2, arity=2).tree_sequence.dump_tables() + tables.nodes.add_row(flags=flag) # Add node 3 + t = tables.tree_sequence().first() + assert t.is_isolated(3) + assert t.siblings(3) == expected + + class TestNodeOrdering(HighLevelTestCase): """ Verify that we can use any node ordering for internal nodes @@ -4473,10 +4627,13 @@ def test_index_from_different_directions(self, index): t2.prev() assert_same_tree_different_order(t1, t2) - def test_seek_0_from_null(self): + @pytest.mark.parametrize("position", [0, 1, 2, 3]) + def test_seek_from_null(self, position): t1, t2 = self.setup() - t1.first() - t2.seek(0) + t1.clear() + t1.seek(position) + t2.first() + t2.seek(position) assert_trees_identical(t1, t2) @pytest.mark.parametrize("index", range(3)) @@ -4529,6 +4686,14 @@ def test_seek_3_from_null(self): t2.seek(3) assert_trees_identical(t1, t2) + def test_seek_3_from_null_prev(self): + t1, t2 = self.setup() + t1.last() + t1.prev() + t2.seek(3) + t2.prev() + assert_trees_identical(t1, t2) + def test_seek_3_from_0(self): t1, t2 = self.setup() t1.last() @@ -4544,6 +4709,37 @@ def test_seek_0_from_3(self): t2.seek(0) assert_trees_identical(t1, t2) + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_mid_null_and_middle(self, ts): + breakpoints = ts.breakpoints(as_array=True) + mid = breakpoints[:-1] + np.diff(breakpoints) / 2 + for index, x in enumerate(mid[:-1]): + t1 = tskit.Tree(ts) + t1.seek(x) + # Also seek to this point manually to make sure we're not + # reusing the seek from null under the hood. 
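+            # (Tree.seek(x) repositions the tree at the tree covering position
+            # x; stepping t2 tree-by-tree with next()/prev() below guarantees
+            # the comparison target never goes through the seek code path.)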
+ t2 = tskit.Tree(ts) + if index <= ts.num_trees / 2: + while t2.index != index: + t2.next() + else: + while t2.index != index: + t2.prev() + assert t1.index == t2.index + assert np.all(t1.parent_array == t2.parent_array) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_last_then_prev(self, ts): + t1 = tskit.Tree(ts) + t1.seek(ts.sequence_length - 0.00001) + assert t1.index == ts.num_trees - 1 + t2 = tskit.Tree(ts) + t2.prev() + assert_trees_identical(t1, t2) + t1.prev() + t2.prev() + assert_trees_identical(t1, t2) + class TestSeek: @pytest.mark.parametrize("ts", get_example_tree_sequences()) diff --git a/python/tests/test_intervals.py b/python/tests/test_intervals.py new file mode 100644 index 0000000000..f4ac31dfea --- /dev/null +++ b/python/tests/test_intervals.py @@ -0,0 +1,858 @@ +# MIT License +# +# Copyright (c) 2023 Tskit Developers +# Copyright (C) 2020-2021 University of Oxford +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +""" +Test cases for the intervals module. 
+""" +import decimal +import fractions +import gzip +import io +import os +import pickle +import textwrap +import xml + +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import tskit + + +class TestRateMapErrors: + @pytest.mark.parametrize( + ("position", "rate"), + [ + ([], []), + ([0], []), + ([0], [0]), + ([1, 2], [0]), + ([0, -1], [0]), + ([0, 1], [-1]), + ], + ) + def test_bad_input(self, position, rate): + with pytest.raises(ValueError): + tskit.RateMap(position=position, rate=rate) + + def test_zero_length_interval(self): + with pytest.raises(ValueError, match=r"at indexes \[2 4\]"): + tskit.RateMap(position=[0, 1, 1, 2, 2, 3], rate=[0, 0, 0, 0, 0]) + + def test_bad_length(self): + positions = np.array([0, 1, 2]) + rates = np.array([0, 1, 2]) + with pytest.raises(ValueError, match="one less entry"): + tskit.RateMap(position=positions, rate=rates) + + def test_bad_first_pos(self): + positions = np.array([1, 2, 3]) + rates = np.array([1, 1]) + with pytest.raises(ValueError, match="First position"): + tskit.RateMap(position=positions, rate=rates) + + def test_bad_rate(self): + positions = np.array([0, 1, 2]) + rates = np.array([1, -1]) + with pytest.raises(ValueError, match="negative.*1"): + tskit.RateMap(position=positions, rate=rates) + + def test_bad_rate_with_missing(self): + positions = np.array([0, 1, 2]) + rates = np.array([np.nan, -1]) + with pytest.raises(ValueError, match="negative.*1"): + tskit.RateMap(position=positions, rate=rates) + + def test_read_only(self): + positions = np.array([0, 0.25, 0.5, 0.75, 1]) + rates = np.array([0.125, 0.25, 0.5, 0.75]) # 1 shorter than positions + rate_map = tskit.RateMap(position=positions, rate=rates) + assert np.all(rates == rate_map.rate) + assert np.all(positions == rate_map.position) + with pytest.raises(AttributeError): + rate_map.rate = 2 * rate_map.rate + with pytest.raises(AttributeError): + rate_map.position = 2 * rate_map.position + with pytest.raises(AttributeError): + rate_map.left = 1234 + with pytest.raises(AttributeError): + rate_map.right = 1234 + with pytest.raises(AttributeError): + rate_map.mid = 1234 + with pytest.raises(ValueError): + rate_map.rate[0] = 1 + with pytest.raises(ValueError): + rate_map.position[0] = 1 + with pytest.raises(ValueError): + rate_map.left[0] = 1 + with pytest.raises(ValueError): + rate_map.mid[0] = 1 + with pytest.raises(ValueError): + rate_map.right[0] = 1 + + +class TestGetRateAllKnown: + examples = [ + tskit.RateMap(position=[0, 1], rate=[0]), + tskit.RateMap(position=[0, 1], rate=[0.1]), + tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]), + tskit.RateMap(position=[0, 1, 2], rate=[0, 0.2]), + tskit.RateMap(position=[0, 1, 2], rate=[0.1, 1e-6]), + tskit.RateMap(position=range(100), rate=range(99)), + ] + + @pytest.mark.parametrize("rate_map", examples) + def test_get_rate_mid(self, rate_map): + rate = rate_map.get_rate(rate_map.mid) + assert len(rate) == len(rate_map) + for j in range(len(rate_map)): + assert rate[j] == rate_map[rate_map.mid[j]] + + @pytest.mark.parametrize("rate_map", examples) + def test_get_rate_left(self, rate_map): + rate = rate_map.get_rate(rate_map.left) + assert len(rate) == len(rate_map) + for j in range(len(rate_map)): + assert rate[j] == rate_map[rate_map.left[j]] + + @pytest.mark.parametrize("rate_map", examples) + def test_get_rate_right(self, rate_map): + rate = rate_map.get_rate(rate_map.right[:-1]) + assert len(rate) == len(rate_map) - 1 + for j in range(len(rate_map) - 1): + assert rate[j] == rate_map[rate_map.right[j]] 
+ + +class TestOperations: + examples = [ + tskit.RateMap.uniform(sequence_length=1, rate=0), + tskit.RateMap.uniform(sequence_length=1, rate=0.1), + tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]), + tskit.RateMap(position=[0, 1, 2], rate=[0, 0.2]), + tskit.RateMap(position=[0, 1, 2], rate=[0.1, 1e-6]), + tskit.RateMap(position=range(100), rate=range(99)), + # Missing data + tskit.RateMap(position=[0, 1, 2], rate=[np.nan, 0]), + tskit.RateMap(position=[0, 1, 2], rate=[0, np.nan]), + tskit.RateMap(position=[0, 1, 2, 3], rate=[0, np.nan, 1]), + ] + + @pytest.mark.parametrize("rate_map", examples) + def test_num_intervals(self, rate_map): + assert rate_map.num_intervals == len(rate_map.rate) + assert rate_map.num_missing_intervals == np.sum(np.isnan(rate_map.rate)) + assert rate_map.num_non_missing_intervals == np.sum(~np.isnan(rate_map.rate)) + + @pytest.mark.parametrize("rate_map", examples) + def test_mask_arrays(self, rate_map): + assert_array_equal(rate_map.missing, np.isnan(rate_map.rate)) + assert_array_equal(rate_map.non_missing, ~np.isnan(rate_map.rate)) + + @pytest.mark.parametrize("rate_map", examples) + def test_missing_intervals(self, rate_map): + missing = [] + for left, right, rate in zip(rate_map.left, rate_map.right, rate_map.rate): + if np.isnan(rate): + missing.append([left, right]) + if len(missing) == 0: + assert len(rate_map.missing_intervals()) == 0 + else: + assert_array_equal(missing, rate_map.missing_intervals()) + + @pytest.mark.parametrize("rate_map", examples) + def test_mean_rate(self, rate_map): + total_span = 0 + total_mass = 0 + for span, mass in zip(rate_map.span, rate_map.mass): + if not np.isnan(mass): + total_span += span + total_mass += mass + assert total_mass / total_span == rate_map.mean_rate + + @pytest.mark.parametrize("rate_map", examples) + def test_total_mass(self, rate_map): + assert rate_map.total_mass == np.nansum(rate_map.mass) + + @pytest.mark.parametrize("rate_map", examples) + def test_get_cumulative_mass(self, rate_map): + assert list(rate_map.get_cumulative_mass([0])) == [0] + assert list(rate_map.get_cumulative_mass([rate_map.sequence_length])) == [ + rate_map.total_mass + ] + assert_array_equal( + rate_map.get_cumulative_mass(rate_map.right), np.nancumsum(rate_map.mass) + ) + + @pytest.mark.parametrize("rate_map", examples) + def test_get_rate(self, rate_map): + assert_array_equal(rate_map.get_rate([0]), rate_map.rate[0]) + assert_array_equal( + rate_map.get_rate([rate_map.sequence_length - 1e-9]), rate_map.rate[-1] + ) + assert_array_equal(rate_map.get_rate(rate_map.left), rate_map.rate) + + @pytest.mark.parametrize("rate_map", examples) + def test_map_semantics(self, rate_map): + assert len(rate_map) == rate_map.num_non_missing_intervals + assert_array_equal(list(rate_map.keys()), rate_map.mid[rate_map.non_missing]) + for x in rate_map.left[rate_map.missing]: + assert x not in rate_map + for x in rate_map.mid[rate_map.missing]: + assert x not in rate_map + + def test_asdict(self): + rate_map = tskit.RateMap.uniform(sequence_length=2, rate=4) + d = rate_map.asdict() + assert_array_equal(d["position"], np.array([0.0, 2.0])) + assert_array_equal(d["rate"], np.array([4.0])) + + +class TestFindIndex: + def test_one_interval(self): + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) + for j in range(10): + assert rate_map.find_index(j) == 0 + assert rate_map.find_index(0.0001) == 0 + assert rate_map.find_index(9.999) == 0 + + def test_two_intervals(self): + rate_map = tskit.RateMap(position=[0, 5, 10], rate=[0.1, 0.1]) + assert 
rate_map.find_index(0) == 0 + assert rate_map.find_index(0.0001) == 0 + assert rate_map.find_index(4.9999) == 0 + assert rate_map.find_index(5) == 1 + assert rate_map.find_index(5.1) == 1 + assert rate_map.find_index(7) == 1 + assert rate_map.find_index(9.999) == 1 + + def test_three_intervals(self): + rate_map = tskit.RateMap(position=[0, 5, 10, 15], rate=[0.1, 0.1, 0.1]) + assert rate_map.find_index(0) == 0 + assert rate_map.find_index(0.0001) == 0 + assert rate_map.find_index(4.9999) == 0 + assert rate_map.find_index(5) == 1 + assert rate_map.find_index(5.1) == 1 + assert rate_map.find_index(7) == 1 + assert rate_map.find_index(9.999) == 1 + assert rate_map.find_index(10) == 2 + assert rate_map.find_index(10.1) == 2 + assert rate_map.find_index(12) == 2 + assert rate_map.find_index(14.9999) == 2 + + def test_out_of_bounds(self): + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) + for bad_value in [-1, -0.0001, 10, 10.0001, 1e9]: + with pytest.raises(KeyError, match="out of bounds"): + rate_map.find_index(bad_value) + + def test_input_types(self): + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) + assert rate_map.find_index(0) == 0 + assert rate_map.find_index(0.0) == 0 + assert rate_map.find_index(np.zeros(1)[0]) == 0 + + +class TestSimpleExamples: + def test_all_missing_one_interval(self): + with pytest.raises(ValueError, match="missing data"): + tskit.RateMap(position=[0, 10], rate=[np.nan]) + + def test_all_missing_two_intervals(self): + with pytest.raises(ValueError, match="missing data"): + tskit.RateMap(position=[0, 5, 10], rate=[np.nan, np.nan]) + + def test_count(self): + rate_map = tskit.RateMap(position=[0, 5, 10], rate=[np.nan, 1]) + assert rate_map.num_intervals == 2 + assert rate_map.num_missing_intervals == 1 + assert rate_map.num_non_missing_intervals == 1 + + def test_missing_arrays(self): + rate_map = tskit.RateMap(position=[0, 5, 10], rate=[np.nan, 1]) + assert list(rate_map.missing) == [True, False] + assert list(rate_map.non_missing) == [False, True] + + def test_missing_at_start_mean_rate(self): + positions = np.array([0, 0.5, 1, 2]) + rates = np.array([np.nan, 0, 1]) + rate_map = tskit.RateMap(position=positions, rate=rates) + assert np.isclose(rate_map.mean_rate, 1 / (1 + 0.5)) + + def test_missing_at_end_mean_rate(self): + positions = np.array([0, 1, 1.5, 2]) + rates = np.array([1, 0, np.nan]) + rate_map = tskit.RateMap(position=positions, rate=rates) + assert np.isclose(rate_map.mean_rate, 1 / (1 + 0.5)) + + def test_interval_properties_all_known(self): + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + assert list(rate_map.left) == [0, 1, 2] + assert list(rate_map.right) == [1, 2, 3] + assert list(rate_map.mid) == [0.5, 1.5, 2.5] + assert list(rate_map.span) == [1, 1, 1] + assert list(rate_map.mass) == [0.1, 0.2, 0.3] + + def test_pickle_non_missing(self): + r1 = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + r2 = pickle.loads(pickle.dumps(r1)) + assert r1 == r2 + + def test_pickle_missing(self): + r1 = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, np.nan, 0.3]) + r2 = pickle.loads(pickle.dumps(r1)) + assert r1 == r2 + + def test_get_cumulative_mass_all_known(self): + rate_map = tskit.RateMap(position=[0, 10, 20, 30], rate=[0.1, 0.2, 0.3]) + assert list(rate_map.mass) == [1, 2, 3] + assert list(rate_map.get_cumulative_mass([10, 20, 30])) == [1, 3, 6] + + def test_cumulative_mass_missing(self): + rate_map = tskit.RateMap(position=[0, 10, 20, 30], rate=[0.1, np.nan, 0.3]) + assert 
list(rate_map.get_cumulative_mass([10, 20, 30])) == [1, 1, 4] + + +class TestDisplay: + def test_str(self): + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) + s = """\ + ╔════╤═════╤═══╤════╤════╗ + ║left│right│mid│span│rate║ + ╠════╪═════╪═══╪════╪════╣ + ║0 │10 │ 5│ 10│ 0.1║ + ╚════╧═════╧═══╧════╧════╝ + """ + assert textwrap.dedent(s) == str(rate_map) + + def test_str_scinot(self): + rate_map = tskit.RateMap(position=[0, 10], rate=[0.000001]) + s = """\ + ╔════╤═════╤═══╤════╤═════╗ + ║left│right│mid│span│rate ║ + ╠════╪═════╪═══╪════╪═════╣ + ║0 │10 │ 5│ 10│1e-06║ + ╚════╧═════╧═══╧════╧═════╝ + """ + assert textwrap.dedent(s) == str(rate_map) + + def test_repr(self): + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) + s = "RateMap(position=array([ 0., 10.]), rate=array([0.1]))" + assert repr(rate_map) == s + + def test_repr_html(self): + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) + html = rate_map._repr_html_() + root = xml.etree.ElementTree.fromstring(html) + assert root.tag == "div" + table = root.find("table") + rows = list(table.find("tbody")) + assert len(rows) == 1 + + def test_long_table(self): + n = 100 + rate_map = tskit.RateMap(position=range(n + 1), rate=[0.1] * n) + headers, data = rate_map._text_header_and_rows(limit=20) + assert len(headers) == 5 + assert len(data) == 21 + # check some left values + assert int(data[0][0]) == 0 + assert int(data[-1][0]) == n - 1 + + def test_short_table(self): + n = 10 + rate_map = tskit.RateMap(position=range(n + 1), rate=[0.1] * n) + headers, data = rate_map._text_header_and_rows(limit=20) + assert len(headers) == 5 + assert len(data) == n + # check some left values. + assert int(data[0][0]) == 0 + assert int(data[-1][0]) == n - 1 + + +class TestRateMapIsMapping: + def test_items(self): + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + items = list(rate_map.items()) + assert items[0] == (0.5, 0.1) + assert items[1] == (1.5, 0.2) + assert items[2] == (2.5, 0.3) + + def test_keys(self): + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + assert list(rate_map.keys()) == [0.5, 1.5, 2.5] + + def test_values(self): + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + assert list(rate_map.values()) == [0.1, 0.2, 0.3] + + def test_in_points(self): + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + # Any point within the map are True + for x in [0, 0.5, 1, 2.9999]: + assert x in rate_map + # Points outside the map are False + for x in [-1, -0.0001, 3, 3.1]: + assert x not in rate_map + + def test_in_slices(self): + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + # slices that are within the map are "in" + for x in [slice(0, 0.5), slice(0, 1), slice(0, 2), slice(2, 3), slice(0, 3)]: + assert x in rate_map + # Any slice that doesn't fully intersect with the map "not in" + assert slice(-0.001, 1) not in rate_map + assert slice(0, 3.0001) not in rate_map + assert slice(2.9999, 3.0001) not in rate_map + assert slice(3, 4) not in rate_map + assert slice(-2, -1) not in rate_map + + def test_other_types_not_in(self): + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + for other_type in [None, "sdf", "123", {}, [], Exception]: + assert other_type not in rate_map + + def test_len(self): + rate_map = tskit.RateMap(position=[0, 1], rate=[0.1]) + assert len(rate_map) == 1 + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + assert len(rate_map) == 2 + rate_map = 
tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + assert len(rate_map) == 3 + + def test_immutable(self): + rate_map = tskit.RateMap(position=[0, 1], rate=[0.1]) + with pytest.raises(TypeError, match="item assignment"): + rate_map[0] = 1 + with pytest.raises(TypeError, match="item deletion"): + del rate_map[0] + + def test_eq(self): + r1 = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + r2 = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + assert r1 == r1 + assert r1 == r2 + r2 = tskit.RateMap(position=[0, 1, 3], rate=[0.1, 0.2]) + assert r1 != r2 + assert tskit.RateMap(position=[0, 1], rate=[0.1]) != tskit.RateMap( + position=[0, 1], rate=[0.2] + ) + assert tskit.RateMap(position=[0, 1], rate=[0.1]) != tskit.RateMap( + position=[0, 10], rate=[0.1] + ) + + def test_getitem_value(self): + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + assert rate_map[0] == 0.1 + assert rate_map[0.5] == 0.1 + assert rate_map[1] == 0.2 + assert rate_map[1.5] == 0.2 + assert rate_map[1.999] == 0.2 + # Try other types + assert rate_map[np.array([1], dtype=np.float32)[0]] == 0.2 + assert rate_map[np.array([1], dtype=np.int32)[0]] == 0.2 + assert rate_map[np.array([1], dtype=np.float64)[0]] == 0.2 + assert rate_map[1 / 2] == 0.1 + assert rate_map[fractions.Fraction(1, 3)] == 0.1 + assert rate_map[decimal.Decimal(1)] == 0.2 + + def test_getitem_slice(self): + r1 = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + # The semantics of the slice() function are tested elsewhere. + assert r1[:] == r1.copy() + assert r1[:] is not r1 + assert r1[1:] == r1.slice(left=1) + assert r1[:1.5] == r1.slice(right=1.5) + assert r1[0.5:1.5] == r1.slice(left=0.5, right=1.5) + + def test_getitem_slice_step(self): + r1 = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + # Trying to set a "step" is a error + with pytest.raises(TypeError, match="interval slicing"): + r1[0:3:1] + + +class TestMappingMissingData: + def test_get_missing(self): + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) + with pytest.raises(KeyError, match="within a missing interval"): + rate_map[0] + with pytest.raises(KeyError, match="within a missing interval"): + rate_map[0.999] + + def test_in_missing(self): + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) + assert 0 not in rate_map + assert 0.999 not in rate_map + assert 1 in rate_map + + def test_keys_missing(self): + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) + assert list(rate_map.keys()) == [1.5] + + +class TestGetIntermediates: + def test_get_rate(self): + positions = np.array([0, 1, 2]) + rates = np.array([1, 4]) + rate_map = tskit.RateMap(position=positions, rate=rates) + assert np.all(rate_map.get_rate([0.5, 1.5]) == rates) + + def test_get_rate_out_of_bounds(self): + positions = np.array([0, 1, 2]) + rates = np.array([1, 4]) + rate_map = tskit.RateMap(position=positions, rate=rates) + with pytest.raises(ValueError, match="out of bounds"): + rate_map.get_rate([1, -0.1]) + with pytest.raises(ValueError, match="out of bounds"): + rate_map.get_rate([2]) + + def test_get_cumulative_mass(self): + positions = np.array([0, 1, 2]) + rates = np.array([1, 4]) + rate_map = tskit.RateMap(position=positions, rate=rates) + assert np.allclose(rate_map.get_cumulative_mass([0.5, 1.5]), np.array([0.5, 3])) + assert rate_map.get_cumulative_mass([2]) == rate_map.total_mass + + def test_get_bad_cumulative_mass(self): + positions = np.array([0, 1, 2]) + rates = np.array([1, 4]) + rate_map = tskit.RateMap(position=positions, 
rate=rates) + with pytest.raises(ValueError, match="positions"): + rate_map.get_cumulative_mass([1, -0.1]) + with pytest.raises(ValueError, match="positions"): + rate_map.get_cumulative_mass([1, 2.1]) + + +class TestSlice: + def test_slice_no_params(self): + # test RateMap.slice(..., trim=False) + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + b = a.slice() + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + assert a == b + + def test_slice_left_examples(self): + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + b = a.slice(left=50) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 50, 100, 200, 300, 400], b.position) + assert_array_equal([np.nan, 0, 1, 2, 3], b.rate) + + b = a.slice(left=100) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 100, 200, 300, 400], b.position) + assert_array_equal([np.nan, 1, 2, 3], b.rate) + + b = a.slice(left=150) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 150, 200, 300, 400], b.position) + assert_array_equal([np.nan, 1, 2, 3], b.rate) + + def test_slice_right_examples(self): + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + b = a.slice(right=300) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 100, 200, 300, 400], b.position) + assert_array_equal([0, 1, 2, np.nan], b.rate) + + b = a.slice(right=250) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 100, 200, 250, 400], b.position) + assert_array_equal([0, 1, 2, np.nan], b.rate) + + def test_slice_left_right_examples(self): + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + b = a.slice(left=50, right=300) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 50, 100, 200, 300, 400], b.position) + assert_array_equal([np.nan, 0, 1, 2, np.nan], b.rate) + + b = a.slice(left=150, right=250) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 150, 200, 250, 400], b.position) + assert_array_equal([np.nan, 1, 2, np.nan], b.rate) + + b = a.slice(left=150, right=300) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 150, 200, 300, 400], b.position) + assert_array_equal([np.nan, 1, 2, np.nan], b.rate) + + b = a.slice(left=150, right=160) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 150, 160, 400], b.position) + assert_array_equal([np.nan, 1, np.nan], b.rate) + + def test_slice_right_missing(self): + # If we take a right-slice into a trailing missing region, + # we should recover the same map. 
+ a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, np.nan]) + b = a.slice(right=350) + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + + b = a.slice(right=300) + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + + def test_slice_left_missing(self): + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[np.nan, 1, 2, 3]) + b = a.slice(left=50) + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + + b = a.slice(left=100) + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + + def test_slice_with_floats(self): + # test RateMap.slice(..., trim=False) with floats + a = tskit.RateMap( + position=[np.pi * x for x in [0, 100, 200, 300, 400]], rate=[0, 1, 2, 3] + ) + b = a.slice(left=50 * np.pi) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 50 * np.pi] + list(a.position[1:]), b.position) + assert_array_equal([np.nan] + list(a.rate), b.rate) + + def test_slice_trim_left(self): + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[1, 2, 3, 4]) + b = a.slice(left=100, trim=True) + assert b == tskit.RateMap(position=[0, 100, 200, 300], rate=[2, 3, 4]) + b = a.slice(left=50, trim=True) + assert b == tskit.RateMap(position=[0, 50, 150, 250, 350], rate=[1, 2, 3, 4]) + + def test_slice_trim_right(self): + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[1, 2, 3, 4]) + b = a.slice(right=300, trim=True) + assert b == tskit.RateMap(position=[0, 100, 200, 300], rate=[1, 2, 3]) + b = a.slice(right=350, trim=True) + assert b == tskit.RateMap(position=[0, 100, 200, 300, 350], rate=[1, 2, 3, 4]) + + def test_slice_error(self): + recomb_map = tskit.RateMap(position=[0, 100], rate=[1]) + with pytest.raises(KeyError): + recomb_map.slice(left=-1) + with pytest.raises(KeyError): + recomb_map.slice(right=-1) + with pytest.raises(KeyError): + recomb_map.slice(left=200) + with pytest.raises(KeyError): + recomb_map.slice(right=200) + with pytest.raises(KeyError): + recomb_map.slice(left=20, right=10) + + +class TestReadHapmap: + def test_read_hapmap_simple(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 1 x 0 + chr1 2 x 0.000001 x + chr1 3 x 0.000006 x x""" + ) + rm = tskit.RateMap.read_hapmap(hapfile) + assert_array_equal(rm.position, [0, 1, 2, 3]) + assert np.allclose(rm.rate, [np.nan, 1e-8, 5e-8], equal_nan=True) + + def test_read_hapmap_from_filename(self, tmp_path): + with open(tmp_path / "hapfile.txt", "w") as hapfile: + hapfile.write( + """\ + HEADER + chr1 1 x 0 + chr1 2 x 0.000001 x + chr1 3 x 0.000006 x x""" + ) + rm = tskit.RateMap.read_hapmap(tmp_path / "hapfile.txt") + assert_array_equal(rm.position, [0, 1, 2, 3]) + assert np.allclose(rm.rate, [np.nan, 1e-8, 5e-8], equal_nan=True) + + @pytest.mark.filterwarnings("ignore:loadtxt") + def test_read_hapmap_empty(self): + hapfile = io.StringIO( + """\ + HEADER""" + ) + with pytest.raises(ValueError, match="Empty"): + tskit.RateMap.read_hapmap(hapfile) + + def test_read_hapmap_col_pos(self): + hapfile = io.StringIO( + """\ + HEADER + 0 0 + 0.000001 1 x + 0.000006 2 x x""" + ) + rm = tskit.RateMap.read_hapmap(hapfile, position_col=1, map_col=0) + assert_array_equal(rm.position, [0, 1, 2]) + assert np.allclose(rm.rate, [1e-8, 5e-8]) + + def test_read_hapmap_map_and_rate(self): + 
hapfile = io.StringIO( + """\ + HEADER + chr1 0 0 0 + chr1 1 1 0.000001 x + chr1 2 2 0.000006 x x""" + ) + with pytest.raises(ValueError, match="both rate_col and map_col"): + tskit.RateMap.read_hapmap(hapfile, rate_col=2, map_col=3) + + def test_read_hapmap_duplicate_pos(self): + hapfile = io.StringIO( + """\ + HEADER + 0 0 + 0.000001 1 x + 0.000006 2 x x""" + ) + with pytest.raises(ValueError, match="same columns"): + tskit.RateMap.read_hapmap(hapfile, map_col=1) + + def test_read_hapmap_nonzero_rate_start(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 1 5 x + chr1 2 0 x x x""" + ) + rm = tskit.RateMap.read_hapmap(hapfile, rate_col=2) + assert_array_equal(rm.position, [0, 1, 2]) + assert_array_equal(rm.rate, [np.nan, 5e-8]) + + def test_read_hapmap_nonzero_rate_end(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 0 5 x + chr1 2 1 x x x""" + ) + with pytest.raises(ValueError, match="last entry.*must be zero"): + tskit.RateMap.read_hapmap(hapfile, rate_col=2) + + def test_read_hapmap_gzipped(self, tmp_path): + hapfile = os.path.join(tmp_path, "hapmap.txt.gz") + with gzip.GzipFile(hapfile, "wb") as gzfile: + gzfile.write(b"HEADER\n") + gzfile.write(b"chr1 0 1\n") + gzfile.write(b"chr1 1 5.5\n") + gzfile.write(b"chr1 2 0\n") + rm = tskit.RateMap.read_hapmap(hapfile, rate_col=2) + assert_array_equal(rm.position, [0, 1, 2]) + assert_array_equal(rm.rate, [1e-8, 5.5e-8]) + + def test_read_hapmap_nonzero_map_start(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 1 x 0.000001 + chr1 2 x 0.000001 x + chr1 3 x 0.000006 x x x""" + ) + rm = tskit.RateMap.read_hapmap(hapfile) + assert_array_equal(rm.position, [0, 1, 2, 3]) + assert np.allclose(rm.rate, [1e-8, 0, 5e-8]) + + def test_read_hapmap_bad_nonzero_map_start(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 0 x 0.0000005 + chr1 1 x 0.000001 x + chr1 2 x 0.000006 x x x""" + ) + with pytest.raises(ValueError, match="start.*must be zero"): + tskit.RateMap.read_hapmap(hapfile) + + def test_sequence_length(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 0 x 0 + chr1 1 x 0.000001 x + chr1 2 x 0.000006 x x x""" + ) + # test identical seq len + rm = tskit.RateMap.read_hapmap(hapfile, sequence_length=2) + assert_array_equal(rm.position, [0, 1, 2]) + assert np.allclose(rm.rate, [1e-8, 5e-8]) + + hapfile.seek(0) + rm = tskit.RateMap.read_hapmap(hapfile, sequence_length=10) + assert_array_equal(rm.position, [0, 1, 2, 10]) + assert np.allclose(rm.rate, [1e-8, 5e-8, np.nan], equal_nan=True) + + def test_bad_sequence_length(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 0 x 0 + chr1 1 x 0.000001 x + chr1 2 x 0.000006 x x x""" + ) + with pytest.raises(ValueError, match="sequence_length"): + tskit.RateMap.read_hapmap(hapfile, sequence_length=1.999) + + def test_no_header(self): + data = """\ + chr1 0 x 0 + chr1 1 x 0.000001 x + chr1 2 x 0.000006 x x x""" + hapfile_noheader = io.StringIO(data) + hapfile_header = io.StringIO("chr pos rate cM\n" + data) + with pytest.raises(ValueError): + tskit.RateMap.read_hapmap(hapfile_header, has_header=False) + rm1 = tskit.RateMap.read_hapmap(hapfile_header) + rm2 = tskit.RateMap.read_hapmap(hapfile_noheader, has_header=False) + assert_array_equal(rm1.rate, rm2.rate) + assert_array_equal(rm1.position, rm2.position) + + def test_hapmap_fragment(self): + hapfile = io.StringIO( + """\ + chr pos rate cM + 1 4283592 3.79115663174456 0 + 1 4361401 0.0664276817058413 0.294986106359414 + 1 7979763 10.9082897515584 0.535345505591925 + 1 8007051 0.0976780648822495 0.833010916332456 + 1 8762788 
0.0899929572085616 0.906829844052373 + 1 9477943 0.0864382908650907 0.971188757364862 + 1 9696341 4.76495005895746 0.990066707213216 + 1 9752154 0.0864316558730679 1.25601286485381 + 1 9881751 0.0 1.26721414815999""" + ) + rm1 = tskit.RateMap.read_hapmap(hapfile) + hapfile.seek(0) + rm2 = tskit.RateMap.read_hapmap(hapfile, rate_col=2) + assert np.allclose(rm1.position, rm2.position) + assert np.allclose(rm1.rate, rm2.rate, equal_nan=True) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 8133c9d7eb..5e5c5b2272 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -347,9 +347,32 @@ def test_simplify_bad_args(self): tc.simplify([0, 1], keep_input_roots="sdf") with pytest.raises(TypeError): tc.simplify([0, 1], filter_populations="x") + with pytest.raises(TypeError): + tc.simplify([0, 1], filter_nodes="x") + with pytest.raises(TypeError): + tc.simplify([0, 1], update_sample_flags="x") with pytest.raises(_tskit.LibraryError): tc.simplify([0, -1]) + @pytest.mark.parametrize("value", [True, False]) + @pytest.mark.parametrize( + "flag", + [ + "filter_sites", + "filter_populations", + "filter_individuals", + "filter_nodes", + "update_sample_flags", + "reduce_to_site_topology", + "keep_unary", + "keep_unary_in_individuals", + "keep_input_roots", + ], + ) + def test_simplify_flags(self, flag, value): + tables = _tskit.TableCollection(1) + tables.simplify([], **{flag: value}) + def test_link_ancestors_bad_args(self): ts = msprime.simulate(10, random_seed=1) tc = ts.tables._ll_tables @@ -740,6 +763,54 @@ def test_table_extend_types( for i, expected_row in enumerate(expected_rows): assert table[len(table_copy) + i] == table_copy[expected_row] + @pytest.mark.parametrize("table_name", tskit.TABLE_NAMES) + def test_table_keep_rows_errors(self, table_name, ts_fixture): + table = getattr(ts_fixture.tables, table_name) + n = len(table) + ll_table = table.ll_table + with pytest.raises(ValueError, match="must be of length"): + ll_table.keep_rows(np.ones(n - 1, dtype=bool)) + with pytest.raises(ValueError, match="must be of length"): + ll_table.keep_rows(np.ones(n + 1, dtype=bool)) + with pytest.raises(TypeError, match="Cannot cast"): + ll_table.keep_rows(np.ones(n, dtype=int)) + + @pytest.mark.parametrize("table_name", tskit.TABLE_NAMES) + def test_table_keep_rows_all(self, table_name, ts_fixture): + table = getattr(ts_fixture.tables, table_name) + n = len(table) + ll_table = table.ll_table + a = ll_table.keep_rows(np.ones(n, dtype=bool)) + assert ll_table.num_rows == n + assert a.shape == (n,) + assert a.dtype == np.int32 + assert np.all(a == np.arange(n)) + + @pytest.mark.parametrize("table_name", tskit.TABLE_NAMES) + def test_table_keep_rows_none(self, table_name, ts_fixture): + table = getattr(ts_fixture.tables, table_name) + n = len(table) + ll_table = table.ll_table + a = ll_table.keep_rows(np.zeros(n, dtype=bool)) + assert ll_table.num_rows == 0 + assert a.shape == (n,) + assert a.dtype == np.int32 + assert np.all(a == -1) + + def test_mutation_table_keep_rows_ref_error(self): + table = _tskit.MutationTable() + table.add_row(site=0, node=0, derived_state="A", parent=2) + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_MUTATION_OUT_OF_BOUNDS"): + table.keep_rows([True]) + + def 
test_individual_table_keep_rows_ref_error(self): + table = _tskit.IndividualTable() + table.add_row(parents=[2]) + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS" + ): + table.keep_rows([True]) + @pytest.mark.parametrize( ["table_name", "column_name"], [ @@ -1425,6 +1496,13 @@ def test_time_units(self): ts.load_tables(tables) assert ts.get_time_units() == value + def test_extend_edges_bad_args(self): + ts1 = self.get_example_tree_sequence(10) + with pytest.raises(TypeError): + ts1.extend_edges() + with pytest.raises(TypeError, match="as an int"): + ts1.extend_edges("sdf") + def test_kc_distance_errors(self): ts1 = self.get_example_tree_sequence(10) with pytest.raises(TypeError): @@ -1458,6 +1536,26 @@ def test_kc_distance(self): x2 = ts2.get_kc_distance(ts1, lambda_) assert x1 == x2 + def test_divergence_matrix(self): + n = 10 + ts = self.get_example_tree_sequence(n, random_seed=12) + D = ts.divergence_matrix([0, ts.get_sequence_length()]) + assert D.shape == (1, n, n) + D = ts.divergence_matrix([0, ts.get_sequence_length()], samples=[0, 1]) + assert D.shape == (1, 2, 2) + with pytest.raises(TypeError): + ts.divergence_matrix(windoze=[0, 1]) + with pytest.raises(ValueError, match="at least 2"): + ts.divergence_matrix(windows=[0]) + with pytest.raises(_tskit.LibraryError, match="BAD_WINDOWS"): + ts.divergence_matrix(windows=[-1, 0, 1]) + with pytest.raises(ValueError): + ts.divergence_matrix(windows=[0, 1], samples="sdf") + with pytest.raises(ValueError, match="Unrecognised stats mode"): + ts.divergence_matrix(windows=[0, 1], mode="sdf") + with pytest.raises(_tskit.LibraryError, match="UNSUPPORTED_STAT_MODE"): + ts.divergence_matrix(windows=[0, 1], mode="node") + def test_load_tables_build_indexes(self): for ts in self.get_example_tree_sequences(): tables = _tskit.TableCollection(sequence_length=ts.get_sequence_length()) @@ -1663,6 +1761,15 @@ def test_window_errors(self): with pytest.raises(_tskit.LibraryError): f(windows=bad_window, **params) + def test_polarisation(self): + ts, f, params = self.get_example() + with pytest.raises(TypeError): + f(polarised="sdf", **params) + x1 = f(polarised=False, **params) + x2 = f(polarised=True, **params) + # Basic check just to run both code paths + assert x1.shape == x2.shape + def test_windows_output(self): ts, f, params = self.get_example() del params["windows"] @@ -2025,6 +2132,78 @@ def f(indexes): f(bad_dim) +class TwoWayWeightedStatsMixin(StatsInterfaceMixin): + """ + Tests for the weighted two way sample stats. 
+ """ + + def get_example(self): + ts, method = self.get_method() + params = { + "weights": np.zeros((ts.get_num_samples(), 2)) + 0.5, + "indexes": [[0, 1]], + "windows": [0, ts.get_sequence_length()], + } + return ts, method, params + + def test_basic_example(self): + ts, method = self.get_method() + div = method( + np.zeros((ts.get_num_samples(), 1)) + 0.5, + [[0, 1]], + windows=[0, ts.get_sequence_length()], + ) + assert div.shape == (1, 1) + + def test_bad_weights(self): + ts, f, params = self.get_example() + del params["weights"] + n = ts.get_num_samples() + + for bad_weight_type in [None, [None, None]]: + with pytest.raises(ValueError, match="object of too small depth"): + f(weights=bad_weight_type, **params) + + for bad_weight_shape in [(n - 1, 1), (n + 1, 1), (0, 3)]: + with pytest.raises(ValueError, match="First dimension must be num_samples"): + f(weights=np.ones(bad_weight_shape), **params) + + def test_output_dims(self): + ts, method, params = self.get_example() + weights = params.pop("weights") + params["windows"] = [0, ts.get_sequence_length()] + + for mode in ["site", "branch"]: + out = method(weights[:, [0]], mode=mode, **params) + assert out.shape == (1, 1) + out = method(weights, mode=mode, **params) + assert out.shape == (1, 1) + out = method(weights[:, [0, 0, 0]], mode=mode, **params) + assert out.shape == (1, 1) + mode = "node" + N = ts.get_num_nodes() + out = method(weights[:, [0]], mode=mode, **params) + assert out.shape == (1, N, 1) + out = method(weights, mode=mode, **params) + assert out.shape == (1, N, 1) + out = method(weights[:, [0, 0, 0]], mode=mode, **params) + assert out.shape == (1, N, 1) + + def test_set_index_errors(self): + ts, method, params = self.get_example() + del params["indexes"] + + def f(indexes): + method(indexes=indexes, **params) + + for bad_array in ["wer", {}, [[[], []], [[], []]]]: + with pytest.raises(ValueError): + f(bad_array) + for bad_dim in [[[]], [[1], [1]]]: + with pytest.raises(ValueError): + f(bad_dim) + + class ThreeWaySampleStatsMixin(SampleSetMixin): """ Tests for the two way sample stats. @@ -2211,6 +2390,12 @@ def get_method(self): return ts, ts.f2 +class TestGeneticRelatedness(LowLevelTestCase, TwoWaySampleStatsMixin): + def get_method(self): + ts = self.get_example_tree_sequence() + return ts, ts.genetic_relatedness + + class TestY3(LowLevelTestCase, ThreeWaySampleStatsMixin): def get_method(self): ts = self.get_example_tree_sequence() @@ -2229,6 +2414,12 @@ def get_method(self): return ts, ts.f4 +class TestWeightedGeneticRelatedness(LowLevelTestCase, TwoWayWeightedStatsMixin): + def get_method(self): + ts = self.get_example_tree_sequence() + return ts, ts.genetic_relatedness_weighted + + class TestGeneralStatsInterface(LowLevelTestCase, StatsInterfaceMixin): """ Tests for the general stats interface. 
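TwoWayWeightedStatsMixin above drives the weighted two-way stats through the same interface as the sample-set variants, with a (num_samples, k) weight matrix taking the place of sample sets. A minimal sketch of the call shape it exercises, assuming `ts` is a low-level `_tskit.TreeSequence` as used in these tests:

    import numpy as np

    n = ts.get_num_samples()
    weights = np.full((n, 2), 0.5)  # two weight vectors, one per column
    out = ts.genetic_relatedness_weighted(
        weights,
        indexes=[[0, 1]],  # pairs of weight columns to combine
        windows=[0, ts.get_sequence_length()],
        mode="site",
    )
    # One row per window, one column per index pair; "node" mode
    # would add a num_nodes axis in the middle.
    assert out.shape == (1, 1)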
@@ -2897,6 +3088,16 @@ def test_seek_errors(self):
             with pytest.raises(_tskit.LibraryError):
                 tree.seek(bad_pos)
 
+    def test_seek_index_errors(self):
+        ts = self.get_example_tree_sequence()
+        tree = _tskit.Tree(ts)
+        for bad_type in ["", "x", {}]:
+            with pytest.raises(TypeError):
+                tree.seek_index(bad_type)
+        for bad_index in [-1, 10**6]:
+            with pytest.raises(_tskit.LibraryError):
+                tree.seek_index(bad_index)
+
     def test_root_threshold(self):
         for ts in self.get_example_tree_sequences():
             tree = _tskit.Tree(ts)
@@ -3792,7 +3993,7 @@ def test_kastore_version(self):
 
     def test_tskit_version(self):
         version = _tskit.get_tskit_version()
-        assert version == (1, 1, 1)
+        assert version == (1, 1, 2)
 
     def test_tskit_version_file(self):
         maj, min_, patch = _tskit.get_tskit_version()
diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py
index dae69c6f85..265fee48b2 100644
--- a/python/tests/test_tables.py
+++ b/python/tests/test_tables.py
@@ -1,6 +1,6 @@
 # MIT License
 #
-# Copyright (c) 2018-2022 Tskit Developers
+# Copyright (c) 2018-2023 Tskit Developers
 # Copyright (c) 2017 University of Oxford
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -140,6 +140,14 @@ def table_5row(self, test_rows):
             table_5row.add_row(**row)
         return table_5row
 
+    def test_asdict(self, table, test_rows):
+        for table_row, test_row in zip(table, test_rows):
+            for k, v in table_row.asdict().items():
+                if isinstance(v, np.ndarray):
+                    assert np.array_equal(v, test_row[k])
+                else:
+                    assert v == test_row[k]
+
     def test_max_rows_increment(self):
         for bad_value in [-1, -(2**10)]:
             with pytest.raises(ValueError):
@@ -625,6 +633,18 @@ def test_append_columns_max_rows(self):
             else:
                 assert table.max_rows == max(max_rows + 1, table.num_rows)
 
+    def test_keep_rows_data(self):
+        input_data = self.make_input_data(100)
+        t1 = self.table_class()
+        t1.append_columns(**input_data)
+        t2 = t1.copy()
+        keep = np.ones(len(t1), dtype=bool)
+        # Only keep the odd rows
+        keep[::2] = 0
+        t1.keep_rows(keep)
+        keep_rows_definition(t2, keep)
+        assert t1.equals(t2)
+
     def test_str(self):
         for num_rows in [0, 10]:
             input_data = self.make_input_data(num_rows)
@@ -1729,6 +1749,21 @@ def test_various_not_equals(self):
         a = tskit.MutationTableRow(**args)
         assert a == b
 
+    def test_keep_rows_data(self):
+        input_data = self.make_input_data(100)
+        t1 = self.table_class()
+        # Set the parents column to -1s for this simple test so that
+        # we don't need to reason about reference integrity
+        t1.append_columns(**input_data)
+        t1.parents = np.full_like(t1.parents, -1)
+        t2 = t1.copy()
+        keep = np.ones(len(t1), dtype=bool)
+        # Only keep the odd rows
+        keep[::2] = 0
+        t1.keep_rows(keep)
+        keep_rows_definition(t2, keep)
+        assert t1.equals(t2)
+
 
 class TestNodeTable(*common_tests):
 
@@ -1992,6 +2027,21 @@ def test_packset_derived_state(self):
         assert np.array_equal(table.derived_state, derived_state)
         assert np.array_equal(table.derived_state_offset, derived_state_offset)
 
+    def test_keep_rows_data(self):
+        input_data = self.make_input_data(100)
+        t1 = self.table_class()
+        # Set the parent column to -1s for this simple test so that
+        # we don't need to reason about reference integrity
+        t1.append_columns(**input_data)
+        t1.parent = np.full_like(t1.parent, -1)
+        t2 = t1.copy()
+        keep = np.ones(len(t1), dtype=bool)
+        # Only keep the odd rows
+        keep[::2] = 0
+        t1.keep_rows(keep)
+        keep_rows_definition(t2, keep)
+        assert t1.equals(t2)
+
 
 class TestMigrationTable(*common_tests):
     columns = [
@@ -3029,8 +3079,8 @@ def test_full_samples(self):
     def test_bad_samples(self):
         n = 10
         ts = msprime.simulate(n, random_seed=self.random_seed)
-        tables = ts.dump_tables()
-        for bad_node in [-1, n, n + 1, ts.num_nodes - 1, ts.num_nodes, 2**31 - 1]:
+        for bad_node in [-1, ts.num_nodes, 2**31 - 1]:
+            tables = ts.dump_tables()
             with pytest.raises(_tskit.LibraryError):
                 tables.simplify(samples=[0, bad_node])
 
@@ -5011,3 +5061,362 @@ def test_setitem_metadata(self, ts_fixture, table_name):
         assert table[0].metadata != table[1].metadata
         table[0] = table[1]
         assert table[0] == table[1]
+
+
+def keep_rows_definition(table, keep):
+    id_map = np.full(len(table), -1, np.int32)
+    copy = table.copy()
+    table.clear()
+    for j, row in enumerate(copy):
+        if keep[j]:
+            id_map[j] = len(table)
+            table.append(row)
+    return id_map
+
+
+class KeepRowsBaseTest:
+    # Simple tests assuming that rows aren't self-referential
+
+    def test_keep_all(self, ts_fixture):
+        table = self.get_table(ts_fixture)
+        before = table.copy()
+        table.keep_rows(np.ones(len(table), dtype=bool))
+        assert table.equals(before)
+
+    def test_keep_none(self, ts_fixture):
+        table = self.get_table(ts_fixture)
+        table.keep_rows(np.zeros(len(table), dtype=bool))
+        assert len(table) == 0
+
+    def check_keep_rows(self, table, keep):
+        copy = table.copy()
+        id_map1 = keep_rows_definition(copy, keep)
+        id_map2 = table.keep_rows(keep)
+        table.assert_equals(copy)
+        np.testing.assert_array_equal(id_map1, id_map2)
+
+    def test_keep_even(self, ts_fixture):
+        table = self.get_table(ts_fixture)
+        keep = np.ones(len(table), dtype=bool)
+        keep[1::2] = 0
+        self.check_keep_rows(table, keep)
+
+    def test_keep_odd(self, ts_fixture):
+        table = self.get_table(ts_fixture)
+        keep = np.ones(len(table), dtype=bool)
+        keep[::2] = 0
+        self.check_keep_rows(table, keep)
+
+    def test_keep_first(self, ts_fixture):
+        table = self.get_table(ts_fixture)
+        keep = np.zeros(len(table), dtype=bool)
+        keep[0] = 1
+        self.check_keep_rows(table, keep)
+        assert len(table) == 1
+
+    def test_keep_last(self, ts_fixture):
+        table = self.get_table(ts_fixture)
+        keep = np.zeros(len(table), dtype=bool)
+        keep[-1] = 1
+        self.check_keep_rows(table, keep)
+        assert len(table) == 1
+
+    @pytest.mark.parametrize("dtype", [np.int32, int, np.float32])
+    def test_bad_array_dtype(self, ts_fixture, dtype):
+        table = self.get_table(ts_fixture)
+        keep = np.zeros(len(table), dtype=dtype)
+        with pytest.raises(TypeError, match="Cannot cast array"):
+            table.keep_rows(keep)
+
+    @pytest.mark.parametrize("falsey", [False, 0, "", None])
+    def test_python_falsey_input(self, ts_fixture, falsey):
+        table = self.get_table(ts_fixture)
+        keep = [falsey] * len(table)
+        self.check_keep_rows(table, keep)
+        assert len(table) == 0
+
+    @pytest.mark.parametrize("truthy", [True, 1, "string", 1e-6])
+    def test_python_truthy_input(self, ts_fixture, truthy):
+        table = self.get_table(ts_fixture)
+        n = len(table)
+        keep = [truthy] * len(table)
+        self.check_keep_rows(table, keep)
+        assert len(table) == n
+
+    @pytest.mark.parametrize("offset", [-1, 1, 100])
+    def test_bad_length(self, ts_fixture, offset):
+        table = self.get_table(ts_fixture)
+        keep = [True] * (len(table) + offset)
+        match_str = f"need:{len(table)}, got:{len(table) + offset}"
+        with pytest.raises(ValueError, match=match_str):
+            table.keep_rows(keep)
+
+    @pytest.mark.parametrize("bad_type", [False, 0, None])
+    def test_non_list_input(self, ts_fixture, bad_type):
+        table = self.get_table(ts_fixture)
+        with pytest.raises(TypeError, match="has no len"):
+            table.keep_rows(bad_type)
+
+
+class TestNodeTableKeepRows(KeepRowsBaseTest):
+    def get_table(self, ts):
+        return ts.dump_tables().nodes
+
+
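+# A minimal sketch (not part of the test suite) of the keep_rows contract that
+# the table-specific classes below exercise, assuming `table` is any tskit
+# table with at least two rows:
+#
+#     keep = np.ones(len(table), dtype=bool)
+#     keep[0] = False                   # drop the first row
+#     id_map = table.keep_rows(keep)    # rows are removed in place
+#     assert id_map[0] == -1            # dropped rows map to tskit.NULL
+#     assert id_map[1] == 0             # kept rows map to their new IDs
+
+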
+class TestEdgeTableKeepRows(KeepRowsBaseTest):
+    def get_table(self, ts):
+        return ts.dump_tables().edges
+
+
+class TestSiteTableKeepRows(KeepRowsBaseTest):
+    def get_table(self, ts):
+        return ts.dump_tables().sites
+
+
+class TestMigrationTableKeepRows(KeepRowsBaseTest):
+    def get_table(self, ts):
+        return ts.dump_tables().migrations
+
+
+class TestPopulationTableKeepRows(KeepRowsBaseTest):
+    def get_table(self, ts):
+        return ts.dump_tables().populations
+
+
+class TestProvenanceTableKeepRows(KeepRowsBaseTest):
+    def get_table(self, ts):
+        return ts.dump_tables().provenances
+
+
+# Null out the self-referential columns (this is why the tests are structured
+# via classes rather than pytest parametrize).
+
+
+class TestIndividualTableKeepRows(KeepRowsBaseTest):
+    def get_table(self, ts):
+        table = ts.dump_tables().individuals
+        table.parents = np.zeros_like(table.parents) - 1
+        return table
+
+    def check_keep_rows(self, table, keep):
+        copy = table.copy()
+        id_map1 = keep_rows_definition(copy, keep)
+        for j, row in enumerate(copy):
+            parents = [p if p == tskit.NULL else id_map1[p] for p in row.parents]
+            copy[j] = row.replace(parents=parents)
+        id_map2 = table.keep_rows(keep)
+        table.assert_equals(copy)
+        np.testing.assert_array_equal(id_map1, id_map2)
+
+    def test_delete_unreferenced(self, ts_fixture):
+        table = ts_fixture.dump_tables().individuals
+        ref_count = np.zeros(len(table))
+        for row in table:
+            for parent in row.parents:
+                ref_count[parent] += 1
+        self.check_keep_rows(table, ref_count > 0)
+
+
+class TestMutationTableKeepRows(KeepRowsBaseTest):
+    def get_table(self, ts):
+        table = ts.dump_tables().mutations
+        table.parent = np.zeros_like(table.parent) - 1
+        return table
+
+    def check_keep_rows(self, table, keep):
+        copy = table.copy()
+        id_map1 = keep_rows_definition(copy, keep)
+        for j, row in enumerate(copy):
+            if row.parent != tskit.NULL:
+                copy[j] = row.replace(parent=id_map1[row.parent])
+        id_map2 = table.keep_rows(keep)
+        table.assert_equals(copy)
+        np.testing.assert_array_equal(id_map1, id_map2)
+
+    def test_delete_unreferenced(self, ts_fixture):
+        table = ts_fixture.dump_tables().mutations
+        parent = table.parent.copy()
+        parent[parent == tskit.NULL] = len(table)
+        references = np.bincount(parent)
+        self.check_keep_rows(table, references[:-1] > 0)
+
+    def test_error_on_bad_ids(self, ts_fixture):
+        table = ts_fixture.dump_tables().mutations
+        table.add_row(site=0, node=0, derived_state="A", parent=10000)
+        before = table.copy()
+        with pytest.raises(tskit.LibraryError, match="TSK_ERR_MUTATION_OUT_OF_BOUNDS"):
+            table.keep_rows(np.ones(len(table), dtype=bool))
+        table.assert_equals(before)
+
+
+class TestKeepRowsExamples:
+    """
+    Some examples of how to use the keep_rows method in an idiomatic
+    and efficient way.
+
+    TODO these should be converted into documentation examples when we
+    write an "examples" section for table editing.
+ """ + + def test_detach_subtree(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(3).tree_sequence + tables = ts.dump_tables() + tables.edges.keep_rows(tables.edges.child != 3) + + # 2.00┊ 4 ┊ + # ┊ ┃ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ts = tables.tree_sequence() + assert ts.num_trees == 1 + assert ts.first().parent_dict == {0: 4, 1: 3, 2: 3} + + def test_delete_older_edges(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(3).tree_sequence + tables = ts.dump_tables() + tables.edges.keep_rows(tables.nodes.time[tables.edges.parent] <= 1) + + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 3 ┊ + # ┊ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ts = tables.tree_sequence() + assert ts.num_trees == 1 + assert ts.first().parent_dict == {1: 3, 2: 3} + + def test_delete_unreferenced_nodes(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(3).tree_sequence + tables = ts.dump_tables() + edges = tables.edges + nodes = tables.nodes + edges.keep_rows(nodes.time[edges.parent] <= 1) + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 3 ┊ + # ┊ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ref_count = np.bincount(edges.child, minlength=len(nodes)) + ref_count += np.bincount(edges.parent, minlength=len(nodes)) + assert list(ref_count) == [0, 1, 1, 2, 0] + id_map = nodes.keep_rows(ref_count > 0) + assert list(id_map) == [-1, 0, 1, 2, -1] + assert len(nodes) == 3 + # Remap the edges IDs + edges.child = id_map[edges.child] + edges.parent = id_map[edges.parent] + ts = tables.tree_sequence() + assert ts.num_trees == 1 + assert ts.first().parent_dict == {0: 2, 1: 2} + + def test_mutation_ids_auto_remapped(self): + mutations = tskit.MutationTable() + # Add 5 initial rows with no parents + for j in range(5): + mutations.add_row(site=j, node=j, derived_state=f"{j}") + # Now 5 more in a chain + last = -1 + for j in range(5): + last = mutations.add_row( + site=10 + j, node=10 + j, parent=last, derived_state=f"{j}" + ) + + # ╔══╤════╤════╤════╤═════════════╤══════╤════════╗ + # ║id│site│node│time│derived_state│parent│metadata║ + # ╠══╪════╪════╪════╪═════════════╪══════╪════════╣ + # ║0 │ 0│ 0│ nan│ 0│ -1│ ║ + # ║1 │ 1│ 1│ nan│ 1│ -1│ ║ + # ║2 │ 2│ 2│ nan│ 2│ -1│ ║ + # ║3 │ 3│ 3│ nan│ 3│ -1│ ║ + # ║4 │ 4│ 4│ nan│ 4│ -1│ ║ + # ║5 │ 10│ 10│ nan│ 0│ -1│ ║ + # ║6 │ 11│ 11│ nan│ 1│ 5│ ║ + # ║7 │ 12│ 12│ nan│ 2│ 6│ ║ + # ║8 │ 13│ 13│ nan│ 3│ 7│ ║ + # ║9 │ 14│ 14│ nan│ 4│ 8│ ║ + # ╚══╧════╧════╧════╧═════════════╧══════╧════════╝ + + keep = np.ones(len(mutations), dtype=bool) + keep[:5] = False + mutations.keep_rows(keep) + + # ╔══╤════╤════╤════╤═════════════╤══════╤════════╗ + # ║id│site│node│time│derived_state│parent│metadata║ + # ╠══╪════╪════╪════╪═════════════╪══════╪════════╣ + # ║0 │ 10│ 10│ nan│ 0│ -1│ ║ + # ║1 │ 11│ 11│ nan│ 1│ 0│ ║ + # ║2 │ 12│ 12│ nan│ 2│ 1│ ║ + # ║3 │ 13│ 13│ nan│ 3│ 2│ ║ + # ║4 │ 14│ 14│ nan│ 4│ 3│ ║ + # ╚══╧════╧════╧════╧═════════════╧══════╧════════╝ + assert list(mutations.site) == [10, 11, 12, 13, 14] + assert list(mutations.node) == [10, 11, 12, 13, 14] + assert list(mutations.parent) == [-1, 0, 1, 2, 3] + + def test_individual_ids_auto_remapped(self): + individuals = tskit.IndividualTable() + # Add some rows with missing parents in different forms + individuals.add_row() + individuals.add_row(parents=[-1]) + individuals.add_row(parents=[-1, -1]) + # Now 5 more in a chain + last = -1 + for _ in range(5): 
+ last = individuals.add_row(parents=[last]) + last = individuals.add_row(parents=[last, last]) + + # ╔══╤═════╤════════╤═══════╤════════╗ + # ║id│flags│location│parents│metadata║ + # ╠══╪═════╪════════╪═══════╪════════╣ + # ║0 │ 0│ │ │ ║ + # ║1 │ 0│ │ -1│ ║ + # ║2 │ 0│ │ -1, -1│ ║ + # ║3 │ 0│ │ -1│ ║ + # ║4 │ 0│ │ 3│ ║ + # ║5 │ 0│ │ 4│ ║ + # ║6 │ 0│ │ 5│ ║ + # ║7 │ 0│ │ 6│ ║ + # ║8 │ 0│ │ 7, 7│ ║ + # ╚══╧═════╧════════╧═══════╧════════╝ + + keep = np.ones(len(individuals), dtype=bool) + # Only delete one row + keep[1] = False + individuals.keep_rows(keep) + + # ╔══╤═════╤════════╤═══════╤════════╗ + # ║id│flags│location│parents│metadata║ + # ╠══╪═════╪════════╪═══════╪════════╣ + # ║0 │ 0│ │ │ ║ + # ║1 │ 0│ │ -1, -1│ ║ + # ║2 │ 0│ │ -1│ ║ + # ║3 │ 0│ │ 2│ ║ + # ║4 │ 0│ │ 3│ ║ + # ║5 │ 0│ │ 4│ ║ + # ║6 │ 0│ │ 5│ ║ + # ║7 │ 0│ │ 6, 6│ ║ + # ╚══╧═════╧════════╧═══════╧════════╝ + parents = [list(ind.parents) for ind in individuals] + assert parents == [[], [-1, -1], [-1], [2], [3], [4], [5], [6, 6]] diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 45c4dc9019..277aa1dd16 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2016-2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -2641,8 +2641,10 @@ class TestSimplifyExamples(TopologyTestCase): def verify_simplify( self, samples, + *, filter_sites=True, keep_input_roots=False, + filter_nodes=True, nodes_before=None, edges_before=None, sites_before=None, @@ -2657,7 +2659,7 @@ def verify_simplify( Verifies that if we run simplify on the specified input we get the required output. 
""" - ts = tskit.load_text( + before = tskit.load_text( nodes=io.StringIO(nodes_before), edges=io.StringIO(edges_before), sites=io.StringIO(sites_before) if sites_before is not None else None, @@ -2666,9 +2668,8 @@ def verify_simplify( ), strict=False, ) - before = ts.dump_tables() - ts = tskit.load_text( + after = tskit.load_text( nodes=io.StringIO(nodes_after), edges=io.StringIO(edges_after), sites=io.StringIO(sites_after) if sites_after is not None else None, @@ -2678,23 +2679,26 @@ def verify_simplify( strict=False, sequence_length=before.sequence_length, ) - after = ts.dump_tables() - # Make sure it's a valid tree sequence - ts = before.tree_sequence() - before.simplify( + + result, _ = do_simplify( + before, samples=samples, filter_sites=filter_sites, keep_input_roots=keep_input_roots, - record_provenance=False, + filter_nodes=filter_nodes, + compare_lib=True, ) if debug: print("before") print(before) - print(before.tree_sequence().draw_text()) + print(before.draw_text()) print("after") print(after) - print(after.tree_sequence().draw_text()) - assert before == after + print(after.draw_text()) + print("result") + print(result) + print(result.draw_text()) + after.tables.assert_equals(result.tables) def test_unsorted_edges(self): # We have two nodes at the same time and interleave edges for @@ -3250,6 +3254,38 @@ def test_unary_edges_no_overlap_internal_sample(self): edges_after=edges_before, ) + def test_keep_nodes(self): + nodes_before = """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + 3 0 2 + 4 0 3 + """ + edges_before = """\ + left right parent child + 0 1 2 0 + 0 1 2 1 + 0 1 3 2 + 0 1 4 3 + """ + edges_after = """\ + left right parent child + 0 1 2 0 + 0 1 2 1 + 0 1 4 2 + """ + self.verify_simplify( + samples=[0, 1], + nodes_before=nodes_before, + edges_before=edges_before, + nodes_after=nodes_before, + edges_after=edges_after, + filter_nodes=False, + keep_input_roots=True, + ) + class TestNonSampleExternalNodes(TopologyTestCase): """ @@ -4711,74 +4747,59 @@ def test_kwargs(self): assert t1.num_tracked_samples() == t2.num_tracked_samples() == 4 -class SimplifyTestBase: +def do_simplify( + ts, + samples=None, + compare_lib=True, + filter_sites=True, + filter_populations=True, + filter_individuals=True, + filter_nodes=True, + keep_unary=False, + keep_input_roots=False, + update_sample_flags=True, +): """ - Base class for simplify tests. + Runs the Python test implementation of simplify. """ - - def do_simplify( - self, + if samples is None: + samples = ts.samples() + s = tests.Simplifier( ts, - samples=None, - compare_lib=True, - filter_sites=True, - filter_populations=True, - filter_individuals=True, - keep_unary=False, - keep_input_roots=False, - ): - """ - Runs the Python test implementation of simplify. 
- """ - if samples is None: - samples = ts.samples() - s = tests.Simplifier( - ts, + samples, + filter_sites=filter_sites, + filter_populations=filter_populations, + filter_individuals=filter_individuals, + filter_nodes=filter_nodes, + keep_unary=keep_unary, + keep_input_roots=keep_input_roots, + update_sample_flags=update_sample_flags, + ) + new_ts, node_map = s.simplify() + if compare_lib: + sts, lib_node_map1 = ts.simplify( samples, filter_sites=filter_sites, - filter_populations=filter_populations, filter_individuals=filter_individuals, + filter_populations=filter_populations, + filter_nodes=filter_nodes, + update_sample_flags=update_sample_flags, keep_unary=keep_unary, keep_input_roots=keep_input_roots, + map_nodes=True, ) - new_ts, node_map = s.simplify() - if compare_lib: - sts, lib_node_map1 = ts.simplify( - samples, - filter_sites=filter_sites, - filter_individuals=filter_individuals, - filter_populations=filter_populations, - keep_unary=keep_unary, - keep_input_roots=keep_input_roots, - map_nodes=True, - ) - lib_tables1 = sts.dump_tables() - - lib_tables2 = ts.dump_tables() - lib_node_map2 = lib_tables2.simplify( - samples, - filter_sites=filter_sites, - keep_unary=keep_unary, - keep_input_roots=keep_input_roots, - filter_individuals=filter_individuals, - filter_populations=filter_populations, - ) + lib_tables1 = sts.dump_tables() - py_tables = new_ts.dump_tables() - for lib_tables, lib_node_map in [ - (lib_tables1, lib_node_map1), - (lib_tables2, lib_node_map2), - ]: + py_tables = new_ts.dump_tables() + py_tables.assert_equals(lib_tables1, ignore_provenance=True) + assert all(node_map == lib_node_map1) + return new_ts, node_map - assert lib_tables.nodes == py_tables.nodes - assert lib_tables.edges == py_tables.edges - assert lib_tables.migrations == py_tables.migrations - assert lib_tables.sites == py_tables.sites - assert lib_tables.mutations == py_tables.mutations - assert lib_tables.individuals == py_tables.individuals - assert lib_tables.populations == py_tables.populations - assert all(node_map == lib_node_map) - return new_ts, node_map + +class SimplifyTestBase: + """ + Base class for simplify tests. + """ class TestSimplify(SimplifyTestBase): @@ -4824,11 +4845,9 @@ def verify_no_samples(self, ts, keep_unary=False): """ t1 = ts.dump_tables() t1.nodes.flags = np.zeros_like(t1.nodes.flags) - ts1, node_map1 = self.do_simplify( - ts, samples=ts.samples(), keep_unary=keep_unary - ) + ts1, node_map1 = do_simplify(ts, samples=ts.samples(), keep_unary=keep_unary) t1 = ts1.dump_tables() - ts2, node_map2 = self.do_simplify(ts, keep_unary=keep_unary) + ts2, node_map2 = do_simplify(ts, keep_unary=keep_unary) t2 = ts2.dump_tables() t1.assert_equals(t2) @@ -4841,7 +4860,7 @@ def verify_single_childified(self, ts, keep_unary=False): """ ts_single = tsutil.single_childify(ts) - tss, node_map = self.do_simplify(ts_single, keep_unary=keep_unary) + tss, node_map = do_simplify(ts_single, keep_unary=keep_unary) # All original nodes should still be present. 
for u in range(ts.num_samples): assert u == node_map[u] @@ -4865,7 +4884,7 @@ def verify_single_childified(self, ts, keep_unary=False): def verify_multiroot_internal_samples(self, ts, keep_unary=False): ts_multiroot = ts.decapitate(np.max(ts.tables.nodes.time) / 2) ts1 = tsutil.jiggle_samples(ts_multiroot) - ts2, node_map = self.do_simplify(ts1, keep_unary=keep_unary) + ts2, node_map = do_simplify(ts1, keep_unary=keep_unary) assert ts1.num_trees >= ts2.num_trees trees2 = ts2.trees() t2 = next(trees2) @@ -4897,10 +4916,10 @@ def test_single_tree(self): def test_single_tree_mutations(self): ts = msprime.simulate(10, mutation_rate=1, random_seed=self.random_seed) assert ts.num_sites > 1 - self.do_simplify(ts) + do_simplify(ts) self.verify_single_childified(ts) # Also with keep_unary == True. - self.do_simplify(ts, keep_unary=True) + do_simplify(ts, keep_unary=True) self.verify_single_childified(ts, keep_unary=True) def test_many_trees_mutations(self): @@ -4910,10 +4929,10 @@ def test_many_trees_mutations(self): assert ts.num_trees > 2 assert ts.num_sites > 2 self.verify_no_samples(ts) - self.do_simplify(ts) + do_simplify(ts) self.verify_single_childified(ts) # Also with keep_unary == True. - self.do_simplify(ts, keep_unary=True) + do_simplify(ts, keep_unary=True) self.verify_single_childified(ts, keep_unary=True) def test_many_trees(self): @@ -4944,14 +4963,14 @@ def test_small_tree_internal_samples(self): nodes.flags = flags ts = tables.tree_sequence() assert ts.sample_size == 5 - tss, node_map = self.do_simplify(ts, [3, 5]) + tss, node_map = do_simplify(ts, [3, 5]) assert node_map[3] == 0 assert node_map[5] == 1 assert tss.num_nodes == 3 assert tss.num_edges == 2 self.verify_no_samples(ts) # with keep_unary == True - tss, node_map = self.do_simplify(ts, [3, 5], keep_unary=True) + tss, node_map = do_simplify(ts, [3, 5], keep_unary=True) assert node_map[3] == 0 assert node_map[5] == 1 assert tss.num_nodes == 5 @@ -4974,7 +4993,7 @@ def test_small_tree_linear_samples(self): nodes.flags = flags ts = tables.tree_sequence() assert ts.sample_size == 2 - tss, node_map = self.do_simplify(ts, [0, 7]) + tss, node_map = do_simplify(ts, [0, 7]) assert node_map[0] == 0 assert node_map[7] == 1 assert tss.num_nodes == 2 @@ -4982,7 +5001,7 @@ def test_small_tree_linear_samples(self): t = next(tss.trees()) assert t.parent_dict == {0: 1} # with keep_unary == True - tss, node_map = self.do_simplify(ts, [0, 7], keep_unary=True) + tss, node_map = do_simplify(ts, [0, 7], keep_unary=True) assert node_map[0] == 0 assert node_map[7] == 1 assert tss.num_nodes == 4 @@ -5006,7 +5025,7 @@ def test_small_tree_internal_and_external_samples(self): nodes.flags = flags ts = tables.tree_sequence() assert ts.sample_size == 3 - tss, node_map = self.do_simplify(ts, [0, 1, 7]) + tss, node_map = do_simplify(ts, [0, 1, 7]) assert node_map[0] == 0 assert node_map[1] == 1 assert node_map[7] == 2 @@ -5015,7 +5034,7 @@ def test_small_tree_internal_and_external_samples(self): t = next(tss.trees()) assert t.parent_dict == {0: 3, 1: 3, 3: 2} # with keep_unary == True - tss, node_map = self.do_simplify(ts, [0, 1, 7], keep_unary=True) + tss, node_map = do_simplify(ts, [0, 1, 7], keep_unary=True) assert node_map[0] == 0 assert node_map[1] == 1 assert node_map[7] == 2 @@ -5044,7 +5063,7 @@ def test_small_tree_mutations(self): assert ts.num_sites == 4 assert ts.num_mutations == 4 for keep in [True, False]: - tss = self.do_simplify(ts, [0, 2], keep_unary=keep)[0] + tss = do_simplify(ts, [0, 2], keep_unary=keep)[0] assert tss.sample_size == 2 
assert tss.num_mutations == 4 assert list(tss.haplotypes()) == ["1011", "0100"] @@ -5059,12 +5078,10 @@ def test_small_tree_filter_zero_mutations(self): assert ts.num_sites == 8 assert ts.num_mutations == 8 for keep in [True, False]: - tss, _ = self.do_simplify(ts, [4, 0, 1], filter_sites=True, keep_unary=keep) + tss, _ = do_simplify(ts, [4, 0, 1], filter_sites=True, keep_unary=keep) assert tss.num_sites == 5 assert tss.num_mutations == 5 - tss, _ = self.do_simplify( - ts, [4, 0, 1], filter_sites=False, keep_unary=keep - ) + tss, _ = do_simplify(ts, [4, 0, 1], filter_sites=False, keep_unary=keep) assert tss.num_sites == 8 assert tss.num_mutations == 5 @@ -5086,7 +5103,7 @@ def test_small_tree_fixed_sites(self): assert ts.num_sites == 3 assert ts.num_mutations == 3 for keep in [True, False]: - tss, _ = self.do_simplify(ts, [4, 1], keep_unary=keep) + tss, _ = do_simplify(ts, [4, 1], keep_unary=keep) assert tss.sample_size == 2 assert tss.num_mutations == 0 assert list(tss.haplotypes()) == ["", ""] @@ -5104,7 +5121,7 @@ def test_small_tree_mutations_over_root(self): assert ts.num_sites == 1 assert ts.num_mutations == 1 for keep_unary, filter_sites in itertools.product([True, False], repeat=2): - tss, _ = self.do_simplify( + tss, _ = do_simplify( ts, [0, 1], filter_sites=filter_sites, keep_unary=keep_unary ) assert tss.num_sites == 1 @@ -5125,7 +5142,7 @@ def test_small_tree_recurrent_mutations(self): assert ts.num_sites == 1 assert ts.num_mutations == 2 for keep in [True, False]: - tss = self.do_simplify(ts, [4, 3], keep_unary=keep)[0] + tss = do_simplify(ts, [4, 3], keep_unary=keep)[0] assert tss.sample_size == 2 assert tss.num_sites == 1 assert tss.num_mutations == 2 @@ -5149,7 +5166,7 @@ def test_small_tree_back_mutations(self): assert list(ts.haplotypes()) == ["0", "1", "0", "0", "1"] # First check if we simplify for all samples and keep original state. for keep in [True, False]: - tss = self.do_simplify(ts, [0, 1, 2, 3, 4], keep_unary=keep)[0] + tss = do_simplify(ts, [0, 1, 2, 3, 4], keep_unary=keep)[0] assert tss.sample_size == 5 assert tss.num_sites == 1 assert tss.num_mutations == 3 @@ -5157,7 +5174,7 @@ def test_small_tree_back_mutations(self): # The ancestral state above 5 should be 0. for keep in [True, False]: - tss = self.do_simplify(ts, [0, 1], keep_unary=keep)[0] + tss = do_simplify(ts, [0, 1], keep_unary=keep)[0] assert tss.sample_size == 2 assert tss.num_sites == 1 assert tss.num_mutations == 3 @@ -5165,7 +5182,7 @@ def test_small_tree_back_mutations(self): # The ancestral state above 7 should be 1. 
for keep in [True, False]: - tss = self.do_simplify(ts, [4, 0, 1], keep_unary=keep)[0] + tss = do_simplify(ts, [4, 0, 1], keep_unary=keep)[0] assert tss.sample_size == 3 assert tss.num_sites == 1 assert tss.num_mutations == 3 @@ -5192,7 +5209,7 @@ def test_overlapping_unary_edges(self): assert ts.num_trees == 3 assert ts.sequence_length == 3 for keep in [True, False]: - tss, node_map = self.do_simplify(ts, samples=[0, 1, 2], keep_unary=keep) + tss, node_map = do_simplify(ts, samples=[0, 1, 2], keep_unary=keep) assert list(node_map) == [0, 1, 2] trees = [{0: 2}, {0: 2, 1: 2}, {1: 2}] for t in tss.trees(): @@ -5220,7 +5237,7 @@ def test_overlapping_unary_edges_internal_samples(self): trees = [{0: 2}, {0: 2, 1: 2}, {1: 2}] for t in ts.trees(): assert t.parent_dict == trees[t.index] - tss, node_map = self.do_simplify(ts) + tss, node_map = do_simplify(ts) assert list(node_map) == [0, 1, 2] def test_isolated_samples(self): @@ -5242,7 +5259,7 @@ def test_isolated_samples(self): assert ts.num_trees == 1 assert ts.num_nodes == 3 for keep in [True, False]: - tss, node_map = self.do_simplify(ts, keep_unary=keep) + tss, node_map = do_simplify(ts, keep_unary=keep) assert ts.tables.nodes == tss.tables.nodes assert ts.tables.edges == tss.tables.edges assert list(node_map) == [0, 1, 2] @@ -5275,7 +5292,7 @@ def test_internal_samples(self): ) ts = tskit.load_text(nodes, edges, strict=False) - tss, node_map = self.do_simplify(ts, [5, 2, 0]) + tss, node_map = do_simplify(ts, [5, 2, 0]) assert node_map[0] == 2 assert node_map[1] == -1 assert node_map[2] == 1 @@ -5290,7 +5307,7 @@ def test_internal_samples(self): for t in tss.trees(): assert t.parent_dict == trees[t.index] # with keep_unary == True - tss, node_map = self.do_simplify(ts, [5, 2, 0], keep_unary=True) + tss, node_map = do_simplify(ts, [5, 2, 0], keep_unary=True) assert node_map[0] == 2 assert node_map[1] == 4 assert node_map[2] == 1 @@ -5343,7 +5360,7 @@ def test_many_mutations_over_single_sample_ancestral_state(self): assert ts.num_sites == 1 assert ts.num_mutations == 2 for keep in [True, False]: - tss, node_map = self.do_simplify(ts, keep_unary=keep) + tss, node_map = do_simplify(ts, keep_unary=keep) assert tss.num_sites == 1 assert tss.num_mutations == 2 assert list(tss.haplotypes(isolated_as_missing=False)) == ["0"] @@ -5384,7 +5401,7 @@ def test_many_mutations_over_single_sample_derived_state(self): assert ts.num_sites == 1 assert ts.num_mutations == 3 for keep in [True, False]: - tss, node_map = self.do_simplify(ts, keep_unary=keep) + tss, node_map = do_simplify(ts, keep_unary=keep) assert tss.num_sites == 1 assert tss.num_mutations == 3 assert list(tss.haplotypes(isolated_as_missing=False)) == ["1"] @@ -5397,7 +5414,7 @@ def test_many_trees_filter_zero_mutations(self): assert ts.num_sites > ts.num_trees for keep in [True, False]: for filter_sites in [True, False]: - tss, _ = self.do_simplify( + tss, _ = do_simplify( ts, samples=None, filter_sites=filter_sites, keep_unary=keep ) assert ts.num_sites == tss.num_sites @@ -5411,7 +5428,7 @@ def test_many_trees_filter_zero_multichar_mutations(self): assert ts.num_mutations == ts.num_trees for keep in [True, False]: for filter_sites in [True, False]: - tss, _ = self.do_simplify( + tss, _ = do_simplify( ts, samples=None, filter_sites=filter_sites, keep_unary=keep ) assert ts.num_sites == tss.num_sites @@ -5423,11 +5440,11 @@ def test_simple_population_filter(self): tables.populations.add_row(metadata=b"unreferenced") assert len(tables.populations) == 2 for keep in [True, False]: - tss, _ = 
self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_populations=True, keep_unary=keep ) assert tss.num_populations == 1 - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_populations=False, keep_unary=keep ) assert tss.num_populations == 2 @@ -5451,14 +5468,14 @@ def test_interleaved_populations_filter(self): ts = tables.tree_sequence() id_map = np.array([-1, 0, -1, -1], dtype=np.int32) for keep in [True, False]: - tss, _ = self.do_simplify(ts, filter_populations=True, keep_unary=keep) + tss, _ = do_simplify(ts, filter_populations=True, keep_unary=keep) assert tss.num_populations == 1 population = tss.population(0) assert population.metadata == bytes([1]) assert np.array_equal( id_map[ts.tables.nodes.population], tss.tables.nodes.population ) - tss, _ = self.do_simplify(ts, filter_populations=False, keep_unary=keep) + tss, _ = do_simplify(ts, filter_populations=False, keep_unary=keep) assert tss.num_populations == 4 def test_removed_node_population_filter(self): @@ -5472,7 +5489,7 @@ def test_removed_node_population_filter(self): tables.nodes.add_row(flags=0, population=1) tables.nodes.add_row(flags=1, population=2) for keep in [True, False]: - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_populations=True, keep_unary=keep ) assert tss.num_nodes == 2 @@ -5482,7 +5499,7 @@ def test_removed_node_population_filter(self): assert tss.node(0).population == 0 assert tss.node(1).population == 1 - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_populations=False, keep_unary=keep ) assert tss.tables.populations == tables.populations @@ -5494,14 +5511,14 @@ def test_simple_individual_filter(self): tables.nodes.add_row(flags=1, individual=0) tables.nodes.add_row(flags=1, individual=0) for keep in [True, False]: - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_individuals=True, keep_unary=keep ) assert tss.num_nodes == 2 assert tss.num_individuals == 1 assert tss.individual(0).flags == 0 - tss, _ = self.do_simplify(tables.tree_sequence(), filter_individuals=False) + tss, _ = do_simplify(tables.tree_sequence(), filter_individuals=False) assert tss.tables.individuals == tables.individuals def test_interleaved_individual_filter(self): @@ -5513,14 +5530,14 @@ def test_interleaved_individual_filter(self): tables.nodes.add_row(flags=1, individual=-1) tables.nodes.add_row(flags=1, individual=1) for keep in [True, False]: - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_individuals=True, keep_unary=keep ) assert tss.num_nodes == 3 assert tss.num_individuals == 1 assert tss.individual(0).flags == 1 - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_individuals=False, keep_unary=keep ) assert tss.tables.individuals == tables.individuals @@ -5536,7 +5553,7 @@ def test_removed_node_individual_filter(self): tables.nodes.add_row(flags=0, individual=1) tables.nodes.add_row(flags=1, individual=2) for keep in [True, False]: - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_individuals=True, keep_unary=keep ) assert tss.num_nodes == 2 @@ -5546,13 +5563,13 @@ def test_removed_node_individual_filter(self): assert tss.node(0).individual == 0 assert tss.node(1).individual == 1 - tss, _ = self.do_simplify( + tss, _ = do_simplify( tables.tree_sequence(), filter_individuals=False, keep_unary=keep ) assert tss.tables.individuals == tables.individuals def 
verify_simplify_haplotypes(self, ts, samples, keep_unary=False):
-        sub_ts, node_map = self.do_simplify(
+        sub_ts, node_map = do_simplify(
             ts, samples, filter_sites=False, keep_unary=keep_unary
         )
         assert ts.num_sites == sub_ts.num_sites
@@ -5628,6 +5645,96 @@ def test_many_trees_recurrent_mutations_internal_samples(self):
                 self.verify_simplify_haplotypes(ts, samples, keep_unary=keep)
 
 
+class TestSimplifyUnreferencedPopulations:
+    def example(self):
+        tables = tskit.TableCollection(1)
+        tables.populations.add_row()
+        tables.populations.add_row()
+        # No references to population 0
+        tables.nodes.add_row(time=0, population=1, flags=1)
+        tables.nodes.add_row(time=0, population=1, flags=1)
+        tables.nodes.add_row(time=1, population=1, flags=0)
+        # Unreferenced node
+        tables.nodes.add_row(time=1, population=1, flags=0)
+        tables.edges.add_row(0, 1, parent=2, child=0)
+        tables.edges.add_row(0, 1, parent=2, child=1)
+        tables.sort()
+        return tables
+
+    def test_no_filter_populations(self):
+        tables = self.example()
+        tables.simplify(filter_populations=False)
+        assert len(tables.populations) == 2
+        assert len(tables.nodes) == 3
+        assert np.all(tables.nodes.population == 1)
+
+    def test_no_filter_populations_nodes(self):
+        tables = self.example()
+        tables.simplify(filter_populations=False, filter_nodes=False)
+        assert len(tables.populations) == 2
+        assert len(tables.nodes) == 4
+        assert np.all(tables.nodes.population == 1)
+
+    def test_filter_populations_no_filter_nodes(self):
+        tables = self.example()
+        tables.simplify(filter_populations=True, filter_nodes=False)
+        assert len(tables.populations) == 1
+        assert len(tables.nodes) == 4
+        assert np.all(tables.nodes.population == 0)
+
+    def test_remapped_default(self):
+        tables = self.example()
+        tables.simplify()
+        assert len(tables.populations) == 1
+        assert len(tables.nodes) == 3
+        assert np.all(tables.nodes.population == 0)
+
+
+class TestSimplifyUnreferencedIndividuals:
+    def example(self):
+        tables = tskit.TableCollection(1)
+        tables.individuals.add_row()
+        tables.individuals.add_row()
+        # No references to individual 0
+        tables.nodes.add_row(time=0, individual=1, flags=1)
+        tables.nodes.add_row(time=0, individual=1, flags=1)
+        tables.nodes.add_row(time=1, individual=1, flags=0)
+        # Unreferenced node
+        tables.nodes.add_row(time=1, individual=1, flags=0)
+        tables.edges.add_row(0, 1, parent=2, child=0)
+        tables.edges.add_row(0, 1, parent=2, child=1)
+        tables.sort()
+        return tables
+
+    def test_no_filter_individuals(self):
+        tables = self.example()
+        tables.simplify(filter_individuals=False)
+        assert len(tables.individuals) == 2
+        assert len(tables.nodes) == 3
+        assert np.all(tables.nodes.individual == 1)
+
+    def test_no_filter_individuals_nodes(self):
+        tables = self.example()
+        tables.simplify(filter_individuals=False, filter_nodes=False)
+        assert len(tables.individuals) == 2
+        assert len(tables.nodes) == 4
+        assert np.all(tables.nodes.individual == 1)
+
+    def test_filter_individuals_no_filter_nodes(self):
+        tables = self.example()
+        tables.simplify(filter_individuals=True, filter_nodes=False)
+        assert len(tables.individuals) == 1
+        assert len(tables.nodes) == 4
+        assert np.all(tables.nodes.individual == 0)
+
+    def test_remapped_default(self):
+        tables = self.example()
+        tables.simplify()
+        assert len(tables.individuals) == 1
+        assert len(tables.nodes) == 3
+        assert np.all(tables.nodes.individual == 0)
+
+
 class TestSimplifyKeepInputRoots(SimplifyTestBase, ExampleTopologyMixin):
     """
     Tests for the keep_input_roots option to simplify.
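TestSimplifyUnreferencedPopulations and TestSimplifyUnreferencedIndividuals pin down how the filter arguments interact: filter_nodes controls whether unreferenced nodes are dropped, while filter_populations and filter_individuals independently control whether unreferenced rows of those tables are dropped and the node references remapped. A sketch of the asymmetric case, where `example_tables()` is a hypothetical stand-in for the `example()` builders above:

    import numpy as np

    tables = example_tables()  # hypothetical: builds the tables from example() above
    tables.simplify(filter_populations=True, filter_nodes=False)
    assert len(tables.populations) == 1  # unreferenced population 0 is dropped
    assert len(tables.nodes) == 4  # the unreferenced node is kept
    assert np.all(tables.nodes.population == 0)  # node references remapped 1 -> 0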
@@ -5643,7 +5750,7 @@ def verify(self, ts): def verify_keep_input_roots(self, ts, samples): ts = tsutil.insert_unique_metadata(ts, ["individuals"]) - ts_with_roots, node_map = self.do_simplify( + ts_with_roots, node_map = do_simplify( ts, samples, keep_input_roots=True, filter_sites=False, compare_lib=True ) new_to_input_map = { @@ -5784,6 +5891,254 @@ def test_many_trees_recurrent_mutations(self): self.verify_keep_input_roots(ts, samples) +class TestSimplifyFilterNodes: + """ + Tests simplify when nodes are kept in the ts with filter_nodes=False + """ + + def reverse_node_indexes(self, ts): + tables = ts.dump_tables() + nodes = tables.nodes + edges = tables.edges + mutations = tables.mutations + nodes.replace_with(nodes[::-1]) + edges.parent = ts.num_nodes - edges.parent - 1 + edges.child = ts.num_nodes - edges.child - 1 + mutations.node = ts.num_nodes - mutations.node - 1 + tables.sort() + return tables.tree_sequence() + + def verify_nodes_unchanged(self, ts_in, resample_size=None, **kwargs): + if resample_size is None: + samples = None + else: + np.random.seed(42) + samples = np.sort( + np.random.choice(ts_in.num_nodes, resample_size, replace=False) + ) + + for ts in (ts_in, self.reverse_node_indexes(ts_in)): + filtered, n_map = do_simplify( + ts, samples=samples, filter_nodes=False, compare_lib=True, **kwargs + ) + assert np.array_equal(n_map, np.arange(ts.num_nodes, dtype=n_map.dtype)) + referenced_nodes = set(filtered.samples()) + referenced_nodes.update(filtered.edges_parent) + referenced_nodes.update(filtered.edges_child) + for n1, n2 in zip(ts.nodes(), filtered.nodes()): + # Ignore the tskit.NODE_IS_SAMPLE flag which can be changed by simplify + n1 = n1.replace(flags=n1.flags | tskit.NODE_IS_SAMPLE) + n2 = n2.replace(flags=n2.flags | tskit.NODE_IS_SAMPLE) + assert n1 == n2 + + # Check that edges are identical to the normal simplify(), + # with the normal "simplify" having altered IDs + simplified, node_map = ts.simplify( + samples=samples, map_nodes=True, **kwargs + ) + simplified_edges = {e for e in simplified.tables.edges} + filtered_edges = { + e.replace(parent=node_map[e.parent], child=node_map[e.child]) + for e in filtered.tables.edges + } + assert filtered_edges == simplified_edges + + def test_empty(self): + ts = tskit.TableCollection(1).tree_sequence() + self.verify_nodes_unchanged(ts) + + def test_all_samples(self): + ts = tskit.Tree.generate_comb(5).tree_sequence + tables = ts.dump_tables() + flags = tables.nodes.flags + flags |= tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + ts = tables.tree_sequence() + assert ts.num_samples == ts.num_nodes + self.verify_nodes_unchanged(ts) + + @pytest.mark.parametrize("resample_size", [None, 4]) + def test_no_topology(self, resample_size): + ts = tskit.Tree.generate_comb(5).tree_sequence + ts = ts.keep_intervals([], simplify=False) + assert ts.num_nodes > 5 # has unreferenced nodes + self.verify_nodes_unchanged(ts, resample_size=resample_size) + + @pytest.mark.parametrize("resample_size", [None, 2]) + def test_stick_tree(self, resample_size): + ts = tskit.Tree.generate_comb(2).tree_sequence + ts = ts.simplify([0], keep_unary=True) + assert ts.first().parent(0) != tskit.NULL + self.verify_nodes_unchanged(ts, resample_size=resample_size) + + # switch to an internal sample + tables = ts.dump_tables() + flags = tables.nodes.flags + flags[0] = 0 + flags[1] = tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + self.verify_nodes_unchanged(tables.tree_sequence(), resample_size=resample_size) + + @pytest.mark.parametrize("resample_size", 
[None, 4]) + def test_internal_samples(self, resample_size): + ts = tskit.Tree.generate_comb(4).tree_sequence + tables = ts.dump_tables() + flags = tables.nodes.flags + flags ^= tskit.NODE_IS_SAMPLE + tables.nodes.flags = flags + ts = tables.tree_sequence() + assert np.all(ts.samples() >= ts.num_samples) + self.verify_nodes_unchanged(ts, resample_size=resample_size) + + @pytest.mark.parametrize("resample_size", [None, 4]) + def test_blank_flanks(self, resample_size): + ts = tskit.Tree.generate_comb(4).tree_sequence + ts = ts.keep_intervals([[0.25, 0.75]], simplify=False) + self.verify_nodes_unchanged(ts, resample_size=resample_size) + + @pytest.mark.parametrize("resample_size", [None, 4]) + def test_multiroot(self, resample_size): + ts = tskit.Tree.generate_balanced(6).tree_sequence + ts = ts.decapitate(2.5) + self.verify_nodes_unchanged(ts, resample_size=resample_size) + + @pytest.mark.parametrize("resample_size", [None, 10]) + def test_with_metadata(self, ts_fixture_for_simplify, resample_size): + assert ts_fixture_for_simplify.num_nodes > 10 + self.verify_nodes_unchanged( + ts_fixture_for_simplify, resample_size=resample_size + ) + + @pytest.mark.parametrize("resample_size", [None, 7]) + def test_complex_ts_with_unary(self, resample_size): + ts = msprime.sim_ancestry( + 3, + sequence_length=10, + recombination_rate=1, + record_full_arg=True, + random_seed=123, + ) + assert ts.num_trees > 2 + ts = msprime.sim_mutations(ts, rate=1, random_seed=123) + # Add some unreferenced nodes + tables = ts.dump_tables() + tables.nodes.add_row(flags=0) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE) + ts = tables.tree_sequence() + self.verify_nodes_unchanged(ts, resample_size=resample_size) + + def test_keeping_unary(self): + # Test interaction with keeping unary nodes + n_samples = 6 + ts = tskit.Tree.generate_comb(n_samples).tree_sequence + num_nodes = ts.num_nodes + reduced_n_samples = [2, n_samples - 1] # last sample is most deeply nested + ts_with_unary = ts.simplify(reduced_n_samples, keep_unary=True) + assert ts_with_unary.num_nodes == num_nodes - n_samples + len(reduced_n_samples) + tree = ts_with_unary.first() + assert any([tree.num_children(u) == 1 for u in tree.nodes()]) + self.verify_nodes_unchanged(ts_with_unary, keep_unary=True) + self.verify_nodes_unchanged(ts_with_unary, keep_unary=False) + + def test_find_unreferenced_nodes(self): + # Simple test to show we can find unreferenced nodes easily. 
+ # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts1 = tskit.Tree.generate_balanced(4).tree_sequence + ts2, node_map = do_simplify( + ts1, + [0, 1, 2], + filter_nodes=False, + ) + assert np.array_equal(node_map, np.arange(ts1.num_nodes)) + node_references = np.zeros(ts1.num_nodes, dtype=np.int32) + node_references[ts2.edges_parent] += 1 + node_references[ts2.edges_child] += 1 + # Simplifying for [0, 1, 2] should remove references to nodes 3 and 5 + assert list(node_references) == [1, 1, 1, 0, 2, 0, 1] + + def test_mutations_on_removed_branches(self): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + tables = tskit.Tree.generate_balanced(4).tree_sequence.dump_tables() + # A mutation on a removed branch should get removed + tables.sites.add_row(0.5, "A") + tables.mutations.add_row(0, node=3, derived_state="T") + ts2, node_map = do_simplify( + tables.tree_sequence(), + [0, 1, 2], + filter_nodes=False, + ) + assert ts2.num_sites == 0 + assert ts2.num_mutations == 0 + + +class TestSimplifyNoUpdateSampleFlags: + """ + Tests for simplify when we don't update the sample flags. + """ + + def test_simple_case_filter_nodes(self): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts1 = tskit.Tree.generate_balanced(4).tree_sequence + ts2, node_map = do_simplify( + ts1, + [0, 1, 6], + update_sample_flags=False, + ) + # Because we don't retain 2 and 3 here, they don't stay as + # samples. But we specified 6 as a sample, so it comes + # through where it would ordinarily be dropped. + + # 2.00┊ 2 ┊ + # ┊ ┃ ┊ + # 1.00┊ 3 ┊ + # ┊ ┏┻┓ ┊ + # 0.00┊ 0 1 ┊ + # 0 1 + assert list(ts2.nodes_flags) == [1, 1, 0, 0] + tree = ts2.first() + assert list(tree.parent_array) == [3, 3, -1, 2, -1] + + def test_simple_case_no_filter_nodes(self): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts1 = tskit.Tree.generate_balanced(4).tree_sequence + ts2, node_map = do_simplify( + ts1, + [0, 1, 6], + update_sample_flags=False, + filter_nodes=False, + ) + + # 2.00┊ 6 ┊ + # ┊ ┃ ┊ + # 1.00┊ 4 ┊ + # ┊ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + assert list(ts2.nodes_flags) == list(ts1.nodes_flags) + tree = ts2.first() + assert list(tree.parent_array) == [4, 4, -1, -1, 6, -1, -1, -1] + + class TestMapToAncestors: """ Tests the AncestorMap class. diff --git a/python/tests/test_tree_positioning.py b/python/tests/test_tree_positioning.py new file mode 100644 index 0000000000..961f0810f7 --- /dev/null +++ b/python/tests/test_tree_positioning.py @@ -0,0 +1,470 @@ +# MIT License +# +# Copyright (c) 2023 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Tests for tree iterator schemes. Mostly used to develop the incremental +iterator infrastructure. +""" +import msprime +import numpy as np +import pytest + +import tests +import tskit +from tests import tsutil +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. + + +class StatefulTree: + """ + Just enough functionality to mimic the low-level tree implementation + for testing of forward/backward moving. + """ + + def __init__(self, ts): + self.ts = ts + self.tree_pos = tsutil.TreePosition(ts) + self.parent = [-1 for _ in range(ts.num_nodes)] + + def __str__(self): + s = f"parent: {self.parent}\nposition:\n" + for line in str(self.tree_pos).splitlines(): + s += f"\t{line}\n" + return s + + def assert_equal(self, other): + assert self.parent == other.parent + assert self.tree_pos.index == other.tree_pos.index + assert self.tree_pos.interval == other.tree_pos.interval + + def next(self): # NOQA: A003 + valid = self.tree_pos.next() + if valid: + for j in range(self.tree_pos.out_range.start, self.tree_pos.out_range.stop): + e = self.tree_pos.out_range.order[j] + c = self.ts.edges_child[e] + self.parent[c] = -1 + for j in range(self.tree_pos.in_range.start, self.tree_pos.in_range.stop): + e = self.tree_pos.in_range.order[j] + c = self.ts.edges_child[e] + p = self.ts.edges_parent[e] + self.parent[c] = p + return valid + + def prev(self): + valid = self.tree_pos.prev() + if valid: + for j in range( + self.tree_pos.out_range.start, self.tree_pos.out_range.stop, -1 + ): + e = self.tree_pos.out_range.order[j] + c = self.ts.edges_child[e] + self.parent[c] = -1 + for j in range( + self.tree_pos.in_range.start, self.tree_pos.in_range.stop, -1 + ): + e = self.tree_pos.in_range.order[j] + c = self.ts.edges_child[e] + p = self.ts.edges_parent[e] + self.parent[c] = p + return valid + + def iter_forward(self, index): + while self.tree_pos.index != index: + self.next() + + def seek_forward(self, index): + old_left, old_right = self.tree_pos.interval + self.tree_pos.seek_forward(index) + left, right = self.tree_pos.interval + # print() + # print("Current interval:", old_left, old_right) + # print("New interval:", left, right) + # print("index:", index, "out_range:", self.tree_pos.out_range) + for j in range(self.tree_pos.out_range.start, self.tree_pos.out_range.stop): + e = self.tree_pos.out_range.order[j] + e_left = self.ts.edges_left[e] + # We only need to remove an edge if it's in the current tree, which + # can only happen if the edge's left coord is < the current tree's + # right coordinate. 
+ if e_left < old_right: + c = self.ts.edges_child[e] + assert self.parent[c] != -1 + self.parent[c] = -1 + assert e_left < left + # print("index:", index, "in_range:", self.tree_pos.in_range) + for j in range(self.tree_pos.in_range.start, self.tree_pos.in_range.stop): + e = self.tree_pos.in_range.order[j] + if self.ts.edges_left[e] <= left < self.ts.edges_right[e]: + # print("keep", j, e, self.ts.edges_left[e], self.ts.edges_right[e]) + # print( + # "INSERT:", + # self.ts.edge(e), + # self.ts.nodes_time[self.ts.edges_parent[e]], + # ) + c = self.ts.edges_child[e] + p = self.ts.edges_parent[e] + self.parent[c] = p + else: + a = self.tree_pos.in_range.start + b = self.tree_pos.in_range.stop + # The first and last indexes in the range should always be valid + # for the tree. + assert a < j < b - 1 + # print("skip", j, e, self.ts.edges_left[e], self.ts.edges_right[e]) + + def seek_backward(self, index): + # TODO + while self.tree_pos.index != index: + self.prev() + + def iter_backward(self, index): + while self.tree_pos.index != index: + self.prev() + + +def check_iters_forward(ts): + alg_t_output = tsutil.algorithm_T(ts) + lib_tree = tskit.Tree(ts) + tree_pos = tsutil.TreePosition(ts) + sample_count = np.zeros(ts.num_nodes, dtype=int) + sample_count[ts.samples()] = 1 + parent1 = [-1 for _ in range(ts.num_nodes)] + i = 0 + lib_tree.next() + while tree_pos.next(): + out_times = [] + for j in range(tree_pos.out_range.start, tree_pos.out_range.stop): + e = tree_pos.out_range.order[j] + c = ts.edges_child[e] + p = ts.edges_parent[e] + out_times.append(ts.nodes_time[p]) + parent1[c] = -1 + in_times = [] + for j in range(tree_pos.in_range.start, tree_pos.in_range.stop): + e = tree_pos.in_range.order[j] + c = ts.edges_child[e] + p = ts.edges_parent[e] + in_times.append(ts.nodes_time[p]) + parent1[c] = p + # We must visit the edges in *increasing* time order on the way in, + # and *decreasing* order on the way out. Otherwise we get quadratic + # behaviour for algorithms that need to propagate changes up to the + # root. + assert out_times == sorted(out_times, reverse=True) + assert in_times == sorted(in_times) + + interval, parent2 = next(alg_t_output) + assert list(interval) == list(tree_pos.interval) + assert parent1 == parent2 + + assert lib_tree.index == i + assert list(lib_tree.interval) == list(interval) + assert list(lib_tree.parent_array[:-1]) == parent1 + + lib_tree.next() + i += 1 + assert i == ts.num_trees + assert lib_tree.index == -1 + assert next(alg_t_output, None) is None + + +def check_iters_back(ts): + alg_t_output = [ + (list(interval), list(parent)) for interval, parent in tsutil.algorithm_T(ts) + ] + i = len(alg_t_output) - 1 + + lib_tree = tskit.Tree(ts) + tree_pos = tsutil.TreePosition(ts) + parent1 = [-1 for _ in range(ts.num_nodes)] + + lib_tree.last() + + while tree_pos.prev(): + # print(tree_pos.out_range) + out_times = [] + for j in range(tree_pos.out_range.start, tree_pos.out_range.stop, -1): + e = tree_pos.out_range.order[j] + c = ts.edges_child[e] + p = ts.edges_parent[e] + out_times.append(ts.nodes_time[p]) + parent1[c] = -1 + in_times = [] + for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, -1): + e = tree_pos.in_range.order[j] + c = ts.edges_child[e] + p = ts.edges_parent[e] + in_times.append(ts.nodes_time[p]) + parent1[c] = p + + # We must visit the edges in *increasing* time order on the way in, + # and *decreasing* order on the way out. Otherwise we get quadratic + # behaviour for algorithms that need to propagate changes up to the + # root. 
+ assert out_times == sorted(out_times, reverse=True) + assert in_times == sorted(in_times) + + interval, parent2 = alg_t_output[i] + assert list(interval) == list(tree_pos.interval) + assert parent1 == parent2 + + assert lib_tree.index == i + assert list(lib_tree.interval) == list(interval) + assert list(lib_tree.parent_array[:-1]) == parent1 + + lib_tree.prev() + i -= 1 + + assert lib_tree.index == -1 + assert i == -1 + + +def check_forward_back_sweep(ts): + alg_t_output = [ + (list(interval), list(parent)) for interval, parent in tsutil.algorithm_T(ts) + ] + for j in range(ts.num_trees - 1): + tree = StatefulTree(ts) + # Seek forward to j + k = 0 + while k <= j: + tree.next() + interval, parent = alg_t_output[k] + assert tree.tree_pos.index == k + assert list(tree.tree_pos.interval) == interval + assert parent == tree.parent + k += 1 + k = j + # And back to zero + while k >= 0: + interval, parent = alg_t_output[k] + assert tree.tree_pos.index == k + assert list(tree.tree_pos.interval) == interval + assert parent == tree.parent + tree.prev() + k -= 1 + + +class TestDirectionSwitching: + # 2.00┊ ┊ 4 ┊ 4 ┊ 4 ┊ + # ┊ ┊ ┏━┻┓ ┊ ┏┻━┓ ┊ ┏┻━┓ ┊ + # 1.00┊ 3 ┊ ┃ 3 ┊ 3 ┃ ┊ 3 ┃ ┊ + # ┊ ┏━╋━┓ ┊ ┃ ┏┻┓ ┊ ┏┻┓ ┃ ┊ ┏┻┓ ┃ ┊ + # 0.00┊ 0 1 2 ┊ 0 1 2 ┊ 0 2 1 ┊ 0 1 2 ┊ + # 0 1 2 3 4 + # index 0 1 2 3 + def ts(self): + return tsutil.all_trees_ts(3) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_forward_to_prev(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_forward(index) + tree1.prev() + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index - 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.iter_backward(index - 1) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_seek_forward_from_prev(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_forward(index) + tree1.prev() + tree1.seek_forward(index) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [0, 1, 2]) + def test_backward_to_next(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_backward(index) + tree1.next() + tree2 = StatefulTree(self.ts()) + tree2.iter_backward(index + 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index + 1) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_forward_next_prev(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_forward(index) + tree1.prev() + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index - 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.iter_backward(index - 1) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_seek_forward_next_prev(self, index): + tree1 = StatefulTree(self.ts()) + tree1.iter_forward(index) + tree1.prev() + tree2 = StatefulTree(self.ts()) + tree2.seek_forward(index - 1) + tree1.assert_equal(tree2) + tree2 = StatefulTree(self.ts()) + tree2.iter_backward(index - 1) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", [1, 2, 3]) + def test_seek_forward_from_null(self, index): + tree1 = StatefulTree(self.ts()) + tree1.seek_forward(index) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + def test_seek_forward_next_null(self): + tree1 = StatefulTree(self.ts()) + tree1.seek_forward(3) + tree1.next() + assert tree1.tree_pos.index == -1 + assert list(tree1.tree_pos.interval) == [0, 0] + + +class TestSeeking: + 
@tests.cached_example + def ts(self): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + return ts + + @pytest.mark.parametrize("index", range(26)) + def test_seek_forward_from_null(self, index): + tree1 = StatefulTree(self.ts()) + tree1.seek_forward(index) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", range(1, 26)) + def test_seek_forward_from_first(self, index): + tree1 = StatefulTree(self.ts()) + tree1.next() + tree1.seek_forward(index) + tree2 = StatefulTree(self.ts()) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("index", range(1, 26)) + def test_seek_last_from_index(self, index): + ts = self.ts() + tree1 = StatefulTree(ts) + tree1.iter_forward(index) + tree1.seek_forward(ts.num_trees - 1) + tree2 = StatefulTree(ts) + tree2.prev() + tree1.assert_equal(tree2) + + +class TestAllTreesTs: + @pytest.mark.parametrize("n", [2, 3, 4]) + def test_forward_full(self, n): + ts = tsutil.all_trees_ts(n) + check_iters_forward(ts) + + @pytest.mark.parametrize("n", [2, 3, 4]) + def test_back_full(self, n): + ts = tsutil.all_trees_ts(n) + check_iters_back(ts) + + @pytest.mark.parametrize("n", [2, 3, 4]) + def test_forward_back(self, n): + ts = tsutil.all_trees_ts(n) + check_forward_back_sweep(ts) + + +class TestManyTreesSimulationExample: + @tests.cached_example + def ts(self): + ts = msprime.sim_ancestry( + 10, sequence_length=1000, recombination_rate=0.1, random_seed=1234 + ) + assert ts.num_trees > 250 + return ts + + @pytest.mark.parametrize("index", [1, 5, 10, 50, 100]) + def test_seek_forward_from_null(self, index): + ts = self.ts() + tree1 = StatefulTree(ts) + tree1.seek_forward(index) + tree2 = StatefulTree(ts) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("num_trees", [1, 5, 10, 50, 100]) + def test_seek_forward_from_mid(self, num_trees): + ts = self.ts() + start_index = ts.num_trees // 2 + dest_index = min(start_index + num_trees, ts.num_trees - 1) + tree1 = StatefulTree(ts) + tree1.iter_forward(start_index) + tree1.seek_forward(dest_index) + tree2 = StatefulTree(ts) + tree2.iter_forward(dest_index) + tree1.assert_equal(tree2) + + def test_forward_full(self): + check_iters_forward(self.ts()) + + def test_back_full(self): + check_iters_back(self.ts()) + + +class TestSuiteExamples: + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_forward_full(self, ts): + check_iters_forward(ts) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_back_full(self, ts): + check_iters_back(ts) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_forward_from_null(self, ts): + index = ts.num_trees // 2 + tree1 = StatefulTree(ts) + tree1.seek_forward(index) + tree2 = StatefulTree(ts) + tree2.iter_forward(index) + tree1.assert_equal(tree2) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_forward_from_first(self, ts): + index = ts.num_trees - 1 + tree1 = StatefulTree(ts) + tree1.next() + tree1.seek_forward(index) + tree2 = StatefulTree(ts) + tree2.iter_forward(index) + tree1.assert_equal(tree2) diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index 7725931b73..99e8e11c55 100644 --- a/python/tests/test_tree_stats.py +++ b/python/tests/test_tree_stats.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (C) 2016 University of Oxford # # 
Permission is hereby granted, free of charge, to any person obtaining a copy @@ -453,7 +453,6 @@ def node_summary(u): # contains the location of the last time we updated the output for a node. last_update = np.zeros((ts.num_nodes, 1)) for (t_left, t_right), edges_out, edges_in in ts.edge_diffs(): - for edge in edges_out: u = edge.child v = edge.parent @@ -980,7 +979,6 @@ def verify(self, ts): self.verify_weighted_stat(ts, W, windows=windows) def verify_definition(self, ts, W, windows, summary_func, ts_method, definition): - # general_stat will need an extra column for p gW = self.transform_weights(W) @@ -1025,7 +1023,6 @@ def verify(self, ts): def verify_definition( self, ts, sample_sets, windows, summary_func, ts_method, definition ): - W = np.array([[u in A for A in sample_sets] for u in ts.samples()], dtype=float) def wrapped_summary_func(x): @@ -1762,7 +1759,6 @@ def divergence( class TestDivergence(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -1974,7 +1970,6 @@ def genetic_relatedness( class TestGeneticRelatedness(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -2035,7 +2030,6 @@ def wrapped_summary_func(x): self.assertArrayAlmostEqual(sigma1, sigma4) def verify_sample_sets_indexes(self, ts, sample_sets, indexes, windows): - n = np.array([len(x) for x in sample_sets]) n_total = sum(n) @@ -2101,6 +2095,291 @@ def test_match_K_c0(self): self.assertArrayAlmostEqual(A, B) +############################################ +# Genetic relatedness weighted +############################################ + + +def genetic_relatedness_matrix(ts, sample_sets, windows=None, mode="site"): + n = len(sample_sets) + indexes = [ + (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2) + ] + if windows is None: + if mode == "node": + n_nodes = ts.num_nodes + K = np.zeros((n_nodes, n, n)) + out = ts.genetic_relatedness( + sample_sets, indexes, mode=mode, proportion=False, span_normalise=True + ) + for node in range(n_nodes): + this_K = np.zeros((n, n)) + this_K[np.triu_indices(n)] = out[node, :] + this_K = this_K + np.triu(this_K, 1).transpose() + K[node, :, :] = this_K + else: + K = np.zeros((n, n)) + K[np.triu_indices(n)] = ts.genetic_relatedness( + sample_sets, indexes, mode=mode, proportion=False, span_normalise=True + ) + K = K + np.triu(K, 1).transpose() + else: + windows = ts.parse_windows(windows) + n_windows = len(windows) - 1 + out = ts.genetic_relatedness( + sample_sets, + indexes, + mode=mode, + windows=windows, + proportion=False, + span_normalise=True, + ) + if mode == "node": + n_nodes = ts.num_nodes + K = np.zeros((n_windows, n_nodes, n, n)) + for win in range(n_windows): + for node in range(n_nodes): + K_this = np.zeros((n, n)) + K_this[np.triu_indices(n)] = out[win, node, :] + K_this = K_this + np.triu(K_this, 1).transpose() + K[win, node, :, :] = K_this + else: + K = np.zeros((n_windows, n, n)) + for win in range(n_windows): + K_this = np.zeros((n, n)) + K_this[np.triu_indices(n)] = out[win, :] + K_this = K_this + np.triu(K_this, 1).transpose() + K[win, :, :] = K_this + return K + + +def genetic_relatedness_weighted(ts, W, indexes, windows=None, mode="site"): + W_mean = W.mean(axis=0) + W = W - W_mean + sample_sets = [[u] for u in ts.samples()] + K = genetic_relatedness_matrix(ts, sample_sets, windows, mode) + n_indexes = len(indexes) + n_nodes = ts.num_nodes + if windows is None: + if mode == "node": + out = 
np.zeros((n_nodes, n_indexes)) + else: + out = np.zeros(n_indexes) + else: + windows = ts.parse_windows(windows) + n_windows = len(windows) - 1 + if mode == "node": + out = np.zeros((n_windows, n_nodes, n_indexes)) + else: + out = np.zeros((n_windows, n_indexes)) + for pair in range(n_indexes): + i1 = indexes[pair][0] + i2 = indexes[pair][1] + if windows is None: + if mode == "node": + for node in range(n_nodes): + this_K = K[node, :, :] + out[node, pair] = W[:, i1] @ this_K @ W[:, i2] + else: + out[pair] = W[:, i1] @ K @ W[:, i2] + else: + for win in range(n_windows): + if mode == "node": + for node in range(n_nodes): + this_K = K[win, node, :, :] + out[win, node, pair] = W[:, i1] @ this_K @ W[:, i2] + else: + this_K = K[win, :, :] + out[win, pair] = W[:, i1] @ this_K @ W[:, i2] + return out + + +def example_index_pairs(weights): + assert weights.shape[1] >= 2 + yield [(0, 1)] + yield [(1, 0), (0, 1)] + if weights.shape[1] > 2: + yield [(0, 1), (1, 2), (0, 2)] + + +class TestGeneticRelatednessWeighted(StatsTestCase, WeightStatsMixin): + # Derived classes define this to get a specific stats mode. + mode = None + + def verify_definition( + self, ts, W, indexes, windows, summary_func, ts_method, definition + ): + # Determine output_dim of the function + M = len(indexes) + + sigma1 = ts.general_stat( + W, summary_func, M, windows, mode=self.mode, span_normalise=True + ) + sigma2 = general_stat( + ts, W, summary_func, windows, mode=self.mode, span_normalise=True + ) + + sigma3 = ts_method( + W, + indexes=indexes, + windows=windows, + mode=self.mode, + ) + sigma4 = definition( + ts, + W, + indexes=indexes, + windows=windows, + mode=self.mode, + ) + assert sigma1.shape == sigma2.shape + assert sigma1.shape == sigma3.shape + assert sigma1.shape == sigma4.shape + self.assertArrayAlmostEqual(sigma1, sigma2) + self.assertArrayAlmostEqual(sigma1, sigma3) + self.assertArrayAlmostEqual(sigma1, sigma4) + + def verify(self, ts): + for W, windows in subset_combos( + self.example_weights(ts, min_size=2), example_windows(ts), p=0.1 + ): + for indexes in example_index_pairs(W): + self.verify_weighted_stat(ts, W, indexes, windows) + + def verify_weighted_stat(self, ts, W, indexes, windows): + W_mean = W.mean(axis=0) + W = W - W_mean + W_sum = W.sum(axis=0) + n = W.shape[0] + + def f(x): + mx = np.sum(x) / n + return np.array( + [ + (x[i] - W_sum[i] * mx) * (x[j] - W_sum[j] * mx) / 2 + for i, j in indexes + ] + ) + + self.verify_definition( + ts, + W, + indexes, + windows, + f, + ts.genetic_relatedness_weighted, + genetic_relatedness_weighted, + ) + + +class TestBranchGeneticRelatednessWeighted( + TestGeneticRelatednessWeighted, TopologyExamplesMixin +): + mode = "branch" + + +class TestNodeGeneticRelatednessWeighted( + TestGeneticRelatednessWeighted, TopologyExamplesMixin +): + mode = "node" + + +class TestSiteGeneticRelatednessWeighted( + TestGeneticRelatednessWeighted, MutatedTopologyExamplesMixin +): + mode = "site" + + +# NOTE: these classes don't follow the same (anti)-patterns as used elsewhere as they +# were added several years afterwards.
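For orientation, the pure-Python definition above reduces to a quadratic form in the mean-centred weights: for each index pair (i, j) the statistic is `W[:, i] @ K @ W[:, j]`, where K is the pairwise relatedness matrix over single-sample sample sets. A minimal usage sketch of the new method, mirroring the simple-example tests that follow (the fixture and expected shapes are taken from those tests; nothing here is part of the patch itself):

```python
import numpy as np
import tskit
from tests import tsutil  # test-suite helper, as imported throughout this module

# Mirror the fixture below: a balanced tree with a site on every branch, and
# arbitrary weights that give non-zero results.
ts = tsutil.insert_branch_sites(tskit.Tree.generate_balanced(5).tree_sequence)
W = np.zeros((ts.num_samples, 2))
W[0, :] = 1
W[1, :] = 2
# With no indexes, a two-column weight array yields a scalar for the
# single column pair (0, 1).
x = ts.genetic_relatedness_weighted(W, mode="site")
assert x.shape == tuple()
# 2D indexes plus windows yield a (num_windows, num_index_pairs) array.
x = ts.genetic_relatedness_weighted(
    W, indexes=[[0, 1]], windows=[0, 0.5, 1], mode="branch"
)
assert x.shape == (2, 1)
```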
+ + +class TestGeneticRelatednessWeightedSimpleExamples: + # Values verified against the simple implementations above + site_value = 11.12 + branch_value = 14.72 + + def fixture(self): + ts = tskit.Tree.generate_balanced(5).tree_sequence + # Arbitrary weights that give non-zero results + W = np.zeros((ts.num_samples, 2)) + W[0, :] = 1 + W[1, :] = 2 + return tsutil.insert_branch_sites(ts), W + + def test_no_arguments_site(self): + ts, W = self.fixture() + X = ts.genetic_relatedness_weighted(W, mode="site") + assert X.shape == tuple() + nt.assert_almost_equal(X, self.site_value) + + def test_windows_site(self): + ts, W = self.fixture() + X = ts.genetic_relatedness_weighted(W, mode="site", windows=[0, 1 - 1e-12, 1]) + assert X.shape == (2,) + nt.assert_almost_equal(X[0], self.site_value) + nt.assert_almost_equal(X[1], 0) + + def test_no_arguments_branch(self): + ts, W = self.fixture() + X = ts.genetic_relatedness_weighted(W, mode="branch") + assert X.shape == tuple() + nt.assert_almost_equal(X, self.branch_value) + + def test_windows_branch(self): + ts, W = self.fixture() + X = ts.genetic_relatedness_weighted(W, mode="branch", windows=[0, 0.5, 1]) + assert X.shape == (2,) + nt.assert_almost_equal(X, self.branch_value) + + def test_indexes_1D(self): + ts, W = self.fixture() + indexes = [0, 1] + X = ts.genetic_relatedness_weighted(W, indexes, mode="branch") + assert X.shape == tuple() + nt.assert_almost_equal(X, self.branch_value) + + def test_indexes_2D(self): + ts, W = self.fixture() + indexes = [[0, 1]] + X = ts.genetic_relatedness_weighted(W, indexes, mode="branch") + assert X.shape == (1,) + nt.assert_almost_equal(X, self.branch_value) + + def test_indexes_2D_windows(self): + ts, W = self.fixture() + indexes = [[0, 1], [0, 1]] + X = ts.genetic_relatedness_weighted( + W, indexes, windows=[0, 0.5, 1], mode="branch" + ) + assert X.shape == (2, 2) + nt.assert_almost_equal(X, self.branch_value) + + +class TestGeneticRelatednessWeightedErrors: + def ts(self): + return tskit.Tree.generate_balanced(3).tree_sequence + + @pytest.mark.parametrize("W", [[0], np.array([0]), np.zeros(100)]) + def test_bad_weight_size(self, W): + with pytest.raises(ValueError, match="First trait dimension"): + self.ts().genetic_relatedness_weighted(W) + + @pytest.mark.parametrize("cols", [1, 3]) + def test_no_indexes_with_non_2_cols(self, cols): + ts = self.ts() + W = np.zeros((ts.num_samples, cols)) + with pytest.raises(ValueError, match="Must specify indexes"): + ts.genetic_relatedness_weighted(W) + + @pytest.mark.parametrize("indexes", [[], [[0]], [[0, 0, 0]], [[[0], [0], [0]]]]) + def test_bad_index_shapes(self, indexes): + ts = self.ts() + W = np.zeros((ts.num_samples, 2)) + with pytest.raises(ValueError, match="Indexes must be convertable to a 2D"): + ts.genetic_relatedness_weighted(W, indexes=indexes) + + ############################################ # Fst ############################################ @@ -2143,7 +2422,6 @@ def single_site_Fst(ts, sample_sets, indexes): class TestFst(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -2332,7 +2610,6 @@ def Y2(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class TestY2(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode.
mode = None @@ -2505,7 +2782,6 @@ def Y3(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class TestY3(StatsTestCase, ThreeWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -2674,7 +2950,6 @@ def f2(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class Testf2(StatsTestCase, TwoWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -2860,7 +3135,6 @@ def f3(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class Testf3(StatsTestCase, ThreeWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -3051,7 +3325,6 @@ def f4(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise= class Testf4(StatsTestCase, FourWaySampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -3315,7 +3588,6 @@ def update_result(window_index, u, right): last_update[u] = right for (t_left, t_right), edges_out, edges_in in ts.edge_diffs(): - for edge in edges_out: u = edge.child v = edge.parent @@ -3476,7 +3748,6 @@ def allele_frequency_spectrum( class TestAlleleFrequencySpectrum(StatsTestCase, SampleSetStatsMixin): - # Derived classes define this to get a specific stats mode. mode = None @@ -5806,7 +6077,6 @@ def f(x): branch_true_diversity_02, ], ): - self.assertAlmostEqual(diversity(ts, A, mode=mode)[0][0], truth) self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0], truth) self.assertAlmostEqual(ts.diversity(A, mode="branch")[0], truth) diff --git a/python/tests/test_util.py b/python/tests/test_util.py index cc4f9d45da..1d78dd0a08 100644 --- a/python/tests/test_util.py +++ b/python/tests/test_util.py @@ -489,6 +489,27 @@ def test_unicode_table(): ) +def test_unicode_table_column_alignments(): + assert ( + util.unicode_table( + [["5", "6", "7", "8"], ["90", "10", "11", "12"]], + header=["1", "2", "3", "4"], + column_alignments="<>><", + ) + == textwrap.dedent( + """ + ╔══╤══╤══╤══╗ + ║1 │2 │3 │4 ║ + ╠══╪══╪══╪══╣ + ║5 │ 6│ 7│8 ║ + ╟──┼──┼──┼──╢ + ║90│10│11│12║ + ╚══╧══╧══╧══╝ + """ + )[1:] + ) + + def test_set_printoptions(): assert tskit._print_options == {"max_lines": 40} util.set_print_options(max_lines=None) diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 7f70e29a1f..b86a159274 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (C) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -24,11 +24,13 @@ A collection of utilities to edit and construct tree sequences. """ import collections +import dataclasses import functools import json import random import string import struct +import typing import msprime import numpy as np @@ -81,6 +83,8 @@ def insert_branch_mutations(ts, mutations_per_branch=1): Returns a copy of the specified tree sequence with a mutation on every branch in every tree. 
""" + if mutations_per_branch == 0: + return ts tables = ts.dump_tables() tables.sites.clear() tables.mutations.clear() @@ -146,23 +150,26 @@ def insert_discrete_time_mutations(ts, num_times=4, num_sites=10): return tables.tree_sequence() -def insert_branch_sites(ts): +def insert_branch_sites(ts, m=1): """ - Returns a copy of the specified tree sequence with a site on every branch + Returns a copy of the specified tree sequence with m sites on every branch of every tree. """ + if m == 0: + return ts tables = ts.dump_tables() tables.sites.clear() tables.mutations.clear() for tree in ts.trees(): left, right = tree.interval - delta = (right - left) / len(list(tree.nodes())) + delta = (right - left) / (m * len(list(tree.nodes()))) x = left for u in tree.nodes(): if tree.parent(u) != tskit.NULL: - site = tables.sites.add_row(position=x, ancestral_state="0") - tables.mutations.add_row(site=site, node=u, derived_state="1") - x += delta + for _ in range(m): + site = tables.sites.add_row(position=x, ancestral_state="0") + tables.mutations.add_row(site=site, node=u, derived_state="1") + x += delta add_provenance(tables.provenances, "insert_branch_sites") return tables.tree_sequence() @@ -1708,6 +1715,196 @@ def iterate(self): left = right +FORWARD = 1 +REVERSE = -1 + + +@dataclasses.dataclass +class Interval: + left: float + right: float + + def __iter__(self): + yield self.left + yield self.right + + +@dataclasses.dataclass +class EdgeRange: + start: int + stop: int + order: typing.List + + +class TreePosition: + def __init__(self, ts): + self.ts = ts + self.index = -1 + self.direction = 0 + self.interval = Interval(0, 0) + self.in_range = EdgeRange(0, 0, None) + self.out_range = EdgeRange(0, 0, None) + + def __str__(self): + s = f"index: {self.index}\ninterval: {self.interval}\n" + s += f"direction: {self.direction}\n" + s += f"in_range: {self.in_range}\n" + s += f"out_range: {self.out_range}\n" + return s + + def set_null(self): + self.index = -1 + self.interval.left = 0 + self.interval.right = 0 + + def next(self): # NOQA: A003 + M = self.ts.num_edges + breakpoints = self.ts.breakpoints(as_array=True) + left_coords = self.ts.edges_left + left_order = self.ts.indexes_edge_insertion_order + right_coords = self.ts.edges_right + right_order = self.ts.indexes_edge_removal_order + + if self.index == -1: + self.interval.right = 0 + self.out_range.stop = 0 + self.in_range.stop = 0 + self.direction = FORWARD + + if self.direction == FORWARD: + left_current_index = self.in_range.stop + right_current_index = self.out_range.stop + else: + left_current_index = self.out_range.stop + 1 + right_current_index = self.in_range.stop + 1 + + left = self.interval.right + + j = right_current_index + self.out_range.start = j + while j < M and right_coords[right_order[j]] == left: + j += 1 + self.out_range.stop = j + self.out_range.order = right_order + + j = left_current_index + self.in_range.start = j + while j < M and left_coords[left_order[j]] == left: + j += 1 + self.in_range.stop = j + self.in_range.order = left_order + + self.direction = FORWARD + self.index += 1 + if self.index == self.ts.num_trees: + self.set_null() + else: + self.interval.left = left + self.interval.right = breakpoints[self.index + 1] + return self.index != -1 + + def prev(self): + M = self.ts.num_edges + breakpoints = self.ts.breakpoints(as_array=True) + right_coords = self.ts.edges_right + right_order = self.ts.indexes_edge_removal_order + left_coords = self.ts.edges_left + left_order = self.ts.indexes_edge_insertion_order + + if self.index 
== -1: + self.index = self.ts.num_trees + self.interval.left = self.ts.sequence_length + self.in_range.stop = M - 1 + self.out_range.stop = M - 1 + self.direction = REVERSE + + if self.direction == REVERSE: + left_current_index = self.out_range.stop + right_current_index = self.in_range.stop + else: + left_current_index = self.in_range.stop - 1 + right_current_index = self.out_range.stop - 1 + + right = self.interval.left + + j = left_current_index + self.out_range.start = j + while j >= 0 and left_coords[left_order[j]] == right: + j -= 1 + self.out_range.stop = j + self.out_range.order = left_order + + j = right_current_index + self.in_range.start = j + while j >= 0 and right_coords[right_order[j]] == right: + j -= 1 + self.in_range.stop = j + self.in_range.order = right_order + + self.direction = REVERSE + self.index -= 1 + if self.index == -1: + self.set_null() + else: + self.interval.left = breakpoints[self.index] + self.interval.right = right + return self.index != -1 + + def seek_forward(self, index): + # NOTE this is still in development and not fully tested. + assert index >= self.index and index < self.ts.num_trees + M = self.ts.num_edges + breakpoints = self.ts.breakpoints(as_array=True) + left_coords = self.ts.edges_left + left_order = self.ts.indexes_edge_insertion_order + right_coords = self.ts.edges_right + right_order = self.ts.indexes_edge_removal_order + + if self.index == -1: + self.interval.right = 0 + self.out_range.stop = 0 + self.in_range.stop = 0 + self.direction = FORWARD + + if self.direction == FORWARD: + left_current_index = self.in_range.stop + right_current_index = self.out_range.stop + else: + left_current_index = self.out_range.stop + 1 + right_current_index = self.in_range.stop + 1 + + self.direction = FORWARD + left = breakpoints[index] + + # The range of edges we need to consider for removal starts + # at the current right index and ends just before the first + # edge whose right coordinate is greater than the new tree's + # left coordinate.
+ j = right_current_index + self.out_range.start = j + # TODO This could be done with binary search + while j < M and right_coords[right_order[j]] <= left: + j += 1 + self.out_range.stop = j + + # The range of edges we need to consider for the new tree + # must have right coordinate > left + j = left_current_index + while j < M and right_coords[left_order[j]] <= left: + j += 1 + self.in_range.start = j + # TODO this could be done with a binary search + while j < M and left_coords[left_order[j]] <= left: + j += 1 + self.in_range.stop = j + + self.interval.left = left + self.interval.right = breakpoints[index + 1] + self.out_range.order = right_order + self.in_range.order = left_order + self.index = index + + def mean_descendants(ts, reference_sets): """ Returns the mean number of nodes from the specified reference sets @@ -1774,7 +1971,6 @@ def update_counts(edge, left, sign): def genealogical_nearest_neighbours(ts, focal, reference_sets): - reference_set_map = np.zeros(ts.num_nodes, dtype=int) - 1 for k, reference_set in enumerate(reference_sets): for u in reference_set: @@ -1932,9 +2128,12 @@ def all_trees_ts(n): return tables.tree_sequence() -def all_fields_ts(): +def all_fields_ts(edge_metadata=True, migrations=True): """ - A tree sequence with data in all fields + A tree sequence with data in all fields. Edge metadata is not set if + edge_metadata is False, and migrations are not defined if migrations is False + (this is needed to test simplify, which doesn't allow either). + """ demography = msprime.Demography() demography.add_population(name="A", initial_size=10_000) @@ -1949,7 +2148,7 @@ sequence_length=5, random_seed=42, recombination_rate=1, - record_migrations=True, + record_migrations=migrations, record_provenance=True, ) ts = msprime.sim_mutations(ts, rate=0.001, random_seed=42) @@ -1973,21 +2172,27 @@ population=i % len(tables.populations), ) ) - tables.migrations.add_row(left=0, right=1, node=21, source=1, dest=3, time=1001) + if migrations: + tables.migrations.add_row(left=0, right=1, node=21, source=1, dest=3, time=1001) # Add metadata for name, table in tables.table_name_map.items(): - if name != "provenances": - table.metadata_schema = tskit.MetadataSchema.permissive_json() - metadatas = [f'{{"foo":"n_{name}_{u}"}}' for u in range(len(table))] - metadata, metadata_offset = tskit.pack_strings(metadatas) - table.set_columns( - **{ - **table.asdict(), - "metadata": metadata, - "metadata_offset": metadata_offset, - } - ) + if name == "provenances": + continue + if name == "migrations" and not migrations: + continue + if name == "edges" and not edge_metadata: + continue + table.metadata_schema = tskit.MetadataSchema.permissive_json() + metadatas = [f'{{"foo":"n_{name}_{u}"}}' for u in range(len(table))] + metadata, metadata_offset = tskit.pack_strings(metadatas) + table.set_columns( + **{ + **table.asdict(), + "metadata": metadata, + "metadata_offset": metadata_offset, + } + ) tables.metadata_schema = tskit.MetadataSchema.permissive_json() tables.metadata = "Test metadata" tables.time_units = "Test time units" diff --git a/python/tskit/__init__.py b/python/tskit/__init__.py index 09e16091e5..c1b153be04 100644 --- a/python/tskit/__init__.py +++ b/python/tskit/__init__.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the 
"Software"), to deal @@ -90,3 +90,4 @@ from tskit.util import * # NOQA from tskit.metadata import * # NOQA from tskit.text_formats import * # NOQA +from tskit.intervals import RateMap # NOQA diff --git a/python/tskit/_version.py b/python/tskit/_version.py index d730ceabab..d36b46f038 100644 --- a/python/tskit/_version.py +++ b/python/tskit/_version.py @@ -1,4 +1,4 @@ # Definitive location for the version number. # During development, should be x.y.z.devN # For beta should be x.y.zbN -tskit_version = "0.5.3" +tskit_version = "0.5.5" diff --git a/python/tskit/drawing.py b/python/tskit/drawing.py index 4097172632..51e0260cd6 100644 --- a/python/tskit/drawing.py +++ b/python/tskit/drawing.py @@ -50,8 +50,9 @@ # constants for whether to plot a tree in a tree sequence OMIT = 1 -LEFT_CLIPPED_BIT = 2 -RIGHT_CLIPPED_BIT = 4 +LEFT_CLIP = 2 +RIGHT_CLIP = 4 +OMIT_MIDDLE = 8 @dataclass @@ -62,7 +63,7 @@ class Offsets: mutation: int = 0 -@dataclass +@dataclass(frozen=True) class Timescaling: "Class used to transform the time axis" max_time: float @@ -77,9 +78,9 @@ def __post_init__(self): if self.use_log_transform: if self.min_time < 0: raise ValueError("Cannot use a log scale if there are negative times") - self.transform = self.log_transform + super().__setattr__("transform", self.log_transform) else: - self.transform = self.linear_transform + super().__setattr__("transform", self.linear_transform) def log_transform(self, y): "Standard log transform but allowing for values of 0 by adding 1" @@ -229,10 +230,11 @@ def create_tick_labels(tick_values, decimal_places=2): return [f"{lab:.{label_precision}f}" for lab in tick_values] -def clip_ts(ts, x_min, x_max): +def clip_ts(ts, x_min, x_max, max_num_trees=None): """ Culls the edges of the tree sequence outside the limits of x_min and x_max if - necessary. + necessary, and flags internal trees for omission if there are more than + max_num_trees in the tree sequence Returns the new tree sequence using the same genomic scale, and an array specifying which trees to actually plot from it. 
This array contains @@ -276,6 +278,12 @@ def clip_ts(ts, x_min, x_max): if ts.num_sites > 0 and np.max(sites.position) > x_max: x_max = ts.sequence_length # Last region has sites but no edges => keep + if max_num_trees is None: + max_num_trees = np.inf + + if max_num_trees < 2: + raise ValueError("Must show at least 2 trees when clipping a tree sequence") + if (x_min > 0) or (x_max < ts.sequence_length): old_breaks = ts.breakpoints(as_array=True) offsets.tree = np.searchsorted(old_breaks, x_min, "right") - 2 @@ -303,10 +311,22 @@ def clip_ts(ts, x_min, x_max): # Which breakpoints are new ones, as a result of clipping new_breaks = np.logical_not(np.isin(ts.breakpoints(as_array=True), old_breaks)) - tree_status[new_breaks[:-1]] |= LEFT_CLIPPED_BIT - tree_status[new_breaks[1:]] |= RIGHT_CLIPPED_BIT + tree_status[new_breaks[:-1]] |= LEFT_CLIP + tree_status[new_breaks[1:]] |= RIGHT_CLIP else: tree_status = np.zeros(ts.num_trees, dtype=np.uint8) + + first_tree = 1 if tree_status[0] & OMIT else 0 + last_tree = ts.num_trees - 2 if tree_status[-1] & OMIT else ts.num_trees - 1 + num_shown_trees = last_tree - first_tree + 1 + if num_shown_trees > max_num_trees: + num_start_trees = max_num_trees // 2 + (1 if max_num_trees % 2 else 0) + num_end_trees = max_num_trees // 2 + assert num_start_trees + num_end_trees == max_num_trees + tree_status[ + (first_tree + num_start_trees) : (last_tree - num_end_trees + 1) + ] = (OMIT | OMIT_MIDDLE) + return ts, tree_status, offsets @@ -336,20 +356,27 @@ def rnd(x): return x -def referenced_nodes(ts): +def edge_and_sample_nodes(ts, omit_regions=None): """ - Return the ids of nodes which are actually plotted in this tree sequence - (i.e. do not include nodes which are not samples and not in any edge: this - happens extensively in plotting tree sequences with x_lim specified) + Return ids of nodes which are mentioned in an edge in this tree sequence or which + are samples: nodes not connected to an edge are often found if x_lim is specified. 
""" - ids = np.concatenate( - ( - ts.tables.edges.child, - ts.tables.edges.parent, - np.where(ts.tables.nodes.flags & NODE_IS_SAMPLE)[0], - ) + if omit_regions is None or len(omit_regions) == 0: + ids = np.concatenate((ts.edges_child, ts.edges_parent)) + else: + ids = np.array([], dtype=ts.edges_child.dtype) + edges = ts.tables.edges + assert omit_regions.shape[1] == 2 + omit_regions = omit_regions.flatten() + assert np.all(omit_regions == np.unique(omit_regions)) # Check they're in order + use_regions = np.concatenate(([0.0], omit_regions, [ts.sequence_length])) + use_regions = use_regions.reshape(-1, 2) + for left, right in use_regions: + used_edges = edges[np.logical_and(edges.left >= left, edges.right < right)] + ids = np.concatenate((ids, used_edges.child, used_edges.parent)) + return np.unique( + np.concatenate((ids, np.where(ts.nodes_flags & NODE_IS_SAMPLE)[0])) ) - return np.unique(ids) def draw_tree( @@ -469,10 +496,17 @@ def add_class(attrs_dict, classes_str): @dataclass class Plotbox: total_size: list - pad_top: float - pad_left: float - pad_bottom: float - pad_right: float + pad_top: float = 0 + pad_left: float = 0 + pad_bottom: float = 0 + pad_right: float = 0 + + def set_padding(self, top, left, bottom, right): + self.pad_top = top + self.pad_left = left + self.pad_bottom = bottom + self.pad_right = right + self._check() @property def max_x(self): @@ -507,6 +541,9 @@ def height(self): return self.bottom - self.top def __post_init__(self): + self._check() + + def _check(self): if self.width < 1 or self.height < 1: raise ValueError("Image size too small to fit") @@ -537,7 +574,92 @@ def draw(self, dwg, add_to, colour="grey"): class SvgPlot: - """The base class for plotting either a tree or a tree sequence as an SVG file""" + """ + The base class for plotting any box to canvas + """ + + text_height = 14 # May want to calculate this based on a font size + line_height = text_height * 1.2 # allowing padding above and below a line + + def __init__( + self, + size, + svg_class, + root_svg_attributes=None, + canvas_size=None, + ): + """ + Creates self.drawing, an svgwrite.Drawing object for further use, and populates + it with a base group. The root_groups will be populated with + items that can be accessed from the outside, such as the plotbox, axes, etc. + """ + + if root_svg_attributes is None: + root_svg_attributes = {} + if canvas_size is None: + canvas_size = size + dwg = svgwrite.Drawing(size=canvas_size, debug=True, **root_svg_attributes) + + self.image_size = size + self.plotbox = Plotbox(size) + self.root_groups = {} + self.svg_class = svg_class + self.timescaling = None + self.root_svg_attributes = root_svg_attributes + self.dwg_base = dwg.add(dwg.g(class_=svg_class)) + self.drawing = dwg + + def get_plotbox(self): + """ + Get the svgwrite plotbox, creating it if necessary. 
+ """ + if "plotbox" not in self.root_groups: + dwg = self.drawing + self.root_groups["plotbox"] = self.dwg_base.add(dwg.g(class_="plotbox")) + return self.root_groups["plotbox"] + + def add_text_in_group(self, text, add_to, pos, group_class=None, **kwargs): + """ + Add the text to the elem within a group; allows text rotations to work smoothly, + otherwise, if x & y parameters are used to position text, rotations applied to + the text tag occur around the (0,0) point of the containing group + """ + dwg = self.drawing + group_attributes = {"transform": f"translate({rnd(pos[0])} {rnd(pos[1])})"} + if group_class is not None: + group_attributes["class_"] = group_class + grp = add_to.add(dwg.g(**group_attributes)) + grp.add(dwg.text(text, **kwargs)) + + +class SvgSkippedPlot(SvgPlot): + def __init__( + self, + size, + num_skipped, + ): + super().__init__( + size, + svg_class="skipped", + ) + container = self.get_plotbox() + x = self.plotbox.width / 2 + y = self.plotbox.height / 2 + self.add_text_in_group( + f"{num_skipped} trees", + container, + (x, y - self.line_height / 2), + text_anchor="middle", + ) + self.add_text_in_group( + "skipped", container, (x, y + self.line_height / 2), text_anchor="middle" + ) + + +class SvgAxisPlot(SvgPlot): + """ + The class used for plotting either a tree or a tree sequence as an SVG file + """ standard_style = ( ".background path {fill: #808080; fill-opacity: 0}" @@ -546,6 +668,7 @@ class SvgPlot: ".x-axis .tick .lab {font-weight: bold; dominant-baseline: hanging}" ".axes, .tree {font-size: 14px; text-anchor: middle}" ".axes line, .edge {stroke: black; fill: none}" + ".axes .ax-skip {stroke-dasharray: 4}" ".y-axis .grid {stroke: #FAFAFA}" ".node > .sym {fill: black; stroke: none}" ".site > .sym {stroke: black}" @@ -561,8 +684,6 @@ class SvgPlot: ) # TODO: we may want to make some of the constants below into parameters - text_height = 14 # May want to calculate this based on a font size - line_height = text_height * 1.2 # allowing padding above and below a line root_branch_fraction = 1 / 8 # Rel root branch len, unless it has a timed mutation default_tick_length = 5 default_tick_length_site = 10 @@ -587,27 +708,18 @@ def __init__( omit_sites=None, canvas_size=None, ): - """ - Creates self.drawing, an svgwrite.Drawing object for further use, and populates - it with a stylesheet and base group. The root_groups will be populated with - items that can be accessed from the outside, such as the plotbox, axes, etc. - """ + super().__init__( + size, + svg_class, + root_svg_attributes, + canvas_size, + ) self.ts = ts - self.image_size = size - self.svg_class = svg_class - if root_svg_attributes is None: - root_svg_attributes = {} - if canvas_size is None: - canvas_size = size - self.root_svg_attributes = root_svg_attributes - dwg = svgwrite.Drawing(size=canvas_size, debug=True, **root_svg_attributes) + dwg = self.drawing # Put all styles in a single stylesheet (required for Inkscape 0.92) style = self.standard_style + ("" if style is None else style) dwg.defs.add(dwg.style(style)) - self.dwg_base = dwg.add(dwg.g(class_=svg_class)) - self.root_groups = {} self.debug_box = debug_box - self.drawing = dwg self.time_scale = check_time_scale(time_scale) self.y_axis = y_axis self.x_axis = x_axis @@ -626,29 +738,6 @@ def __init__( self.omit_sites = omit_sites self.mutations_outside_tree = set() # mutations in here get an additional class - def get_plotbox(self): - """ - Get the svgwrite plotbox (contains the tree(s) but not axes etc), creating it - if necessary. 
- """ - if "plotbox" not in self.root_groups: - dwg = self.drawing - self.root_groups["plotbox"] = self.dwg_base.add(dwg.g(class_="plotbox")) - return self.root_groups["plotbox"] - - def add_text_in_group(self, text, add_to, pos, group_class=None, **kwargs): - """ - Add the text to the elem within a group; allows text rotations to work smoothly, - otherwise, if x & y parameters are used to position text, rotations applied to - the text tag occur around the (0,0) point of the containing group - """ - dwg = self.drawing - group_attributes = {"transform": f"translate({rnd(pos[0])} {rnd(pos[1])})"} - if group_class is not None: - group_attributes["class_"] = group_class - grp = add_to.add(dwg.g(**group_attributes)) - grp.add(dwg.text(text, **kwargs)) - def set_spacing(self, top=0, left=0, bottom=0, right=0): """ Set edges, but allow space for axes etc @@ -663,7 +752,7 @@ def set_spacing(self, top=0, left=0, bottom=0, right=0): bottom += self.x_axis_offset if self.y_axis: left = self.y_axis_offset # Override user-provided, so y-axis is at x=0 - self.plotbox = Plotbox(self.image_size, top, left, bottom, right) + self.plotbox.set_padding(top, left, bottom, right) if self.debug_box: self.root_groups["debug"] = self.dwg_base.add( self.drawing.g(class_="debug") @@ -682,9 +771,12 @@ def draw_x_axis( tick_length_lower=default_tick_length, tick_length_upper=None, # If None, use the same as tick_length_lower site_muts=None, # A dict of site id => mutation to plot as ticks on the x axis + alternate_dash_positions=None, # Where to alternate the axis from solid to dash ): if not self.x_axis and not self.x_label: return + if alternate_dash_positions is None: + alternate_dash_positions = np.array([]) dwg = self.drawing axes = self.get_axes() x_axis = axes.add(dwg.g(class_="x-axis")) @@ -702,7 +794,21 @@ def draw_x_axis( if tick_length_upper is None: tick_length_upper = tick_length_lower y = rnd(self.plotbox.max_y - self.x_axis_offset) - x_axis.add(dwg.line((self.plotbox.left, y), (self.plotbox.right, y))) + dash_locs = np.concatenate( + ( + [self.plotbox.left], + self.x_transform(alternate_dash_positions), + [self.plotbox.right], + ) + ) + for i, (x1, x2) in enumerate(zip(dash_locs[:-1], dash_locs[1:])): + x_axis.add( + dwg.line( + (rnd(x1), y), + (rnd(x2), y), + class_="ax-skip" if i % 2 else "ax-line", + ) + ) if tick_positions is not None: if tick_labels is None or isinstance(tick_labels, np.ndarray): if tick_labels is None: @@ -790,7 +896,7 @@ def draw_y_axis( transform="translate(11) rotate(-90)", ) if self.y_axis: - y_axis.add(dwg.line((x, rnd(lower)), (x, rnd(upper)))) + y_axis.add(dwg.line((x, rnd(lower)), (x, rnd(upper)), class_="ax-line")) ticks_group = y_axis.add(dwg.g(class_="ticks")) for y, label in ticks.items(): tick = ticks_group.add( @@ -859,7 +965,7 @@ def shade_background( diag_h=rnd(diag_height), tick_h=rnd(tick_length_lower), ax_x=rnd(prev_break_x - break_x), - ldiag_x=rnd(prev_tree_x - prev_break_x), + ldiag_x=rnd(rnd(prev_tree_x) - rnd(prev_break_x)), ) ) ) @@ -870,7 +976,7 @@ def x_transform(self, x): ) -class SvgTreeSequence(SvgPlot): +class SvgTreeSequence(SvgAxisPlot): """ A class to draw a tree sequence in SVG format. 
@@ -906,6 +1012,7 @@ def __init__( mutation_label_attrs=None, tree_height_scale=None, max_tree_height=None, + max_num_trees=None, **kwargs, ): if max_time is None and max_tree_height is not None: @@ -923,10 +1030,13 @@ FutureWarning, ) x_lim = check_x_lim(x_lim, max_x=ts.sequence_length) - ts, self.tree_status, offsets = clip_ts(ts, x_lim[0], x_lim[1]) - num_trees = int(np.sum((self.tree_status & OMIT) != OMIT)) + ts, self.tree_status, offsets = clip_ts(ts, x_lim[0], x_lim[1], max_num_trees) + + use_tree = self.tree_status & OMIT == 0 + use_skipped = np.append(np.diff(self.tree_status & OMIT_MIDDLE == 0) == 1, 0) + num_plotboxes = np.sum(np.logical_or(use_tree, use_skipped)) if size is None: - size = (200 * num_trees, 200) + size = (200 * int(num_plotboxes), 200) if max_time is None: max_time = "ts" if min_time is None: @@ -954,53 +1064,68 @@ if force_root_branch is None: force_root_branch = any( any(tree.parent(mut.node) == NULL for mut in tree.mutations()) - for tree in ts.trees() + for tree, use in zip(ts.trees(), use_tree) + if use ) # TODO add general padding arguments following matplotlib's terminology. self.set_spacing(top=0, left=20, bottom=10, right=20) - svg_trees = [ - SvgTree( - tree, - (self.plotbox.width / num_trees, self.plotbox.height), - time_scale=time_scale, - node_labels=node_labels, - mutation_labels=mutation_labels, - order=order, - force_root_branch=force_root_branch, - symbol_size=symbol_size, - max_time=max_time, - min_time=min_time, - node_attrs=node_attrs, - mutation_attrs=mutation_attrs, - edge_attrs=edge_attrs, - node_label_attrs=node_label_attrs, - mutation_label_attrs=mutation_label_attrs, - offsets=offsets, - # Do not plot axes on these subplots - **kwargs, # pass though e.g. debug boxes - ) - for status, tree in zip(self.tree_status, ts.trees()) - if (status & OMIT) != OMIT - ] + subplot_size = (self.plotbox.width / num_plotboxes, self.plotbox.height) + subplots = [] + for tree, use, summary in zip(ts.trees(), use_tree, use_skipped): + if use: + subplots.append( + SvgTree( + tree, + size=subplot_size, + time_scale=time_scale, + node_labels=node_labels, + mutation_labels=mutation_labels, + order=order, + force_root_branch=force_root_branch, + symbol_size=symbol_size, + max_time=max_time, + min_time=min_time, + node_attrs=node_attrs, + mutation_attrs=mutation_attrs, + edge_attrs=edge_attrs, + node_label_attrs=node_label_attrs, + mutation_label_attrs=mutation_label_attrs, + offsets=offsets, + # Do not plot axes on these subplots + **kwargs, # pass through e.g. 
debug boxes + ) + ) + last_used_index = tree.index + elif summary: + subplots.append( + SvgSkippedPlot( + size=subplot_size, num_skipped=tree.index - last_used_index + ) + ) y = self.plotbox.top - self.tree_plotbox = svg_trees[0].plotbox + self.tree_plotbox = subplots[0].plotbox + tree_is_used, breaks, skipbreaks = self.find_used_trees() self.draw_x_axis( x_scale, + tree_is_used, + breaks, + skipbreaks, tick_length_lower=self.default_tick_length, # TODO - parameterize tick_length_upper=self.default_tick_length_site, # TODO - parameterize ) y_low = self.tree_plotbox.bottom if y_axis is not None: - self.timescaling = svg_trees[0].timescaling - for svg_tree in svg_trees: - if self.timescaling != svg_tree.timescaling: - raise ValueError( - "Can't draw a tree sequence Y axis if trees vary in timescale" - ) + tscales = {s.timescaling for s in subplots if s.timescaling} + if len(tscales) > 1: + raise ValueError( + "Can't draw a tree sequence Y axis if trees vary in timescale" + ) + self.timescaling = tscales.pop() y_low = self.timescaling.transform(self.timescaling.min_time) if y_ticks is None: - y_ticks = np.unique(ts.tables.nodes.time[referenced_nodes(ts)]) + used_nodes = edge_and_sample_nodes(ts, breaks[skipbreaks]) + y_ticks = np.unique(ts.nodes_time[used_nodes]) if self.time_scale == "rank": # Ticks labelled by time not rank y_ticks = dict(enumerate(y_ticks)) @@ -1013,78 +1138,128 @@ def __init__( gridlines=y_gridlines, ) - tree_x = self.plotbox.left - trees = self.get_plotbox() # Top-level TS plotbox contains all trees - trees["class"] = trees["class"] + " trees" - for svg_tree in svg_trees: - tree = trees.add( + subplot_x = self.plotbox.left + container = self.get_plotbox() # Top-level TS plotbox contains all trees + container["class"] = container["class"] + " trees" + for subplot in subplots: + svg_subplot = container.add( self.drawing.g( - class_=svg_tree.svg_class, transform=f"translate({rnd(tree_x)} {y})" + class_=subplot.svg_class, + transform=f"translate({rnd(subplot_x)} {y})", ) ) - for svg_items in svg_tree.root_groups.values(): - tree.add(svg_items) - tree_x += svg_tree.image_size[0] - assert self.tree_plotbox == svg_tree.plotbox + for svg_items in subplot.root_groups.values(): + svg_subplot.add(svg_items) + subplot_x += subplot.image_size[0] + + def find_used_trees(self): + """ + Return a boolean array of which trees are actually plotted, + a list of which breakpoints are used to transition between plotted trees, + and a 2 x n array (often n=0) of indexes into these breakpoints delimiting + the regions that should be plotted as "skipped" + """ + tree_is_used = (self.tree_status & OMIT) != OMIT + break_used_as_tree_left = np.append(tree_is_used, False) + break_used_as_tree_right = np.insert(tree_is_used, 0, False) + break_used = np.logical_or(break_used_as_tree_left, break_used_as_tree_right) + all_breaks = self.ts.breakpoints(True) + used_breaks = all_breaks[break_used] + mark_skip_transitions = np.concatenate( + ([False], np.diff(self.tree_status & OMIT_MIDDLE) != 0, [False]) + ) + skipregion_indexes = np.where(mark_skip_transitions[break_used])[0] + assert len(skipregion_indexes) % 2 == 0 # all skipped regions have start, end + return tree_is_used, used_breaks, skipregion_indexes.reshape((-1, 2)) def draw_x_axis( self, x_scale, - tick_length_lower=SvgPlot.default_tick_length, - tick_length_upper=SvgPlot.default_tick_length_site, + tree_is_used, + breaks, + skipbreaks, + tick_length_lower=SvgAxisPlot.default_tick_length, + tick_length_upper=SvgAxisPlot.default_tick_length_site, 
): """ - Add extra functionality to the original draw_x_axis method in SvgPlot, mainly + Add extra functionality to the original draw_x_axis method in SvgAxisPlot, to account for the background shading that is displayed in a tree sequence + and in case trees are omitted from the middle of the tree sequence """ if not self.x_axis and not self.x_label: return - left_break_status = np.append(self.tree_status, OMIT) - right_break_status = np.insert(self.tree_status, 0, OMIT) - use_left = (left_break_status & OMIT) != OMIT - use_right = (right_break_status & OMIT) != OMIT - all_breaks = self.ts.breakpoints(True) - breaks = all_breaks[np.logical_or(use_left, use_right)] if x_scale == "physical": - # Assume the trees are simply concatenated end-to-end - self.x_transform = ( - lambda x: self.plotbox.left - + (x - breaks[0]) / (breaks[-1] - breaks[0]) * self.plotbox.width + # In a tree sequence plot, the x_transform is used for the ticks, background + # shading positions, and sites along the x-axis. Each tree will have its own + # separate x_transform function for node positions within the tree. + + # For a plot with a break on the x-axis (representing "skipped" trees), the + # x_transform is a piecewise function. We need to identify the breakpoints + # where the x-scale transitions from the standard scale to the scale(s) used + # within a skipped region + + skipregion_plot_width = self.tree_plotbox.width + skipregion_span = np.diff(breaks[skipbreaks]).T[0] + std_scale = ( + self.plotbox.width - skipregion_plot_width * len(skipregion_span) + ) / (breaks[-1] - breaks[0] - np.sum(skipregion_span)) + skipregion_pos = breaks[skipbreaks].flatten() + genome_pos = np.concatenate(([breaks[0]], skipregion_pos, [breaks[-1]])) + plot_step = np.full(len(genome_pos) - 1, skipregion_plot_width) + plot_step[::2] = std_scale * np.diff(genome_pos)[::2] + plot_pos = np.cumsum(np.insert(plot_step, 0, self.plotbox.left)) + # Convert to slope + intercept form + slope = np.diff(plot_pos) / np.diff(genome_pos) + intercept = plot_pos[1:] - slope * genome_pos[1:] + self.x_transform = lambda y: ( + y * slope[np.searchsorted(skipregion_pos, y)] + + intercept[np.searchsorted(skipregion_pos, y)] ) + tick_positions = breaks + site_muts = { + s.id: s.mutations + for tree, use in zip(self.ts.trees(), tree_is_used) + for s in tree.sites() + if use + } + self.shade_background( breaks, tick_length_lower, self.tree_plotbox.max_x, self.plotbox.pad_bottom + self.tree_plotbox.pad_bottom, ) - site_muts = {s.id: s.mutations for s in self.ts.sites()} - # omit tick on LHS for trees that have been clipped on left, and same on RHS - use_left = np.logical_and( - use_left, (left_break_status & LEFT_CLIPPED_BIT) != LEFT_CLIPPED_BIT - ) - use_right = np.logical_and( - use_right, (right_break_status & RIGHT_CLIPPED_BIT) != RIGHT_CLIPPED_BIT - ) - super().draw_x_axis( - tick_positions=all_breaks[np.logical_or(use_left, use_right)], - tick_length_lower=tick_length_lower, - tick_length_upper=tick_length_upper, - site_muts=site_muts, - ) - else: - # No background shading needed if x_scale is "treewise" + + # For a treewise plot, the only time the x_transform is used is to apply + # to tick positions, so simply use positions 0..num_used_breaks for the + # positions, and a simple transform self.x_transform = ( lambda x: self.plotbox.left + x / (len(breaks) - 1) * self.plotbox.width ) - super().draw_x_axis( - tick_positions=np.arange(len(breaks)), - tick_labels=breaks, - tick_length_lower=tick_length_lower, - ) + tick_positions = np.arange(len(breaks)) + 
+ site_muts = None # It doesn't make sense to plot sites for "treewise" plots + tick_length_upper = None # No sites plotted, so use the default upper tick + + # NB: no background shading needed if x_scale is "treewise" + + skipregion_pos = skipbreaks.flatten() + + first_tick = 1 if np.any(self.tree_status[tree_is_used] & LEFT_CLIP) else 0 + last_tick = -1 if np.any(self.tree_status[tree_is_used] & RIGHT_CLIP) else None + + super().draw_x_axis( + tick_positions=tick_positions[first_tick:last_tick], + tick_labels=breaks[first_tick:last_tick], + tick_length_lower=tick_length_lower, + tick_length_upper=tick_length_upper, + site_muts=site_muts, + alternate_dash_positions=skipregion_pos, + ) -class SvgTree(SvgPlot): +class SvgTree(SvgAxisPlot): """ A class to draw a tree in SVG format. @@ -1335,8 +1510,8 @@ def assign_y_coordinates( max_time, min_time, force_root_branch, - bottom_space=SvgPlot.line_height, - top_space=SvgPlot.line_height, + bottom_space=SvgAxisPlot.line_height, + top_space=SvgAxisPlot.line_height, ): """ Create a self.node_height dict, a self.timescaling instance and @@ -1346,8 +1521,8 @@ def assign_y_coordinates( """ max_time = check_max_time(max_time, self.time_scale != "rank") min_time = check_min_time(min_time, self.time_scale != "rank") - node_time = self.ts.tables.nodes.time - mut_time = self.ts.tables.mutations.time + node_time = self.ts.nodes_time + mut_time = self.ts.mutations_time root_branch_len = 0 if self.time_scale == "rank": t = np.zeros_like(node_time) @@ -1358,7 +1533,8 @@ def assign_y_coordinates( else: # only rank the nodes that are actually referenced in the edge table # (non-referenced nodes could occur if the user specifies x_lim values) - use_time = referenced_nodes(self.ts) + # However, we do include nodes in trees that have been skipped + use_time = edge_and_sample_nodes(self.ts) t[use_time] = node_time[use_time] node_time = t times = np.unique(node_time[node_time <= self.ts.max_root_time]) @@ -1400,7 +1576,7 @@ def assign_y_coordinates( min_time = min(self.node_height.values()) # don't need to check mutation times, as they must be above a node elif min_time == "ts": - min_time = np.min(self.ts.tables.nodes.time[referenced_nodes(self.ts)]) + min_time = np.min(self.ts.nodes_time[edge_and_sample_nodes(self.ts)]) # In pathological cases, all the nodes are at the same time if min_time == max_time: max_time = min_time + 1 @@ -1967,7 +2143,6 @@ def _draw(self): self.canvas[y, xv] = mid_char self.canvas[y, left] = left_child self.canvas[y, right] = right_child - # print(self.canvas) if self.orientation == TOP: self.canvas = np.flip(self.canvas, axis=0) # Reverse the time positions so that we can use them in the tree @@ -2070,4 +2245,3 @@ def _draw(self): # Move the padding to the left. 
self.canvas[:, :-1] = self.canvas[:, 1:] self.canvas[:, -1] = " " - # print(self.canvas) diff --git a/python/tskit/genotypes.py b/python/tskit/genotypes.py index d0abfb3835..239e135777 100644 --- a/python/tskit/genotypes.py +++ b/python/tskit/genotypes.py @@ -1,7 +1,7 @@ # # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -340,6 +340,17 @@ def _repr_html_(self) -> str: """ return util.variant_html(self) + def __repr__(self): + d = { + "site": self.site, + "samples": self.samples, + "alleles": self.alleles, + "genotypes": self.genotypes, + "has_missing_data": self.has_missing_data, + "isolated_as_missing": self.isolated_as_missing, + } + return f"Variant({repr(d)})" + # # Miscellaneous auxiliary methods. diff --git a/python/tskit/intervals.py b/python/tskit/intervals.py new file mode 100644 index 0000000000..0c78c50b5b --- /dev/null +++ b/python/tskit/intervals.py @@ -0,0 +1,601 @@ +# MIT License +# +# Copyright (c) 2023 Tskit Developers +# Copyright (C) 2020-2021 University of Oxford +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +""" +Utilities for working with intervals and interval maps. +""" +from __future__ import annotations + +import collections.abc +import numbers + +import numpy as np + +import tskit +import tskit.util as util + + +class RateMap(collections.abc.Mapping): + """ + A class mapping a non-negative rate value to a set of non-overlapping intervals + along the genome. Intervals for which the rate is unknown (i.e., missing data) + are encoded by NaN values in the ``rate`` array. + + :param list position: A list of :math:`n+1` positions, starting at 0, and ending + in the sequence length over which the RateMap will apply. + :param list rate: A list of :math:`n` positive rates that apply between each + position. Intervals with missing data are encoded by NaN values. + """ + + # The args are marked keyword only to give us some flexibility in how we + # create class this in the future. + def __init__( + self, + *, + position, + rate, + ): + # Making the arrays read-only guarantees rate and cumulative mass stay in sync + # We prevent the arrays themselves being overwritten by making self.position, + # etc properties. + + # TODO we always coerce the position type to float here, but we may not + # want to do this. int32 is a perfectly good choice a lot of the time. 
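Before the constructor internals, a short usage sketch of the `RateMap` API this file adds; the positions and rates are hypothetical, and we import directly from the new `tskit.intervals` module introduced in this diff:

```python
import numpy as np
from tskit.intervals import RateMap  # module added in this diff

# Hypothetical map: three intervals on a 30 kb genome; the middle interval
# has unknown rate, encoded as NaN (missing data).
rate_map = RateMap(
    position=[0, 10_000, 20_000, 30_000],
    rate=[1e-8, np.nan, 2e-8],
)
print(rate_map.num_missing_intervals)      # 1
print(rate_map.get_rate([5_000, 15_000]))  # [1.e-08 nan]
print(rate_map.mean_rate)                  # 1.5e-08; the NaN interval is excluded
```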
+ self._position = np.array(position, dtype=float) + self._position.flags.writeable = False + self._rate = np.array(rate, dtype=float) + self._rate.flags.writeable = False + size = len(self._position) + if size < 2: + raise ValueError("Must have at least two positions") + if len(self._rate) != size - 1: + raise ValueError( + "Rate array must have one less entry than the position array" + ) + if self._position[0] != 0: + raise ValueError("First position must be zero") + + span = self.span + if np.any(span <= 0): + bad_pos = np.where(span <= 0)[0] + 1 + raise ValueError( + f"Position values not strictly increasing at indexes {bad_pos}" + ) + if np.any(self._rate < 0): + bad_rates = np.where(self._rate < 0)[0] + raise ValueError(f"Rate values negative at indexes {bad_rates}") + self._missing = np.isnan(self.rate) + self._num_missing_intervals = np.sum(self._missing) + if self._num_missing_intervals == len(self.rate): + raise ValueError("All intervals are missing data") + # We don't expose the cumulative mass array as a part of the array + # API is it's not quite as obvious how it lines up for each interval. + # It's really the sum of the mass up to but not including the current + # interval, which is a bit confusing. Probably best to just leave + # it as a function, so that people can sample at regular positions + # along the genome anyway, emphasising that it's a continuous function, + # not a step function like the other interval attributes. + self._cumulative_mass = np.insert(np.nancumsum(self.mass), 0, 0) + assert self._cumulative_mass[0] == 0 + self._cumulative_mass.flags.writeable = False + + @property + def left(self): + """ + The left position of each interval (inclusive). + """ + return self._position[:-1] + + @property + def right(self): + """ + The right position of each interval (exclusive). + """ + return self._position[1:] + + @property + def mid(self): + """ + Returns the midpoint of each interval. + """ + mid = self.left + self.span / 2 + mid.flags.writeable = False + return mid + + @property + def span(self): + """ + Returns the span (i.e., ``right - left``) of each of the intervals. + """ + span = self.right - self.left + span.flags.writeable = False + return span + + @property + def position(self): + """ + The breakpoint positions between intervals. This is equal to the + :attr:`~.RateMap.left` array with the :attr:`sequence_length` + appended. + """ + return self._position + + @property + def rate(self): + """ + The rate associated with each interval. Missing data is encoded + by NaN values. + """ + return self._rate + + @property + def mass(self): + r""" + The "mass" of each interval, defined as the :attr:`~.RateMap.rate` + :math:`\times` :attr:`~.RateMap.span`. This is NaN for intervals + containing missing data. + """ + return self._rate * self.span + + @property + def missing(self): + """ + A boolean array encoding whether each interval contains missing data. + Equivalent to ``np.isnan(rate_map.rate)`` + """ + return self._missing + + @property + def non_missing(self): + """ + A boolean array encoding whether each interval contains non-missing data. + Equivalent to ``np.logical_not(np.isnan(rate_map.rate))`` + """ + return ~self._missing + + # + # Interval counts + # + + @property + def num_intervals(self) -> int: + """ + The total number of intervals in this map. Equal to + :attr:`~.RateMap.num_missing_intervals` + + :attr:`~.RateMap.num_non_missing_intervals`. 
+ """ + return len(self._rate) + + @property + def num_missing_intervals(self) -> int: + """ + Returns the number of missing intervals, i.e., those in which the + :attr:`~.RateMap.rate` value is a NaN. + """ + return self._num_missing_intervals + + @property + def num_non_missing_intervals(self) -> int: + """ + The number of non missing intervals, i.e., those in which the + :attr:`~.RateMap.rate` value is not a NaN. + """ + return self.num_intervals - self.num_missing_intervals + + @property + def sequence_length(self): + """ + The sequence length covered by this map + """ + return self.position[-1] + + @property + def total_mass(self): + """ + The cumulative total mass over the entire map. + """ + return self._cumulative_mass[-1] + + @property + def mean_rate(self): + """ + The mean rate over this map weighted by the span covered by each rate. + Unknown intervals are excluded. + """ + total_span = np.sum(self.span[self.non_missing]) + return self.total_mass / total_span + + def get_rate(self, x): + """ + Return the rate at the specified list of positions. + + .. note:: This function will return a NaN value for any positions + that contain missing data. + + :param numpy.ndarray x: The positions for which to return values. + :return: An array of rates, the same length as ``x``. + :rtype: numpy.ndarray + """ + loc = np.searchsorted(self.position, x, side="right") - 1 + if np.any(loc < 0) or np.any(loc >= len(self.rate)): + raise ValueError("position out of bounds") + return self.rate[loc] + + def get_cumulative_mass(self, x): + """ + Return the cumulative mass of the map up to (but not including) a + given point for a list of positions along the map. This is equal to + the integral of the rate from 0 to the point. + + :param numpy.ndarray x: The positions for which to return values. + + :return: An array of cumulative mass values, the same length as ``x`` + :rtype: numpy.ndarray + """ + x = np.array(x) + if np.any(x < 0) or np.any(x > self.sequence_length): + raise ValueError(f"Cannot have positions < 0 or > {self.sequence_length}") + return np.interp(x, self.position, self._cumulative_mass) + + def find_index(self, x: float) -> int: + """ + Returns the index of the interval that the specified position falls within, + such that ``rate_map.left[index] <= x < self.rate_map.right[index]``. + + :param float x: The position to search. + :return: The index of the interval containing this point. + :rtype: int + :raises KeyError: if the position is not contained in any of the intervals. + """ + if x < 0 or x >= self.sequence_length: + raise KeyError(f"Position {x} out of bounds") + index = np.searchsorted(self.position, x, side="left") + if x < self.position[index]: + index -= 1 + assert self.left[index] <= x < self.right[index] + return index + + def missing_intervals(self): + """ + Returns the left and right coordinates of the intervals containing + missing data in this map as a 2D numpy array + with shape (:attr:`~.RateMap.num_missing_intervals`, 2). Each row + of this returned array is therefore a ``left``, ``right`` tuple + corresponding to the coordinates of the missing intervals. + + :return: A numpy array of the coordinates of intervals containing + missing data. + :rtype: numpy.ndarray + """ + out = np.empty((self.num_missing_intervals, 2)) + out[:, 0] = self.left[self.missing] + out[:, 1] = self.right[self.missing] + return out + + def asdict(self): + return {"position": self.position, "rate": self.rate} + + # + # Dunder methods. 
We implement the Mapping protocol via __iter__, __len__ + # and __getitem__. We have some extra semantics for __getitem__, providing + # slice notation. + # + + def __iter__(self): + # The clinching argument for using mid here is that if we used + # left instead we would have + # RateMap([0, 1], [0.1]) == RateMap([0, 100], [0.1]) + # by the inherited definition of equality since the dictionary items + # would be equal. + # Similarly, we only return the midpoints of known intervals + # because NaN values are not equal, and we would need to do + # something to work around this. It seems reasonable that + # this high-level operation returns the *known* values only + # anyway. + yield from self.mid[self.non_missing] + + def __len__(self): + return np.sum(self.non_missing) + + def __getitem__(self, key): + if isinstance(key, slice): + if key.step is not None: + raise TypeError("Only interval slicing is supported") + return self.slice(key.start, key.stop) + if isinstance(key, numbers.Number): + index = self.find_index(key) + if np.isnan(self.rate[index]): + # To be consistent with the __iter__ definition above we + # don't consider these missing positions to be "in" the map. + raise KeyError(f"Position {key} is within a missing interval") + return self.rate[index] + # TODO we could implement numpy array indexing here and call + # to get_rate. Note we'd need to take care that we return a keyerror + # if the returned array contains any nans though. + raise KeyError("Key {key} not in map") + + def _text_header_and_rows(self, limit=None): + headers = ("left", "right", "mid", "span", "rate") + num_rows = len(self.left) + rows = [] + row_indexes = util.truncate_rows(num_rows, limit) + for j in row_indexes: + if j == -1: + rows.append(f"__skipped__{num_rows-limit}") + else: + rows.append( + [ + f"{self.left[j]:.10g}", + f"{self.right[j]:.10g}", + f"{self.mid[j]:.10g}", + f"{self.span[j]:.10g}", + f"{self.rate[j]:.2g}", + ] + ) + return headers, rows + + def __str__(self): + header, rows = self._text_header_and_rows( + limit=tskit._print_options["max_lines"] + ) + table = util.unicode_table( + rows=rows, + header=header, + column_alignments="<<>>>", + ) + return table + + def _repr_html_(self): + header, rows = self._text_header_and_rows( + limit=tskit._print_options["max_lines"] + ) + return util.html_table(rows, header=header) + + def __repr__(self): + return f"RateMap(position={repr(self.position)}, rate={repr(self.rate)})" + + # + # Methods for building rate maps. + # + + def copy(self) -> RateMap: + """ + Returns a deep copy of this RateMap. + """ + # We take read-only copies of the arrays in the constructor anyway, so + # no need for copying. + return RateMap(position=self.position, rate=self.rate) + + def slice(self, left=None, right=None, *, trim=False) -> RateMap: # noqa: A003 + """ + Returns a subset of this rate map in the specified interval. + + :param float left: The left coordinate (inclusive) of the region to keep. + If ``None``, defaults to 0. + :param float right: The right coordinate (exclusive) of the region to keep. + If ``None``, defaults to the sequence length. + :param bool trim: If True, remove the flanking regions such that the + sequence length of the new rate map is ``right`` - ``left``. If ``False`` + (default), do not change the coordinate system and mark the flanking + regions as "unknown". 
+        :return: A new RateMap instance
+        :rtype: RateMap
+        """
+        left = 0 if left is None else left
+        right = self.sequence_length if right is None else right
+        if not (0 <= left < right <= self.sequence_length):
+            raise KeyError(f"Invalid slice: left={left}, right={right}")
+
+        i = self.find_index(left)
+        j = i + np.searchsorted(self.position[i:], right, side="right")
+        if right > self.position[j - 1]:
+            j += 1
+
+        position = self.position[i:j].copy()
+        rate = self.rate[i : j - 1].copy()
+        position[0] = left
+        position[-1] = right
+
+        if trim:
+            # Return trimmed map with changed coords
+            return RateMap(position=position - left, rate=rate)
+
+        # Need to check regions before & after sliced region are filled out:
+        if left != 0:
+            if np.isnan(rate[0]):
+                position[0] = 0  # Extend
+            else:
+                rate = np.insert(rate, 0, np.nan)  # Prepend
+                position = np.insert(position, 0, 0)
+        if right != self.position[-1]:
+            if np.isnan(rate[-1]):
+                position[-1] = self.sequence_length  # Extend
+            else:
+                rate = np.append(rate, np.nan)  # Append
+                position = np.append(position, self.position[-1])
+        return RateMap(position=position, rate=rate)
+
+    @staticmethod
+    def uniform(sequence_length, rate) -> RateMap:
+        """
+        Create a uniform rate map.
+        """
+        return RateMap(position=[0, sequence_length], rate=[rate])
+
+    @staticmethod
+    def read_hapmap(
+        fileobj,
+        sequence_length=None,
+        *,
+        has_header=True,
+        position_col=None,
+        rate_col=None,
+        map_col=None,
+    ):
+        # Black barfs with an INTERNAL_ERROR trying to reformat this docstring,
+        # so we explicitly disable reformatting here.
+        # fmt: off
+        """
+        Parses the specified file in HapMap format and returns a :class:`.RateMap`.
+        HapMap files must be white-space-delimited, and by default are assumed to
+        contain a single header line (which is ignored). Each subsequent line
+        then contains a physical position (in base pairs) and either a genetic
+        map position (in centiMorgans) or a recombination rate (in centiMorgans
+        per megabase). The value in the rate column in a given line gives the
+        constant rate between the physical position in that line (inclusive) and the
+        physical position on the next line (exclusive).
+        By default, the second column of the file is taken
+        as the physical position and the fourth column is taken as the genetic
+        position, as seen in the following sample of the format::
+
+            Chromosome  Position(bp)  Rate(cM/Mb)  Map(cM)
+            chr10       48232         0.1614       0.002664
+            chr10       48486         0.1589       0.002705
+            chr10       50009         0.159        0.002947
+            chr10       52147         0.1574       0.003287
+            ...
+            chr10       133762002     3.358        181.129345
+            chr10       133766368     0.000        181.144008
+
+        In the example above, the first row has a nonzero genetic map position
+        (last column, cM), implying a nonzero recombination rate before that
+        position, that is assumed to extend to the start of the chromosome
+        (at position 0 bp). However, if the first line has a nonzero bp position
+        (second column) and a zero genetic map position (last column, cM),
+        then the recombination rate before that position is *unknown*, producing
+        :ref:`missing data `.
+
+        .. note::
+            The rows are all assumed to come from the same contig, and the
+            first column is currently ignored. Therefore if you have a single
+            file containing several contigs or chromosomes, you must split
+            it up into multiple files, and pass each one separately to this
+            function.
+
+        :param str fileobj: Filename or file to read.
This is passed directly + to :func:`numpy.loadtxt`, so if the filename extension is .gz or .bz2, + the file is decompressed first + :param float sequence_length: The total length of the map. If ``None``, + then assume it is the last physical position listed in the file. + Otherwise it must be greater then or equal to the last physical + position in the file, and the region between the last physical position + and the sequence_length is padded with a rate of zero. + :param bool has_header: If True (default), assume the file has a header row + and ignore the first line of the file. + :param int position_col: The zero-based index of the column in the file + specifying the physical position in base pairs. If ``None`` (default) + assume an index of 1 (i.e. the second column). + :param int rate_col: The zero-based index of the column in the file + specifying the rate in cM/Mb. If ``None`` (default) do not use the rate + column, but calculate rates using the genetic map positions, as + specified in ``map_col``. If the rate column is used, the + interval from 0 to first physical position in the file is marked as + unknown, and the last value in the rate column must be zero. + :param int map_col: The zero-based index of the column in the file + specifying the genetic map position in centiMorgans. If ``None`` + (default), assume an index of 3 (i.e. the fourth column). If the first + genetic position is 0 the interval from position 0 to the first + physical position in the file is marked as unknown. Otherwise, act + as if an additional row, specifying physical position 0 and genetic + position 0, exists at the start of the file. + :return: A RateMap object. + :rtype: RateMap + """ + # fmt: on + column_defs = {} # column definitions passed to np.loadtxt + if rate_col is None and map_col is None: + # Default to map_col + map_col = 3 + elif rate_col is not None and map_col is not None: + raise ValueError("Cannot specify both rate_col and map_col") + if map_col is not None: + column_defs[map_col] = ("map", float) + else: + column_defs[rate_col] = ("rate", float) + position_col = 1 if position_col is None else position_col + if position_col in column_defs: + raise ValueError( + "Cannot specify the same columns for position_col and " + "rate_col or map_col" + ) + column_defs[position_col] = ("pos", int) + + column_names = [c[0] for c in column_defs.values()] + column_data = np.loadtxt( + fileobj, + skiprows=1 if has_header else 0, + dtype=list(column_defs.values()), + usecols=list(column_defs.keys()), + unpack=True, + ) + data = dict(zip(column_names, column_data)) + + if "map" not in data: + assert "rate" in data + if data["rate"][-1] != 0: + raise ValueError("The last entry in the 'rate' column must be zero") + pos_Mb = data["pos"] / 1e6 + map_pos = np.cumsum(data["rate"][:-1] * np.diff(pos_Mb)) + data["map"] = np.insert(map_pos, 0, 0) / 100 + else: + data["map"] /= 100 # Convert centiMorgans to Morgans + if len(data["map"]) == 0: + raise ValueError("Empty hapmap file") + + # TO DO: read in chrom name from col 0 and poss set as .name + # attribute on the RateMap + + physical_positions = data["pos"] + genetic_positions = data["map"] + start = physical_positions[0] + end = physical_positions[-1] + + if genetic_positions[0] > 0 and start == 0: + raise ValueError( + "The map distance at the start of the chromosome must be zero" + ) + if start > 0: + physical_positions = np.insert(physical_positions, 0, 0) + if genetic_positions[0] > 0: + # Exception for a map that starts > 0cM: include the start rate + # in 
the mean + start = 0 + genetic_positions = np.insert(genetic_positions, 0, 0) + + if sequence_length is not None: + if sequence_length < end: + raise ValueError( + "The sequence_length cannot be less than the last physical position " + f" ({physical_positions[-1]})" + ) + if sequence_length > end: + physical_positions = np.append(physical_positions, sequence_length) + genetic_positions = np.append(genetic_positions, genetic_positions[-1]) + + assert genetic_positions[0] == 0 + rate = np.diff(genetic_positions) / np.diff(physical_positions) + if start != 0: + rate[0] = np.nan + if end != physical_positions[-1]: + rate[-1] = np.nan + return RateMap(position=physical_positions, rate=rate) diff --git a/python/tskit/provenance.py b/python/tskit/provenance.py index 82fb19518a..bc88e29f1a 100644 --- a/python/tskit/provenance.py +++ b/python/tskit/provenance.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2020 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2016-2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -117,7 +117,7 @@ def validate_provenance(provenance): :param dict provenance: The dictionary representing a JSON document to be validated against the schema. - :raises: :class:`tskit.ProvenanceValidationError` + :raises ProvenanceValidationError: if the schema is not valid. """ schema = get_schema() try: diff --git a/python/tskit/stats.py b/python/tskit/stats.py index 3e161662b1..a972b01454 100644 --- a/python/tskit/stats.py +++ b/python/tskit/stats.py @@ -215,24 +215,49 @@ def resample_blocks(self, block_multiplier): ) if self.cum_weights[-1, i] > 0: self.quantile[:, i] = self.cum_weights[:, i] / self.cum_weights[-1, i] + else: + self.quantile[:, i] = np.nan class CoalescenceTimeDistribution: """ Class to precompute a table of sorted/weighted node times, from which to calculate the empirical distribution function and estimate coalescence rates in time windows. + + To compute weights efficiently requires an update operation of the form: + + ``output[parent], state[parent] = update(state[children])`` + + where ``output`` are the weights associated with the node, and ``state`` + are values that are needed to compute ``output`` that are recursively + calculated along the tree. The value of ``state`` on the leaves is + initialized via, + + ``state[sample] = initialize(sample, sample_sets)`` """ @staticmethod - def _count_coalescence_events(node, tree, sample_sets): - # TODO this will count unary nodes: should it count nodes - # with >1 child instead? - return np.array([1], dtype=np.int32) + def _count_coalescence_events(): + """ + Count the number of samples that coalesce in ``node``, within each + set of samples in ``sample_sets``. + """ + + def initialize(node, sample_sets): + singles = np.array([[node in s for s in sample_sets]], dtype=np.float64) + return (singles,) + + def update(singles_per_child): + singles = np.sum(singles_per_child, axis=0, keepdims=True) + is_ancestor = (singles > 0).astype(np.float64) + return is_ancestor, (singles,) + + return (initialize, update) @staticmethod - def _count_pair_coalescence_events(node, tree, sample_sets): + def _count_pair_coalescence_events(): """ - Count the number of pairs that coalesce in node, within and between the + Count the number of pairs that coalesce in ``node``, within and between the sets of samples in ``sample_sets``. 
The count of pairs with members that belong to sets :math:`a` and :math:`b` is: @@ -246,30 +271,29 @@ def _count_pair_coalescence_events(node, tree, sample_sets): correspond to counts of pairs with set labels ``[(0,0), (0,1), (1,1)]``. """ - # TODO needs to be optimized, use np.intersect1d - children = tree.children(node) - samples_per_child = [set(list(tree.samples(c))) for c in children] - sample_counts = np.zeros((len(sample_sets), len(children)), dtype=np.int32) - for i, s1 in enumerate(samples_per_child): - for a, s2 in enumerate([set(s) for s in sample_sets]): - sample_counts[a, i] = len(s1 & s2) - - pair_counts = [] - for a, b in itertools.combinations_with_replacement( - range(sample_counts.shape[0]), 2 - ): - count = 0 - for i, j in itertools.combinations(range(sample_counts.shape[1]), 2): - count += ( - sample_counts[a, i] * sample_counts[b, j] - + sample_counts[a, j] * sample_counts[b, i] - ) / (1 + int(a == b)) - pair_counts.append(count) - - return np.array(pair_counts, dtype=np.int32) + def initialize(node, sample_sets): + singles = np.array([[node in s for s in sample_sets]], dtype=np.float64) + return (singles,) + + def update(singles_per_child): + C = singles_per_child.shape[0] # number of children + S = singles_per_child.shape[1] # number of sample sets + singles = np.sum(singles_per_child, axis=0, keepdims=True) + pairs = np.zeros((1, int(S * (S + 1) / 2)), dtype=np.float64) + for a, b in itertools.combinations(range(C), 2): + for i, (j, k) in enumerate( + itertools.combinations_with_replacement(range(S), 2) + ): + pairs[0, i] += ( + singles_per_child[a, j] * singles_per_child[b, k] + + singles_per_child[a, k] * singles_per_child[b, j] + ) / (1 + int(j == k)) + return pairs, (singles,) + + return (initialize, update) @staticmethod - def _count_trio_first_coalescence_events(node, tree, sample_sets): + def _count_trio_first_coalescence_events(): """ Count the number of pairs that coalesce in node with an outgroup, within and between the sets of samples in ``sample_sets``. In other @@ -290,88 +314,160 @@ def _count_trio_first_coalescence_events(node, tree, sample_sets): correspond to counts of pairs with set labels, ``[((0,0),0), ((0,0),1), ..., ((0,1),0), ((0,1),1), ...]``. 
""" - samples = list(tree.samples(node)) - outg_counts = [len(s) - len(np.intersect1d(samples, s)) for s in sample_sets] - pair_counts = CoalescenceTimeDistribution._count_pair_coalescence_events( - node, tree, sample_sets - ) - trio_counts = [] - for i in pair_counts: - for j in outg_counts: - trio_counts.append(i * j) - return np.array(trio_counts, dtype=np.int32) - def _update_weights_by_edge_diff(self, tree, edge_diff, running_weights): + def initialize(node, sample_sets): + S = len(sample_sets) + totals = np.array([[len(s) for s in sample_sets]], dtype=np.float64) + singles = np.array([[node in s for s in sample_sets]], dtype=np.float64) + pairs = np.zeros((1, int(S * (S + 1) / 2)), dtype=np.float64) + return ( + totals, + singles, + pairs, + ) + + def update(totals_per_child, singles_per_child, pairs_per_child): + C = totals_per_child.shape[0] # number of children + S = totals_per_child.shape[1] # number of sample sets + totals = np.mean(totals_per_child, axis=0, keepdims=True) + singles = np.sum(singles_per_child, axis=0, keepdims=True) + pairs = np.zeros((1, int(S * (S + 1) / 2)), dtype=np.float64) + for a, b in itertools.combinations(range(C), 2): + pair_iterator = itertools.combinations_with_replacement(range(S), 2) + for i, (j, k) in enumerate(pair_iterator): + pairs[0, i] += ( + singles_per_child[a, j] * singles_per_child[b, k] + + singles_per_child[a, k] * singles_per_child[b, j] + ) / (1 + int(j == k)) + outgr = totals - singles + trios = np.zeros((1, pairs.size * outgr.size), dtype=np.float64) + trio_iterator = itertools.product(range(pairs.size), range(outgr.size)) + for i, (j, k) in enumerate(trio_iterator): + trios[0, i] += pairs[0, j] * outgr[0, k] + return trios, ( + totals, + singles, + pairs, + ) + + return (initialize, update) + + def _update_running_with_edge_diff( + self, tree, edge_diff, running_output, running_state, running_index + ): """ - Update ``running_weights`` to reflect ``tree`` using edge differences - ``edge_diff`` with the previous tree. + Update ``running_output`` and ``running_state`` to reflect ``tree``, + using edge differences ``edge_diff`` with the previous tree. + The dict ``running_index`` maps node IDs onto rows of the running arrays. """ assert edge_diff.interval == tree.interval - # nodes that have been removed from tree - removed = {i.child for i in edge_diff.edges_out if tree.is_isolated(i.child)} - # TODO: What if sample is removed from tree? In that case should all - # nodes be updated for trio first coalescences? 
- - # nodes where descendant subtree has been altered - modified = {i.parent for i in edge_diff.edges_in} - for i in copy.deepcopy(modified): - while tree.parent(i) != tskit.NULL and not tree.parent(i) in modified: + # empty rows in the running arrays + available_rows = {i for i in range(self.running_array_size)} + available_rows -= set(running_index.values()) + + # find internal nodes that have been removed from tree or are unary + removed_nodes = set() + for i in edge_diff.edges_out: + for j in [i.child, i.parent]: + if tree.num_children(j) < 2 and not tree.is_sample(j): + removed_nodes.add(j) + + # find non-unary nodes where descendant subtree has been altered + modified_nodes = { + i.parent for i in edge_diff.edges_in if tree.num_children(i.parent) > 1 + } + for i in copy.deepcopy(modified_nodes): + while tree.parent(i) != tskit.NULL and not tree.parent(i) in modified_nodes: i = tree.parent(i) - modified.add(i) + if tree.num_children(i) > 1: + modified_nodes.add(i) + + # clear running state/output for nodes that are no longer in tree + for i in removed_nodes: + if i in running_index: + running_state[running_index[i], :] = 0 + running_output[running_index[i], :] = 0 + available_rows.add(running_index.pop(i)) + + # recalculate state/output for nodes whose descendants have changed + for i in sorted(modified_nodes, key=lambda node: tree.time(node)): + children = [] + for c in tree.children(i): # skip unary children + while tree.num_children(c) == 1: + (c,) = tree.children(c) + children.append(c) + child_index = [running_index[c] for c in children] + + inputs = ( + running_state[child_index][:, state_index] + for state_index in self.state_indices + ) + output, state = self._update(*inputs) - # recalculate weights for current tree - for i in removed: - running_weights[i, :] = 0 - for i in modified: - running_weights[i, :] = self.weight_func(i, tree, self.sample_sets) - self.weight_func_evals += len(modified) + # update running state/output arrays + if i not in running_index: + running_index[i] = available_rows.pop() + running_output[running_index[i], :] = output + for state_index, x in zip(self.state_indices, state): + running_state[running_index[i], state_index] = x + + # track the number of times the weight function was called + self.weight_func_evals += len(modified_nodes) def _build_ecdf_table_for_window( - self, left, right, tree, edge_diffs, running_weights + self, + left, + right, + tree, + edge_diffs, + running_output, + running_state, + running_index, ): """ - Construct ECDF table for genomic interval [left, right]. Update ``tree``, - ``edge_diffs``, and ``running_weights`` for input for next window. Trees are - counted as belonging to any interval with which they overlap, and thus - can be used in several intervals. Thus, the concatenation of ECDF - tables across multiple intervals is not the same as the ECDF table - for the union of those intervals. Trees within intervals are chunked - into roughly equal-sized blocks for bootstrapping. + Construct ECDF table for genomic interval [left, right]. Update + ``tree``; ``edge_diffs``; and ``running_output``, ``running_state``, + `running_idx``; for input for next window. Trees are counted as + belonging to any interval with which they overlap, and thus can be used + in several intervals. Thus, the concatenation of ECDF tables across + multiple intervals is not the same as the ECDF table for the union of + those intervals. Trees within intervals are chunked into roughly + equal-sized blocks for bootstrapping. 
""" assert tree.interval.left <= left and right > left + # TODO: if bootstrapping, block span needs to be tracked + # and used to renormalise each replicate. This should be + # done by the bootstrapping machinery, not here. + # assign trees in window to equal-sized blocks with unique id - other_tree = tree.copy() - # TODO: is a full copy of the tree needed, given that the original is - # mutated below? - if right >= other_tree.tree_sequence.sequence_length: - other_tree.last() - else: - # other_tree.seek(right) won't work if `right` is recomb breakpoint - while other_tree.interval.right < right: - other_tree.next() - tree_idx = np.arange(tree.index, other_tree.index + 1) - tree.index tree_offset = tree.index + if right >= tree.tree_sequence.sequence_length: + tree.last() + else: + # tree.seek(right) won't work if `right` is recomb breakpoint + while tree.interval.right < right: + tree.next() + tree_idx = np.arange(tree_offset, tree.index + 1) - tree_offset num_blocks = min(self.num_blocks, len(tree_idx)) tree_blocks = np.floor_divide(num_blocks * tree_idx, len(tree_idx)) # calculate span weights - # TODO: if bootstrapping, does block span need to be tracked - # and used to renormalise each replicate? - other_tree.seek(tree.interval.left) - tree_span = [ - min(other_tree.interval.right, right) - max(other_tree.interval.left, left) - ] - while other_tree.index < tree_offset + tree_idx[-1]: - other_tree.next() + tree.seek_index(tree_offset) + tree_span = [min(tree.interval.right, right) - max(tree.interval.left, left)] + while tree.index < tree_offset + tree_idx[-1]: + tree.next() tree_span.append( - min(other_tree.interval.right, right) - - max(other_tree.interval.left, left) + min(tree.interval.right, right) - max(tree.interval.left, left) ) - tree_span = np.array(tree_span) / sum(tree_span) + tree_span = np.array(tree_span) + total_span = np.sum(tree_span) + assert np.isclose( + total_span, min(right, tree.tree_sequence.sequence_length) - left + ) # storage if using single window, block for entire tree sequence buffer_size = self.buffer_size @@ -381,49 +477,64 @@ def _build_ecdf_table_for_window( weights = np.zeros((table_size, self.num_weights)) # assemble table of coalescence times in window + num_record = 0 + accessible_span = 0.0 + span_weight = 1.0 indices = np.zeros(tree.tree_sequence.num_nodes, dtype=np.int32) - 1 last_block = np.zeros(tree.tree_sequence.num_nodes, dtype=np.int32) - 1 - num_record = 0 + tree.seek_index(tree_offset) while tree.index != tskit.NULL: if tree.interval.right > left: current_block = tree_blocks[tree.index - tree_offset] if self.span_normalise: - span_weight = tree_span[tree.index - tree_offset] - else: - span_weight = 1.0 - nodes_in_tree = np.array( - [i for i in tree.nodes() if tree.is_internal(i)] - ) - # TODO this will fail if all nodes are isolated (masked tree) - nodes_to_add = nodes_in_tree[ - np.where(last_block[nodes_in_tree] != current_block) - ] - if len(nodes_to_add) > 0: - idx = np.arange(num_record, num_record + len(nodes_to_add)) - last_block[nodes_to_add] = current_block - indices[nodes_to_add] = idx - if table_size < num_record + len(nodes_to_add): - table_size += buffer_size - time = np.pad(time, (0, buffer_size)) - block = np.pad(block, (0, buffer_size)) - weights = np.pad(weights, ((0, buffer_size), (0, 0))) - time[idx] = [tree.time(i) for i in nodes_to_add] - block[idx] = current_block - num_record += len(nodes_to_add) - weights[indices[nodes_in_tree], :] += ( - span_weight * running_weights[nodes_in_tree, :] + span_weight = 
tree_span[tree.index - tree_offset] / total_span + + # TODO: shouldn't need to loop over all keys (nodes) for every tree + internal_nodes = np.array( + [i for i in running_index.keys() if not tree.is_sample(i)], + dtype=np.int32, ) + if internal_nodes.size > 0: + accessible_span += tree_span[tree.index - tree_offset] + rows_in_running = np.array( + [running_index[i] for i in internal_nodes], dtype=np.int32 + ) + nodes_to_add = internal_nodes[ + last_block[internal_nodes] != current_block + ] + if nodes_to_add.size > 0: + table_idx = np.arange( + num_record, num_record + len(nodes_to_add) + ) + last_block[nodes_to_add] = current_block + indices[nodes_to_add] = table_idx + if table_size < num_record + len(nodes_to_add): + table_size += buffer_size + time = np.pad(time, (0, buffer_size)) + block = np.pad(block, (0, buffer_size)) + weights = np.pad(weights, ((0, buffer_size), (0, 0))) + time[table_idx] = [tree.time(i) for i in nodes_to_add] + block[table_idx] = current_block + num_record += len(nodes_to_add) + weights[indices[internal_nodes], :] += ( + span_weight * running_output[rows_in_running, :] + ) + if tree.interval.right < right: # if current tree does not cross window boundary, move to next tree.next() - self._update_weights_by_edge_diff( - tree, next(edge_diffs), running_weights + self._update_running_with_edge_diff( + tree, next(edge_diffs), running_output, running_state, running_index ) else: # use current tree as initial tree for next window break + # reweight span so that weights are averaged over nonmissing trees + if self.span_normalise: + weights *= total_span / accessible_span + return CoalescenceTimeTable(time, block, weights) def _generate_ecdf_tables(self, ts, window_breaks): @@ -437,11 +548,35 @@ def _generate_ecdf_tables(self, ts, window_breaks): tree = ts.first() edge_diffs = ts.edge_diffs() - running_weights = np.zeros((ts.num_nodes, self.num_weights)) - self._update_weights_by_edge_diff(tree, next(edge_diffs), running_weights) + + # initialize running arrays for first tree + running_index = {i: n for i, n in enumerate(tree.samples())} + running_output = np.zeros( + (self.running_array_size, self.num_weights), + dtype=np.float64, + ) + running_state = np.zeros( + (self.running_array_size, self.num_states), + dtype=np.float64, + ) + for node in tree.samples(): + state = self._initialize(node, self.sample_sets) + for state_index, x in zip(self.state_indices, state): + running_state[running_index[node], state_index] = x + + self._update_running_with_edge_diff( + tree, next(edge_diffs), running_output, running_state, running_index + ) + for left, right in zip(window_breaks[:-1], window_breaks[1:]): yield self._build_ecdf_table_for_window( - left, right, tree, edge_diffs, running_weights + left, + right, + tree, + edge_diffs, + running_output, + running_state, + running_index, ) def __init__( @@ -463,18 +598,47 @@ def __init__( self.sample_sets = sample_sets if weight_func is None or weight_func == "coalescence_events": - self.weight_func = self._count_coalescence_events + self._initialize, self._update = self._count_coalescence_events() elif weight_func == "pair_coalescence_events": - self.weight_func = self._count_pair_coalescence_events + self._initialize, self._update = self._count_pair_coalescence_events() elif weight_func == "trio_first_coalescence_events": - self.weight_func = self._count_trio_first_coalescence_events + self._initialize, self._update = self._count_trio_first_coalescence_events() else: - assert callable(weight_func) - self.weight_func = weight_func - 
_weight_func_eval = self.weight_func(0, ts.first(), self.sample_sets) - assert isinstance(_weight_func_eval, np.ndarray) - assert _weight_func_eval.ndim == 1 - self.num_weights = len(_weight_func_eval) + # user supplies pair of callables ``(initialize, update)`` + assert isinstance(weight_func, tuple) + assert len(weight_func) == 2 + self._initialize, self._update = weight_func + assert callable(self._initialize) + assert callable(self._update) + + # check initialization operation + _state = self._initialize(0, self.sample_sets) + assert isinstance(_state, tuple) + self.num_states = 0 + self.state_indices = [] + for x in _state: + # ``assert is_row_vector(x)`` + assert isinstance(x, np.ndarray) + assert x.ndim == 2 + assert x.shape[0] == 1 + index = list(range(self.num_states, self.num_states + x.size)) + self.state_indices.append(index) + self.num_states += x.size + + # check update operation + _weights, _state = self._update(*_state) + assert isinstance(_state, tuple) + for state_index, x in zip(self.state_indices, _state): + # ``assert is_row_vector(x, len(state_index))`` + assert isinstance(x, np.ndarray) + assert x.ndim == 2 + assert x.shape[0] == 1 + assert x.size == len(state_index) + # ``assert is_row_vector(_weights)`` + assert isinstance(_weights, np.ndarray) + assert _weights.ndim == 2 + assert _weights.shape[0] == 1 + self.num_weights = _weights.size if window_breaks is None: window_breaks = np.array([0.0, ts.sequence_length]) @@ -499,6 +663,7 @@ def __init__( self.span_normalise = span_normalise self.buffer_size = ts.num_nodes + self.running_array_size = ts.num_samples * 2 - 1 # assumes no unary nodes self.weight_func_evals = 0 self.tables = [table for table in self._generate_ecdf_tables(ts, window_breaks)] @@ -533,13 +698,44 @@ def ecdf(self, times): values[:, :, k] = table.quantile[indices, :].T return values - # TODO - # - # def quantile(self, times): - # """ - # Return interpolated quantiles of coalescence times, using the same - # approach as numpy.quantile(..., method="linear") - # """ + def quantile(self, quantiles): + """ + Return interpolated quantiles of weighted coalescence times. + """ + + assert isinstance(quantiles, np.ndarray) + assert quantiles.ndim == 1 + assert np.all(np.logical_and(quantiles >= 0, quantiles <= 1)) + + values = np.empty((self.num_weights, quantiles.size, self.num_windows)) + values[:] = np.nan + for k, table in enumerate(self.tables): + # retrieve ECDF for each unique timepoint in table + last_index = np.flatnonzero(table.time[:-1] != table.time[1:]) + time = np.append(table.time[last_index], table.time[-1]) + ecdf = np.append( + table.quantile[last_index, :], table.quantile[[-1]], axis=0 + ) + for i in range(self.num_weights): + if not np.isnan(ecdf[-1, i]): + # interpolation requires strictly increasing arguments, so + # retrieve leftmost x for step-like F(x), including F(0) = 0. 
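The interpolation in the new `quantile` method inverts the ECDF the same way as numpy's "linear" quantile method: evaluate the sorted values at fractional index `q * (n - 1)`. A self-contained sketch of that equivalence for an unweighted sample (assumes NumPy >= 1.22 for the `method` keyword):

```python
import numpy as np

rng = np.random.default_rng(1)
times = np.sort(rng.exponential(size=100))
n = times.size

q = np.array([0.1, 0.5, 0.9])
# Linear interpolation at fractional index q * (n - 1) over the sorted sample
# reproduces numpy's "linear" quantile method exactly.
interp = np.interp(q * (n - 1), np.arange(n), times)
assert np.allclose(interp, np.quantile(times, q, method="linear"))
```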
+ assert ecdf[-1, i] == 1.0 + assert ecdf[0, i] == 0.0 + delta = ecdf[1:, i] - ecdf[:-1, i] + first_index = 1 + np.flatnonzero(delta > 0) + + n_eff = first_index.size + weight = delta[first_index - 1] + cum_weight = np.roll(ecdf[first_index, i], 1) + cum_weight[0] = 0 + midpoint = np.arange(n_eff) * weight + (n_eff - 1) * cum_weight + assert midpoint[0] == 0 + assert midpoint[-1] == n_eff - 1 + values[i, :, k] = np.interp( + quantiles * (n_eff - 1), midpoint, time[first_index] + ) + return values def num_coalesced(self, times): """ @@ -597,11 +793,12 @@ def mean(self, since=0.0): values[:, k] = np.nan else: for i in range(self.num_weights): - if table.cum_weights[-1, i] > 0: - multiplier = table.block_multiplier[table.block[index:]] + multiplier = table.block_multiplier[table.block[index:]] + weights = table.weights[index:, i] * multiplier + if np.any(weights > 0): values[i, k] = np.average( table.time[index:] - since, - weights=table.weights[index:, i] * multiplier, + weights=weights, ) return values diff --git a/python/tskit/tables.py b/python/tskit/tables.py index ea2ab364d2..c582654463 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -1,7 +1,7 @@ # # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -27,7 +27,6 @@ import collections.abc import dataclasses import datetime -import itertools import json import numbers import warnings @@ -613,6 +612,30 @@ def truncate(self, num_rows): """ return self.ll_table.truncate(num_rows) + def keep_rows(self, keep): + """ + .. include:: substitutions/table_keep_rows_main.rst + + :param array-like keep: The rows to keep as a boolean array. Must + be the same length as the table, and convertible to a numpy + array of dtype bool. + :return: The mapping between old and new row IDs as a numpy + array (dtype int32). + :rtype: numpy.ndarray (dtype=np.int32) + """ + # We do this check here rather than in the C code because calling + # len() on the input will cause a more readable exception to be + # raised than the inscrutable errors we get from numpy when + # converting arguments of the wrong type. + if len(keep) != len(self): + msg = ( + "Argument for keep_rows must be a boolean array of " + "the same length as the table. " + f"(need:{len(self)}, got:{len(keep)})" + ) + raise ValueError(msg) + return self.ll_table.keep_rows(keep) + # Pickle support def __getstate__(self): return self.asdict() @@ -658,40 +681,12 @@ def __str__(self): def _repr_html_(self): """ - Called by jupyter notebooks to render tables + Called e.g. by jupyter notebooks to render tables """ headers, rows = self._text_header_and_rows( limit=tskit._print_options["max_lines"] ) - headers = "".join(f"{header}" for header in headers) - rows = ( - f'{row[11:]}' - f" rows skipped (tskit.set_print_options)" - if "__skipped__" in row - else "".join(f"{cell}" for cell in row) - for row in rows - ) - rows = "".join(f"{row}\n" for row in rows) - return f""" -
-            <div>
-                <table border="1" class="tskit-table">
-                    <thead>
-                        <tr>
-                            {headers}
-                        </tr>
-                    </thead>
-                    <tbody>
-                        {rows}
-                    </tbody>
-                </table>
-            </div>
- """ + return util.html_table(rows, header=headers) def _columns_all_integer(self, *colnames): # For displaying floating point values without loads of decimal places @@ -852,15 +847,8 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "flags", "location", "parents", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) - for j in indexes: + row_indexes = util.truncate_rows(self.num_rows, limit) + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -1059,6 +1047,33 @@ def packset_parents(self, parents): d["parents_offset"] = offset self.set_columns(**d) + def keep_rows(self, keep): + """ + .. include:: substitutions/table_keep_rows_main.rst + + The values in the ``parents`` column are updated according to this + map, so that reference integrity within the table is maintained. + As a consequence of this, the values in the ``parents`` column + for kept rows are bounds-checked and an error raised if they + are not valid. Rows that are deleted are not checked for + parent ID integrity. + + If an attempt is made to delete rows that are referred to by + the ``parents`` column of rows that are retained, an error + is raised. + + These error conditions are checked before any alterations to + the table are made. + + :param array-like keep: The rows to keep as a boolean array. Must + be the same length as the table, and convertible to a numpy + array of dtype bool. + :return: The mapping between old and new row IDs as a numpy + array (dtype int32). + :rtype: numpy.ndarray (dtype=np.int32) + """ + return super().keep_rows(keep) + class NodeTable(MetadataTable): """ @@ -1105,16 +1120,9 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "flags", "population", "individual", "time", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) decimal_places_times = 0 if self._columns_all_integer("time") else 8 - for j in indexes: + for j in row_indexes: row = self[j] if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") @@ -1306,16 +1314,9 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "left", "right", "parent", "child", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) decimal_places = 0 if self._columns_all_integer("left", "right") else 8 - for j in indexes: + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -1528,17 +1529,10 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "left", "right", "node", "source", "dest", "time", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], 
- range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) decimal_places_coords = 0 if self._columns_all_integer("left", "right") else 8 decimal_places_times = 0 if self._columns_all_integer("time") else 8 - for j in indexes: + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -1748,16 +1742,9 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "position", "ancestral_state", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) decimal_places = 0 if self._columns_all_integer("position") else 8 - for j in indexes: + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -1971,17 +1958,10 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "site", "node", "time", "derived_state", "parent", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) # Currently mutations do not have discretised times: this for consistency decimal_places_times = 0 if self._columns_all_integer("time") else 8 - for j in indexes: + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -2182,6 +2162,33 @@ def packset_derived_state(self, derived_states): d["derived_state_offset"] = offset self.set_columns(**d) + def keep_rows(self, keep): + """ + .. include:: substitutions/table_keep_rows_main.rst + + The values in the ``parent`` column are updated according to this + map, so that reference integrity within the table is maintained. + As a consequence of this, the values in the ``parent`` column + for kept rows are bounds-checked and an error raised if they + are not valid. Rows that are deleted are not checked for + parent ID integrity. + + If an attempt is made to delete rows that are referred to by + the ``parent`` column of rows that are retained, an error + is raised. + + These error conditions are checked before any alterations to + the table are made. + + :param array-like keep: The rows to keep as a boolean array. Must + be the same length as the table, and convertible to a numpy + array of dtype bool. + :return: The mapping between old and new row IDs as a numpy + array (dtype int32). 
+ :rtype: numpy.ndarray (dtype=np.int32) + """ + return super().keep_rows(keep) + class PopulationTable(MetadataTable): """ @@ -2232,15 +2239,8 @@ def add_row(self, metadata=None): def _text_header_and_rows(self, limit=None): headers = ("id", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) - for j in indexes: + row_indexes = util.truncate_rows(self.num_rows, limit) + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -2490,15 +2490,8 @@ def append_columns( def _text_header_and_rows(self, limit=None): headers = ("id", "timestamp", "record") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) - for j in indexes: + row_indexes = util.truncate_rows(self.num_rows, limit) + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -3282,12 +3275,18 @@ def __getstate__(self): def load(cls, file_or_path, *, skip_tables=False, skip_reference_sequence=False): file, local_file = util.convert_file_like_to_open_file(file_or_path, "rb") ll_tc = _tskit.TableCollection() - ll_tc.load( - file, - skip_tables=skip_tables, - skip_reference_sequence=skip_reference_sequence, - ) - return TableCollection(ll_tables=ll_tc) + try: + ll_tc.load( + file, + skip_tables=skip_tables, + skip_reference_sequence=skip_reference_sequence, + ) + return TableCollection(ll_tables=ll_tc) + except tskit.FileFormatError as e: + util.raise_known_file_format_errors(file, e) + finally: + if local_file: + file.close() def dump(self, file_or_path): """ @@ -3348,6 +3347,8 @@ def simplify( filter_populations=None, filter_individuals=None, filter_sites=None, + filter_nodes=None, + update_sample_flags=None, keep_unary=False, keep_unary_in_individuals=None, keep_input_roots=False, @@ -3357,15 +3358,16 @@ def simplify( """ Simplifies the tables in place to retain only the information necessary to reconstruct the tree sequence describing the given ``samples``. - This will change the ID of the nodes, so that the node - ``samples[k]`` will have ID ``k`` in the result. The resulting - NodeTable will have only the first ``len(samples)`` nodes marked - as samples. The mapping from node IDs in the current set of tables to - their equivalent values in the simplified tables is also returned as a - numpy array. If an array ``a`` is returned by this function and ``u`` - is the ID of a node in the input table, then ``a[u]`` is the ID of this - node in the output table. For any node ``u`` that is not mapped into - the output tables, this mapping will equal ``-1``. + If ``filter_nodes`` is True (the default), this can change the ID of + the nodes, so that the node ``samples[k]`` will have ID ``k`` in the + result, resulting in a NodeTable where only the first ``len(samples)`` + nodes are marked as samples. The mapping from node IDs in the current + set of tables to their equivalent values in the simplified tables is + returned as a numpy array. If an array ``a`` is returned by this + function and ``u`` is the ID of a node in the input table, then + ``a[u]`` is the ID of this node in the output table. For any node ``u`` + that is not mapped into the output tables, this mapping will equal + ``-1``. 
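For illustration, the remapping behaviour described above can be sketched as follows (a minimal sketch, not part of the diff; it assumes ``tables`` is an existing, sorted ``tskit.TableCollection`` with at least four nodes, and the IDs are invented)::

    import tskit

    samples = [3, 0]  # need not already be flagged as samples
    node_map = tables.simplify(samples)

    # With the default filter_nodes=True, samples[k] gets ID k:
    assert node_map[3] == 0
    assert node_map[0] == 1
    # Nodes dropped by simplification map to tskit.NULL (-1).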
Tables operated on by this function must: be sorted (see :meth:`TableCollection.sort`), have children be born strictly after their @@ -3374,10 +3376,11 @@ def simplify( requirements to specify a valid tree sequence (but the resulting tables will). - This is identical to :meth:`TreeSequence.simplify` but acts *in place* to - alter the data in this :class:`TableCollection`. Please see the - :meth:`TreeSequence.simplify` method for a description of the remaining - parameters. + .. seealso:: + This is identical to :meth:`TreeSequence.simplify` but acts *in place* to + alter the data in this :class:`TableCollection`. Please see the + :meth:`TreeSequence.simplify` method for a description of the remaining + parameters. :param list[int] samples: A list of node IDs to retain as samples. They need not be nodes marked as samples in the original tree sequence, but @@ -3399,6 +3402,15 @@ def simplify( not referenced by mutations after simplification; new site IDs are allocated sequentially from zero. If False, the site table will not be altered in any way. (Default: None, treated as True) + :param bool filter_nodes: If True, remove any nodes that are + not referenced by edges after simplification. If False, the only + potential change to the node table may be to change the node flags + (if ``samples`` is specified and different from the existing samples). + (Default: None, treated as True) + :param bool update_sample_flags: If True, update node flags so that + nodes in the specified list of samples have the NODE_IS_SAMPLE + flag after simplification, and nodes that are not in this list + do not. (Default: None, treated as True) :param bool keep_unary: If True, preserve unary nodes (i.e. nodes with exactly one child) that exist on the path from samples to root. (Default: False) @@ -3440,6 +3452,10 @@ def simplify( filter_individuals = True if filter_sites is None: filter_sites = True + if filter_nodes is None: + filter_nodes = True + if update_sample_flags is None: + update_sample_flags = True if keep_unary_in_individuals is None: keep_unary_in_individuals = False @@ -3448,6 +3464,8 @@ def simplify( filter_sites=filter_sites, filter_individuals=filter_individuals, filter_populations=filter_populations, + filter_nodes=filter_nodes, + update_sample_flags=update_sample_flags, reduce_to_site_topology=reduce_to_site_topology, keep_unary=keep_unary, keep_unary_in_individuals=keep_unary_in_individuals, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index c228d5b9b2..03b0f069c6 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -635,7 +635,7 @@ class Tree: :param list tracked_samples: The list of samples to be tracked and counted using the :meth:`Tree.num_tracked_samples` method. :param bool sample_lists: If True, provide more efficient access - to the samples beneath a give node using the + to the samples beneath a given node using the :meth:`Tree.samples` method. :param int root_threshold: The minimum number of samples that a node must be ancestral to for it to be in the list of roots.
By default @@ -665,6 +665,8 @@ def __init__( if sample_lists: options |= _tskit.SAMPLE_LISTS kwargs = {"options": options} + if root_threshold <= 0: + raise ValueError("Root threshold must be greater than 0") if tracked_samples is not None: # TODO remove this when we allow numpy arrays in the low-level API. kwargs["tracked_samples"] = list(tracked_samples) @@ -725,7 +727,8 @@ def tree_sequence(self): def root_threshold(self): """ Returns the minimum number of samples that a node must be an ancestor - of to be considered a potential root. + of to be considered a potential root. This can be set, for example, when + calling the :meth:`TreeSequence.trees` iterator. :return: The root threshold. :rtype: :class:`TreeSequence` @@ -817,7 +820,6 @@ def seek_index(self, index): .. include:: substitutions/linear_traversal_warning.rst - :param int index: The tree index to seek to. :raises IndexError: If an index outside the acceptable range is provided. """ @@ -826,12 +828,7 @@ def seek_index(self, index): index += num_trees if index < 0 or index >= num_trees: raise IndexError("Index out of bounds") - # This should be implemented in C efficiently using the indexes. - # No point in complicating the current implementation by trying - # to seek from the correct direction. - self.first() - while self.index != index: - self.next() + self._ll_tree.seek_index(index) def seek(self, position): """ @@ -881,7 +878,7 @@ def unrank(num_leaves, rank, *, span=1, branch_length=1) -> Tree: from which the tree is taken will have its :attr:`~tskit.TreeSequence.sequence_length` equal to ``span``. :param: float branch_length: The minimum length of a branch in this tree. - :raises: ValueError: If the given rank is out of bounds for trees + :raises ValueError: If the given rank is out of bounds for trees with ``num_leaves`` leaves. """ rank_tree = combinatorics.RankTree.unrank(num_leaves, rank) @@ -1000,7 +997,9 @@ def mrca(self, *args): Returns the most recent common ancestor of the specified nodes. :param int `*args`: input node IDs, must be at least 2. - :return: The most recent common ancestor of input nodes. + :return: The node ID of the most recent common ancestor of the + input nodes, or :data:`tskit.NULL` if the nodes do not share + a common ancestor in the tree. :rtype: int """ if len(args) < 2: @@ -1205,6 +1204,33 @@ def right_sib_array(self): """ return self._right_sib_array + def siblings(self, u): + """ + Returns the sibling(s) of the specified node ``u`` as a tuple of integer + node IDs. If ``u`` has no siblings or is not a node in the current tree, + returns an empty tuple. If ``u`` is the root of a single-root tree, + returns an empty tuple; if ``u`` is the root of a multi-root tree, + returns the other roots (note all the roots are related by the virtual root). + If ``u`` is the virtual root (which has no siblings), returns an empty tuple. + If ``u`` is an isolated node, whether it has siblings or not depends on + whether it is a sample or non-sample node; if it is a sample node, + returns the root(s) of the tree, otherwise, returns an empty tuple. + The ordering of siblings is arbitrary and should not be depended on; + see the :ref:`data model ` section for details. + + :param int u: The node of interest. + :return: The siblings of ``u``. 
+ :rtype: tuple(int) + """ + if u == self.virtual_root: + return tuple() + parent = self.parent(u) + if self.is_root(u): + parent = self.virtual_root + if parent != tskit.NULL: + return tuple(v for v in self.children(parent) if u != v) + return tuple() + @property def num_children_array(self): """ @@ -1546,6 +1572,12 @@ def roots(self): Only requires O(number of roots) time. + .. note:: + In trees with large amounts of :ref:`sec_data_model_missing_data`, + for example where a region of the genome lacks any ancestral information, + there can be a very large number of roots, potentially all the samples + in the tree sequence. + :return: The list of roots in this tree. :rtype: list """ @@ -1568,12 +1600,29 @@ def root(self): :return: The root node. :rtype: int - :raises: :class:`ValueError` if this tree contains more than one root. + :raises ValueError: if this tree contains more than one root. """ if self.has_multiple_roots: raise ValueError("More than one root exists. Use tree.roots instead") return self.left_root + def is_root(self, u) -> bool: + """ + Returns ``True`` if the specified node is a root in this tree (see + :attr:`~Tree.roots` for the definition of a root). This is exactly equivalent to + finding the node ID in :attr:`~Tree.roots`, but is more efficient for trees + with large numbers of roots, such as in regions with extensive + :ref:`sec_data_model_missing_data`. Note that ``False`` is returned for all + other nodes, including :ref:`isolated` + non-sample nodes which are not found in the topology of the current tree. + + :param int u: The node of interest. + :return: ``True`` if ``u`` is a root. + """ + return ( + self.num_samples(u) >= self.root_threshold and self.parent(u) == tskit.NULL + ) + def get_index(self): # Deprecated alias for self.index return self.index @@ -2134,9 +2183,12 @@ def _sample_generator(self, u): def samples(self, u=None): """ Returns an iterator over the numerical IDs of all the sample nodes in - this tree that are underneath node ``u``. If ``u`` is a sample, it is - included in the returned iterator. If u is not specified, return all - sample node IDs in the tree. + this tree that are underneath the node with ID ``u``. If ``u`` is a sample, + it is included in the returned iterator. If ``u`` is not a sample, it is + possible for the returned iterator to be empty, for example if ``u`` is an + :meth:`isolated` node that is not part of the current + topology. If ``u`` is not specified, return all sample node IDs in the tree + (equivalent to all the sample node IDs in the tree sequence).
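As a quick illustration of the new ``Tree.siblings`` and ``Tree.is_root`` methods added above (a self-contained sketch; the two-leaf topology is invented for the example)::

    import tskit

    tables = tskit.TableCollection(sequence_length=1.0)
    # Two sample leaves (0 and 1) joined by a single root (2).
    for _ in range(2):
        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
    root = tables.nodes.add_row(flags=0, time=1)
    tables.edges.add_row(left=0, right=1, parent=root, child=0)
    tables.edges.add_row(left=0, right=1, parent=root, child=1)
    tree = tables.tree_sequence().first()

    assert tree.is_root(root)
    assert tree.siblings(0) == (1,)
    assert tree.siblings(root) == ()  # single-root tree: no other roots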
If the :meth:`TreeSequence.trees` method is called with ``sample_lists=True``, this method uses an efficient algorithm to find @@ -4040,6 +4092,8 @@ def load(cls, file_or_path, *, skip_tables=False, skip_reference_sequence=False) skip_reference_sequence=skip_reference_sequence, ) return TreeSequence(ts) + except tskit.FileFormatError as e: + util.raise_known_file_format_errors(file, e) finally: if local_file: file.close() @@ -4509,14 +4563,14 @@ def max_root_time(self): raise ValueError( "max_root_time is not defined in a tree sequence with 0 samples" ) - ret = max(self.node(u).time for u in self.samples()) + ret = max(self.nodes_time[u] for u in self.samples()) if self.num_edges > 0: # Edges are guaranteed to be listed in parent-time order, so we can get the # last one to get the oldest root edge = self.edge(self.num_edges - 1) # However, we can have situations where there is a sample older than a # 'proper' root - ret = max(ret, self.node(edge.parent).time) + ret = max(ret, self.nodes_time[edge.parent]) return ret def migrations(self): @@ -5157,10 +5211,10 @@ def haplotypes( *Deprecated in 0.3.0. Use ``isolated_as_missing``, but inverting value. Will be removed in a future version* :rtype: collections.abc.Iterable - :raises: TypeError if the ``missing_data_character`` or any of the alleles + :raises TypeError: if the ``missing_data_character`` or any of the alleles at a site are not a single ascii character. - :raises: ValueError - if the ``missing_data_character`` exists in one of the alleles + :raises ValueError: if the ``missing_data_character`` exists in one of the + alleles """ if impute_missing_data is not None: warnings.warn( @@ -5467,10 +5521,9 @@ def alignments( :return: An iterator over the alignment strings for specified samples in this tree sequence, in the order given in ``samples``. :rtype: collections.abc.Iterable - :raises: ValueError - if any genome coordinate in this tree sequence is not discrete, - or if the ``reference_sequence`` is not of the correct length. - :raises: TypeError if any of the alleles at a site are not a + :raises ValueError: if any genome coordinate in this tree sequence is not + discrete, or if the ``reference_sequence`` is not of the correct length. + :raises TypeError: if any of the alleles at a site are not a single ascii character. """ if not self.discrete_genome: @@ -6489,6 +6542,8 @@ def simplify( filter_populations=None, filter_individuals=None, filter_sites=None, + filter_nodes=None, + update_sample_flags=None, keep_unary=False, keep_unary_in_individuals=None, keep_input_roots=False, @@ -6503,15 +6558,10 @@ def simplify( original tree sequence, or :data:`tskit.NULL` (-1) if ``u`` is no longer present in the simplified tree sequence. - In the returned tree sequence, the node with ID ``0`` corresponds to - ``samples[0]``, node ``1`` corresponds to ``samples[1]`` etc., and all - the passed-in nodes are flagged as samples. The remaining node IDs in - the returned tree sequence are allocated sequentially in time order - and are not flagged as samples. - - If you wish to simplify a set of tables that do not satisfy all - requirements for building a TreeSequence, then use - :meth:`TableCollection.simplify`. + .. note:: + If you wish to simplify a set of tables that do not satisfy all + requirements for building a TreeSequence, then use + :meth:`TableCollection.simplify`. 
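An aside on the ``max_root_time`` hunk above: it swaps per-node object access for direct lookups in the ``nodes_time`` array. The same idiom applies in user code (assuming ``ts`` is any ``tskit.TreeSequence``)::

    # Slower: constructs a Node object for every call.
    oldest_sample = max(ts.node(u).time for u in ts.samples())

    # Equivalent, using the numpy array of node times directly.
    oldest_sample = ts.nodes_time[ts.samples()].max()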
If the ``reduce_to_site_topology`` parameter is True, the returned tree sequence will contain only topological information that is necessary to @@ -6523,12 +6573,33 @@ def simplify( (up to node ID remapping) to the topology of the corresponding tree in the input tree sequence. - If ``filter_populations``, ``filter_individuals`` or ``filter_sites`` is - True, any of the corresponding objects that are not referenced elsewhere - are filtered out. As this is the default behaviour, it is important to - realise IDs for these objects may change through simplification. By setting - these parameters to False, however, the corresponding tables can be preserved - without changes. + If ``filter_populations``, ``filter_individuals``, ``filter_sites``, or + ``filter_nodes`` is True, any of the corresponding objects that are not + referenced elsewhere are filtered out. As this is the default behaviour, + it is important to realise IDs for these objects may change through + simplification. By setting these parameters to False, however, the + corresponding tables can be preserved without changes. + + If ``filter_nodes`` is False, then the output node table will be + unchanged except for updating the sample status of nodes and any ID + remappings caused by filtering individuals and populations (if the + ``filter_individuals`` and ``filter_populations`` options are enabled). + Nodes that are in the specified list of ``samples`` will be marked as + samples in the output, and nodes that are currently marked as samples + in the node table but not in the specified list of ``samples`` will + have their :data:`tskit.NODE_IS_SAMPLE` flag cleared. Note also that + the order of the ``samples`` list is not meaningful when + ``filter_nodes`` is False. In this case, the returned node mapping is + always the identity mapping, such that ``a[u] == u`` for all nodes. + + Setting the ``update_sample_flags`` parameter to False prevents the + automatic sample status update of nodes (described above) from + occurring, making it the responsibility of calling code to keep track of + the ultimate sample status of nodes. This is an advanced option, mostly + of use when combined with the ``filter_nodes=False``, + ``filter_populations=False`` and ``filter_individuals=False`` options, + which together guarantee that the node table will not be altered by + simplification. :param list[int] samples: A list of node IDs to retain as samples. They need not be nodes marked as samples in the original tree sequence, but @@ -6554,6 +6625,15 @@ def simplify( not referenced by mutations after simplification; new site IDs are allocated sequentially from zero. If False, the site table will not be altered in any way. (Default: None, treated as True) + :param bool filter_nodes: If True, remove any nodes that are + not referenced by edges after simplification. If False, the only + potential change to the node table may be to change the node flags + (if ``samples`` is specified and different from the existing samples). + (Default: None, treated as True) + :param bool update_sample_flags: If True, update node flags so that + nodes in the specified list of samples have the NODE_IS_SAMPLE + flag after simplification, and nodes that are not in this list + do not. (Default: None, treated as True) :param bool keep_unary: If True, preserve unary nodes (i.e., nodes with exactly one child) that exist on the path from samples to root.
(Default: False) @@ -6585,6 +6665,8 @@ def simplify( filter_populations=filter_populations, filter_individuals=filter_individuals, filter_sites=filter_sites, + filter_nodes=filter_nodes, + update_sample_flags=update_sample_flags, keep_unary=keep_unary, keep_unary_in_individuals=keep_unary_in_individuals, keep_input_roots=keep_input_roots, @@ -6820,147 +6902,47 @@ def decapitate(self, time, *, flags=None, population=None, metadata=None): tables.delete_older(time) return tables.tree_sequence() - def _extend(self, forwards=True): - print("forwards:", forwards) - num_edges = np.full(self.num_nodes, 0) - - t = self.tables - edges = t.edges.copy() - t.edges.clear() - new_left = edges.left - new_right = edges.right - - # edge diff stuff - M = edges.num_rows - if forwards: - I = self.indexes_edge_insertion_order - O = self.indexes_edge_removal_order - else: - I = np.flip(self.indexes_edge_removal_order) - O = np.flip(self.indexes_edge_insertion_order) - tj = 0 - tk = 0 - # "here" will be left if fowards else right - here = 0 if forwards else self.sequence_length - edges_out = [] - edges_in = [] - endpoint = self.sequence_length if forwards else 0 - sign = +1 if forwards else -1 - near_edge = edges.left if forwards else edges.right - far_edge = edges.right if forwards else edges.left - - while (tj < M) or (forwards and here < endpoint): - # clear out non-extended or postponed edges - edges_out = [[e, False] for e, x in edges_out if x] - edges_in = [[e, False] for e, x in edges_in if x] - - # Find edges_out between trees - while (tk < M) and (far_edge[O[tk]] == here): - edges_out.append([O[tk], False]) - num_edges[edges.parent[O[tk]]] -= 1 - num_edges[edges.child[O[tk]]] -= 1 - #print("Edge Out", tk, edges[O[tk]]) - tk += 1 - # Find edges_in between trees - while (tj < M) and (near_edge[I[tj]] == here): - edges_in.append([I[tj], False]) - num_edges[edges.parent[I[tj]]] += 1 - num_edges[edges.child[I[tj]]] += 1 - #print("Edge In", tj, edges[I[tj]]) - tj += 1 - - # Find smallest length right endpoint of all edges in edges_in and edges_out - # there should equal the endpoint of a T_k - there = self.sequence_length if forwards else 0 - if forwards: - if tk < M: - there = min(there, far_edge[O[tk]]) - if tj < M: - there = min(there, near_edge[I[tj]]) - else: - if tk < M: - there = max(there, far_edge[O[tk]]) - if tj < M: - there = max(there, near_edge[I[tj]]) - print("All Edges Out", edges_out) - print("All Edges In", edges_in) - assert np.all(num_edges >= 0) - print("-------------", here, len(edges_out), len(edges_in)) - for ex1 in edges_out: - #print("e1:", e1, [edges.parent[O[e1]], edges.child[O[e1]]], edges[O[e1]]) - if not ex1[1]: - e1 = ex1[0] - for ex2 in edges_out: - #print("e2:", e2, num_edges[edges.child[e2]], ":", [edges.parent[O[e2]], edges.child[O[e2]]]) - if not ex2[1]: - # need the intermediate node to not be present in - # the new tree - e2 = ex2[0] - if ((edges.parent[e1] == edges.child[e2]) - and (num_edges[edges.child[e2]] == 0)): - for ex_in in edges_in: - e_in = ex_in[0] - #print("ein", e_in, [edges.parent[I[e_in]], edges.child[I[e_in]]]) - if sign * far_edge[e_in] > sign * here: - if ( - edges.child[e1] == edges.child[e_in] - and edges.parent[e2] == edges.parent[e_in] - ): - print("EXTEND") - # extend e2->e1 and postpone e_in - ex1[1] = True - ex2[1] = True - ex_in[1] = True - if forwards: - new_right[e1] = there - new_right[e2] = there - new_left[e_in] = there - else: - new_left[e1] = there - new_left[e2] = there - new_right[e_in] = there - # amend num_edges: the intermediate - # 
node has 2 edges instead of 0 - num_edges[edges.parent[e1]] += 2 - # cleanup at end of loop - here = there - - for j in range(edges.num_rows): - left = new_left[j] - right = new_right[j] - if left < right: - e = edges[j].replace(left=left, right=right) - t.edges.append(e) - t.build_index() - return t.tree_sequence() - - - def extend_edges(self, max_iter=100): - ''' - Returns a new tree sequence whose unary nodes are extended to neighboring trees given the following condition: - While iterating over the tree sequence, in each tree, we identify connecting edges with unary nodes. - If an equivalent edge segment exists in the next tree without that unary node, - we extend the connecting edges from the previous tree into the next tree, - subsequently adding that unary node to the tree. - This in turn reduces the length of the edge just removed from the next tree, - and if its length becomes zero it is removed from the edge table. - - : param max_iters: (int) -- the number of iterations we analyze the tree sequence to edge extend. - The process will halt if there is no change in edge count over two consecutive iterations. Default = 100 - - :return: A new tree sequence with unary nodes extended across the tree sequence. - :rtype: tskit.TreeSequence - ''' - ts = self - last_num_edges = ts.num_edges - for _ in range(max_iter): - ts = ts._extend(forwards=True) - ts = ts._extend(forwards=False) - if ts.num_edges == last_num_edges: - break - else: - last_num_edges = ts.num_edges - return ts + def extend_edges(self, max_iter=10): + """ + Returns a new tree sequence in which the span covered by ancestral nodes + is "extended" to regions of the genome according to the following rule: + If an ancestral segment corresponding to node ``n`` has parent ``p`` and + child ``c`` on some portion of the genome, and on an adjacent segment of + genome ``p`` is the immediate parent of ``c``, then ``n`` is inserted into the + edge from ``p`` to ``c``. This involves extending the span of the edges + from ``p`` to ``n`` and ``n`` to ``c`` and reducing the span of the edge from + ``p`` to ``c``. Since the latter edge may be removed entirely, this process + reduces (or at least does not increase) the number of edges in the tree + sequence. + + *Note:* this is a somewhat experimental operation, and is probably not + what you are looking for. + + The method works by iterating over the genome to look for edges that can + be extended in this way; the maximum number of such iterations is + controlled by ``max_iter``. + + The rationale is that we know that ``n`` carries a portion of the segment + of ancestral genome inherited by ``c`` from ``p``, and so likely carries + the *entire* inherited segment (since the implication otherwise would + be that distinct recombined segments were passed down separately from + ``p`` to ``c``). + + The method will not affect the marginal trees (so, following up with + ``simplify`` will recover the original tree sequence, possibly with edges + in a different order). + + :param int max_iter: The maximum number of iterations over the tree + sequence. Defaults to 10. + + :return: A new tree sequence with unary nodes extended.
+ :rtype: tskit.TreeSequence + """ + max_iter = int(max_iter) + if max_iter <= 0: + raise ValueError("max_iter must be a positive integer.") + ll_ts = self._ll_tree_sequence.extend_edges(max_iter) + return TreeSequence(ll_ts) def subset( self, @@ -7100,6 +7082,7 @@ def draw_svg( y_gridlines=None, omit_sites=None, canvas_size=None, + max_num_trees=None, **kwargs, ): """ @@ -7186,6 +7169,11 @@ elements, allowing extra room e.g. for unusually long labels. If ``None`` take the canvas size to be the same as the target drawing size (see ``size``, above). Default: None + :param int max_num_trees: The maximum number of trees to plot. If there are + more trees than this in the tree sequence, the middle trees will be skipped + from the plot and a message "XX trees skipped" displayed in their place. + If ``None``, all the trees will be plotted: this can produce a very wide + plot if there are many trees in the tree sequence. Default: None :return: An SVG representation of a tree sequence. :rtype: SVGString @@ -7223,6 +7211,7 @@ y_gridlines=y_gridlines, omit_sites=omit_sites, canvas_size=canvas_size, + max_num_trees=max_num_trees, **kwargs, ) output = draw.drawing.tostring() @@ -7612,6 +7601,48 @@ def __k_way_sample_set_stat( stat = stat[()] return stat + def __k_way_weighted_stat( + self, + ll_method, + k, + W, + indexes=None, + windows=None, + mode=None, + span_normalise=True, + polarised=False, + ): + W = np.asarray(W) + if indexes is None: + if W.shape[1] != k: + raise ValueError( + "Must specify indexes if there are not exactly {} columns " + "in W.".format(k) + ) + indexes = np.arange(k, dtype=np.int32) + drop_dimension = False + indexes = util.safe_np_int_cast(indexes, np.int32) + if len(indexes.shape) == 1: + indexes = indexes.reshape((1, indexes.shape[0])) + drop_dimension = True + if len(indexes.shape) != 2 or indexes.shape[1] != k: + raise ValueError( + "Indexes must be convertible to a 2D numpy array with {} " + "columns".format(k) + ) + stat = self.__run_windowed_stat( + windows, + ll_method, + W, + indexes, + mode=mode, + span_normalise=span_normalise, + polarised=polarised, + ) + if drop_dimension: + stat = stat.reshape(stat.shape[:-1]) + return stat + ############################################ # Statistics definitions ############################################ @@ -7749,8 +7780,80 @@ def divergence( span_normalise=span_normalise, ) - # JK: commenting this out for now to get the other methods well tested. - # Issue: https://github.com/tskit-dev/tskit/issues/201 + ############################################ + # Pairwise sample x sample statistics + ############################################ + + def _chunk_sequence_by_tree(self, num_chunks): + """ + Return a list of (left, right) genome interval tuples that contain + approximately equal numbers of trees. At most self.num_trees + single-tree intervals can be returned.
+ """ + if num_chunks <= 0 or int(num_chunks) != num_chunks: + raise ValueError("Number of chunks must be an integer > 0") + num_chunks = min(self.num_trees, num_chunks) + breakpoints = self.breakpoints(as_array=True)[:-1] + splits = np.array_split(breakpoints, num_chunks) + chunks = [] + for j in range(num_chunks - 1): + chunks.append((splits[j][0], splits[j + 1][0])) + chunks.append((splits[-1][0], self.sequence_length)) + return chunks + + @staticmethod + def _chunk_windows(windows, num_chunks): + """ + Returns a list of (at most) num_chunks windows, which represent splitting + up the specified list of windows into roughly equal work. + + Currently this is implemented by just splitting up into roughly equal + numbers of windows in each chunk. + """ + if num_chunks <= 0 or int(num_chunks) != num_chunks: + raise ValueError("Number of chunks must be an integer > 0") + num_chunks = min(len(windows) - 1, num_chunks) + splits = np.array_split(windows[:-1], num_chunks) + chunks = [] + for j in range(num_chunks - 1): + chunk = np.append(splits[j], splits[j + 1][0]) + chunks.append(chunk) + chunk = np.append(splits[-1], windows[-1]) + chunks.append(chunk) + return chunks + + def _parallelise_divmat_by_tree(self, num_threads, **kwargs): + """ + No windows were specified, so we can chunk up the whole genome by + tree, and do a simple sum of the results. + """ + + def worker(interval): + return self._ll_tree_sequence.divergence_matrix(interval, **kwargs) + + work = self._chunk_sequence_by_tree(num_threads) + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool: + results = pool.map(worker, work) + return sum(results) + + def _parallelise_divmat_by_window(self, windows, num_threads, **kwargs): + """ + We assume we have a number of windows that's >= to the number + of threads available, and let each thread have a chunk of the + windows. There will definitely cases where this leads to + pathological behaviour, so we may need a more sophisticated + strategy at some point. + """ + + def worker(sub_windows): + return self._ll_tree_sequence.divergence_matrix(sub_windows, **kwargs) + + work = self._chunk_windows(windows, num_threads) + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, sub_windows) for sub_windows in work] + concurrent.futures.wait(futures) + return np.vstack([future.result() for future in futures]) + # def divergence_matrix(self, sample_sets, windows=None, mode="site"): # """ # Finds the mean divergence between pairs of samples from each set of @@ -7784,6 +7887,36 @@ def divergence( # A[w, i, j] = A[w, j, i] = x[w][k] # k += 1 # return A + # NOTE: see older definition of divmat here, which may be useful when documenting + # this function. 
See https://github.com/tskit-dev/tskit/issues/2781 + def divergence_matrix( + self, *, windows=None, samples=None, num_threads=0, mode=None + ): + windows_specified = windows is not None + windows = [0, self.sequence_length] if windows is None else windows + + mode = "site" if mode is None else mode + + # NOTE: maybe we want to use a different default for num_threads here, just + # following the approach in GNN + if num_threads <= 0: + D = self._ll_tree_sequence.divergence_matrix( + windows, samples=samples, mode=mode + ) + else: + if windows_specified: + D = self._parallelise_divmat_by_window( + windows, num_threads, samples=samples, mode=mode + ) + else: + D = self._parallelise_divmat_by_tree( + num_threads, samples=samples, mode=mode + ) + + if not windows_specified: + # Drop the windows dimension + D = D[0] + return D def genetic_relatedness( self, @@ -7897,6 +8030,51 @@ def genetic_relatedness( return out + def genetic_relatedness_weighted( + self, + W, + indexes=None, + windows=None, + mode="site", + span_normalise=True, + polarised=False, + ): + r""" + Computes weighted genetic relatedness. If the k-th pair of indices is (i, j) + then the k-th column of output will be + :math:`\sum_{a,b} W_{ai} W_{bj} C_{ab}`, + where :math:`W` is the matrix of weights, and :math:`C_{ab}` is the + :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>` between sample + a and sample b, summing over all pairs of samples in the tree sequence. + + :param numpy.ndarray W: An array of values with one row for each sample node and + one column for each set of weights. + :param list indexes: A list of 2-tuples, or None (default). Note that if + indexes = None, then W must have exactly two columns and this is equivalent + to indexes = [(0,1)]. + :param list windows: An increasing list of breakpoints between the windows + to compute the statistic in. + :param str mode: A string giving the "type" of the statistic to be computed + (defaults to "site"). + :param bool span_normalise: Whether to divide the result by the span of the + window (defaults to True). + :return: A ndarray with shape equal to (num windows, num statistics). + """ + if len(W) != self.num_samples: + raise ValueError( + "First trait dimension must be equal to number of samples." + ) + return self.__k_way_weighted_stat( + self._ll_tree_sequence.genetic_relatedness_weighted, + 2, + W, + indexes=indexes, + windows=windows, + mode=mode, + span_normalise=span_normalise, + polarised=polarised, + ) + def trait_covariance(self, W, windows=None, mode="site", span_normalise=True): """ Computes the mean squared covariances between each of the columns of ``W`` @@ -8324,6 +8502,7 @@ def Tajimas_D(self, sample_sets=None, windows=None, mode="site"): :return: A ndarray with shape equal to (num windows, num statistics). If there is one sample set and windows=None, a numpy scalar is returned. """ + # TODO this should be done in C as we'll want to support this method there. def tjd_func(sample_set_sizes, flattened, **kwargs): n = sample_set_sizes @@ -8936,6 +9115,36 @@ def coalescence_time_distribution( span_normalise=span_normalise, ) + def impute_unknown_mutations_time( + self, + method=None, + ): + """ + Returns an array of mutation times, where any unknown times are + imputed from the times of associated nodes. Not to be confused with + :meth:`TableCollection.compute_mutation_times`, which modifies the + ``time`` column of the mutations table in place. + + :param str method: The method used to impute the unknown mutation times. 
+ Currently only "min" is supported, which uses the time of the node + below the mutation as the mutation time. The "min" method can also + be specified by ``method=None`` (Default: ``None``). + :return: An array of length equal to the number of mutations in the + tree sequence. + """ + allowed_methods = ["min"] + if method is None: + method = "min" + if method not in allowed_methods: + raise ValueError( + f"Mutations time imputation method must be chosen from {allowed_methods}" + ) + if method == "min": + mutations_time = self.mutations_time.copy() + unknown = tskit.is_unknown_time(mutations_time) + mutations_time[unknown] = self.nodes_time[self.mutations_node[unknown]] + return mutations_time + ############################################ # # Deprecated APIs. These are either already unsupported, or will be unsupported in a diff --git a/python/tskit/util.py b/python/tskit/util.py index 9baa298ceb..28e9876b5a 100644 --- a/python/tskit/util.py +++ b/python/tskit/util.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2023 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,6 +23,8 @@ Module responsible for various utility functions used in other modules. """ import dataclasses +import io +import itertools import json import numbers import os @@ -45,6 +47,13 @@ def replace(self, **kwargs): """ return dataclasses.replace(self, **kwargs) + def asdict(self, **kwargs): + """ + Return a new dict which maps field names to their corresponding values + in this dataclass. + """ + return dataclasses.asdict(self, **kwargs) + def canonical_json(obj): """ @@ -320,7 +329,7 @@ def obj_to_collapsed_html(d, name=None, open_depth=0): :param str name: Name for this object :param int open_depth: By default sub-sections are collapsed. If this number is - non-zero the first layers up to open_depth will be opened. + non-zero the first layers up to open_depth will be opened. :return: The HTML as a string :rtype: str """ @@ -369,19 +378,25 @@ def render_metadata(md, length=40): return truncate_string_end(str(md), length) -def unicode_table(rows, title=None, header=None, row_separator=True): +def unicode_table( + rows, *, title=None, header=None, row_separator=True, column_alignments=None +): """ Convert a table (list of lists) of strings to a unicode table. If a row contains the string "__skipped__NNN" then "skipped N rows" is displayed. :param list[list[str]] rows: List of rows, each of which is a list of strings for - each cell. The first column will be left justified, the others right. Each row must - have the same number of cells. + each cell. Each row must have the same number of cells. :param str title: If specified the first output row will be a single cell - containing this string, left-justified. [optional] + containing this string, left-justified. [optional] :param list[str] header: Specifies a row above the main rows which will be in double - lined borders and left justified. Must be same length as each row. [optional] + lined borders and left justified. Must be same length as each row. [optional] :param boolean row_separator: If True add lines between each row. [Default: True] + :param column_alignments str: A string of the same length as the number of cells in + a row (i.e. columns) where each character specifies an alignment such as ``<``, + ``>`` or ``^`` as used in Python's string formatting mini-language. 
If ``None``, + set the first column to be left justified and the remaining columns to be right + justified [Default: ``None``] :return: The table as a string :rtype: str """ @@ -392,6 +407,8 @@ def unicode_table(rows, title=None, header=None, row_separator=True): widths = [ max(len(row[i_col]) for row in all_rows) for i_col in range(len(all_rows[0])) ] + if column_alignments is None: + column_alignments = "<" + ">" * (len(widths) - 1) out = [] inner_width = sum(widths) + len(header or rows[0]) - 1 if title is not None: @@ -423,9 +440,13 @@ def unicode_table(rows, title=None, header=None, row_separator=True): else: if i != 0 and not last_skipped and row_separator: out.append(f"╟{'┼'.join('─' * w for w in widths)}╢\n") + out.append( - f"║{row[0].ljust(widths[0])}│" - f"{'│'.join(cell.rjust(w) for cell, w in zip(row[1:], widths[1:]))}║\n" + "║" + + "│".join( + f"{r:{a}{w}}" for r, w, a in zip(row, widths, column_alignments) + ) + + "║\n" ) last_skipped = False @@ -433,6 +454,38 @@ def unicode_table(rows, title=None, header=None, row_separator=True): return "".join(out) + +def html_table(rows, *, header): + headers = "".join(f"<th>{h}</th>" for h in header) + rows = ( + f'<td colspan="{len(header)}"><em>{row[11:]}' + f" rows skipped (tskit.set_print_options)</em></td>" + if "__skipped__" in row + else "".join(f"<td>{cell}</td>" for cell in row) + for row in rows + ) + rows = "".join(f"<tr>{row}</tr>\n" for row in rows) + return f""" + <div> + <style scoped=""> + .tskit-table tbody tr th:only-of-type {{vertical-align: middle;}} + .tskit-table tbody tr th {{vertical-align: top;}} + .tskit-table tbody td {{text-align: right;padding: 0.5em 0.5em;}} + .tskit-table tbody th {{padding: 0.5em 0.5em;}} + </style> + <table border="1" class="tskit-table"> + <thead> + <tr> + {headers} + </tr> + </thead> + <tbody> + {rows} + </tbody> + </table> + </div>
+ """ + + def tree_sequence_html(ts): table_rows = "".join( f""" @@ -674,6 +727,20 @@ def set_print_options(*, max_lines=40): tskit._print_options = {"max_lines": max_lines} +def truncate_rows(num_rows, limit=None): + """ + Return a list of indexes into a set of rows, but if a ``limit`` is set, truncate the + number of rows and place a single ``-1`` entry, instead of the intermediate indexes + """ + if limit is None or num_rows <= limit: + return range(num_rows) + return itertools.chain( + range(limit // 2), + [-1], + range(num_rows - (limit - (limit // 2)), num_rows), + ) + + def random_nucleotides(length: numbers.Number, *, seed: Union[int, None] = None) -> str: """ Returns a random string of nucleotides of the specified length. Characters @@ -690,3 +757,32 @@ def random_nucleotides(length: numbers.Number, *, seed: Union[int, None] = None) encoded_nucleotides = np.array(list(map(ord, "ACTG")), dtype=np.int8) a = rng.choice(encoded_nucleotides, size=int(length)) return a.tobytes().decode("ascii") + + +def raise_known_file_format_errors(open_file, existing_exception): + """ + Sniffs the file for pk-zip or hdf header bytes, then raises an exception + if these are detected, if not raises the existing exception. + """ + # Check for HDF5 header bytes + try: + open_file.seek(0) + header = open_file.read(4) + except io.UnsupportedOperation: + # If we can't seek, we can't sniff the file. + raise existing_exception + if header == b"\x89HDF": + raise tskit.FileFormatError( + "The specified file appears to be in HDF5 format. This file " + "may have been generated by msprime < 0.6.0 (June 2018) which " + "can no longer be read directly. Please convert to the new " + "kastore format using the ``tskit upgrade`` command." + ) from existing_exception + if header[:2] == b"\x50\x4b": + raise tskit.FileFormatError( + "The specified file appears to be in zip format, so may be a compressed " + "tree sequence. Try using the tszip module to decompress this file before " + "loading. `pip install tszip; tsunzip ` or use " + "`tszip.decompress` in Python code." + ) from existing_exception + raise existing_exception