Skip to content

Commit

Permalink
Function to compare two sgkit xarray datasets and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 16, 2023
1 parent c88d5ec commit 621f15b
Show file tree
Hide file tree
Showing 2 changed files with 297 additions and 0 deletions.
87 changes: 87 additions & 0 deletions src/compare_vcfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import numpy as np
import xarray as xr

import sgkit as sg


def get_matching_indices(arr1, arr2):
idx_pairs = []
i = j = 0
while i < len(arr1) and j < len(arr2):
if arr1[i] == arr2[j]:
idx_pairs.append((i, j))
i += 1
j += 1
elif arr1[i] < arr2[j]:
i += 1
else:
j += 1
return idx_pairs


def remap_genotypes(ds1, ds2):
"""
Remap genotypes of `ds2` to `ds1` for each site.
:param xarray.Dataset ds1: sgkit-style dataset
:param xarray.Dataset ds2: sgkit-style dataset
:return: Remapped genotypes of `ds2`
:rtype: xarray.DataArray
"""
common_site_pos = get_matching_indices(ds1.variant_position, ds2.variant_position)
remapped_ds2 = xr.DataArray(
np.zeros([len(common_site_pos), ds1.dims["sample"], ds1.dims["ploidy"]]),
dims=["variants", "samples", "ploidy"]
)

i = 0
for ds1_idx, ds2_idx in common_site_pos:
# Get the allele lists at matching positions
ds1_alleles = np.array([a for a in ds1.variant_allele[ds1_idx].values if a != b''])
ds2_alleles = np.array([a for a in ds2.variant_allele[ds2_idx].values if a != b''])

# Modify the allele lists such that both lists contain the same alleles
ds1_alleles = np.delete(ds1_alleles, np.where(ds1_alleles == b''))
ds2_alleles = np.delete(ds2_alleles, np.where(ds2_alleles == b''))
ds1_uniq = np.setdiff1d(ds1_alleles, ds2_alleles)
ds2_uniq = np.setdiff1d(ds2_alleles, ds1_alleles)
ds1_alleles = np.append(ds1_alleles, ds2_uniq)
ds2_alleles = np.append(ds2_alleles, ds1_uniq)
assert np.array_equal(np.sort(ds1_alleles), np.sort(ds2_alleles))

# Get index map from the allele list of ds2 to the allele list of ds1
ds1_sort_idx = np.argsort(ds1_alleles)
ds2_sort_idx = np.argsort(ds2_alleles)
index_array = np.argsort(ds2_sort_idx)[ds1_sort_idx]

# Remap genotypes 2 to genotypes 1
#ds1_genotype = ds1.call_genotype[ds1_idx].values
ds2_genotype = ds2.call_genotype[ds2_idx].values
remapped_ds2_genotype = index_array[ds2_genotype].tolist()

remapped_ds2[i] = remapped_ds2_genotype
i += 1

return remapped_ds2


def make_compatible_genotypes(ds1, ds2):
"""
Make `ds2` compatible with `ds1` by remapping genotypes.
Definition of compatibility:
1. `ds1` and `ds2` have the same number of samples.
2. `ds1` and `ds2` have the same ploidy.
3. `ds1` and `ds2` have the same number of variable sites.
4. `ds1` and `ds2` have the same allele list at each site.
:param xarray.Dataset ds1: sgkit-style dataset
:param xarray.Dataset ds2: sgkit-style dataset
:return: Compatible `ds1` and `ds2`
:rtype: tuple(xarray.Dataset, xarray.Dataset)
"""
remapped_ds2_genotype = remap_genotypes(ds1, ds2)
common_site_pos = remapped_ds2_genotype.variant_position
ds1_subset = None
ds2_subset = None
return (ds1_subset, ds2_subset)
210 changes: 210 additions & 0 deletions tests/test_compare_vcfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import pytest

import numpy as np
import xarray as xr

import sgkit as sg

import sys
sys.path.append('../src/')
import compare_vcfs


# Assumptions:
# 1) All individuals (i.e., samples) are diploid in both ds1 and ds2.
# 2) All site positions are the same in both ds1 and ds2.
# 3) There is no missing data.

def make_test_case():
"""
Create an xarray dataset that contains:
1) two diploid samples (i.e., individuals); and
2) two biallelic sites on contig 20 at positions 5 and 10.
:return: Dataset for sgkit use.
:rtype: xr.Dataset
"""
contig_id = ['20']
variant_contig = np.array([0, 0], dtype=np.int64)
variant_position = np.array([5, 10], dtype=np.int64)
variant_allele = np.array([
[b'A', b'C'],
[b'G', b'T'],
])
sample_id = np.array([
'tsk0',
'tsk1',
])
call_genotype = np.array(
[
[
[0, 1], # tsk0
[1, 0], # tsk1
],
[
[1, 0], # tsk0
[0, 1], # tsk1
],
],
dtype=np.int8
)
call_genotype_mask = np.zeros_like(call_genotype, dtype=bool) # no mask

ds = xr.Dataset(
{
"contig_id": ("contig", contig_id),
"variant_contig": ("variants", variant_contig),
"variant_position": ("variants", variant_position),
"variant_allele": (["variants", "alleles"], variant_allele),
"sample_id": ("sample", sample_id),
"call_genotype": (["variants", "samples", "ploidy"], call_genotype),
"call_genotype_mask": (["variants", "samples", "ploidy"], call_genotype_mask),
},
attrs={
"contigs": contig_id,
"source": "sgkit" + "-" + str(sg.__version__),
}
)

return ds


def test_both_biallelic_same_alleles_same_order():
ds1 = make_test_case()
ds2 = ds1.copy(deep=True)
actual = compare_vcfs.remap_genotypes(ds1, ds2)
expected = xr.DataArray(
[
[[0, 1], [1, 0]],
[[1, 0], [0, 1]],
],
dims=["variants", "samples", "ploidy"]
)
assert np.array_equal(actual, expected)


def test_both_biallelic_same_alleles_different_order():
ds1 = make_test_case()
ds2 = ds1.copy(deep=True)
for i in np.arange(ds2.variant_contig.size):
ds2.variant_allele[i] = xr.DataArray(np.flip(ds2.variant_allele[i]))
ds2.call_genotype[i] = xr.DataArray(np.where(ds2.call_genotype[i] == 0, 1, 0))
actual = compare_vcfs.remap_genotypes(ds1, ds2)
expected = xr.DataArray(
[
[[0, 1], [1, 0]],
[[1, 0], [0, 1]],
],
dims=["variants", "samples", "ploidy"]
)
assert np.array_equal(actual, expected)


def test_biallelic_monoallelic():
ds1 = make_test_case()
ds2 = ds1.copy(deep=True)
# At the first site, one allele is shared.
# At the second site, no allele is shared.
for i in np.arange(ds2.variant_contig.size):
ds2.variant_allele[i] = xr.DataArray([b'C', b''])
ds2.call_genotype[i] = xr.DataArray(np.zeros_like(ds2.call_genotype[i]))
# Subtest 1
actual = compare_vcfs.remap_genotypes(ds1, ds2)
expected = xr.DataArray(
[
[[1, 1], [1, 1]],
[[2, 2], [2, 2]],
],
dims=["variants", "samples", "ploidy"]
)
assert np.array_equal(actual, expected)
# Subtest 2
actual = compare_vcfs.remap_genotypes(ds2, ds1)
expected = xr.DataArray(
[
[[1, 0], [0, 1]],
[[2, 1], [1, 2]],
],
dims=["variants", "samples", "ploidy"]
)
assert np.array_equal(actual, expected)


def test_both_biallelic_different_alleles():
ds1 = make_test_case()
ds2 = ds1.copy(deep=True)
# At the first site, one allele is shared.
ds2.variant_allele[0] = xr.DataArray([b'C', b'G'])
ds2.call_genotype[0] = xr.DataArray([[0, 1], [1, 0]])
# At the second site, no allele is shared.
ds2.variant_allele[1] = xr.DataArray([b'A', b'C'])
ds2.call_genotype[1] = xr.DataArray([[0, 1], [1, 0]])
# Subtest 1
actual = compare_vcfs.remap_genotypes(ds1, ds2)
expected = xr.DataArray(
[
[[1, 2], [2, 1]],
[[2, 3], [3, 2]],
],
dims=["variants", "samples", "ploidy"]
)
assert np.array_equal(actual, expected)
# Subtest 2
actual = compare_vcfs.remap_genotypes(ds2, ds1)
expected = xr.DataArray(
[
[[2, 0], [0, 2]],
[[3, 2], [2, 3]],
],
dims=["variants", "samples", "ploidy"]
)
assert np.array_equal(actual, expected)


def test_both_monoallelic():
ds1 = make_test_case()
ds2 = ds1.copy(deep=True)
# Overwrite certain data variables in ds1 and ds2.
# At the first site, one allele is shared.
ds1.variant_allele[0] = xr.DataArray([b'C', b''])
ds1.call_genotype[0] = xr.DataArray(np.zeros_like(ds1.call_genotype[0]))
ds2.variant_allele[0] = xr.DataArray([b'C', b''])
ds2.call_genotype[0] = xr.DataArray(np.zeros_like(ds2.call_genotype[0]))
# At the second site, no allele is shared.
ds1.variant_allele[1] = xr.DataArray([b'C', b''])
ds1.call_genotype[1] = xr.DataArray(np.zeros_like(ds1.call_genotype[1]))
ds2.variant_allele[1] = xr.DataArray([b'G', b''])
ds2.call_genotype[1] = xr.DataArray(np.zeros_like(ds2.call_genotype[1]))
actual = compare_vcfs.remap_genotypes(ds1, ds2)
expected = xr.DataArray(
[
[[0, 0], [0, 0]],
[[1, 1], [1, 1]],
],
dims=["variants", "samples", "ploidy"]
)
assert np.array_equal(actual, expected)


@pytest.mark.parametrize(
"arr1, arr2, expected",
[
pytest.param([1, 5, 9], [5, 9, 15], [(1, 0), (2, 1)], id="sites shared"),
pytest.param([1, 9], [5, 15], [], id="no sites shared"),
pytest.param([1], [], [], id="one empty"),
pytest.param([], [], [], id="both empty"),
]
)
def test_get_matching_indices(arr1, arr2, expected):
actual = compare_vcfs.get_matching_indices(arr1, arr2)
assert np.array_equal(actual, expected)


def test_both_empty():
"""TODO"""
raise NotImplementedError


def test_one_empty():
"""TODO"""
raise NotImplementedError

0 comments on commit 621f15b

Please sign in to comment.