diff --git a/scripts/create_test_subset.py b/scripts/create_test_subset.py index 250abf284..6605c2347 100755 --- a/scripts/create_test_subset.py +++ b/scripts/create_test_subset.py @@ -166,34 +166,65 @@ """ ) +COHORT_QUERY = gql( + """ + query CohortQuery($project: String!) { + project(name: $project) { + cohorts { + id + sequencingGroups { + sample { + id + } + } + } + } + } + """ +) + def main( project: str, samples_n: int, families_n: int, + cohort_samples_n: int, additional_families: set[str], additional_samples: set[str], + cohorts: set[str], skip_ped: bool = True, ): """ Script creates a test subset for a given project. - A new project with a prefix -test is created, and for any files in sample/meta, + A new project with a suffix -test is created, and for any files in sample/meta, sequence/meta, or analysis/output a copy in the -test namespace is created. """ - - if not any([additional_families, additional_samples, samples_n, families_n]): + if not any( + [additional_families, additional_samples, samples_n, families_n, cohorts] + ): raise ValueError('Come on, what exactly are you asking for?') + if cohorts and not cohort_samples_n: + raise ValueError( + 'You must specify the number of samples to transfer from the cohort.' + ) + # for reproducibility logger.info('Setting random seed to 42') random.seed(42) - # 1. Find and SG IDs to be moved by Family ID -test. + # 1. Find SG IDs to be moved by Family ID to -test. if families_n or additional_families: additional_samples.update( get_sids_for_families(project, families_n, additional_families) ) + # 1.5 Find SG IDs to be moved by Cohort ID to -test. + if cohorts: + additional_samples.update( + get_sids_for_cohorts(project, cohorts, cohort_samples_n) + ) + # 2. Get all sample IDs and their SG IDs in project. logger.info(f'Querying all samples in {project}') sid_output = query(SG_ID_QUERY, variables={'project': project}) @@ -206,7 +237,7 @@ def main( ) # 4. Query all the samples from the selected sgs - logger.info(f'Transfering {len(additional_samples)} samples. Querying metadata.') + logger.info(f'Transferring {len(additional_samples)} samples. Querying metadata.') original_project_subset_data = query( QUERY_ALL_DATA, {'project': project, 'sids': list(additional_samples)} ) @@ -627,6 +658,30 @@ def get_sids_for_families( return included_sids +def get_sids_for_cohorts( + project: str, cohorts: set[str], cohort_samples_n: int +) -> set[str]: + """Returns cohort_samples_n specific samples for given cohort IDs.""" + + cohort_sid_output = query(COHORT_QUERY, {'project': project}) + + all_cohort_groups = cohort_sid_output.get('project', {}).get('cohorts', []) + + all_cohorts_sample_ids_subset: set[str] = set() + for cohort in all_cohort_groups: + sids_for_cohort: list[str] = [] + if cohort.get('id') in cohorts: + seq_groups = cohort.get('sequencingGroups', []) + for seq_group in seq_groups: + sample = seq_group.get('sample') + sids_for_cohort.append(sample['id']) + all_cohorts_sample_ids_subset.update( + random.sample(sids_for_cohort, cohort_samples_n) + ) + + return all_cohorts_sample_ids_subset + + def transfer_families( initial_project: str, target_project: str, internal_participant_ids: list[int] ) -> list[int]: @@ -852,8 +907,16 @@ def file_exists(path: str) -> bool: parser.add_argument( '--project', required=True, help='The sample-metadata project ($DATASET)' ) - parser.add_argument('-n', type=int, help='# Random Samples to copy', default=0) + parser.add_argument( + '-n', type=int, help='# Random Samples to copy', default=DEFAULT_SAMPLES_N + ) parser.add_argument('-f', type=int, help='# Random families to copy', default=0) + parser.add_argument( + '-nsamples-cohort', + type=int, + help='# Random samples to copy from each cohort', + default=0, + ) # Flag to be used when there isn't available pedigree/family information. parser.add_argument( '--skip-ped', @@ -874,6 +937,13 @@ def file_exists(path: str) -> bool: type=str, default={}, ) + parser.add_argument( + '--cohorts', + nargs='+', + help='Cohorts to take random samples from.', + type=str, + default={}, + ) parser.add_argument( '--noninteractive', action='store_true', help='Skip interactive confirmation' ) @@ -886,7 +956,9 @@ def file_exists(path: str) -> bool: project=args.project, samples_n=args.n, families_n=args.f, + cohort_samples_n=args.nsamples_cohort, additional_samples=set(args.samples), additional_families=set(args.families), skip_ped=args.skip_ped, + cohorts=set(args.cohorts), )