From 6baf01326b7eb6d390ab30d93c041205d092afcf Mon Sep 17 00:00:00 2001 From: MattWellie Date: Tue, 19 Sep 2023 11:13:45 +1000 Subject: [PATCH] my code don't jiggle jiggle, it folds --- scripts/create_test_subset.py | 459 ++++++++++++---------------------- 1 file changed, 165 insertions(+), 294 deletions(-) diff --git a/scripts/create_test_subset.py b/scripts/create_test_subset.py index 4551d6c2f..9fe9838b1 100755 --- a/scripts/create_test_subset.py +++ b/scripts/create_test_subset.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# pylint: disable=too-many-instance-attributes,too-many-locals,too-many-arguments +# pylint: disable=too-many-instance-attributes,too-many-locals """ Example Invocation @@ -11,24 +11,18 @@ This example will populate acute-care-test with the metamist data for 4 families. """ -from typing import Optional, Counter as CounterType +import csv import logging import os import random import subprocess +from argparse import ArgumentParser from collections import Counter -import csv -import click from google.cloud import storage -from metamist.apis import ( - AnalysisApi, - AssayApi, - SampleApi, - FamilyApi, - ParticipantApi, -) +from metamist.apis import AnalysisApi, AssayApi, SampleApi, FamilyApi, ParticipantApi +from metamist.graphql import gql, query from metamist.models import ( AssayUpsert, SampleUpsert, @@ -38,7 +32,6 @@ SequencingGroupUpsert, ) -from metamist.graphql import gql, query logger = logging.getLogger(__file__) logging.basicConfig(format='%(levelname)s (%(name)s %(lineno)s): %(message)s') @@ -157,7 +150,8 @@ """ ) -PARTICIPANT_QUERY = """ +PARTICIPANT_QUERY = gql( + """ query ($project: String!) 
{ project (externalId: $project) { participants { @@ -167,102 +161,45 @@ } } """ +) -@click.command() -@click.option( - '--project', - required=True, - help='The sample-metadata project ($DATASET)', -) -@click.option( - '-n', - '--samples', - 'samples_n', - type=int, - help='Number of samples to subset', -) -@click.option( - '--families', - 'families_n', - type=int, - help='Minimal number of families to include', -) -# Flag to be used when there isn't available pedigree/family information. -@click.option( - '--skip-ped', - 'skip_ped', - is_flag=True, - default=False, - help='Skip transferring pedigree/family information', -) -@click.option( - '--add-family', - 'additional_families', - type=str, - multiple=True, - help="""Additional families to include. - All samples from these fams will be included. - This is in addition to the number of families specified in - --families and the number of samples specified in -n""", -) -@click.option( - '--add-sample', - 'additional_samples', - type=str, - multiple=True, - help="""Additional samples to include. - This is in addition to the number of families specified in - --families and the number of samples specified in -n""", -) def main( project: str, - samples_n: Optional[int], - families_n: Optional[int], - skip_ped: Optional[bool] = True, - additional_families: Optional[tuple[str]] = None, - additional_samples: Optional[tuple[str]] = None, + samples_n: int, + families_n: int, + additional_families: set[str], + additional_samples: set[str], + skip_ped: bool = True, ): """ Script creates a test subset for a given project. A new project with a prefix -test is created, and for any files in sample/meta, sequence/meta, or analysis/output a copy in the -test namespace is created. """ - samples_n, families_n = _validate_opts(samples_n, families_n) - _additional_families: list[str] = list(additional_families) - _additional_samples: list[str] = list(additional_samples) - - # 1. Determine the sids to be moved into -test. 
- specific_sids = _get_sids_for_families( - project, - families_n, - _additional_families, - ) - if not samples_n and not families_n: - samples_n = DEFAULT_SAMPLES_N - if not samples_n and families_n: - samples_n = 0 - specific_sids = specific_sids + _additional_samples + if not any([additional_families, additional_samples, samples_n, families_n]): + raise ValueError('Come on, what exactly are you asking for?') + + # for reproducibility + random.seed(42) + + # 1. Find any SG IDs to be moved into -test by Family ID. + if families_n or additional_families: + additional_samples.update( + get_sids_for_families(project, families_n, additional_families) + ) # 2. Get all sids in project. sid_output = query(SG_ID_QUERY, variables={'project': project}) - all_sids = [sid['id'] for sid in sid_output.get('project').get('samples')] - - # 3. Subtract the specific_sgs from all the sgs - sgids_after_inclusions = list(set(all_sids) - set(specific_sids)) - # 4. Randomly select from the remaining sgs - random_sgs: list[str] = [] - random.seed(42) # for reproducibility - if (samples_n - len(specific_sids)) > 0: - random_sgs = random.sample( - sgids_after_inclusions, samples_n - len(specific_sids) - ) - # 5. Add the specific_sgs to the randomly selected sgs - final_subset_sids = specific_sids + random_sgs - # 6. Query all the samples from the selected sgs + all_sids = {sid['id'] for sid in sid_output.get('project').get('samples')} + # 3. Randomly pick from the remaining sgs; sorted so the seed gives a stable draw + remaining_sids = sorted(all_sids - additional_samples) + additional_samples.update(random.sample(remaining_sids, samples_n)) + + # 4. 
Query all the samples from the selected sgs original_project_subset_data = query( - QUERY_ALL_DATA, {'project': project, 'sids': final_subset_sids} + QUERY_ALL_DATA, {'project': project, 'sids': list(additional_samples)} ) # Pull Participant Data @@ -313,14 +250,14 @@ def transfer_samples_sgs_assays( sample_sgs: list[SequencingGroupUpsert] = [] for sg in s.get('sequencingGroups'): sg_assays: list[AssayUpsert] = [] - _existing_sg = _get_existing_sg( + existing_sg = get_existing_sg( existing_data, s.get('externalId'), sg.get('type') ) - _existing_sgid = _existing_sg.get('id') if _existing_sg else None + _existing_sgid = existing_sg.get('id') if existing_sg else None for assay in sg.get('assays'): _existing_assay: dict[str, str] = {} if _existing_sgid: - _existing_assay = _get_existing_assay( + _existing_assay = get_existing_assay( existing_data, s.get('externalId'), _existing_sgid, @@ -348,23 +285,23 @@ def transfer_samples_sgs_assays( ) sample_sgs.append(sg_upsert) - _sample_type = None if s['type'] == 'None' else s['type'] - _existing_sid: str = None - _existing_sample = _get_existing_sample(existing_data, s['externalId']) - if _existing_sample: - _existing_sid = _existing_sample['id'] + sample_type = None if s['type'] == 'None' else s['type'] + existing_sid: str | None = None + existing_sample = get_existing_sample(existing_data, s['externalId']) + if existing_sample: + existing_sid = existing_sample['id'] - _existing_pid: int = None + existing_pid: int | None = None if s['participant']: - _existing_pid = upserted_participant_map[s['participant']['externalId']] + existing_pid = upserted_participant_map[s['participant']['externalId']] sample_upsert = SampleUpsert( external_id=s['externalId'], - type=_sample_type or None, - meta=(_copy_files_in_dict(s['meta'], project) or {}), - participant_id=_existing_pid, + type=sample_type or None, + meta=(copy_files_in_dict(s['meta'], project) or {}), + participant_id=existing_pid, sequencing_groups=sample_sgs, - 
id=_existing_sid, + id=existing_sid, ) logger.info(f'Processing sample {s["id"]}') @@ -392,30 +329,27 @@ def transfer_analyses( for s in samples: for sg in s['sequencingGroups']: - _existing_sg = _get_existing_sg( + existing_sg = get_existing_sg( existing_data, s.get('externalId'), sg.get('type') ) - _existing_sgid = _existing_sg.get('id') if _existing_sg else None + existing_sgid = existing_sg.get('id') if existing_sg else None for analysis in sg['analyses']: if analysis['type'] not in ['cram', 'gvcf']: # Currently the create_test_subset script only handles crams or gvcf files. continue - _existing_analysis: dict = {} - if _existing_sgid: - _existing_analysis = _get_existing_analysis( - existing_data, - s['externalId'], - _existing_sgid, - analysis['type'], + existing_analysis: dict = {} + if existing_sgid: + existing_analysis = get_existing_analysis( + existing_data, s['externalId'], existing_sgid, analysis['type'] ) - _existing_analysis_id = ( - _existing_analysis.get('id') if _existing_analysis else None + existing_analysis_id = ( + existing_analysis.get('id') if existing_analysis else None ) - if _existing_analysis_id: + if existing_analysis_id: am = AnalysisUpdateModel( type=analysis['type'], - output=_copy_files_in_dict( + output=copy_files_in_dict( analysis['output'], project, (str(sg['id']), new_sg_map[s['externalId']][0]), @@ -425,13 +359,13 @@ def transfer_analyses( meta=analysis['meta'], ) aapi.update_analysis( - analysis_id=_existing_analysis_id, + analysis_id=existing_analysis_id, analysis_update_model=am, ) else: am = Analysis( type=analysis['type'], - output=_copy_files_in_dict( + output=copy_files_in_dict( analysis['output'], project, (str(sg['id']), new_sg_map[s['externalId']][0]), @@ -445,21 +379,31 @@ def transfer_analyses( aapi.create_analysis(project=target_project, analysis=am) -def _get_existing_sample(data: dict, sample_id: str) -> dict: - for sample in data.get('project').get('samples'): +def get_existing_sample(data: dict, sample_id: str) 
-> dict | None: + """ + Get the existing sample object for this ID + Returns: + The Sample dictionary, or None if unmatched + """ + for sample in data.get('project', {}).get('samples', []): if sample.get('externalId') == sample_id: return sample return None -def _get_existing_sg( +def get_existing_sg( existing_data: dict, sample_id: str, sg_type: str = None, sg_id: str = None -) -> dict: +) -> dict | None: + """ + Find a SG ID in the main data based on a sample ID + Match either on CPG ID or type (exome/genome) + Returns: + The SG Data, or None if no match is found + """ if not sg_type and not sg_id: raise ValueError('Must provide sg_type or sg_id when getting exsisting sg') - sample = _get_existing_sample(existing_data, sample_id) - if sample: + if sample := get_existing_sample(existing_data, sample_id): for sg in sample.get('sequencingGroups'): if sg_id and sg.get('id') == sg_id: return sg @@ -469,80 +413,71 @@ def _get_existing_sg( return None -def _get_existing_assay( +def get_existing_assay( data: dict, sample_id: str, sg_id: str, assay_type: str -) -> dict: - sg = _get_existing_sg( - existing_data=data, - sample_id=sample_id, - sg_id=sg_id, - ) - for assay in sg.get('assays'): - if assay.get('type') == assay_type: - return assay +) -> dict | None: + """ + Find assay in main data for this SGID + Returns: + The Assay Data, or None if no match is found + """ + if sg := get_existing_sg(existing_data=data, sample_id=sample_id, sg_id=sg_id): + for assay in sg.get('assays', []): + if assay.get('type') == assay_type: + return assay return None -def _get_existing_analysis( +def get_existing_analysis( data: dict, sample_id: str, sg_id: str, analysis_type: str -) -> dict: - sg = _get_existing_sg(existing_data=data, sample_id=sample_id, sg_id=sg_id) - for analysis in sg.get('analyses'): - if analysis.get('type') == analysis_type: - return analysis +) -> dict | None: + """ + Find the existing SG for this sample, then identify any relevant analysis objs + Returns: + an 
analysis dict, or None if the right type isn't found + """ + if sg := get_existing_sg(existing_data=data, sample_id=sample_id, sg_id=sg_id): + for analysis in sg.get('analyses', []): + if analysis.get('type') == analysis_type: + return analysis return None -def _get_sids_for_families( - project: str, - families_n: int, - additional_families, -) -> list[str]: +def get_sids_for_families( + project: str, families_n: int, additional_families: set[str] +) -> set[str]: """Returns specific sequencing groups to be included in the test project.""" - included_sids: list = [] - _num_families_to_subset: int = None - _randomly_selected_families: list = [] - - # Case 1: If neither families_n nor _additional_families - if not families_n and not additional_families: - return included_sids - - # Case 2: If families_n but not _additional_families - if families_n and not additional_families: - _num_families_to_subset = families_n - - # Case 3: If both families_n and _additional_families - if families_n and additional_families: - _num_families_to_subset = families_n - len(additional_families) - family_sgid_output = query(QUERY_FAMILY_SGID, {'project': project}) - # 1. Remove the families in _families_to_subset - all_family_sgids = family_sgid_output.get('project').get('families') - _filtered_family_sgids = [ - fam for fam in all_family_sgids if fam['externalId'] not in additional_families - ] - _user_input_families = [ + all_family_sgids = family_sgid_output.get('project', {}).get('families', []) + assert all_family_sgids, 'No families returned in GQL result' + + # 1. Remove the specifically requested families + user_input_families = [ fam for fam in all_family_sgids if fam['externalId'] in additional_families ] # TODO: Replace this with the nice script that randomly selects better :) - # 2. 
Randomly select _num_families_to_subset from the remaining families - if _num_families_to_subset: - _randomly_selected_families = random.sample( - _filtered_family_sgids, _num_families_to_subset + # 2. Randomly select from the remaining families (families_n can be 0) + user_input_families.extend( + random.sample( + [ + fam + for fam in all_family_sgids + if fam['externalId'] not in additional_families + ], + families_n, ) + ) - # 3. Combine the families in _families_to_subset with the randomly selected families & return sequencing group ids - - _all_families_to_subset = _randomly_selected_families + _user_input_families - - for fam in _all_families_to_subset: + # 3. Pull SGs from random + specific families + included_sids: set[str] = set() + for fam in user_input_families: for participant in fam['participants']: for sample in participant['samples']: - included_sids.append(sample['id']) + included_sids.add(sample['id']) return included_sids @@ -635,9 +570,9 @@ def transfer_participants( 'reported_gender': participant.get('reportedGender'), 'reported_sex': participant.get('reportedSex'), 'id': participant.get('id'), + 'samples': [], } # Participants are being created before the samples are, so this will be empty for now. 
- transfer_participant['samples'] = [] participants_to_transfer.append(transfer_participant) upserted_participants = papi.upsert_participants( @@ -653,110 +588,10 @@ def transfer_participants( return external_to_internal_participant_id_map -def get_samples_for_families(project: str, additional_families: list[str]) -> list[str]: - """Returns the samples that belong to a list of families""" - - full_pedigree = fapi.get_pedigree( - project=project, - replace_with_participant_external_ids=False, - replace_with_family_external_ids=True, - ) - - ipids = [ - family['individual_id'] - for family in full_pedigree - if family['family_id'] in additional_families - ] - - sample_objects = sapi.get_samples( - body_get_samples={ - 'project_ids': [project], - 'participant_ids': ipids, - 'active': True, - } - ) - - samples: list[str] = [sample['id'] for sample in sample_objects] - - return samples - - -def get_fams_for_samples( - project: str, - additional_samples: Optional[list[str]] = None, -) -> list[str]: - """Returns the families that a list of samples belong to""" - sample_objects = sapi.get_samples( - body_get_samples={ - 'project_ids': [project], - 'sample_ids': additional_samples, - 'active': True, - } - ) - - pids = [sample['participant_id'] for sample in sample_objects] - full_pedigree = fapi.get_pedigree( - project=project, - replace_with_participant_external_ids=False, - replace_with_family_external_ids=True, - ) - - fams: set[str] = { - fam['family_id'] for fam in full_pedigree if str(fam['individual_id']) in pids - } - - return list(fams) - - -def _normalise_map(unformatted_map: list[list[str]]) -> dict[str, str]: - """Input format: [[value1,key1,key2],[value2,key4]] - Output format: {key1:value1, key2: value1, key3:value2}""" - - normalised_map = {} - for group in unformatted_map: - value = group[0] - for key in group[1:]: - normalised_map[key] = value - - return normalised_map - - -def _validate_opts( - samples_n: int, families_n: int -) -> tuple[Optional[int], 
Optional[int]]: - """Validates the options passed to the script""" - if samples_n is None and families_n is None: - samples_n = DEFAULT_SAMPLES_N - logger.info( - f'Neither --samples nor --families specified, defaulting to selecting ' - f'{samples_n} samples' - ) - - return samples_n, families_n - - -def _print_fam_stats(families: list[dict[str, str]]): - family_sizes = Counter([fam['family_id'] for fam in families]) - fam_by_size: CounterType[int] = Counter() - # determine number of singles, duos, trios, etc - for fam in family_sizes: - fam_by_size[family_sizes[fam]] += 1 - for fam_size in sorted(fam_by_size): - if fam_size == 1: - label = 'singles' - elif fam_size == 2: - label = 'duos' - elif fam_size == 3: - label = 'trios' - else: - label = f'{fam_size} members' - logger.info(f' {label}: {fam_by_size[fam_size]}') - - -def _get_random_families( +def get_random_families( families: list[dict[str, str]], families_n: int, - include_single_person_families: Optional[bool] = False, + include_single_person_families: bool = False, ) -> list[str]: """Obtains a subset of families, that are a little less random. By default single-person families are discarded. 
@@ -801,7 +636,7 @@ def _get_random_families( return returned_families -def _copy_files_in_dict(d, dataset: str, sid_replacement: tuple[str, str] = None): +def copy_files_in_dict(d, dataset: str, sid_replacement: tuple[str, str] = None): """ Replaces all `gs://cpg-{project}-main*/` paths into `gs://cpg-{project}-test*/` and creates copies if needed @@ -840,16 +675,12 @@ def _copy_files_in_dict(d, dataset: str, sid_replacement: tuple[str, str] = None subprocess.run(cmd, check=False, shell=True) return new_path if isinstance(d, list): - return [_copy_files_in_dict(x, dataset) for x in d] + return [copy_files_in_dict(x, dataset) for x in d] if isinstance(d, dict): - return {k: _copy_files_in_dict(v, dataset) for k, v in d.items()} + return {k: copy_files_in_dict(v, dataset) for k, v in d.items()} return d -def _pretty_format_samples(samples: list[dict[str, str]]) -> str: - return ', '.join(f"{s['id']}/{s['external_id']}" for s in samples) - - def file_exists(path: str) -> bool: """ Check if the object exists, where the object can be: @@ -868,5 +699,45 @@ def file_exists(path: str) -> bool: if __name__ == '__main__': - # pylint: disable=no-value-for-parameter - main() + parser = ArgumentParser(description='Argument parser for subset generator') + parser.add_argument( + '--project', required=True, help='The sample-metadata project ($DATASET)' + ) + parser.add_argument('-n', type=int, help='# Random Samples to copy', default=0) + parser.add_argument('-f', type=int, help='# Random families to copy', default=0) + # Flag to be used when there isn't available pedigree/family information. 
+ parser.add_argument( + '--skip-ped', + action='store_true', + help='Skip transferring pedigree/family information', + ) + parser.add_argument( + '--families', + nargs='+', + help='Additional families to include.', + type=str, # applied per token; nargs collects the tokens into a list + default=[], + ) + parser.add_argument( + '--samples', + nargs='+', + help='Additional samples to include.', + type=str, # applied per token; nargs collects the tokens into a list + default=[], + ) + parser.add_argument( + '--noninteractive', action='store_true', help='Skip interactive confirmation' + ) + args, fail = parser.parse_known_args() + if fail: + parser.print_help() + raise AttributeError(f'Invalid arguments: {fail}') + + main( + project=args.project, + samples_n=args.n, + families_n=args.f, + additional_samples=set(args.samples), # main() treats these as sets + additional_families=set(args.families), # membership-tested per family + skip_ped=args.skip_ped, + )