Skip to content

Commit

Permalink
Update generate caids pipeline: add gnomAD v4 and update Hail usage (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
nadeaujoshua authored Feb 16, 2024
1 parent 141ca39 commit 85950f5
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 12 deletions.
4 changes: 3 additions & 1 deletion data-pipeline/caids/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
```
hailctl dataproc start my-cluster
hailctl dataproc submit my-cluster export_vcfs.py "gnomAD v4.0" gs://my-bucket/path/to/gnomad_v4.vcf.gz
hailctl dataproc submit my-cluster export_vcfs.py "gnomAD v3.1.1" gs://my-bucket/path/to/gnomad_v3.vcf.gz
hailctl dataproc submit my-cluster export_vcfs.py "gnomAD v2.1.1" gs://my-bucket/path/to/gnomad_v2.vcf.gz
hailctl dataproc submit my-cluster export_vcfs.py "ExAC" gs://my-bucket/path/to/exac.vcf.gz
Expand All @@ -19,7 +20,8 @@
```
gcloud compute instances create my-instance \
--machine-type=e2-standard-2 \
--scopes=default,storage-rw
--scopes=default,storage-rw \
--subnet=my-subnet
```

- Connect to the instance.
Expand Down
13 changes: 12 additions & 1 deletion data-pipeline/caids/export_vcfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@
import hail as hl


def get_gnomad_v4_variants(
    table_url: str = "gs://gcp-public-data--gnomad/release/4.0/ht/genomes/gnomad.genomes.v4.0.sites.ht",
    n_partitions: int = 5000,
) -> hl.Table:
    """Get locus/alleles for all gnomAD v4 variants.

    Reads the public gnomAD v4.0 genomes sites Hail Table, strips all global
    and row annotations (keeping only the table key: locus and alleles), and
    repartitions so downstream VCF export produces evenly sized shards.

    :param table_url: Path to the gnomAD v4 sites Hail Table. Defaults to the
        public gnomAD GCS bucket.
    :param n_partitions: Number of partitions for the returned table.
    :return: A Hail Table containing only the key fields (locus, alleles).
    """
    ds = hl.read_table(table_url)
    # Drop global annotations; only per-variant keys are needed.
    ds = ds.select_globals()
    # select() with no arguments drops every non-key row field.
    ds = ds.select()
    # shuffle=True forces a full shuffle so the output partitions are
    # evenly sized rather than inheriting the source table's skew.
    ds = ds.repartition(n_partitions, shuffle=True)
    return ds


def get_gnomad_v3_variants() -> hl.Table:
"""Get locus/alleles for all gnomAD v3 variants."""
ds = hl.read_table("gs://gcp-public-data--gnomad/release/3.1.1/ht/genomes/gnomad.genomes.v3.1.1.sites.ht")
Expand Down Expand Up @@ -45,6 +54,8 @@ def get_exac_variants() -> hl.Table:

def get_variants(dataset: str) -> hl.Table:
"""Get locus/alleles for all variants in the given dataset."""
if dataset == "gnomAD v4.0":
return get_gnomad_v4_variants()
if dataset == "gnomAD v3.1.1":
return get_gnomad_v3_variants()
if dataset == "gnomAD v2.1.1":
Expand Down Expand Up @@ -76,7 +87,7 @@ def export_vcfs(ds: hl.Table, output_url: str) -> None:

def main():
parser = argparse.ArgumentParser()
parser.add_argument("dataset", choices=("ExAC", "gnomAD v2.1.1", "gnomAD v3.1.1"))
parser.add_argument("dataset", choices=("ExAC", "gnomAD v2.1.1", "gnomAD v3.1.1", "gnomAD v4.0"))
parser.add_argument("output_url")
args = parser.parse_args()

Expand Down
24 changes: 14 additions & 10 deletions data-pipeline/caids/get_caids.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from typing import Awaitable, Callable, TypeVar

import aiohttp
from hailtop.aiotools import LocalAsyncFS, RouterAsyncFS
from hailtop.aiogoogle import GoogleStorageAsyncFS
from hailtop.utils import bounded_gather, sleep_and_backoff, tqdm
from hailtop.aiotools.router_fs import RouterAsyncFS
from hailtop.utils import bounded_gather, sleep_before_try
from hailtop.utils.rich_progress_bar import SimpleCopyToolProgressBar
from rich.console import Console

console = Console()

logger = logging.getLogger("get_caids")
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -86,19 +88,18 @@ def is_transient_error(e):


async def retry_transient_errors(f: Callable[..., Awaitable[T]], max_attempts: int = 3) -> T:
    """Await ``f()`` until it succeeds, retrying transient failures.

    A non-transient exception propagates immediately. Transient errors are
    retried up to ``max_attempts`` total attempts, sleeping with hailtop's
    backoff helper between attempts; the final transient error is re-raised.

    :param f: Zero-argument coroutine factory to invoke.
    :param max_attempts: Maximum number of attempts before giving up.
    :return: Whatever ``f()`` returns on the first successful attempt.
    """
    attempt = 0
    while True:
        try:
            return await f()
        except Exception as exc:
            attempt += 1
            # Give up immediately on non-transient errors, and after the
            # final allowed attempt for transient ones.
            if not is_transient_error(exc) or attempt >= max_attempts:
                raise
        await sleep_before_try(attempt)


async def get_caids(sharded_vcf_url: str, output_url: str, *, parallelism: int = 4, request_timeout: int = 10,) -> None:
Expand All @@ -118,8 +119,9 @@ async def get_caids(sharded_vcf_url: str, output_url: str, *, parallelism: int =
output_url = output_url.rstrip("/")

with ThreadPoolExecutor() as thread_pool:
local_kwargs = {'thread_pool': thread_pool}
async with RouterAsyncFS(
"file", [LocalAsyncFS(thread_pool), GoogleStorageAsyncFS()]
local_kwargs=local_kwargs
) as fs, aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=request_timeout * 60)) as session:
# The ClinGen Allele Registry API does not accept VCFs with contigs other than 1-22, X, Y, and M.
# Remove other contigs from the VCF header.
Expand Down Expand Up @@ -158,7 +160,9 @@ async def get_caids(sharded_vcf_url: str, output_url: str, *, parallelism: int =
if part_name not in completed_parts:
remaining_part_urls.append(part_url)

with tqdm(total=len(remaining_part_urls)) as progress:
logger.warning(f'\n\nParts Counts\nTotal: {len(all_part_urls)}\nCompleted: {len(completed_parts)}\nRemaining: {len(remaining_part_urls)}\n')

with SimpleCopyToolProgressBar(total=len(remaining_part_urls)) as progress:

def create_task(part_url):
async def task():
Expand Down

0 comments on commit 85950f5

Please sign in to comment.