Update to v0.2.0
Rename the spaCy model to `en_core_web_sm`.
Fix special token handling for the transformers LM.
Update the TODOs in README.md.
sai-prasanna committed Nov 27, 2019
1 parent 2a7cff2 commit 98f1aea
Showing 5 changed files with 46 additions and 32 deletions.
10 changes: 4 additions & 6 deletions README.md
@@ -26,8 +26,6 @@ Unlike many approaches to GEC, this approach does NOT require annotated training

This work builds upon https://github.com/chrisjbryant/lmgec-lite/



## Components

### Language Models
@@ -42,8 +40,8 @@ Pre-trained language models for other languages, inflectors, common error patterns

## TODOs

* Research distilling GPT-2 into a smaller model (LSTM?) to reduce the horrendous latency.
* Experiment on GEC dev sets to obtain optimal thresholds.
* Find a way to handle insertions.
* Use edits in existing GEC corpus to generate candidates.
* Tests
* Publish benchmarks of the model.
* Think of simple ways to generate insertion candidates.
* Add more languages.
* Check whether LemmInflect proposals are actually better than just using [AGID](https://github.com/sai-prasanna/lmgec-lite/tree/master/resources/agid-2016.01.19).
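
The last TODO compares LemmInflect proposals against the AGID word lists. A minimal sketch of how LemmInflect produces inflection candidates (the lemmas are illustrative; assumes `pip install lemminflect`):

```python
from lemminflect import getAllInflections, getInflection

# All inflected forms of a verb, keyed by Penn Treebank tag.
print(getAllInflections("be", upos="VERB"))

# A single targeted inflection: the past tense of "give".
print(getInflection("give", tag="VBD"))  # -> ('gave',)
```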
2 changes: 1 addition & 1 deletion lmproof/candidate_generators.py
@@ -96,7 +96,7 @@ def load(cls, language: str) -> "SpellCorrectGenerator":
/ "frequency_dictionary_en_82_765.txt"
)
sym_spell.create_dictionary(str(dict_path))
spacy_model = spacy.load("en", disable=["parser", "ner"])
spacy_model = spacy.load("en_core_web_sm", disable=["parser", "ner"])
else:
raise RuntimeError(f"The language {language} is currently not supported.")
return cls(sym_spell, spacy_model)
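
The change above replaces spaCy's `"en"` shortcut with the model's package name, which is the forward-compatible way to load it. A minimal sketch of candidate generation with this setup, assuming `en_core_web_sm` has been installed via `python -m spacy download en_core_web_sm` (the corpus path and sample sentence are illustrative, not from the repository):

```python
import spacy
from symspellpy import SymSpell, Verbosity

# Parser and NER are disabled because candidate generation only needs
# tokenization and tagging.
spacy_model = spacy.load("en_core_web_sm", disable=["parser", "ner"])

# Build a SymSpell dictionary from a plain-text corpus (illustrative path;
# the repository uses symspellpy's bundled English frequency dictionary).
sym_spell = SymSpell()
sym_spell.create_dictionary("corpus.txt")

# Propose spelling candidates for each token of a sentence.
doc = spacy_model("Ths sentence has a typo.")
for token in doc:
    suggestions = sym_spell.lookup(token.text, Verbosity.CLOSEST, max_edit_distance=2)
    print(token.text, [s.term for s in suggestions])
```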
8 changes: 6 additions & 2 deletions lmproof/scorer.py
@@ -1,7 +1,7 @@
from typing import List, Optional
import logging

import torch
import logging
from torch.nn import CrossEntropyLoss
from transformers import (
AutoTokenizer,
@@ -25,6 +25,7 @@ def __init__(
model: PreTrainedModel,
device: str = "cpu",
batch_size: int = 1,
add_special_tokens: bool = False,
normalize: bool = False,
):
# Load pre-trained model tokenizer (vocabulary)
@@ -33,6 +34,7 @@ def __init__(
self.model = model.to(self.device).eval()
self.batch_size = batch_size
self.normalize = normalize
self._add_special_tokens = add_special_tokens
self._loss_fn = CrossEntropyLoss(ignore_index=-1)

@classmethod
@@ -58,7 +60,9 @@ def score(self, sentences: List[str]) -> List[Optional[float]]:

tokenized_batch = []
for i, sentence in enumerate(batched_sentences):
tokens = self.tokenizer.encode(sentence)
tokens = self.tokenizer.encode(
sentence, add_special_tokens=self._add_special_tokens
)
if len(tokens) <= self.tokenizer.max_len:
tokenized_batch.append(torch.LongTensor(tokens)) # type: ignore
batch_scored_idx.append(i)
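
The new `add_special_tokens` argument is passed straight through to `tokenizer.encode`, so the scorer controls whether markers such as `[CLS]`/`[SEP]` are added before the language model scores a sentence. A rough sketch of that scoring idea, assuming a GPT-2 model (the `score` helper is illustrative, not the library's actual method):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

def score(sentence: str, add_special_tokens: bool = False) -> float:
    # GPT-2 defines no [CLS]/[SEP], so the flag is a no-op here, but passing
    # it through keeps the behaviour explicit for tokenizers that do add them.
    ids = tokenizer.encode(sentence, add_special_tokens=add_special_tokens)
    input_ids = torch.tensor([ids])
    with torch.no_grad():
        # With labels supplied, the model's first output is the mean
        # token-level cross-entropy loss over the sequence.
        loss = model(input_ids, labels=input_ids)[0]
    return -loss.item()  # closer to zero ~ judged more fluent

print(score("She went to the store."))   # expected to outscore the next line
print(score("She goed to the store."))
```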