Skip to content

Commit

Permalink
OPENNLP-1661: Fix custom models being wiped from OpenNLP user.home di…
Browse files Browse the repository at this point in the history
…rectory (#704)

- deletes AbstractDownloadUtilTest.java removing historical code that wiped models
- adds package-private DownloadUtil#existsModel(..) method to check models for certain language exist locally
- adds to and adjusts related test classes
  • Loading branch information
mawiesne authored Dec 3, 2024
1 parent f4de6c2 commit e91ceb1
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 109 deletions.
60 changes: 49 additions & 11 deletions opennlp-tools/src/main/java/opennlp/tools/util/DownloadUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,45 @@ public enum ModelType {
System.getProperty("OPENNLP_DOWNLOAD_BASE_URL", "https://dlcdn.apache.org/opennlp/");
private static final String MODEL_URI_PATH =
System.getProperty("OPENNLP_DOWNLOAD_MODEL_PATH", "models/ud-models-1.2/");
private static final String OPENNLP_DOWNLOAD_HOME = "OPENNLP_DOWNLOAD_HOME";

private static Map<String, Map<ModelType, String>> availableModels;

/**
* Checks if a model of the specified {@code modelType} has been downloaded already
* for a particular {@code language}.
*
* @param language The ISO language code of the requested model.
* @param modelType The {@link DownloadUtil.ModelType type} of model.
* @return {@code true} if a model exists locally, {@code false} otherwise.
* @throws IOException Thrown if IO errors occurred or the computed hash sum
* of an associated, local model file was incorrect.
*/
static boolean existsModel(String language, ModelType modelType) throws IOException {
Map<ModelType, String> modelsByLanguage = getAvailableModels().get(language);
if (modelsByLanguage == null) {
return false;
} else {
final String url = modelsByLanguage.get(modelType);
if (url != null) {
final Path homeDirectory = getDownloadHome();
final String filename = url.substring(url.lastIndexOf("/") + 1);
final Path localFile = homeDirectory.resolve(filename);
boolean exists;
if (Files.exists(localFile)) {
// if this does not throw the requested model is valid!
validateModel(new URL(url + ".sha512"), localFile);
exists = true;
} else {
exists = false;
}
return exists;
} else {
return false;
}
}
}

/**
* Triggers a download for the specified {@link DownloadUtil.ModelType}.
*
Expand All @@ -94,7 +130,7 @@ public static <T extends BaseModel> T downloadModel(String language, ModelType m
Class<T> type) throws IOException {

if (getAvailableModels().containsKey(language)) {
final String url = (getAvailableModels().get(language).get(modelType));
final String url = getAvailableModels().get(language).get(modelType);
if (url != null) {
return downloadModel(new URL(url), type);
}
Expand All @@ -119,8 +155,7 @@ public static <T extends BaseModel> T downloadModel(String language, ModelType m
*/
public static <T extends BaseModel> T downloadModel(URL url, Class<T> type) throws IOException {

final Path homeDirectory = Paths.get(System.getProperty("OPENNLP_DOWNLOAD_HOME",
System.getProperty("user.home"))).resolve(".opennlp");
final Path homeDirectory = getDownloadHome();

if (!Files.isDirectory(homeDirectory)) {
try {
Expand All @@ -131,20 +166,17 @@ public static <T extends BaseModel> T downloadModel(URL url, Class<T> type) thro
}

final String filename = url.toString().substring(url.toString().lastIndexOf("/") + 1);
final Path localFile = Paths.get(homeDirectory.toString(), filename);
final Path localFile = homeDirectory.resolve(filename);

if (!Files.exists(localFile)) {
logger.debug("Downloading model from {} to {}.", url, localFile);
logger.debug("Downloading model to {}.", localFile);

try (final InputStream in = url.openStream()) {
Files.copy(in, localFile, StandardCopyOption.REPLACE_EXISTING);
}

validateModel(new URL(url + ".sha512"), localFile);

logger.debug("Download complete.");
} else {
System.out.println("Model file already exists. Skipping download.");
logger.debug("Model file '{}' already exists. Skipping download.", filename);
}

Expand All @@ -167,7 +199,7 @@ public static Map<String, Map<ModelType, String>> getAvailableModels() {
}

/**
* Validates the downloaded model.
* Validates a downloaded model via the specified {@link Path downloadedModel path}.
*
* @param sha512 the url to get the sha512 hash
* @param downloadedModel the model file to check
Expand All @@ -187,8 +219,8 @@ private static void validateModel(URL sha512, Path downloadedModel) throws IOExc
// Validate SHA512 checksum
final String actualChecksum = calculateSHA512(downloadedModel);
if (!actualChecksum.equalsIgnoreCase(expectedChecksum)) {
throw new IOException("SHA512 checksum validation failed. Expected: "
+ expectedChecksum + ", but got: " + actualChecksum);
throw new IOException("SHA512 checksum validation failed for " + downloadedModel.getFileName() +
". Expected: " + expectedChecksum + ", but got: " + actualChecksum);
}
}

Expand All @@ -198,6 +230,7 @@ private static String calculateSHA512(Path file) throws IOException {
try (InputStream fis = Files.newInputStream(file);
DigestInputStream dis = new DigestInputStream(fis, digest)) {
byte[] buffer = new byte[4096];
//noinspection StatementWithEmptyBody
while (dis.read(buffer) != -1) {
// Reading the file to update the digest
}
Expand All @@ -217,6 +250,11 @@ private static String byteArrayToHexString(byte[] bytes) {
}
}

private static Path getDownloadHome() {
return Paths.get(System.getProperty(OPENNLP_DOWNLOAD_HOME,
System.getProperty("user.home"))).resolve(".opennlp");
}

@Internal
static class DownloadParser {

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;

@EnabledWhenCDNAvailable(hostname = "dlcdn.apache.org")
public class DownloadUtilDownloadTwiceTest extends AbstractDownloadUtilTest {
public class DownloadUtilDownloadTwiceTest {

/*
* Programmatic change to debug log to ensure that we can see log messages to
Expand All @@ -60,24 +60,24 @@ public static void cleanup() {

@Test
public void testDownloadModelTwice() throws IOException {
String lang = "de";
DownloadUtil.ModelType type = DownloadUtil.ModelType.SENTENCE_DETECTOR;

try (LogCaptor logCaptor = LogCaptor.forClass(DownloadUtil.class)) {

DownloadUtil.downloadModel("de",
DownloadUtil.ModelType.SENTENCE_DETECTOR, SentenceModel.class);

assertEquals(2, logCaptor.getDebugLogs().size());
checkDebugLogsContainMessageFragment(logCaptor.getDebugLogs(), "Download complete.");
boolean alreadyDownloaded = DownloadUtil.existsModel(lang, type);
DownloadUtil.downloadModel(lang, type, SentenceModel.class);

if (! alreadyDownloaded) {
assertEquals(2, logCaptor.getDebugLogs().size());
checkDebugLogsContainMessageFragment(logCaptor.getDebugLogs(), "Download complete.");
} else {
assertEquals(1, logCaptor.getDebugLogs().size());
checkDebugLogsContainMessageFragment(logCaptor.getDebugLogs(), "already exists. Skipping download.");
}
logCaptor.clearLogs();

// try to download again
DownloadUtil.downloadModel("de",
DownloadUtil.ModelType.SENTENCE_DETECTOR, SentenceModel.class);
assertEquals(1, logCaptor.getDebugLogs().size());
checkDebugLogsContainMessageFragment(logCaptor.getDebugLogs(), "already exists. Skipping download.");
logCaptor.clearLogs();

DownloadUtil.downloadModel("de",
DownloadUtil.ModelType.SENTENCE_DETECTOR, SentenceModel.class);
DownloadUtil.downloadModel(lang, type, SentenceModel.class);
assertEquals(1, logCaptor.getDebugLogs().size());
checkDebugLogsContainMessageFragment(logCaptor.getDebugLogs(), "already exists. Skipping download.");
logCaptor.clearLogs();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.net.URL;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
Expand All @@ -32,11 +33,12 @@
import opennlp.tools.tokenize.TokenizerModel;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class DownloadUtilTest extends AbstractDownloadUtilTest {
public class DownloadUtilTest {

@ParameterizedTest(name = "Verify \"{0}\" sentence model")
@ValueSource(strings = {"en", "fr", "de", "it", "nl", "bg", "ca", "cs", "da", "el",
Expand All @@ -62,13 +64,33 @@ public void testDownloadModelByURL(String language, URL url) throws IOException
assertTrue(model.isLoadedFromSerialized());
}

@Test
@EnabledWhenCDNAvailable(hostname = "dlcdn.apache.org")
public void testExistsModel() throws IOException {
final String lang = "en";
final DownloadUtil.ModelType type = DownloadUtil.ModelType.SENTENCE_DETECTOR;
// Prepare
SentenceModel model = DownloadUtil.downloadModel(lang, type, SentenceModel.class);
assertNotNull(model);
assertEquals(lang, model.getLanguage());
// Test
assertTrue(DownloadUtil.existsModel(lang, type));
}

@ParameterizedTest
@NullAndEmptySource
@ValueSource(strings = {"xy", "\t", "\n"})
@EnabledWhenCDNAvailable(hostname = "dlcdn.apache.org")
public void testExistsModelInvalid(String input) throws IOException {
assertFalse(DownloadUtil.existsModel(input, DownloadUtil.ModelType.SENTENCE_DETECTOR));
}

@ParameterizedTest(name = "Detect invalid input: \"{0}\"")
@NullAndEmptySource
@ValueSource(strings = {" ", "\t", "\n"})
public void testDownloadModelInvalid(String input) {
assertThrows(IOException.class, () -> DownloadUtil.downloadModel(
input, DownloadUtil.ModelType.SENTENCE_DETECTOR, SentenceModel.class),
"Invalid model");
assertThrows(IOException.class, () -> DownloadUtil.downloadModel(input,
DownloadUtil.ModelType.SENTENCE_DETECTOR, SentenceModel.class), "Invalid model");
}

private static final DownloadUtil.ModelType MT_TOKENIZER = DownloadUtil.ModelType.TOKENIZER;
Expand Down

0 comments on commit e91ceb1

Please sign in to comment.