diff --git a/helix/evoprotgrad.py b/helix/evoprotgrad.py index 852558b..b26b7cb 100644 --- a/helix/evoprotgrad.py +++ b/helix/evoprotgrad.py @@ -5,7 +5,7 @@ from .main import stub -def download_esm_models(slugs: list[str] = ["facebook/esm1b_t33_650M_UR50S", "facebook/esm2_t33_650M_UR50D", "facebook/esm2_t36_3B_UR50D"]): +def download_esm_models(slugs: list[str] = ["facebook/esm1b_t33_650M_UR50S", "facebook/esm2_t33_650M_UR50D", "facebook/esm2_t36_3B_UR50D", "facebook/esm2_t48_15B_UR50D"]): from transformers import EsmForMaskedLM, AutoTokenizer for slug in slugs: EsmForMaskedLM.from_pretrained(slug) @@ -19,7 +19,7 @@ def download_esm_models(slugs: list[str] = ["facebook/esm1b_t33_650M_UR50S", "fa "pandas").run_function(download_esm_models) -@stub.cls(gpu='A10G', timeout=2000, image=image, allow_cross_region_volumes=True, concurrency_limit=9) +@stub.cls(gpu='A100', timeout=2000, image=image, allow_cross_region_volumes=True, concurrency_limit=9) class EvoProtGrad: def __init__(self, experts: list[str] = ["esm"], device: str = "cuda"): from evo_prot_grad import get_expert @@ -50,7 +50,7 @@ def evolve(self, sequence: str, n_steps: int = 100, parallel_chains: int = 10, m @stub.local_entrypoint() -def get_evoprotgrad_variants(sequence: str, output_csv_file: str = None, output_fasta_file: str = None, experts: str = "esm", n_steps: int = 100, num_chains: int = 20, max_mutations: int = -1, random_seed: int = None, concurrency_limit: int = 30): +def get_evoprotgrad_variants(sequence: str, output_csv_file: str = None, output_fasta_file: str = None, experts: str = "esm", n_steps: int = 100, num_chains: int = 20, max_mutations: int = -1, random_seed: int = None, batch_size: int = 9): from .evoprotgrad import EvoProtGrad from helix.utils import dataframe_to_fasta, count_mutations @@ -61,13 +61,13 @@ def get_evoprotgrad_variants(sequence: str, output_csv_file: str = None, output_ raise Exception( "Must specify either output_csv_file or output_fasta_file") - num_calls = num_chains // concurrency_limit - remaining_chains = num_chains % concurrency_limit + num_calls = num_chains // batch_size + remaining_chains = num_chains % batch_size print( f"Running {num_chains} parallel chains in {num_calls+1} containers") results = [] - args = [(sequence, n_steps, concurrency_limit, max_mutations, random_seed) + args = [(sequence, n_steps, batch_size, max_mutations, random_seed) for _ in range(num_calls)] if remaining_chains > 0: args.append((sequence, n_steps, remaining_chains, diff --git a/pyproject.toml b/pyproject.toml index e6cb6e0..62e4440 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "helixbio" -version = "0.1.9" +version = "0.2.0" description = "" authors = ["Ragnor Comerford "] readme = "README.md"