Skip to content

Commit

Permalink
Add support for cohorts in create_test_subset script
Browse files Browse the repository at this point in the history
  • Loading branch information
vivbak committed Jun 9, 2024
1 parent b8121aa commit daca8cc
Showing 1 changed file with 78 additions and 6 deletions.
84 changes: 78 additions & 6 deletions scripts/create_test_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,34 +166,65 @@
"""
)

# GraphQL query used to pick samples by cohort: for the project named by the
# $project variable it returns every cohort's id and, for each of the cohort's
# sequencing groups, the id of the sample that sequencing group belongs to.
COHORT_QUERY = gql(
"""
query CohortQuery($project: String!) {
project(name: $project) {
cohorts {
id
sequencingGroups {
sample {
id
}
}
}
}
}
"""
)


def main(
project: str,
samples_n: int,
families_n: int,
cohort_samples_n: int,
additional_families: set[str],
additional_samples: set[str],
cohorts: set[str],
skip_ped: bool = True,
):
"""
Script creates a test subset for a given project.
A new project with a prefix -test is created, and for any files in sample/meta,
A new project with a suffix -test is created, and for any files in sample/meta,
sequence/meta, or analysis/output a copy in the -test namespace is created.
"""

if not any([additional_families, additional_samples, samples_n, families_n]):
if not any(
[additional_families, additional_samples, samples_n, families_n, cohorts]
):
raise ValueError('Come on, what exactly are you asking for?')

if cohorts and not cohort_samples_n:
raise ValueError(
'You must specify the number of samples to transfer from the cohort.'
)

# for reproducibility
logger.info('Setting random seed to 42')
random.seed(42)

# 1. Find and SG IDs to be moved by Family ID -test.
# 1. Find SG IDs to be moved by Family ID to -test.
if families_n or additional_families:
additional_samples.update(
get_sids_for_families(project, families_n, additional_families)
)

# 1.5 Find SG IDs to be moved by Cohort ID to -test.
if cohorts:
additional_samples.update(
get_sids_for_cohorts(project, cohorts, cohort_samples_n)
)

# 2. Get all sample IDs and their SG IDs in project.
logger.info(f'Querying all samples in {project}')
sid_output = query(SG_ID_QUERY, variables={'project': project})
Expand All @@ -206,7 +237,7 @@ def main(
)

# 4. Query all the samples from the selected sgs
logger.info(f'Transfering {len(additional_samples)} samples. Querying metadata.')
logger.info(f'Transferring {len(additional_samples)} samples. Querying metadata.')
original_project_subset_data = query(
QUERY_ALL_DATA, {'project': project, 'sids': list(additional_samples)}
)
Expand Down Expand Up @@ -627,6 +658,30 @@ def get_sids_for_families(
return included_sids


def get_sids_for_cohorts(
    project: str, cohorts: set[str], cohort_samples_n: int
) -> set[str]:
    """
    Return a random subset of sample IDs drawn from each requested cohort.

    Runs COHORT_QUERY for `project`, and for every returned cohort whose ID is
    in `cohorts` draws up to `cohort_samples_n` distinct sample IDs (one draw
    per cohort). The draws use the module-level `random` state, so results
    are reproducible when the caller has seeded it.

    Args:
        project: name of the project whose cohorts are queried.
        cohorts: cohort IDs to take samples from.
        cohort_samples_n: number of samples to draw from each cohort. If a
            cohort has fewer distinct samples than this, all of them are
            taken and a warning is logged.

    Returns:
        The union of the sample IDs drawn from all matching cohorts.

    Raises:
        ValueError: if `cohort_samples_n` is negative (from `random.sample`).
    """
    cohort_sid_output = query(COHORT_QUERY, {'project': project})

    # A null `project` or `cohorts` in the response is treated as "no cohorts"
    # instead of crashing on None.
    all_cohort_groups = (cohort_sid_output.get('project') or {}).get('cohorts') or []

    all_cohorts_sample_ids_subset: set[str] = set()
    seen_cohort_ids: set[str] = set()
    for cohort in all_cohort_groups:
        cohort_id = cohort.get('id')
        if cohort_id not in cohorts:
            continue
        seen_cohort_ids.add(cohort_id)

        # One sample can back several sequencing groups. dict.fromkeys removes
        # the duplicates while keeping response order, so a seeded draw stays
        # reproducible (iterating a set of strings is not stable across runs).
        sids_for_cohort = list(
            dict.fromkeys(
                seq_group['sample']['id']
                for seq_group in cohort.get('sequencingGroups') or []
                if seq_group.get('sample')
            )
        )

        # random.sample raises if asked for more items than the population
        # holds; take the whole cohort in that case. A negative request is
        # left untouched so that it still raises ValueError.
        n_to_take = min(cohort_samples_n, len(sids_for_cohort))
        if n_to_take < cohort_samples_n:
            logger.warning(
                f'Cohort {cohort_id} has only {len(sids_for_cohort)} samples, '
                f'fewer than the {cohort_samples_n} requested. Taking all of them.'
            )

        all_cohorts_sample_ids_subset.update(
            random.sample(sids_for_cohort, n_to_take)
        )

    missing_cohorts = set(cohorts) - seen_cohort_ids
    if missing_cohorts:
        logger.warning(
            f'Cohorts not found in {project}: {sorted(missing_cohorts)}'
        )

    return all_cohorts_sample_ids_subset


def transfer_families(
initial_project: str, target_project: str, internal_participant_ids: list[int]
) -> list[int]:
Expand Down Expand Up @@ -852,8 +907,16 @@ def file_exists(path: str) -> bool:
parser.add_argument(
'--project', required=True, help='The sample-metadata project ($DATASET)'
)
parser.add_argument('-n', type=int, help='# Random Samples to copy', default=0)
parser.add_argument(
'-n', type=int, help='# Random Samples to copy', default=DEFAULT_SAMPLES_N
)
parser.add_argument('-f', type=int, help='# Random families to copy', default=0)
parser.add_argument(
'-nsamples-cohort',
type=int,
help='# Random samples to copy from each cohort',
default=0,
)
# Flag to be used when there isn't available pedigree/family information.
parser.add_argument(
'--skip-ped',
Expand All @@ -874,6 +937,13 @@ def file_exists(path: str) -> bool:
type=str,
default={},
)
parser.add_argument(
'--cohorts',
nargs='+',
help='Cohorts to take random samples from.',
type=str,
default={},
)
parser.add_argument(
'--noninteractive', action='store_true', help='Skip interactive confirmation'
)
Expand All @@ -886,7 +956,9 @@ def file_exists(path: str) -> bool:
project=args.project,
samples_n=args.n,
families_n=args.f,
cohort_samples_n=args.nsamples_cohort,
additional_samples=set(args.samples),
additional_families=set(args.families),
skip_ped=args.skip_ped,
cohorts=set(args.cohorts),
)

0 comments on commit daca8cc

Please sign in to comment.