Format docstrings and use formatter (#59)
* format more

* run docformatter

* add docformatter to workflow file

* run black after docformatter

* reformat docstring code
willdumm authored Sep 18, 2024
1 parent 7222493 commit 3f98c6c
Showing 15 changed files with 213 additions and 270 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build-and-test.yml
@@ -29,6 +29,7 @@ jobs:
           python=${{ matrix.python-version }}
           black
           flake8
+          docformatter
         init-shell: bash
         cache-environment: false
         post-cleanup: 'none'
2 changes: 2 additions & 0 deletions Makefile
@@ -8,9 +8,11 @@ test:
 	pytest tests
 
 format:
+	docformatter --in-place --black --recursive netam tests
 	black netam tests
 
 checkformat:
+	docformatter --check --black --recursive netam tests
 	black --check netam tests
 
 lint:
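For context, here is the kind of rewrite this pairing produces on a toy function (a hypothetical illustration, not a file from this commit): docformatter runs first and pulls the docstring summary up onto the opening quotes, closing one-line docstrings on the same line, and black then normalizes the surrounding code.

```python
# Before formatting: summary on its own line, as in the old netam style.
def add(a, b):
    """
    Add two numbers and return the result.
    """
    return a + b


# After `docformatter --black` followed by `black`: the summary joins the
# opening triple quotes, and a one-line docstring closes on the same line.
def add(a, b):
    """Add two numbers and return the result."""
    return a + b
```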
10 changes: 4 additions & 6 deletions netam/attention_map.py
@@ -25,8 +25,8 @@
 
 
 def reshape_tensor(tensor, head_count):
-    """
-    Reshape the tensor to include the head dimension.
+    """Reshape the tensor to include the head dimension.
+
     Assumes batch size is 1 and squeezes it out.
     """
     assert tensor.size(0) == 1, "Batch size should be 1"
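The rest of the function is collapsed in this view. A minimal sketch of a head-splitting reshape matching this docstring might look as follows; the `(1, seq_len, embed_dim)` input shape and the output layout are assumptions, not taken from the diff.

```python
import torch

def reshape_tensor_sketch(tensor: torch.Tensor, head_count: int) -> torch.Tensor:
    # Hypothetical sketch: split the embedding dimension into heads and
    # squeeze out the leading batch dimension, assuming shape (1, L, E).
    assert tensor.size(0) == 1, "Batch size should be 1"
    _, seq_len, embed_dim = tensor.shape
    head_dim = embed_dim // head_count
    # (1, L, E) -> (L, head_count, head_dim) -> (head_count, L, head_dim)
    return tensor.squeeze(0).view(seq_len, head_count, head_dim).transpose(0, 1)
```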
@@ -60,10 +60,8 @@ def wrap(*args, **kwargs):
 
 
 def attention_mapss_of(model, sequences):
-    """
-    Get a list of attention maps (across sequences) as described in the module
-    docstring.
-    """
+    """Get a list of attention maps (across sequences) as described in the module
+    docstring."""
     model = copy.deepcopy(model)
     model.eval()
     layer_count = len(model.encoder.layers)
3 changes: 1 addition & 2 deletions netam/cli.py
@@ -8,8 +8,7 @@ def concatenate_csvs(
     is_tsv: bool = False,
     record_path: bool = False,
 ):
-    """
-    This function concatenates multiple CSV or TSV files into one CSV file.
+    """This function concatenates multiple CSV or TSV files into one CSV file.
 
     Args:
         input_csvs: A string of paths to the input CSV or TSV files separated by commas.
51 changes: 26 additions & 25 deletions netam/common.py
@@ -46,7 +46,8 @@ def kmer_to_index_of(all_kmers):
 
 
 def aa_idx_tensor_of_str_ambig(aa_str):
-    """Return the indices of the amino acids in a string, allowing the ambiguous character."""
+    """Return the indices of the amino acids in a string, allowing the ambiguous
+    character."""
     try:
         return torch.tensor(
             [AA_STR_SORTED_AMBIG.index(aa) for aa in aa_str], dtype=torch.int
@@ -57,8 +58,10 @@ def aa_idx_tensor_of_str_ambig(aa_str):
 
 
 def generic_mask_tensor_of(ambig_symb, seq_str, length=None):
-    """Return a mask tensor indicating non-empty and non-ambiguous sites. Sites
-    beyond the length of the sequence are masked."""
+    """Return a mask tensor indicating non-empty and non-ambiguous sites.
+
+    Sites beyond the length of the sequence are masked.
+    """
     if length is None:
         length = len(seq_str)
     mask = torch.zeros(length, dtype=torch.bool)


def stack_heterogeneous(tensors, pad_value=0.0):
"""
Stack an iterable of 1D or 2D torch.Tensor objects of different lengths along the first dimension into a single tensor.
"""Stack an iterable of 1D or 2D torch.Tensor objects of different lengths along the
first dimension into a single tensor.
Parameters:
tensors (iterable): An iterable of 1D or 2D torch.Tensor objects with variable lengths in the first dimension.
pad_value (number): The value used for padding shorter tensors. Default is 0.
black --check netam tests
Args:
tensors (iterable): An iterable of 1D or 2D torch.Tensor objects with variable lengths in the first dimension.
pad_value (number): The value used for padding shorter tensors. Default is 0.
Returns:
torch.Tensor: A stacked tensor with all input tensors padded to the length of the longest tensor in the first dimension.
torch.Tensor: A stacked tensor with all input tensors padded to the length of the longest tensor in the first dimension.
"""
if tensors is None or len(tensors) == 0:
return torch.Tensor() # Return an empty tensor if no tensors are provided
Expand Down Expand Up @@ -144,8 +148,7 @@ def stack_heterogeneous(tensors, pad_value=0.0):


def optimizer_of_name(optimizer_name, model_parameters, **kwargs):
"""
Build a torch.optim optimizer from a string name and model parameters.
"""Build a torch.optim optimizer from a string name and model parameters.
Use a SGD optimizer with momentum if the optimizer_name is "SGDMomentum".
"""
Expand All @@ -162,8 +165,8 @@ def optimizer_of_name(optimizer_name, model_parameters, **kwargs):


def find_least_used_cuda_gpu():
"""
Find the least used CUDA GPU on the system using nvidia-smi.
"""Find the least used CUDA GPU on the system using nvidia-smi.
If they are all idle, return None.
"""
result = subprocess.run(
Expand All @@ -182,10 +185,10 @@ def find_least_used_cuda_gpu():


def pick_device(gpu_index=None):
"""
Pick a device for PyTorch to use. If CUDA is available, use the least used
GPU, and if all are idle use the gpu_index modulo the number of GPUs. If
gpu_index is None, then use a random GPU.
"""Pick a device for PyTorch to use.
If CUDA is available, use the least used GPU, and if all are idle use the gpu_index
modulo the number of GPUs. If gpu_index is None, then use a random GPU.
"""

# check that CUDA is usable
Expand Down Expand Up @@ -214,11 +217,10 @@ def check_CUDA():


def print_tensor_devices(scope="local"):
"""
Print the devices of all PyTorch tensors in the given scope.
"""Print the devices of all PyTorch tensors in the given scope.
Args:
scope (str): 'local' for local scope, 'global' for global scope.
scope (str): 'local' for local scope, 'global' for global scope.
"""
if scope == "local":
frame = inspect.currentframe()
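A sketch of how the folded body could continue, assuming the frame inspection shown above and a straightforward global fallback:

```python
import inspect
import torch

def print_tensor_devices_sketch(scope="local"):
    # Hypothetical sketch: walk the chosen namespace and report the device
    # of every torch.Tensor bound in it.
    if scope == "local":
        variables = inspect.currentframe().f_back.f_locals
    else:
        variables = globals()
    for name, value in variables.items():
        if isinstance(value, torch.Tensor):
            print(f"{name}: {value.device}")
```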
@@ -275,15 +277,14 @@ def forward(self, x: Tensor) -> Tensor:
 
 
 def linear_bump_lr(epoch, warmup_epochs, total_epochs, max_lr, min_lr):
-    """
-    Linearly increase the learning rate from min_lr to max_lr over warmup_epochs,
+    """Linearly increase the learning rate from min_lr to max_lr over warmup_epochs,
     then linearly decrease the learning rate from max_lr to min_lr.
 
     See https://github.com/matsengrp/netam/pull/41 for more details.
 
-    pd.Series([
-        linear_bump_lr(epoch, warmup_epochs=20, total_epochs=200, max_lr=0.01, min_lr=1e-5)
-        for epoch in range(200)]).plot()
+    Example:
+
+    .. code-block:: python
+
+        pd.Series([linear_bump_lr(epoch, warmup_epochs=20, total_epochs=200, max_lr=0.01, min_lr=1e-5) for epoch in range(200)]).plot()
     """
     if epoch < warmup_epochs:
         lr = min_lr + ((max_lr - min_lr) / warmup_epochs) * epoch
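The decay branch is folded. A complete sketch of the schedule, where the warm-up branch is copied from the visible hunk and the decay branch is an assumed mirror image down to min_lr at total_epochs, followed by the docstring's plotting one-liner:

```python
import pandas as pd

def linear_bump_lr_sketch(epoch, warmup_epochs, total_epochs, max_lr, min_lr):
    # Warm-up branch as shown in the hunk above.
    if epoch < warmup_epochs:
        return min_lr + ((max_lr - min_lr) / warmup_epochs) * epoch
    # Assumed decay branch: linear from max_lr back down to min_lr.
    decay_epochs = total_epochs - warmup_epochs
    return max_lr - ((max_lr - min_lr) / decay_epochs) * (epoch - warmup_epochs)

pd.Series(
    [
        linear_bump_lr_sketch(e, warmup_epochs=20, total_epochs=200, max_lr=0.01, min_lr=1e-5)
        for e in range(200)
    ]
).plot()
```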
38 changes: 13 additions & 25 deletions netam/dnsm.py
@@ -1,11 +1,10 @@
-"""
-Here we define a mutation-selection model that is just about mutation vs no mutation, and is trainable.
+"""Here we define a mutation-selection model that is just about mutation vs no mutation,
+and is trainable.
 
 We'll use these conventions:
 
 * B is the batch size
 * L is the max sequence length
 """
 
 import copy
@@ -101,9 +100,8 @@ def of_seriess(
         all_subs_probs_series: pd.Series,
         branch_length_multiplier=5.0,
     ):
-        """
-        Alternative constructor that takes the raw data and calculates the
-        initial branch lengths.
+        """Alternative constructor that takes the raw data and calculates the initial
+        branch lengths.
 
         The `_series` arguments are series of Tensors which get stacked to
         create the full object.
@@ -125,10 +123,8 @@ def of_seriess(
 
     @classmethod
     def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
-        """
-        Alternative constructor that takes in a pcp_df and calculates the
-        initial branch lengths.
-        """
+        """Alternative constructor that takes in a pcp_df and calculates the initial
+        branch lengths."""
         assert "rates" in pcp_df.columns, "pcp_df must have a neutral rates column"
         return cls.of_seriess(
             pcp_df["parent"],
@@ -150,8 +146,7 @@ def clone(self):
         return new_dataset
 
     def subset_via_indices(self, indices):
-        """
-        Create a new dataset with a subset of the data, as per `indices`.
+        """Create a new dataset with a subset of the data, as per `indices`.
 
         Whether the new dataset is a deep copy or a shallow copy using slices
         depends on `indices`: if `indices` is an iterable of integers, then we
@@ -167,9 +162,7 @@ def subset_via_indices(self, indices):
         return new_dataset
 
     def split(self, into_count: int):
-        """
-        Split self into a list of into_count subsets.
-        """
+        """Split self into a list of into_count subsets."""
         dataset_size = len(self)
         indices = list(range(dataset_size))
         split_indices = np.array_split(indices, into_count)
@@ -280,8 +273,7 @@ def to(self, device):
 
 
 def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
-    """
-    Perform a train-val split based on a "in_train" column.
+    """Perform a train-val split based on a "in_train" column.
 
     Stays here so it can be used in tests.
     """
@@ -311,9 +303,8 @@ def load_branch_lengths(self, in_csv_prefix):
         self.val_dataset.load_branch_lengths(in_csv_prefix + ".val_branch_lengths.csv")
 
     def prediction_pair_of_batch(self, batch):
-        """
-        Get log neutral mutation probabilities and log selection factors for a batch of data.
-        """
+        """Get log neutral mutation probabilities and log selection factors for a batch
+        of data."""
         aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
         mask = batch["mask"].to(self.device)
         log_neutral_aa_mut_probs = batch["log_neutral_aa_mut_probs"].to(self.device)
@@ -332,8 +323,7 @@ def predictions_of_pair(self, log_neutral_aa_mut_probs, log_selection_factors):
         return predictions
 
     def predictions_of_batch(self, batch):
-        """
-        Make predictions for a batch of data.
+        """Make predictions for a batch of data.
 
         Note that we use the mask for prediction as part of the input for the
         transformer, though we don't mask the predictions themselves.
@@ -447,9 +437,7 @@ def to_crepe(self):
 
 
 def worker_optimize_branch_length(model, dataset, optimization_kwargs):
-    """
-    The worker used for parallel branch length optimization.
-    """
+    """The worker used for parallel branch length optimization."""
     burrito = DNSMBurrito(None, dataset, copy.deepcopy(model))
     return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
