Skip to content

Commit

Permalink
Added new jaxtyping syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
mkranzlein committed Feb 22, 2024
1 parent 36dc46d commit 394c4c1
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 19 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ This is **not the original repo for HiPool and I am not an author on the HiPool
This repo uses [jaxtyping](https://github.com/google/jaxtyping) and [typeguard](https://typeguard.readthedocs.io/) to enforce correct tensor dimensions at runtime. If you see an unfamiliar type annotation or a decorator like the one in the example below, it's for type checking.

```python
@jaxtyped
@typechecked
@jaxtyped(typechecker=typechecker)
def some_function(x: Float[torch.Tensor, "10, 768"]):
pass
```
Expand Down
5 changes: 2 additions & 3 deletions src/hipool/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
import torch
from jaxtyping import Float, Integer, jaxtyped
from torch import Tensor
from typeguard import typechecked
from typeguard import typechecked as typechecker


@jaxtyped
@typechecked
@jaxtyped(typechecker=typechecker)
def chunk_document(input_ids: list[int], chunk_len: int, overlap_len: int) -> dict:
"""Splits a document into chunks of wordpiece subwords."""

Expand Down
5 changes: 2 additions & 3 deletions src/hipool/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler
from torcheval.metrics import BinaryF1Score, BinaryPrecision, BinaryRecall
from typeguard import typechecked
from typeguard import typechecked as typechecker

from hipool.models import DocModel, TokenClassificationModel
from hipool.utils import collate_sentences
Expand Down Expand Up @@ -144,8 +144,7 @@ def eval_sentence_metalanguage(doc_data_loader, token_model: TokenClassification
f = class_metrics["f"].compute().item()
print(f"class {i}\t{p:.4f}\t{r:.4f}\t{f:.4f}")

@jaxtyped
@typechecked
@jaxtyped(typechecker=typechecker)
def get_eval_mask(seq_input_ids, # : Integer[Tensor, "k c"],
overlap_len, longest_seq):
"""Create a mask to identify which tokens should be evaluated."""
Expand Down
8 changes: 3 additions & 5 deletions src/hipool/hipool.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jaxtyping import Float, jaxtyped
from torch import Tensor
from torch_geometric.nn import DenseGCNConv
from typeguard import typechecked
from typeguard import typechecked as typechecker


class HiPool(torch.nn.Module):
Expand Down Expand Up @@ -48,8 +48,7 @@ def map_low_to_high(self, num_low_nodes: int,
mapping = mapping[:num_low_nodes]
return mapping

@jaxtyped
@typechecked
@jaxtyped(typechecker=typechecker)
def cluster_attention(self, x: Float[Tensor, "low low_dim"],
low_to_high_mapping: Float[Tensor, "low high"],
attention_weights: Float[Tensor, "low_dim low_dim"]) -> Float[Tensor, "high low_dim"]:
Expand All @@ -67,8 +66,7 @@ def cluster_attention(self, x: Float[Tensor, "low low_dim"],
output: Float[Tensor, "high low_dim"] = torch.matmul(scores, x) + high_representations
return output

@jaxtyped
@typechecked
@jaxtyped(typechecker=typechecker)
def forward(self, x: Float[Tensor, "low in_dim"],
adj_matrix: Float[Tensor, "low low"]):
"""A forward pass through the HiPool model.
Expand Down
8 changes: 3 additions & 5 deletions src/hipool/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from jaxtyping import Float, Integer, jaxtyped
from torch import nn, Tensor
from typeguard import typechecked
from typeguard import typechecked as typechecker

from hipool.hipool import HiPool

Expand Down Expand Up @@ -64,8 +64,7 @@ def __init__(self, num_labels, bert_model, device):
self.device = device
self.linear = nn.Linear(768, num_labels).to(device)

@jaxtyped
@typechecked
@jaxtyped(typechecker=typechecker)
def forward(self, ids: Integer[Tensor, "_ c"],
mask: Integer[Tensor, "_ c"],
token_type_ids: Integer[Tensor, "_ c"]):
Expand All @@ -91,8 +90,7 @@ def __init__(self, num_labels, bert_model, device, use_doc_embedding=False, doc_
else:
self.linear = nn.Linear(768, num_labels).to(device)

@jaxtyped
@typechecked
@jaxtyped(typechecker=typechecker)
def forward(self, ids: Integer[Tensor, "_ c"],
mask: Integer[Tensor, "_ c"],
token_type_ids: Integer[Tensor, "_ c"],
Expand Down
1 change: 0 additions & 1 deletion src/hipool/sent_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torch.utils.data.sampler import SequentialSampler
from torcheval.metrics import BinaryF1Score, BinaryPrecision, BinaryRecall
from tqdm import tqdm
from typeguard import typechecked

from hipool.models import SentenceClassificationModel
from hipool.curiam_categories import REDUCED_CATEGORIES
Expand Down

0 comments on commit 394c4c1

Please sign in to comment.