Commit 49dd452

OPENNLP-1525 Improve TokenizerME to make use of abbreviations provided in TokenizerModel
- enhances TokenizerME impl to make use of abbreviations if available in TokenizerModel
- adjusts expectation in TokenizerFactoryTest#testCustomPatternForTokenizerMEDeu to 14 tokens
1 parent 5deae24
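
The new behavior only takes effect when the loaded TokenizerModel actually carries an abbreviation dictionary. A minimal usage sketch, assuming a hypothetical model file de-token.bin that was trained with such a dictionary (the file name is illustrative; the sentence is the one from the adjusted German test):

import java.io.FileInputStream;
import java.io.InputStream;

import opennlp.tools.tokenize.TokenizerME;
import opennlp.tools.tokenize.TokenizerModel;

public class AbbreviationAwareTokenization {
  public static void main(String[] args) throws Exception {
    // Hypothetical model file; it must have been trained with a TokenizerFactory
    // that carried an abbreviation Dictionary, otherwise the new code path stays inactive.
    try (InputStream in = new FileInputStream("de-token.bin")) {
      TokenizerModel model = new TokenizerModel(in);
      TokenizerME tokenizer = new TokenizerME(model);

      // If "S." and "ff." are in the model's abbreviation dictionary, they are kept
      // as single tokens instead of being split before the dot.
      String[] tokens = tokenizer.tokenize(
          "Ich wähle den auf S. 183 ff. mitgeteilten Traum von der botanischen Monographie.");
      for (String token : tokens) {
        System.out.println(token);
      }
    }
  }
}

Because abbreviations such as "S." and "ff." are no longer split before the dot, the German sentence in TokenizerFactoryTest#testCustomPatternForTokenizerMEDeu now yields 14 tokens instead of 16.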

File tree

- opennlp-tools/src/main/java/opennlp/tools/tokenize/TokenizerME.java
- opennlp-tools/src/test/java/opennlp/tools/tokenize/TokenizerFactoryTest.java

2 files changed: +78 -32 lines changed

opennlp-tools/src/main/java/opennlp/tools/tokenize/TokenizerME.java

+41 -5

@@ -110,6 +110,11 @@ public class TokenizerME extends AbstractTokenizer {
 
   private final List<Span> newTokens;
 
+  /*
+   * The {@link Dictionary abbreviation dictionary} if available (may be {@code null}).
+   */
+  private final Dictionary abbDict;
+
   /**
    * Initializes a {@link TokenizerME} by downloading a default model.
    * @param language The language of the tokenizer.
@@ -132,6 +137,7 @@ public TokenizerME(TokenizerModel model) {
     this.model = model.getMaxentModel();
     this.useAlphaNumericOptimization = factory.isUseAlphaNumericOptimization();
 
+    abbDict = model.getAbbreviations();
     newTokens = new ArrayList<>();
     tokProbs = new ArrayList<>(50);
   }
@@ -151,6 +157,7 @@ public TokenizerME(TokenizerModel model, Factory factory) {
     this.model = model.getMaxentModel();
     useAlphaNumericOptimization = model.useAlphaNumericOptimization();
 
+    abbDict = model.getAbbreviations();
     newTokens = new ArrayList<>();
     tokProbs = new ArrayList<>(50);
   }
@@ -182,6 +189,7 @@ public double[] getTokenProbabilities() {
    *
    * @return A {@link Span} array containing individual tokens as elements.
    */
+  @Override
   public Span[] tokenizePos(String d) {
     WhitespaceTokenizer whitespaceTokenizer = WhitespaceTokenizer.INSTANCE;
     whitespaceTokenizer.setKeepNewLines(keepNewLines);
@@ -208,14 +216,22 @@ public Span[] tokenizePos(String d) {
           String best = model.getBestOutcome(probs);
           tokenProb *= probs[model.getIndex(best)];
           if (best.equals(TokenizerME.SPLIT)) {
-            newTokens.add(new Span(start, j));
-            tokProbs.add(tokenProb);
-            start = j;
+            if (isAcceptableAbbreviation(tok)) {
+              newTokens.add(new Span(start, end));
+              tokProbs.add(tokenProb);
+              start = j + 1; // To compensate for the abbreviation dot
+            } else {
+              newTokens.add(new Span(start, j));
+              tokProbs.add(tokenProb);
+              start = j;
+            }
             tokenProb = 1.0;
           }
         }
-        newTokens.add(new Span(start, end));
-        tokProbs.add(tokenProb);
+        if (start < end) {
+          newTokens.add(new Span(start, end));
+          tokProbs.add(tokenProb);
+        }
       }
     }
 
@@ -258,4 +274,24 @@ public boolean useAlphaNumericOptimization() {
     return useAlphaNumericOptimization;
   }
 
+  /**
+   * Allows checking a token abbreviation candidate for acceptability.
+   *
+   * <p>Note: The implementation always returns {@code false} if no
+   * abbreviation dictionary is available for the underlying model.</p>
+   *
+   * @param s the {@link CharSequence} in which the break occurred.
+   * @return {@code true} if the candidate is acceptable, {@code false} otherwise.
+   */
+  protected boolean isAcceptableAbbreviation(CharSequence s) {
+    if (abbDict == null)
+      return false;
+
+    for (String abb : abbDict.asStringSet()) {
+      if (abb.equals(s.toString())) {
+        return true;
+      }
+    }
+    return false;
+  }
 }
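
Since isAcceptableAbbreviation(CharSequence) is protected, the lookup strategy can be swapped out in a subclass. A sketch under that assumption (the subclass and its case-insensitive matching are illustrative, not part of this commit):

import opennlp.tools.dictionary.Dictionary;
import opennlp.tools.tokenize.TokenizerME;
import opennlp.tools.tokenize.TokenizerModel;

// Hypothetical subclass: accepts abbreviation candidates case-insensitively
// instead of the exact string comparison used by TokenizerME.
public class CaseInsensitiveAbbreviationTokenizer extends TokenizerME {

  private final Dictionary abbreviations;

  public CaseInsensitiveAbbreviationTokenizer(TokenizerModel model) {
    super(model);
    // The abbDict field of TokenizerME is private, so re-read the dictionary here.
    this.abbreviations = model.getAbbreviations();
  }

  @Override
  protected boolean isAcceptableAbbreviation(CharSequence s) {
    if (abbreviations == null) {
      return false;
    }
    String candidate = s.toString();
    return abbreviations.asStringSet().stream()
        .anyMatch(abb -> abb.equalsIgnoreCase(candidate));
  }
}

The base implementation keeps an exact, case-sensitive comparison against abbDict.asStringSet(), so a subclass like this would be the place to relax matching if a model's dictionary mixes cases.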

opennlp-tools/src/test/java/opennlp/tools/tokenize/TokenizerFactoryTest.java

+37 -27

@@ -20,8 +20,8 @@
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
-import java.io.InputStream;
 import java.nio.charset.StandardCharsets;
+import java.util.Locale;
 import java.util.regex.Pattern;
 
 import org.junit.jupiter.api.Assertions;
@@ -54,24 +54,28 @@ private static TokenizerModel train(TokenizerFactory factory)
     return TokenizerME.train(createSampleStream(), factory, TrainingParameters.defaultParams());
   }
 
-  private static Dictionary loadAbbDictionary() throws IOException {
-    InputStream in = TokenizerFactoryTest.class.getClassLoader()
-        .getResourceAsStream("opennlp/tools/sentdetect/abb.xml");
-
-    return new Dictionary(in);
+  private static Dictionary loadAbbDictionary(Locale loc) throws IOException {
+    final String abbrevDict;
+    if (loc.equals(Locale.GERMAN)) {
+      abbrevDict = "opennlp/tools/sentdetect/abb_DE.xml";
+    } else {
+      abbrevDict = "opennlp/tools/sentdetect/abb.xml";
+    }
+    return new Dictionary(TokenizerFactoryTest.class.getClassLoader()
+        .getResourceAsStream(abbrevDict));
   }
 
   @Test
   void testDefault() throws IOException {
 
-    Dictionary dic = loadAbbDictionary();
+    Dictionary dic = loadAbbDictionary(Locale.ENGLISH);
     final String lang = "eng";
 
     TokenizerModel model = train(new TokenizerFactory(lang, dic, false, null));
 
     TokenizerFactory factory = model.getFactory();
     Assertions.assertNotNull(factory.getAbbreviationDictionary());
-    Assertions.assertTrue(factory.getContextGenerator() instanceof DefaultTokenContextGenerator);
+    Assertions.assertInstanceOf(DefaultTokenContextGenerator.class, factory.getContextGenerator());
 
     String defaultPattern = Factory.DEFAULT_ALPHANUMERIC.pattern();
     Assertions.assertEquals(defaultPattern, factory.getAlphaNumericPattern().pattern());
@@ -87,7 +91,7 @@ void testDefault() throws IOException {
 
     factory = fromSerialized.getFactory();
     Assertions.assertNotNull(factory.getAbbreviationDictionary());
-    Assertions.assertTrue(factory.getContextGenerator() instanceof DefaultTokenContextGenerator);
+    Assertions.assertInstanceOf(DefaultTokenContextGenerator.class, factory.getContextGenerator());
 
     Assertions.assertEquals(defaultPattern, factory.getAlphaNumericPattern().pattern());
     Assertions.assertEquals(lang, factory.getLanguageCode());
@@ -105,7 +109,7 @@ void testNullDict() throws IOException {
 
     TokenizerFactory factory = model.getFactory();
     Assertions.assertNull(factory.getAbbreviationDictionary());
-    Assertions.assertTrue(factory.getContextGenerator() instanceof DefaultTokenContextGenerator);
+    Assertions.assertInstanceOf(DefaultTokenContextGenerator.class, factory.getContextGenerator());
 
     String defaultPattern = Factory.DEFAULT_ALPHANUMERIC.pattern();
     Assertions.assertEquals(defaultPattern, factory.getAlphaNumericPattern().pattern());
@@ -121,7 +125,7 @@ void testNullDict() throws IOException {
 
     factory = fromSerialized.getFactory();
     Assertions.assertNull(factory.getAbbreviationDictionary());
-    Assertions.assertTrue(factory.getContextGenerator() instanceof DefaultTokenContextGenerator);
+    Assertions.assertInstanceOf(DefaultTokenContextGenerator.class, factory.getContextGenerator());
 
     Assertions.assertEquals(defaultPattern, factory.getAlphaNumericPattern().pattern());
     Assertions.assertEquals(lang, factory.getLanguageCode());
@@ -141,7 +145,7 @@ void testCustomPatternAndAlphaOpt() throws IOException {
 
     TokenizerFactory factory = model.getFactory();
     Assertions.assertNull(factory.getAbbreviationDictionary());
-    Assertions.assertTrue(factory.getContextGenerator() instanceof DefaultTokenContextGenerator);
+    Assertions.assertInstanceOf(DefaultTokenContextGenerator.class, factory.getContextGenerator());
 
     Assertions.assertEquals(pattern, factory.getAlphaNumericPattern().pattern());
     Assertions.assertEquals(lang, factory.getLanguageCode());
@@ -156,7 +160,7 @@ void testCustomPatternAndAlphaOpt() throws IOException {
 
     factory = fromSerialized.getFactory();
     Assertions.assertNull(factory.getAbbreviationDictionary());
-    Assertions.assertTrue(factory.getContextGenerator() instanceof DefaultTokenContextGenerator);
+    Assertions.assertInstanceOf(DefaultTokenContextGenerator.class, factory.getContextGenerator());
     Assertions.assertEquals(pattern, factory.getAlphaNumericPattern().pattern());
     Assertions.assertEquals(lang, factory.getLanguageCode());
     Assertions.assertEquals(lang, model.getLanguage());
@@ -165,18 +169,24 @@ void testCustomPatternAndAlphaOpt() throws IOException {
 
   void checkCustomPatternForTokenizerME(String lang, String pattern, String sentence,
                                         int expectedNumTokens) throws IOException {
-
-    TokenizerModel model = train(new TokenizerFactory(lang, null, true,
+    Locale loc = Locale.ENGLISH;
+    if ("deu".equals(lang)) {
+      loc = Locale.GERMAN;
+    }
+    TokenizerModel model = train(new TokenizerFactory(lang, loadAbbDictionary(loc), true,
         Pattern.compile(pattern)));
 
     TokenizerME tokenizer = new TokenizerME(model);
     String[] tokens = tokenizer.tokenize(sentence);
 
     Assertions.assertEquals(expectedNumTokens, tokens.length);
-    String[] sentSplit = sentence.replaceAll("\\.", " .")
-        .replaceAll("'", " '").split(" ");
+    String[] sentSplit = sentence.replaceAll("'", " '").split(" ");
     for (int i = 0; i < sentSplit.length; i++) {
-      Assertions.assertEquals(sentSplit[i], tokens[i]);
+      String sElement = sentSplit[i];
+      if (i == sentSplit.length - 1) {
+        sElement = sElement.replace(".", ""); // compensate for sentence ending
+      }
+      Assertions.assertEquals(sElement, tokens[i]);
     }
   }
 
@@ -185,7 +195,7 @@ void testCustomPatternForTokenizerMEDeu() throws IOException {
     String lang = "deu";
     String pattern = "^[A-Za-z0-9äéöüÄÉÖÜß]+$";
     String sentence = "Ich wähle den auf S. 183 ff. mitgeteilten Traum von der botanischen Monographie.";
-    checkCustomPatternForTokenizerME(lang, pattern, sentence, 16);
+    checkCustomPatternForTokenizerME(lang, pattern, sentence, 14);
   }
 
   @Test
@@ -267,16 +277,16 @@ void testContractionsEng() throws IOException {
   @Test
   void testDummyFactory() throws IOException {
 
-    Dictionary dic = loadAbbDictionary();
+    Dictionary dic = loadAbbDictionary(Locale.ENGLISH);
     final String lang = "eng";
     String pattern = "^[0-9A-Za-z]+$";
 
     TokenizerModel model = train(new DummyTokenizerFactory(lang, dic, true,
         Pattern.compile(pattern)));
 
     TokenizerFactory factory = model.getFactory();
-    Assertions.assertTrue(factory.getAbbreviationDictionary() instanceof DummyDictionary);
-    Assertions.assertTrue(factory.getContextGenerator() instanceof DummyContextGenerator);
+    Assertions.assertInstanceOf(DummyDictionary.class, factory.getAbbreviationDictionary());
+    Assertions.assertInstanceOf(DummyContextGenerator.class, factory.getContextGenerator());
     Assertions.assertEquals(pattern, factory.getAlphaNumericPattern().pattern());
     Assertions.assertEquals(lang, factory.getLanguageCode());
     Assertions.assertEquals(lang, model.getLanguage());
@@ -289,8 +299,8 @@ void testDummyFactory() throws IOException {
     TokenizerModel fromSerialized = new TokenizerModel(in);
 
     factory = fromSerialized.getFactory();
-    Assertions.assertTrue(factory.getAbbreviationDictionary() instanceof DummyDictionary);
-    Assertions.assertTrue(factory.getContextGenerator() instanceof DummyContextGenerator);
+    Assertions.assertInstanceOf(DummyDictionary.class, factory.getAbbreviationDictionary());
+    Assertions.assertInstanceOf(DummyContextGenerator.class, factory.getContextGenerator());
     Assertions.assertEquals(pattern, factory.getAlphaNumericPattern().pattern());
     Assertions.assertEquals(lang, factory.getLanguageCode());
     Assertions.assertEquals(lang, model.getLanguage());
@@ -299,16 +309,16 @@ void testDummyFactory() throws IOException {
 
   @Test
   void testCreateDummyFactory() throws IOException {
-    Dictionary dic = loadAbbDictionary();
+    Dictionary dic = loadAbbDictionary(Locale.ENGLISH);
     final String lang = "eng";
     String pattern = "^[0-9A-Za-z]+$";
 
     TokenizerFactory factory = TokenizerFactory.create(
         DummyTokenizerFactory.class.getCanonicalName(), lang, dic, true,
         Pattern.compile(pattern));
 
-    Assertions.assertTrue(factory.getAbbreviationDictionary() instanceof DummyDictionary);
-    Assertions.assertTrue(factory.getContextGenerator() instanceof DummyContextGenerator);
+    Assertions.assertInstanceOf(DummyDictionary.class, factory.getAbbreviationDictionary());
+    Assertions.assertInstanceOf(DummyContextGenerator.class, factory.getContextGenerator());
     Assertions.assertEquals(pattern, factory.getAlphaNumericPattern().pattern());
     Assertions.assertEquals(lang, factory.getLanguageCode());
     Assertions.assertTrue(factory.isUseAlphaNumericOptimization());
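
The tests read the abbreviation dictionary from the bundled abb.xml / abb_DE.xml resources; an equivalent dictionary can also be assembled in code and handed to the TokenizerFactory that goes into TokenizerME.train(...). A sketch under that assumption (the entries, language code, and pattern below are illustrative, not the contents of the real resources):

import java.util.regex.Pattern;

import opennlp.tools.dictionary.Dictionary;
import opennlp.tools.tokenize.TokenizerFactory;
import opennlp.tools.util.StringList;

public class AbbreviationFactorySetup {
  public static void main(String[] args) {
    // Illustrative abbreviation entries built in code instead of loading abb_DE.xml.
    Dictionary abbreviations = new Dictionary();
    abbreviations.put(new StringList("S."));
    abbreviations.put(new StringList("ff."));

    // The factory stores the dictionary in the trained TokenizerModel, which is
    // where TokenizerME later retrieves it via model.getAbbreviations().
    TokenizerFactory factory = new TokenizerFactory(
        "deu", abbreviations, true, Pattern.compile("^[A-Za-z0-9äéöüÄÉÖÜß]+$"));

    // The factory would then be passed to TokenizerME.train(samples, factory, params)
    // together with a TokenSample stream (not constructed here).
    System.out.println(factory.getAbbreviationDictionary().asStringSet());
  }
}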
