From 394c4c1ea4a4f3b5ab4fb1bcf334ddd52deaf4c2 Mon Sep 17 00:00:00 2001 From: Michael Kranzlein <8162250+mkranzlein@users.noreply.github.com> Date: Thu, 22 Feb 2024 16:57:34 -0500 Subject: [PATCH] Added new jaxtyping syntax --- README.md | 3 +-- src/hipool/chunk.py | 5 ++--- src/hipool/eval.py | 5 ++--- src/hipool/hipool.py | 8 +++----- src/hipool/models.py | 8 +++----- src/hipool/sent_model_utils.py | 1 - 6 files changed, 11 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 7505420..66d69e9 100644 --- a/README.md +++ b/README.md @@ -36,8 +36,7 @@ This is **not the original repo for HiPool and I am not an author on the HiPool This repo uses [jaxtyping](https://github.com/google/jaxtyping) and [typeguard](https://typeguard.readthedocs.io/) to enforce correct tensor dimensions at runtime. If you see an unfamiliar type annotation or decorators like in the example the below, it's for type checking. ```python -@jaxtyped -@typechecked +@jaxtyped(typechecker=typechecker) def some_function(x: Float[torch.Tensor, "10, 768"]): pass ``` diff --git a/src/hipool/chunk.py b/src/hipool/chunk.py index e3f7add..5d2ea32 100644 --- a/src/hipool/chunk.py +++ b/src/hipool/chunk.py @@ -9,11 +9,10 @@ import torch from jaxtyping import Float, Integer, jaxtyped from torch import Tensor -from typeguard import typechecked +from typeguard import typechecked as typechecker -@jaxtyped -@typechecked +@jaxtyped(typechecker=typechecker) def chunk_document(input_ids: list[int], chunk_len: int, overlap_len: int) -> dict: """Splits a document into chunks of wordpiece subwords.""" diff --git a/src/hipool/eval.py b/src/hipool/eval.py index 1e9e8cb..3709663 100644 --- a/src/hipool/eval.py +++ b/src/hipool/eval.py @@ -7,7 +7,7 @@ from torch.utils.data import DataLoader from torch.utils.data.sampler import SequentialSampler from torcheval.metrics import BinaryF1Score, BinaryPrecision, BinaryRecall -from typeguard import typechecked +from typeguard import typechecked as typechecker from hipool.models import DocModel, TokenClassificationModel from hipool.utils import collate_sentences @@ -144,8 +144,7 @@ def eval_sentence_metalanguage(doc_data_loader, token_model: TokenClassification f = class_metrics["f"].compute().item() print(f"class {i}\t{p:.4f}\t{r:.4f}\t{f:.4f}") -@jaxtyped -@typechecked +@jaxtyped(typechecker=typechecker) def get_eval_mask(seq_input_ids, # : Integer[Tensor, "k c"], overlap_len, longest_seq): """Create a mask to identify which tokens should be evaluated.""" diff --git a/src/hipool/hipool.py b/src/hipool/hipool.py index 93cb180..190f436 100644 --- a/src/hipool/hipool.py +++ b/src/hipool/hipool.py @@ -10,7 +10,7 @@ from jaxtyping import Float, jaxtyped from torch import Tensor from torch_geometric.nn import DenseGCNConv -from typeguard import typechecked +from typeguard import typechecked as typechecker class HiPool(torch.nn.Module): @@ -48,8 +48,7 @@ def map_low_to_high(self, num_low_nodes: int, mapping = mapping[:num_low_nodes] return mapping - @jaxtyped - @typechecked + @jaxtyped(typechecker=typechecker) def cluster_attention(self, x: Float[Tensor, "low low_dim"], low_to_high_mapping: Float[Tensor, "low high"], attention_weights: Float[Tensor, "low_dim low_dim"]) -> Float[Tensor, "high low_dim"]: @@ -67,8 +66,7 @@ def cluster_attention(self, x: Float[Tensor, "low low_dim"], output: Float[Tensor, "high low_dim"] = torch.matmul(scores, x) + high_representations return output - @jaxtyped - @typechecked + @jaxtyped(typechecker=typechecker) def forward(self, x: Float[Tensor, "low in_dim"], adj_matrix: Float[Tensor, "low low"]): """A forward pass through the HiPool model. diff --git a/src/hipool/models.py b/src/hipool/models.py index f4f0b57..6101930 100644 --- a/src/hipool/models.py +++ b/src/hipool/models.py @@ -9,7 +9,7 @@ import torch from jaxtyping import Float, Integer, jaxtyped from torch import nn, Tensor -from typeguard import typechecked +from typeguard import typechecked as typechecker from hipool.hipool import HiPool @@ -64,8 +64,7 @@ def __init__(self, num_labels, bert_model, device): self.device = device self.linear = nn.Linear(768, num_labels).to(device) - @jaxtyped - @typechecked + @jaxtyped(typechecker=typechecker) def forward(self, ids: Integer[Tensor, "_ c"], mask: Integer[Tensor, "_ c"], token_type_ids: Integer[Tensor, "_ c"]): @@ -91,8 +90,7 @@ def __init__(self, num_labels, bert_model, device, use_doc_embedding=False, doc_ else: self.linear = nn.Linear(768, num_labels).to(device) - @jaxtyped - @typechecked + @jaxtyped(typechecker=typechecker) def forward(self, ids: Integer[Tensor, "_ c"], mask: Integer[Tensor, "_ c"], token_type_ids: Integer[Tensor, "_ c"], diff --git a/src/hipool/sent_model_utils.py b/src/hipool/sent_model_utils.py index fba94c2..013a142 100644 --- a/src/hipool/sent_model_utils.py +++ b/src/hipool/sent_model_utils.py @@ -8,7 +8,6 @@ from torch.utils.data.sampler import SequentialSampler from torcheval.metrics import BinaryF1Score, BinaryPrecision, BinaryRecall from tqdm import tqdm -from typeguard import typechecked from hipool.models import SentenceClassificationModel from hipool.curiam_categories import REDUCED_CATEGORIES