Skip to content

Commit

Permalink
better names and type annots
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Jun 3, 2024
1 parent c3bef03 commit 2af11e1
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
clamp_probability,
aa_mask_tensor_of,
stack_heterogeneous,
pick_device,
)
import netam.framework as framework
from netam.hyper_burrito import HyperBurrito
Expand All @@ -47,11 +46,11 @@
class DNSMDataset(Dataset):
def __init__(
self,
nt_parents,
nt_children,
all_rates,
all_subs_probs,
branch_lengths,
nt_parents: pd.Series,
nt_children: pd.Series,
all_rates: torch.Tensor,
all_subs_probs: torch.Tensor,
branch_lengths: torch.Tensor,
):
self.nt_parents = nt_parents
self.nt_children = nt_children
Expand Down Expand Up @@ -97,15 +96,18 @@ def __init__(
@classmethod
def from_data(
cls,
nt_parents,
nt_children,
all_rates,
all_subs_probs,
nt_parents: pd.Series,
nt_children: pd.Series,
all_rates_series: pd.Series,
all_subs_probs_series: pd.Series,
branch_length_multiplier=5.0,
):
"""
Alternative constructor that takes the raw data and calculates the initial
branch lengths.
Alternative constructor that takes the raw data and calculates the
initial branch lengths.
The `_series` arguments are pandas Series of tensors, which get combined
(via `stack_heterogeneous`) to create the full object.
"""
initial_branch_lengths = np.array(
[
Expand All @@ -117,9 +119,8 @@ def from_data(
return cls(
nt_parents.reset_index(drop=True),
nt_children.reset_index(drop=True),
# TODO we should use different names or something
stack_heterogeneous(all_rates.reset_index(drop=True)),
stack_heterogeneous(all_subs_probs.reset_index(drop=True)),
stack_heterogeneous(all_rates_series.reset_index(drop=True)),
stack_heterogeneous(all_subs_probs_series.reset_index(drop=True)),
initial_branch_lengths,
)

Expand All @@ -135,8 +136,8 @@ def clone(self):

def clone_with_indices(self, indices):
new_dataset = DNSMDataset(
self.nt_parents[indices],
self.nt_children[indices],
self.nt_parents[indices].reset_index(drop=True),
self.nt_children[indices].reset_index(drop=True),
self.all_rates[indices],
self.all_subs_probs[indices],
self._branch_lengths[indices],
Expand Down

0 comments on commit 2af11e1

Please sign in to comment.