
Minor fixes; pyproject.toml file #47

Merged: 19 commits, Aug 13, 2024
netam/common.py (15 additions, 3 deletions)
@@ -181,10 +181,11 @@ def find_least_used_cuda_gpu():
     return utilization.index(min(utilization))
 
 
-def pick_device(gpu_index=0):
+def pick_device(gpu_index=None):
     """
     Pick a device for PyTorch to use. If CUDA is available, use the least used
-    GPU, and if all are idle use the gpu_index modulo the number of GPUs.
+    GPU, and if all are idle use the gpu_index modulo the number of GPUs. If
+    gpu_index is None, then use a random GPU.
     """
 
     # check that CUDA is usable
@@ -198,7 +199,10 @@ def check_CUDA():
     if torch.backends.cudnn.is_available() and check_CUDA():
         which_gpu = find_least_used_cuda_gpu()
         if which_gpu is None:
-            which_gpu = gpu_index % torch.cuda.device_count()
+            if gpu_index is None:
+                which_gpu = np.random.randint(torch.cuda.device_count())
+            else:
+                which_gpu = gpu_index % torch.cuda.device_count()
         print(f"Using CUDA GPU {which_gpu}")
         return torch.device(f"cuda:{which_gpu}")
     elif torch.backends.mps.is_available():
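
As a quick sanity check of the new default, a minimal sketch (`pick_device` is the function changed above; the `Linear` model is a made-up stand-in):

```python
import torch
from netam.common import pick_device

device = pick_device()      # all GPUs idle -> a random GPU; otherwise the least used one
# device = pick_device(2)   # old-style behavior: pin to index 2 % device_count
model = torch.nn.Linear(4, 1).to(device)  # hypothetical model moved to the chosen device
```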
@@ -235,6 +239,14 @@ def get_memory_usage_mb():
     return usage.ru_maxrss / 1024  # Convert from KB to MB
 
 
+def tensor_to_np_if_needed(x):
+    if isinstance(x, torch.Tensor):
+        return x.cpu().numpy()
+    else:
+        assert isinstance(x, np.ndarray)
+        return x
+
+
 # Reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
 class PositionalEncoding(nn.Module):
     def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
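
For reference, the new helper's contract in miniature (the inputs here are toy values):

```python
import numpy as np
import torch
from netam.common import tensor_to_np_if_needed

as_np = tensor_to_np_if_needed(torch.tensor([0.1, 0.2]))  # Tensor -> ndarray via .cpu().numpy()
still_np = tensor_to_np_if_needed(np.array([0.1, 0.2]))   # ndarray passes through unchanged
assert isinstance(as_np, np.ndarray) and isinstance(still_np, np.ndarray)
```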
netam/dnsm.py (20 additions, 5 deletions)
@@ -310,12 +310,9 @@ def load_branch_lengths(self, in_csv_prefix):
         )
         self.val_dataset.load_branch_lengths(in_csv_prefix + ".val_branch_lengths.csv")
 
-    def predictions_of_batch(self, batch):
+    def prediction_pair_of_batch(self, batch):
         """
-        Make predictions for a batch of data.
-
-        Note that we use the mask for prediction as part of the input for the
-        transformer, though we don't mask the predictions themselves.
+        Get log neutral mutation probabilities and log selection factors for a batch of data.
         """
         aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
         mask = batch["mask"].to(self.device)
@@ -325,12 +322,27 @@ def predictions_of_batch(self, batch):
             f"log_neutral_aa_mut_probs has non-finite values at relevant positions: {log_neutral_aa_mut_probs[mask]}"
         )
         log_selection_factors = self.model(aa_parents_idxs, mask)
+        return log_neutral_aa_mut_probs, log_selection_factors
+
+    def predictions_of_pair(self, log_neutral_aa_mut_probs, log_selection_factors):
         # Take the product of the neutral mutation probabilities and the selection factors.
         predictions = torch.exp(log_neutral_aa_mut_probs + log_selection_factors)
         assert torch.isfinite(predictions).all()
         predictions = clamp_probability(predictions)
         return predictions
 
+    def predictions_of_batch(self, batch):
+        """
+        Make predictions for a batch of data.
+
+        Note that we use the mask for prediction as part of the input for the
+        transformer, though we don't mask the predictions themselves.
+        """
+        log_neutral_aa_mut_probs, log_selection_factors = self.prediction_pair_of_batch(
+            batch
+        )
+        return self.predictions_of_pair(log_neutral_aa_mut_probs, log_selection_factors)
+
     def loss_of_batch(self, batch):
         aa_subs_indicator = batch["subs_indicator"].to(self.device)
         mask = batch["mask"].to(self.device)
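
The refactor splits the old `predictions_of_batch` in two: `prediction_pair_of_batch` exposes the (log neutral probability, log selection factor) pair, and `predictions_of_pair` combines them, so the intermediate pair can be inspected or reused on its own. The combining step in isolation, with made-up tensor values (`clamp_probability`, not shown, would then cap the result):

```python
import torch

# Toy stand-ins for the pair returned by prediction_pair_of_batch.
log_neutral_aa_mut_probs = torch.log(torch.tensor([0.01, 0.05]))
log_selection_factors = torch.log(torch.tensor([2.0, 0.5]))

# predictions_of_pair: a product in probability space is a sum in log space.
predictions = torch.exp(log_neutral_aa_mut_probs + log_selection_factors)
assert torch.isfinite(predictions).all()
print(predictions)  # tensor([0.0200, 0.0250])
```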
@@ -408,6 +420,9 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
 
     def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         worker_count = min(mp.cpu_count() // 2, 10)
+        # The following can be used when one wants a better traceback.
+        # burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model))
+        # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
         with mp.Pool(worker_count) as pool:
             splits = dataset.split(worker_count)
             results = pool.starmap(
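
The parallel path splits the dataset across a process pool and `starmap`s the per-split optimization; the commented-out lines are a serial escape hatch for when a readable traceback matters more than speed. The same shape in miniature, with a placeholder worker standing in for the real branch-length optimization:

```python
import multiprocessing as mp

def optimize_split(split, scale):
    # Placeholder for the per-worker branch length optimization.
    return [x * scale for x in split]

if __name__ == "__main__":
    data = list(range(100))
    worker_count = min(mp.cpu_count() // 2, 10)
    splits = [data[i::worker_count] for i in range(worker_count)]  # stand-in for dataset.split()
    with mp.Pool(worker_count) as pool:
        results = pool.starmap(optimize_split, [(s, 2.0) for s in splits])
    # results holds one output list per worker, to be concatenated afterwards.
```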
netam/framework.py (16 additions, 4 deletions)
@@ -20,6 +20,7 @@
     kmer_to_index_of,
     nt_mask_tensor_of,
     optimizer_of_name,
+    tensor_to_np_if_needed,
     BASES,
     BASES_AND_N_TO_INDEX,
     BIG,
@@ -217,8 +218,10 @@ def normalized_mutation_frequency(self):
     def export_branch_lengths(self, out_csv_path):
         pd.DataFrame(
             {
-                "branch_length": self.branch_lengths,
-                "mut_freq": self.normalized_mutation_frequency(),
+                "branch_length": tensor_to_np_if_needed(self.branch_lengths),
+                "mut_freq": tensor_to_np_if_needed(
+                    self.normalized_mutation_frequency()
+                ),
             }
         ).to_csv(out_csv_path, index=False)
 
@@ -346,15 +349,24 @@ def trimmed_shm_model_outputs_of_crepe(crepe, parents):
 
 
 def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None):
-    pcp_df = pd.read_csv(pcp_df_path_gz, compression="gzip", index_col=0).reset_index(
-        drop=True
+    """
+    Load a PCP dataframe from a gzipped CSV file.
+
+    `orig_pcp_idx` is the index column from the original file, even if we subset by
+    sampling or by choosing V families.
+    """
+    pcp_df = (
+        pd.read_csv(pcp_df_path_gz, compression="gzip", index_col=0)
+        .reset_index()
+        .rename(columns={"index": "orig_pcp_idx"})
     )
     pcp_df["v_family"] = pcp_df["v_gene"].str.split("-").str[0]
     if chosen_v_families is not None:
         chosen_v_families = set(chosen_v_families)
         pcp_df = pcp_df[pcp_df["v_family"].isin(chosen_v_families)]
     if sample_count is not None:
         pcp_df = pcp_df.sample(sample_count)
+    pcp_df.reset_index(drop=True, inplace=True)
     return pcp_df
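
The point of the new indexing is that every surviving row keeps a pointer back to its row in the original file, while the working index stays a clean 0..n-1. The same pandas pattern on fabricated data:

```python
import pandas as pd

df = pd.DataFrame({"v_gene": ["IGHV1-2", "IGHV3-7", "IGHV1-8"]})
df = df.reset_index().rename(columns={"index": "orig_pcp_idx"})
df = df.sample(2)                        # subsetting leaves gaps in the index...
df.reset_index(drop=True, inplace=True)  # ...so renumber the rows 0..n-1,
print(df)  # ...while orig_pcp_idx still records each row's original position.
```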
pyproject.toml (3 additions, new file)
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=64", "wheel"]
+build-backend = "setuptools.build_meta"
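
This is the minimal PEP 518 build declaration: with it present, tools like `pip install .` and `python -m build` fetch setuptools>=64 and wheel into an isolated build environment instead of assuming a legacy `setup.py` invocation.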
setup.py (1 addition)
@@ -12,6 +12,7 @@
     packages=find_packages(),
     install_requires=[
         "biopython",
+        "natsort",
         "optuna",
         "pandas",
         "pyyaml",