Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OPENNLP-1525 Improve TokenizerME to make use of abbreviations provided in TokenizerModel #562

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import opennlp.tools.util.DownloadUtil;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.Span;
import opennlp.tools.util.StringList;
import opennlp.tools.util.TrainingParameters;

/**
Expand Down Expand Up @@ -110,6 +111,11 @@ public class TokenizerME extends AbstractTokenizer {

private final List<Span> newTokens;

/**
 * The {@link Dictionary abbreviation dictionary} if available (may be {@code null}).
 */
private final Dictionary abbDict;

/**
* Initializes a {@link TokenizerME} by downloading a default model.
* @param language The language of the tokenizer.
Expand All @@ -132,6 +138,7 @@ public TokenizerME(TokenizerModel model) {
this.model = model.getMaxentModel();
this.useAlphaNumericOptimization = factory.isUseAlphaNumericOptimization();

abbDict = model.getAbbreviations();
newTokens = new ArrayList<>();
tokProbs = new ArrayList<>(50);
}
Expand All @@ -151,6 +158,7 @@ public TokenizerME(TokenizerModel model, Factory factory) {
this.model = model.getMaxentModel();
useAlphaNumericOptimization = model.useAlphaNumericOptimization();

abbDict = model.getAbbreviations();
newTokens = new ArrayList<>();
tokProbs = new ArrayList<>(50);
}
Expand Down Expand Up @@ -182,6 +190,7 @@ public double[] getTokenProbabilities() {
*
* @return A {@link Span} array containing individual tokens as elements.
*/
@Override
public Span[] tokenizePos(String d) {
WhitespaceTokenizer whitespaceTokenizer = WhitespaceTokenizer.INSTANCE;
whitespaceTokenizer.setKeepNewLines(keepNewLines);
Expand All @@ -208,14 +217,22 @@ public Span[] tokenizePos(String d) {
String best = model.getBestOutcome(probs);
tokenProb *= probs[model.getIndex(best)];
if (best.equals(TokenizerME.SPLIT)) {
newTokens.add(new Span(start, j));
tokProbs.add(tokenProb);
start = j;
if (isAcceptableAbbreviation(tok)) {
newTokens.add(new Span(start, end));
tokProbs.add(tokenProb);
start = j + 1; // To compensate for the abbreviation dot
} else {
newTokens.add(new Span(start, j));
tokProbs.add(tokenProb);
start = j;
}
tokenProb = 1.0;
}
}
newTokens.add(new Span(start, end));
tokProbs.add(tokenProb);
if (start < end) {
newTokens.add(new Span(start, end));
tokProbs.add(tokenProb);
}
}
}

Expand Down Expand Up @@ -258,4 +275,19 @@ public boolean useAlphaNumericOptimization() {
return useAlphaNumericOptimization;
}

/**
 * Allows checking a token abbreviation candidate for acceptability.
 *
 * <p>Note: The implementation always returns {@code false} if no
 * abbreviation dictionary is available for the underlying model.</p>
 *
 * @param s the {@link CharSequence token} to check for.
 * @return {@code true} if the candidate is acceptable, {@code false} otherwise.
 */
protected boolean isAcceptableAbbreviation(CharSequence s) {
  // Without an abbreviation dictionary no candidate can be accepted.
  if (abbDict == null) {
    return false;
  }
  // Dictionary lookup operates on StringList entries, hence the wrapping.
  return abbDict.contains(new StringList(s.toString()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Locale;
import java.util.regex.Pattern;

import org.junit.jupiter.api.Assertions;
Expand Down Expand Up @@ -54,24 +54,28 @@ private static TokenizerModel train(TokenizerFactory factory)
return TokenizerME.train(createSampleStream(), factory, TrainingParameters.defaultParams());
}

private static Dictionary loadAbbDictionary() throws IOException {
InputStream in = TokenizerFactoryTest.class.getClassLoader()
.getResourceAsStream("opennlp/tools/sentdetect/abb.xml");

return new Dictionary(in);
/**
 * Loads an abbreviation {@link Dictionary} for the given {@link Locale}.
 *
 * @param loc selects the dictionary; {@link Locale#GERMAN} maps to the German
 *            abbreviation file, any other locale falls back to the English one.
 * @return a parsed abbreviation {@link Dictionary}, never {@code null}.
 * @throws IOException if the resource is missing or cannot be parsed.
 */
private static Dictionary loadAbbDictionary(Locale loc) throws IOException {
  final String abbrevDict;
  if (loc.equals(Locale.GERMAN)) {
    abbrevDict = "opennlp/tools/sentdetect/abb_DE.xml";
  } else {
    abbrevDict = "opennlp/tools/sentdetect/abb.xml";
  }
  // try-with-resources: don't leak the stream if Dictionary parsing throws,
  // and fail with a clear message (instead of an NPE) if the resource is absent.
  try (InputStream in = TokenizerFactoryTest.class.getClassLoader()
      .getResourceAsStream(abbrevDict)) {
    if (in == null) {
      throw new IOException("Test resource not found: " + abbrevDict);
    }
    return new Dictionary(in);
  }
}

@Test
void testDefault() throws IOException {

Dictionary dic = loadAbbDictionary();
Dictionary dic = loadAbbDictionary(Locale.ENGLISH);
final String lang = "eng";

TokenizerModel model = train(new TokenizerFactory(lang, dic, false, null));

TokenizerFactory factory = model.getFactory();
Assertions.assertNotNull(factory.getAbbreviationDictionary());
Assertions.assertTrue(factory.getContextGenerator() instanceof DefaultTokenContextGenerator);
Assertions.assertInstanceOf(DefaultTokenContextGenerator.class, factory.getContextGenerator());

String defaultPattern = Factory.DEFAULT_ALPHANUMERIC.pattern();
Assertions.assertEquals(defaultPattern, factory.getAlphaNumericPattern().pattern());
Expand All @@ -87,7 +91,7 @@ void testDefault() throws IOException {

factory = fromSerialized.getFactory();
Assertions.assertNotNull(factory.getAbbreviationDictionary());
Assertions.assertTrue(factory.getContextGenerator() instanceof DefaultTokenContextGenerator);
Assertions.assertInstanceOf(DefaultTokenContextGenerator.class, factory.getContextGenerator());

Assertions.assertEquals(defaultPattern, factory.getAlphaNumericPattern().pattern());
Assertions.assertEquals(lang, factory.getLanguageCode());
Expand All @@ -105,7 +109,7 @@ void testNullDict() throws IOException {

TokenizerFactory factory = model.getFactory();
Assertions.assertNull(factory.getAbbreviationDictionary());
Assertions.assertTrue(factory.getContextGenerator() instanceof DefaultTokenContextGenerator);
Assertions.assertInstanceOf(DefaultTokenContextGenerator.class, factory.getContextGenerator());

String defaultPattern = Factory.DEFAULT_ALPHANUMERIC.pattern();
Assertions.assertEquals(defaultPattern, factory.getAlphaNumericPattern().pattern());
Expand All @@ -121,7 +125,7 @@ void testNullDict() throws IOException {

factory = fromSerialized.getFactory();
Assertions.assertNull(factory.getAbbreviationDictionary());
Assertions.assertTrue(factory.getContextGenerator() instanceof DefaultTokenContextGenerator);
Assertions.assertInstanceOf(DefaultTokenContextGenerator.class, factory.getContextGenerator());

Assertions.assertEquals(defaultPattern, factory.getAlphaNumericPattern().pattern());
Assertions.assertEquals(lang, factory.getLanguageCode());
Expand All @@ -141,7 +145,7 @@ void testCustomPatternAndAlphaOpt() throws IOException {

TokenizerFactory factory = model.getFactory();
Assertions.assertNull(factory.getAbbreviationDictionary());
Assertions.assertTrue(factory.getContextGenerator() instanceof DefaultTokenContextGenerator);
Assertions.assertInstanceOf(DefaultTokenContextGenerator.class, factory.getContextGenerator());

Assertions.assertEquals(pattern, factory.getAlphaNumericPattern().pattern());
Assertions.assertEquals(lang, factory.getLanguageCode());
Expand All @@ -156,7 +160,7 @@ void testCustomPatternAndAlphaOpt() throws IOException {

factory = fromSerialized.getFactory();
Assertions.assertNull(factory.getAbbreviationDictionary());
Assertions.assertTrue(factory.getContextGenerator() instanceof DefaultTokenContextGenerator);
Assertions.assertInstanceOf(DefaultTokenContextGenerator.class, factory.getContextGenerator());
Assertions.assertEquals(pattern, factory.getAlphaNumericPattern().pattern());
Assertions.assertEquals(lang, factory.getLanguageCode());
Assertions.assertEquals(lang, model.getLanguage());
Expand All @@ -165,18 +169,24 @@ void testCustomPatternAndAlphaOpt() throws IOException {

void checkCustomPatternForTokenizerME(String lang, String pattern, String sentence,
int expectedNumTokens) throws IOException {

TokenizerModel model = train(new TokenizerFactory(lang, null, true,
Locale loc = Locale.ENGLISH;
if ("deu".equals(lang)) {
loc = Locale.GERMAN;
}
TokenizerModel model = train(new TokenizerFactory(lang, loadAbbDictionary(loc), true,
Pattern.compile(pattern)));

TokenizerME tokenizer = new TokenizerME(model);
String[] tokens = tokenizer.tokenize(sentence);

Assertions.assertEquals(expectedNumTokens, tokens.length);
String[] sentSplit = sentence.replaceAll("\\.", " .")
.replaceAll("'", " '").split(" ");
String[] sentSplit = sentence.replaceAll("'", " '").split(" ");
for (int i = 0; i < sentSplit.length; i++) {
Assertions.assertEquals(sentSplit[i], tokens[i]);
String sElement = sentSplit[i];
if (i == sentSplit.length - 1) {
sElement = sElement.replace(".", ""); // compensate for sentence ending
}
Assertions.assertEquals(sElement, tokens[i]);
}
}

Expand All @@ -185,7 +195,7 @@ void testCustomPatternForTokenizerMEDeu() throws IOException {
String lang = "deu";
String pattern = "^[A-Za-z0-9äéöüÄÉÖÜß]+$";
String sentence = "Ich wähle den auf S. 183 ff. mitgeteilten Traum von der botanischen Monographie.";
checkCustomPatternForTokenizerME(lang, pattern, sentence, 16);
checkCustomPatternForTokenizerME(lang, pattern, sentence, 14);
}

@Test
Expand Down Expand Up @@ -267,16 +277,16 @@ void testContractionsEng() throws IOException {
@Test
void testDummyFactory() throws IOException {

Dictionary dic = loadAbbDictionary();
Dictionary dic = loadAbbDictionary(Locale.ENGLISH);
final String lang = "eng";
String pattern = "^[0-9A-Za-z]+$";

TokenizerModel model = train(new DummyTokenizerFactory(lang, dic, true,
Pattern.compile(pattern)));

TokenizerFactory factory = model.getFactory();
Assertions.assertTrue(factory.getAbbreviationDictionary() instanceof DummyDictionary);
Assertions.assertTrue(factory.getContextGenerator() instanceof DummyContextGenerator);
Assertions.assertInstanceOf(DummyDictionary.class, factory.getAbbreviationDictionary());
Assertions.assertInstanceOf(DummyContextGenerator.class, factory.getContextGenerator());
Assertions.assertEquals(pattern, factory.getAlphaNumericPattern().pattern());
Assertions.assertEquals(lang, factory.getLanguageCode());
Assertions.assertEquals(lang, model.getLanguage());
Expand All @@ -289,8 +299,8 @@ void testDummyFactory() throws IOException {
TokenizerModel fromSerialized = new TokenizerModel(in);

factory = fromSerialized.getFactory();
Assertions.assertTrue(factory.getAbbreviationDictionary() instanceof DummyDictionary);
Assertions.assertTrue(factory.getContextGenerator() instanceof DummyContextGenerator);
Assertions.assertInstanceOf(DummyDictionary.class, factory.getAbbreviationDictionary());
Assertions.assertInstanceOf(DummyContextGenerator.class, factory.getContextGenerator());
Assertions.assertEquals(pattern, factory.getAlphaNumericPattern().pattern());
Assertions.assertEquals(lang, factory.getLanguageCode());
Assertions.assertEquals(lang, model.getLanguage());
Expand All @@ -299,16 +309,16 @@ void testDummyFactory() throws IOException {

@Test
void testCreateDummyFactory() throws IOException {
Dictionary dic = loadAbbDictionary();
Dictionary dic = loadAbbDictionary(Locale.ENGLISH);
final String lang = "eng";
String pattern = "^[0-9A-Za-z]+$";

TokenizerFactory factory = TokenizerFactory.create(
DummyTokenizerFactory.class.getCanonicalName(), lang, dic, true,
Pattern.compile(pattern));

Assertions.assertTrue(factory.getAbbreviationDictionary() instanceof DummyDictionary);
Assertions.assertTrue(factory.getContextGenerator() instanceof DummyContextGenerator);
Assertions.assertInstanceOf(DummyDictionary.class, factory.getAbbreviationDictionary());
Assertions.assertInstanceOf(DummyContextGenerator.class, factory.getContextGenerator());
Assertions.assertEquals(pattern, factory.getAlphaNumericPattern().pattern());
Assertions.assertEquals(lang, factory.getLanguageCode());
Assertions.assertTrue(factory.isUseAlphaNumericOptimization());
Expand Down