Skip to content

Commit

Permalink
Merge pull request #556 from populationgenomics/resolve-main-conflicts
Browse files Browse the repository at this point in the history
Resolve main conflicts
  • Loading branch information
illusional authored Sep 20, 2023
2 parents 151f61f + 0c148a6 commit d7adc60
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions scripts/create_test_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import random
import subprocess
import traceback
import typing
from collections import Counter
from typing import Optional

Expand Down Expand Up @@ -97,20 +96,28 @@
This is in addition to the number of families specified in
--families and the number of samples specified in -n""",
)
@click.option(
'--noninteractive',
'noninteractive',
is_flag=True,
default=False,
help='Skip interactive confirmation',
)
def main(
project: str,
samples_n: Optional[int],
families_n: Optional[int],
skip_ped: Optional[bool] = True,
additional_families: Optional[tuple[str]] = None,
additional_samples: Optional[tuple[str]] = None,
noninteractive: Optional[bool] = False,
):
"""
Script creates a test subset for a given project.
A new project with the suffix -test is created, and for any files in sample/meta,
sequence/meta, or analysis/output a copy in the -test namespace is created.
"""
samples_n, families_n = _validate_opts(samples_n, families_n)
samples_n, families_n = _validate_opts(samples_n, families_n, noninteractive)
_additional_families: list[str] = list(additional_families)
_additional_samples: list[str] = list(additional_samples)

Expand All @@ -121,7 +128,7 @@ def main(
}
)
logger.info(f'Found {len(all_samples)} samples')
if samples_n and samples_n >= len(all_samples):
if (samples_n and samples_n >= len(all_samples)) and not noninteractive:
resp = str(
input(
f'Requesting {samples_n} samples which is >= '
Expand Down Expand Up @@ -440,8 +447,7 @@ def get_map_ipid_esid(

ip_es_map = []
for ip_is_pair in ip_is_map:
samples_per_participant = []
samples_per_participant.append(ip_is_pair[0])
samples_per_participant = [ip_is_pair[0]]
for isid in ip_is_pair[1:]:
if isid in is_es_map:
samples_per_participant.append(is_es_map[isid])
Expand All @@ -453,10 +459,9 @@ def get_map_ipid_esid(
return external_sample_internal_participant_map


def get_samples_for_families(project: str, additional_families: list[str]):
def get_samples_for_families(project: str, additional_families: list[str]) -> list[str]:
"""Returns the samples that belong to a list of families"""

samples: list[str] = []
full_pedigree = fapi.get_pedigree(
project=project,
replace_with_participant_external_ids=False,
Expand All @@ -477,17 +482,16 @@ def get_samples_for_families(project: str, additional_families: list[str]):
}
)

samples = [sample['id'] for sample in sample_objects]
samples: list[str] = [sample['id'] for sample in sample_objects]

return samples


def get_fams_for_samples(
project: str,
additional_samples: Optional[list[str]] = None,
):
) -> list[str]:
"""Returns the families that a list of samples belong to"""
fams: set[str] = set()
sample_objects = sapi.get_samples(
body_get_samples={
'project_ids': [project],
Expand All @@ -503,7 +507,7 @@ def get_fams_for_samples(
replace_with_family_external_ids=True,
)

fams = {
fams: set[str] = {
fam['family_id'] for fam in full_pedigree if str(fam['individual_id']) in pids
}

Expand All @@ -524,7 +528,7 @@ def _normalise_map(unformatted_map: list[list[str]]) -> dict[str, str]:


def _validate_opts(
samples_n: int, families_n: int
samples_n: int, families_n: int, noninteractive: bool
) -> tuple[Optional[int], Optional[int]]:
if samples_n is not None and families_n is not None:
raise click.BadParameter('Please specify only one of --samples or --families')
Expand All @@ -539,7 +543,7 @@ def _validate_opts(
if families_n is not None and families_n < 1:
raise click.BadParameter('Please specify --families higher than 0')

if families_n is not None and families_n >= 30:
if (families_n is not None and families_n >= 30) and not noninteractive:
resp = str(
input(
f'You requested a subset of {families_n} families. '
Expand All @@ -549,7 +553,7 @@ def _validate_opts(
if resp.lower() != 'y':
raise SystemExit()

if samples_n is not None and samples_n >= 100:
if (samples_n is not None and samples_n >= 100) and not noninteractive:
resp = str(
input(
f'You requested a subset of {samples_n} samples. '
Expand All @@ -563,7 +567,7 @@ def _validate_opts(

def _print_fam_stats(families: list[dict[str, str]]):
family_sizes = Counter([fam['family_id'] for fam in families])
fam_by_size: typing.Counter[int] = Counter()
fam_by_size: Counter[int] = Counter()
# determine number of singles, duos, trios, etc
for fam in family_sizes:
fam_by_size[family_sizes[fam]] += 1
Expand Down

0 comments on commit d7adc60

Please sign in to comment.