diff --git a/scripts/create_test_subset.py b/scripts/create_test_subset.py index c14f13d79..5d7d9513a 100755 --- a/scripts/create_test_subset.py +++ b/scripts/create_test_subset.py @@ -21,7 +21,6 @@ import random import subprocess import traceback -import typing from collections import Counter from typing import Optional @@ -97,6 +96,13 @@ This is in addition to the number of families specified in --families and the number of samples specified in -n""", ) +@click.option( + '--noninteractive', + 'noninteractive', + is_flag=True, + default=False, + help='Skip interactive confirmation', +) def main( project: str, samples_n: Optional[int], @@ -104,13 +110,14 @@ def main( skip_ped: Optional[bool] = True, additional_families: Optional[tuple[str]] = None, additional_samples: Optional[tuple[str]] = None, + noninteractive: Optional[bool] = False, ): """ Script creates a test subset for a given project. A new project with a prefix -test is created, and for any files in sample/meta, sequence/meta, or analysis/output a copy in the -test namespace is created. """ - samples_n, families_n = _validate_opts(samples_n, families_n) + samples_n, families_n = _validate_opts(samples_n, families_n, noninteractive) _additional_families: list[str] = list(additional_families) _additional_samples: list[str] = list(additional_samples) @@ -121,7 +128,7 @@ def main( } ) logger.info(f'Found {len(all_samples)} samples') - if samples_n and samples_n >= len(all_samples): + if (samples_n and samples_n >= len(all_samples)) and not noninteractive: resp = str( input( f'Requesting {samples_n} samples which is >= ' @@ -440,8 +447,7 @@ def get_map_ipid_esid( ip_es_map = [] for ip_is_pair in ip_is_map: - samples_per_participant = [] - samples_per_participant.append(ip_is_pair[0]) + samples_per_participant = [ip_is_pair[0]] for isid in ip_is_pair[1:]: if isid in is_es_map: samples_per_participant.append(is_es_map[isid]) @@ -453,10 +459,9 @@ def get_map_ipid_esid( return external_sample_internal_participant_map -def get_samples_for_families(project: str, additional_families: list[str]): +def get_samples_for_families(project: str, additional_families: list[str]) -> list[str]: """Returns the samples that belong to a list of families""" - samples: list[str] = [] full_pedigree = fapi.get_pedigree( project=project, replace_with_participant_external_ids=False, @@ -477,7 +482,7 @@ def get_samples_for_families(project: str, additional_families: list[str]): } ) - samples = [sample['id'] for sample in sample_objects] + samples: list[str] = [sample['id'] for sample in sample_objects] return samples @@ -485,9 +490,8 @@ def get_samples_for_families(project: str, additional_families: list[str]): def get_fams_for_samples( project: str, additional_samples: Optional[list[str]] = None, -): +) -> list[str]: """Returns the families that a list of samples belong to""" - fams: set[str] = set() sample_objects = sapi.get_samples( body_get_samples={ 'project_ids': [project], @@ -503,7 +507,7 @@ def get_fams_for_samples( replace_with_family_external_ids=True, ) - fams = { + fams: set[str] = { fam['family_id'] for fam in full_pedigree if str(fam['individual_id']) in pids } @@ -524,7 +528,7 @@ def _normalise_map(unformatted_map: list[list[str]]) -> dict[str, str]: def _validate_opts( - samples_n: int, families_n: int + samples_n: int, families_n: int, noninteractive: bool ) -> tuple[Optional[int], Optional[int]]: if samples_n is not None and families_n is not None: raise click.BadParameter('Please specify only one of --samples or --families') @@ -539,7 +543,7 @@ def _validate_opts( if families_n is not None and families_n < 1: raise click.BadParameter('Please specify --families higher than 0') - if families_n is not None and families_n >= 30: + if (families_n is not None and families_n >= 30) and not noninteractive: resp = str( input( f'You requested a subset of {families_n} families. ' @@ -549,7 +553,7 @@ def _validate_opts( if resp.lower() != 'y': raise SystemExit() - if samples_n is not None and samples_n >= 100: + if (samples_n is not None and samples_n >= 100) and not noninteractive: resp = str( input( f'You requested a subset of {samples_n} samples. ' @@ -563,7 +567,7 @@ def _validate_opts( def _print_fam_stats(families: list[dict[str, str]]): family_sizes = Counter([fam['family_id'] for fam in families]) - fam_by_size: typing.Counter[int] = Counter() + fam_by_size: Counter[int] = Counter() # determine number of singles, duos, trios, etc for fam in family_sizes: fam_by_size[family_sizes[fam]] += 1