-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Function to compare two sgkit xarray datasets and tests
- Loading branch information
Showing
2 changed files
with
395 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
import numpy as np | ||
import xarray as xr | ||
|
||
|
||
_ACGT_ALLELES = np.array([b'A', b'C', b'G', b'T']) | ||
|
||
|
||
def get_matching_indices(arr1, arr2): | ||
""" | ||
Get the indices of `arr1` and `arr2`, | ||
where the values of `arr1` and `arr2` are equal. | ||
:param numpy.ndarray arr1: 1D array | ||
:param numpy.ndarray arr2: 1D array | ||
:return: Indices of `arr1` and `arr2` | ||
:rtype: numpy.ndarray | ||
""" | ||
idx_pairs = [] | ||
i = j = 0 | ||
while i < len(arr1) and j < len(arr2): | ||
if arr1[i] == arr2[j]: | ||
idx_pairs.append((i, j)) | ||
i += 1 | ||
j += 1 | ||
elif arr1[i] < arr2[j]: | ||
i += 1 | ||
else: | ||
j += 1 | ||
return np.array(idx_pairs) | ||
|
||
|
||
def remap_genotypes(ds1, ds2, actg_alleles=False): | ||
""" | ||
Remap genotypes of `ds2` to `ds1` for each site, such that: | ||
1. There are only genotypes shared between `ds1` and `ds2`. | ||
2. The allele lists always have length 4 (full or padded). | ||
Assumptions: | ||
1. Only ACGT alleles. | ||
2. No mixed ploidy. | ||
Additionally, if `actg_alleles` is set to True, | ||
all allele lists in `ds2` are set to ACGT. | ||
:param xarray.Dataset ds1: sgkit-style dataset. | ||
:param xarray.Dataset ds2: sgkit-style dataset. | ||
:param bool acgt_alleles: All allele lists are set to ACGT (default = False). | ||
:return: Remapped genotypes of `ds2`. | ||
:rtype: xarray.DataArray | ||
""" | ||
common_site_idx = get_matching_indices( | ||
ds1.variant_position, | ||
ds2.variant_position | ||
) | ||
|
||
remapped_ds2_variant_allele = xr.DataArray( | ||
np.empty([len(common_site_idx), 4], dtype='|S1'), # ACGT | ||
dims=["variants", "alleles"] | ||
) | ||
remapped_ds2_call_genotype = xr.DataArray( | ||
np.zeros([len(common_site_idx), ds2.dims["samples"], ds2.dims["ploidy"]]), | ||
dims=["variants", "samples", "ploidy"] | ||
) | ||
|
||
i = 0 | ||
for ds1_idx, ds2_idx in common_site_idx: | ||
# Get the allele lists at matching positions | ||
ds1_alleles = _ACGT_ALLELES if actg_alleles else \ | ||
np.array([a for a in ds1.variant_allele[ds1_idx].values if a != b'']) | ||
ds2_alleles = np.array([a for a in ds2.variant_allele[ds2_idx].values if a != b'']) | ||
|
||
if not np.all(np.isin(ds1_alleles, _ACGT_ALLELES)) or \ | ||
not np.all(np.isin(ds2_alleles, _ACGT_ALLELES)): | ||
raise ValueError("Alleles must be ACGT.") | ||
|
||
# Modify the allele lists such that both lists contain the same alleles | ||
ds1_uniq = np.setdiff1d(ds1_alleles, ds2_alleles) | ||
ds2_uniq = np.setdiff1d(ds2_alleles, ds1_alleles) | ||
ds1_alleles = np.append(ds1_alleles, ds2_uniq) | ||
ds2_alleles = np.append(ds2_alleles, ds1_uniq) | ||
|
||
if not np.array_equal(np.sort(ds1_alleles), np.sort(ds2_alleles)): | ||
raise ValueError("Allele lists are not the same.") | ||
|
||
# Get index map from the allele list of ds2 to the allele list of ds1 | ||
ds1_sort_idx = np.argsort(ds1_alleles) | ||
ds2_sort_idx = np.argsort(ds2_alleles) | ||
index_array = np.argsort(ds2_sort_idx)[ds1_sort_idx] | ||
|
||
# Pad allele list so that it is length 4 | ||
if len(ds1_alleles) < 4: | ||
ds1_alleles = np.append(ds1_alleles, np.full(4 - len(ds1_alleles), b'')) | ||
|
||
# Remap genotypes 2 to genotypes 1 | ||
remapped_ds2_variant_allele[i] = ds1_alleles | ||
ds2_genotype = ds2.call_genotype[ds2_idx].values | ||
remapped_ds2_call_genotype[i] = index_array[ds2_genotype].tolist() | ||
i += 1 | ||
|
||
return (remapped_ds2_variant_allele, remapped_ds2_call_genotype) | ||
|
||
|
||
def make_compatible_genotypes(ds1, ds2, acgt_alleles=False): | ||
""" | ||
Make `ds2` compatible with `ds1` by remapping genotypes. | ||
Definition of compatibility: | ||
1. `ds1` and `ds2` have the same number of samples. | ||
2. `ds1` and `ds2` have the same ploidy. | ||
3. `ds1` and `ds2` have the same number of variable sites. | ||
4. `ds1` and `ds2` have the same allele list at each site. | ||
Assumptions: | ||
1. Only ACGT alleles. | ||
2. No mixed ploidy. | ||
Additionally, if `actg_alleles` is set to True, | ||
all allele lists in `ds1` and `ds2` are set to ACGT. | ||
:param xarray.Dataset ds1: sgkit-style dataset. | ||
:param xarray.Dataset ds2: sgkit-style dataset. | ||
:param bool acgt_alleles: All allele lists are set to ACGT (default = False). | ||
:return: Compatible `ds1` and `ds2`. | ||
:rtype: tuple(xarray.Dataset, xarray.Dataset) | ||
""" | ||
# TODO: Refactor, routine run again when calling `remap_genotypes` | ||
common_site_idx = get_matching_indices(ds1, ds2) | ||
ds1_idx, ds2_idx = np.split(common_site_idx, len(common_site_idx), axis=1) | ||
ds1_idx = np.array(ds1_idx.flatten()) | ||
ds2_idx = np.array(ds2_idx.flatten()) | ||
assert len(ds1_idx) == len(ds2_idx) == len(common_site_idx) | ||
|
||
if acgt_alleles: | ||
remapped_ds1_alleles, remapped_ds1_genotypes = remap_genotypes(ds2, ds1, acgt_alleles=acgt_alleles) | ||
assert remapped_ds1_alleles.shape == (len(common_site_idx), 4) | ||
assert remapped_ds1_genotypes.shape == (len(common_site_idx), ds1.dims["samples"], ds1.dims["ploidy"]) | ||
|
||
remapped_ds2_alleles, remapped_ds2_genotypes = remap_genotypes(ds1, ds2, acgt_alleles=acgt_alleles) | ||
assert remapped_ds2_alleles.shape == (len(common_site_idx), 4) | ||
assert remapped_ds2_genotypes.shape == (len(common_site_idx), ds2.dims["samples"], ds2.dims["ploidy"]) | ||
|
||
# Subset `ds1` to common sites | ||
ds1_subset = ds1.copy(deep=True) # TODO: Copying is expensive, another way? | ||
ds1_subset["variant_contig"] = ds1_subset["variant_contig"].isel(variants=ds1_idx) | ||
ds1_subset["variant_position"] = ds1_subset["variant_position"].isel(variants=ds1_idx) | ||
if acgt_alleles: | ||
ds1_subset["variant_allele"] = remapped_ds1_alleles | ||
ds1_subset["call_genotype"] = remapped_ds1_genotypes | ||
else: | ||
ds1_subset["variant_allele"] = remapped_ds2_alleles | ||
ds1_subset["call_genotype"] = ds1_subset["call_genotype"].isel(variants=ds1_idx) | ||
ds1_subset["call_genotype_mask"] = ds1_subset["call_genotype_mask"].isel(variants=ds1_idx) | ||
|
||
# Subset `ds2` to common sites | ||
ds2_subset = ds2.copy(deep=True) | ||
ds2_subset["variant_contig"] = ds2_subset["variant_contig"].isel(variants=ds2_idx) | ||
ds2_subset["variant_position"] = ds2_subset["variant_position"].isel(variants=ds2_idx) | ||
ds2_subset["variant_allele"] = remapped_ds2_alleles | ||
ds2_subset["call_genotype"] = remapped_ds2_genotypes | ||
ds2_subset["call_genotype_mask"] = ds2_subset["call_genotype_mask"].isel(variants=ds2_idx) | ||
|
||
return (ds1_subset, ds2_subset) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
import pytest | ||
|
||
import numpy as np | ||
import xarray as xr | ||
|
||
import sgkit as sg | ||
|
||
import sys | ||
sys.path.append('../src/') | ||
import compare_vcfs | ||
|
||
|
||
# Assumptions: | ||
# 1. All individuals (i.e., samples) are diploid in both ds1 and ds2. | ||
# 2. There is no missing data. | ||
|
||
def make_test_case(): | ||
""" | ||
Create an xarray dataset that contains: | ||
1) two diploid samples (i.e., individuals); and | ||
2) two biallelic sites on contig 20 at positions 5 and 10. | ||
:return: Dataset for sgkit use. | ||
:rtype: xr.Dataset | ||
""" | ||
contig_id = ['20'] | ||
variant_contig = np.array([0, 0], dtype=np.int64) | ||
variant_position = np.array([5, 10], dtype=np.int64) | ||
variant_allele = np.array([ | ||
[b'A', b'C'], | ||
[b'G', b'T'], | ||
]) | ||
sample_id = np.array([ | ||
'tsk0', | ||
'tsk1', | ||
]) | ||
call_genotype = np.array( | ||
[ | ||
[ | ||
[0, 1], # tsk0 | ||
[1, 0], # tsk1 | ||
], | ||
[ | ||
[1, 0], # tsk0 | ||
[0, 1], # tsk1 | ||
], | ||
], | ||
dtype=np.int8 | ||
) | ||
call_genotype_mask = np.zeros_like(call_genotype, dtype=bool) # no mask | ||
|
||
ds = xr.Dataset( | ||
{ | ||
"contig_id": ("contigs", contig_id), | ||
"variant_contig": ("variants", variant_contig), | ||
"variant_position": ("variants", variant_position), | ||
"variant_allele": (["variants", "alleles"], variant_allele), | ||
"sample_id": ("samples", sample_id), | ||
"call_genotype": (["variants", "samples", "ploidy"], call_genotype), | ||
"call_genotype_mask": (["variants", "samples", "ploidy"], call_genotype_mask), | ||
}, | ||
attrs={ | ||
"contigs": contig_id, | ||
"source": "sgkit" + "-" + str(sg.__version__), | ||
} | ||
) | ||
|
||
return ds | ||
|
||
|
||
def test_both_biallelic_same_alleles_same_order(): | ||
ds1 = make_test_case() | ||
ds2 = ds1.copy(deep=True) | ||
_, actual = compare_vcfs.remap_genotypes(ds1, ds2) | ||
expected = xr.DataArray( | ||
[ | ||
[[0, 1], [1, 0]], | ||
[[1, 0], [0, 1]], | ||
], | ||
dims=["variants", "samples", "ploidy"] | ||
) | ||
assert np.array_equal(actual, expected) | ||
|
||
|
||
def test_both_biallelic_same_alleles_different_order(): | ||
ds1 = make_test_case() | ||
ds2 = ds1.copy(deep=True) | ||
for i in np.arange(ds2.variant_contig.size): | ||
ds2.variant_allele[i] = xr.DataArray(np.flip(ds2.variant_allele[i])) | ||
ds2.call_genotype[i] = xr.DataArray(np.where(ds2.call_genotype[i] == 0, 1, 0)) | ||
_, actual = compare_vcfs.remap_genotypes(ds1, ds2) | ||
expected = xr.DataArray( | ||
[ | ||
[[0, 1], [1, 0]], | ||
[[1, 0], [0, 1]], | ||
], | ||
dims=["variants", "samples", "ploidy"] | ||
) | ||
assert np.array_equal(actual, expected) | ||
|
||
|
||
def test_both_biallelic_different_alleles(): | ||
ds1 = make_test_case() | ||
ds2 = ds1.copy(deep=True) | ||
# At the first site, one allele is shared. | ||
ds2.variant_allele[0] = xr.DataArray([b'C', b'G']) | ||
ds2.call_genotype[0] = xr.DataArray([[0, 1], [1, 0]]) | ||
# At the second site, no allele is shared. | ||
ds2.variant_allele[1] = xr.DataArray([b'A', b'C']) | ||
ds2.call_genotype[1] = xr.DataArray([[0, 1], [1, 0]]) | ||
# Subtest 1 | ||
_, actual = compare_vcfs.remap_genotypes(ds1, ds2) | ||
expected = xr.DataArray( | ||
[ | ||
[[1, 2], [2, 1]], | ||
[[2, 3], [3, 2]], | ||
], | ||
dims=["variants", "samples", "ploidy"] | ||
) | ||
assert np.array_equal(actual, expected) | ||
# Subtest 2 | ||
_, actual = compare_vcfs.remap_genotypes(ds2, ds1) | ||
expected = xr.DataArray( | ||
[ | ||
[[2, 0], [0, 2]], | ||
[[3, 2], [2, 3]], | ||
], | ||
dims=["variants", "samples", "ploidy"] | ||
) | ||
assert np.array_equal(actual, expected) | ||
|
||
|
||
def test_biallelic_monoallelic(): | ||
ds1 = make_test_case() | ||
ds2 = ds1.copy(deep=True) | ||
# At the first site, one allele is shared. | ||
# At the second site, no allele is shared. | ||
for i in np.arange(ds2.variant_contig.size): | ||
ds2.variant_allele[i] = xr.DataArray([b'C', b'']) | ||
ds2.call_genotype[i] = xr.DataArray(np.zeros_like(ds2.call_genotype[i])) | ||
# Subtest 1 | ||
_, actual = compare_vcfs.remap_genotypes(ds1, ds2) | ||
expected = xr.DataArray( | ||
[ | ||
[[1, 1], [1, 1]], | ||
[[2, 2], [2, 2]], | ||
], | ||
dims=["variants", "samples", "ploidy"] | ||
) | ||
assert np.array_equal(actual, expected) | ||
# Subtest 2 | ||
_, actual = compare_vcfs.remap_genotypes(ds2, ds1) | ||
expected = xr.DataArray( | ||
[ | ||
[[1, 0], [0, 1]], | ||
[[2, 1], [1, 2]], | ||
], | ||
dims=["variants", "samples", "ploidy"] | ||
) | ||
assert np.array_equal(actual, expected) | ||
|
||
|
||
def test_both_monoallelic(): | ||
ds1 = make_test_case() | ||
ds2 = ds1.copy(deep=True) | ||
# Overwrite certain data variables in ds1 and ds2. | ||
# At the first site, one allele is shared. | ||
ds1.variant_allele[0] = xr.DataArray([b'C', b'']) | ||
ds1.call_genotype[0] = xr.DataArray(np.zeros_like(ds1.call_genotype[0])) | ||
ds2.variant_allele[0] = xr.DataArray([b'C', b'']) | ||
ds2.call_genotype[0] = xr.DataArray(np.zeros_like(ds2.call_genotype[0])) | ||
# At the second site, no allele is shared. | ||
ds1.variant_allele[1] = xr.DataArray([b'C', b'']) | ||
ds1.call_genotype[1] = xr.DataArray(np.zeros_like(ds1.call_genotype[1])) | ||
ds2.variant_allele[1] = xr.DataArray([b'G', b'']) | ||
ds2.call_genotype[1] = xr.DataArray(np.zeros_like(ds2.call_genotype[1])) | ||
_, actual = compare_vcfs.remap_genotypes(ds1, ds2) | ||
expected = xr.DataArray( | ||
[ | ||
[[0, 0], [0, 0]], | ||
[[1, 1], [1, 1]], | ||
], | ||
dims=["variants", "samples", "ploidy"] | ||
) | ||
assert np.array_equal(actual, expected) | ||
|
||
|
||
def test_actg_alleles_true(): | ||
ds1 = make_test_case() | ||
ds2 = ds1.copy(deep=True) | ||
actual_alleles, actual_genotypes = compare_vcfs.remap_genotypes(ds1, ds2, actg_alleles=True) | ||
expected_alleles = np.array([ | ||
[b'A', b'C', b'G', b'T'], | ||
[b'A', b'C', b'G', b'T'], | ||
]) | ||
expected_genotypes = xr.DataArray( | ||
[ | ||
[[0, 1], [1, 0]], | ||
[[3, 2], [2, 3]], | ||
], | ||
dims=["variants", "samples", "ploidy"] | ||
) | ||
assert np.array_equal(actual_alleles, expected_alleles) | ||
assert np.array_equal(actual_genotypes, expected_genotypes) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"arr1, arr2, expected", | ||
[ | ||
pytest.param([1, 5, 9], [5, 9, 15], [(1, 0), (2, 1)], id="sites shared"), | ||
pytest.param([1, 9], [5, 15], [], id="no sites shared"), | ||
pytest.param([1], [], [], id="one empty"), | ||
pytest.param([], [], [], id="both empty"), | ||
] | ||
) | ||
def test_get_matching_indices(arr1, arr2, expected): | ||
actual = compare_vcfs.get_matching_indices(arr1, arr2) | ||
assert np.array_equal(actual, expected) | ||
|
||
|
||
def test_both_empty(): | ||
"""TODO""" | ||
raise NotImplementedError | ||
|
||
|
||
def test_one_empty(): | ||
"""TODO""" | ||
raise NotImplementedError | ||
|
||
|
||
def test_non_acgt(): | ||
"""TODO""" | ||
raise NotImplementedError |