From 93e44596a778442c3a4c1321250d525198a978b9 Mon Sep 17 00:00:00 2001 From: Ivan Kripakov Date: Mon, 6 Nov 2023 15:21:43 +0400 Subject: [PATCH] provide more granular way to manage embedding cache --- .../cache/EmbeddingCacheProvider.java | 71 ++++-- .../frs/core/trainservice/dao/SubjectDao.java | 4 + .../core/trainservice/dto/CacheActionDto.java | 67 ++++-- .../service/EmbeddingService.java | 5 +- .../service/NotificationHandler.java | 108 ++++++++++ .../service/NotificationReceiverService.java | 74 ++++--- .../trainservice/service/SubjectService.java | 33 ++- .../trainservice/system/global/Constants.java | 2 +- .../cache/EmbeddingCacheProviderTest.java | 203 +++++++++++++++++- .../NotificationReceiverServiceTest.java | 114 ++++++++++ .../service/SubjectServiceTest.java | 35 ++- java/pom.xml | 12 +- 12 files changed, 629 insertions(+), 99 deletions(-) create mode 100644 java/api/src/main/java/com/exadel/frs/core/trainservice/service/NotificationHandler.java create mode 100644 java/api/src/test/java/com/exadel/frs/core/trainservice/service/NotificationReceiverServiceTest.java diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProvider.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProvider.java index 294d52fac7..10aafb8e14 100644 --- a/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProvider.java +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProvider.java @@ -1,18 +1,25 @@ package com.exadel.frs.core.trainservice.cache; +import com.exadel.frs.commonservice.entity.Embedding; +import com.exadel.frs.commonservice.projection.EmbeddingProjection; import com.exadel.frs.core.trainservice.dto.CacheActionDto; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.AddEmbeddings; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.CacheAction; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.RemoveEmbeddings; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.RemoveSubjects; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.RenameSubjects; import com.exadel.frs.core.trainservice.service.EmbeddingService; import com.exadel.frs.core.trainservice.service.NotificationSenderService; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; -import java.util.Optional; -import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; - import static com.exadel.frs.core.trainservice.system.global.Constants.SERVER_UUID; @Component @@ -34,34 +41,60 @@ public class EmbeddingCacheProvider { .build(); public EmbeddingCollection getOrLoad(final String apiKey) { - var result = cache.getIfPresent(apiKey); - if (result == null) { result = embeddingService.doWithEnhancedEmbeddingProjectionStream(apiKey, EmbeddingCollection::from); - cache.put(apiKey, result); - - notifyCacheEvent("UPDATE", apiKey); } - return result; } - public void ifPresent(String apiKey, Consumer consumer) { - Optional.ofNullable(cache.getIfPresent(apiKey)) - .ifPresent(consumer); + public void removeEmbedding(String apiKey, EmbeddingProjection embedding) { + getOrLoad(apiKey).removeEmbedding(embedding); + notifyCacheEvent( + CacheAction.REMOVE_EMBEDDINGS, + apiKey, + new RemoveEmbeddings(Map.of(embedding.subjectName(), List.of(embedding.embeddingId()))) + ); + } + + public void updateSubjectName(String apiKey, String oldSubjectName, String newSubjectName) { + getOrLoad(apiKey).updateSubjectName(oldSubjectName, newSubjectName); + notifyCacheEvent(CacheAction.RENAME_SUBJECTS, apiKey, new RenameSubjects(Map.of(oldSubjectName, newSubjectName))); + } - cache.getIfPresent(apiKey); - notifyCacheEvent("UPDATE", apiKey); + public void removeBySubjectName(String apiKey, String subjectName) { + getOrLoad(apiKey).removeEmbeddingsBySubjectName(subjectName); + notifyCacheEvent(CacheAction.REMOVE_SUBJECTS, apiKey, new RemoveSubjects(List.of(subjectName))); + } + + + public void addEmbedding(String apiKey, Embedding embedding) { + getOrLoad(apiKey).addEmbedding(embedding); + notifyCacheEvent(CacheAction.ADD_EMBEDDINGS, apiKey, new AddEmbeddings(List.of(embedding.getId()))); + } + + /** + * Method can be used to make changes in cache without sending notification. + * Use it carefully, because changes you do will not be visible for other compreface-api instances + * + * @param apiKey domain + * @param action what to do with {@link EmbeddingCollection} + */ + public void exposeIfPresent(String apiKey, Consumer action) { + action.accept(getOrLoad(apiKey)); } public void invalidate(final String apiKey) { cache.invalidate(apiKey); - notifyCacheEvent("DELETE", apiKey); + notifyCacheEvent(CacheAction.INVALIDATE, apiKey, null); } - + /** + * @deprecated + * See {@link com.exadel.frs.core.trainservice.service.NotificationHandler#handleUpdate(CacheActionDto)} + */ + @Deprecated(forRemoval = true) public void receivePutOnCache(String apiKey) { var result = embeddingService.doWithEnhancedEmbeddingProjectionStream(apiKey, EmbeddingCollection::from); cache.put(apiKey, result); @@ -71,8 +104,8 @@ public void receiveInvalidateCache(final String apiKey) { cache.invalidate(apiKey); } - private void notifyCacheEvent(String event, String apiKey) { - CacheActionDto cacheActionDto = new CacheActionDto(event, apiKey, SERVER_UUID); + private void notifyCacheEvent(CacheAction event, String apiKey, Object action) { + CacheActionDto cacheActionDto = new CacheActionDto<>(event, apiKey, SERVER_UUID, action); notificationSenderService.notifyCacheChange(cacheActionDto); } } diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/dao/SubjectDao.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/dao/SubjectDao.java index f470241262..29a2aa2e4b 100644 --- a/java/api/src/main/java/com/exadel/frs/core/trainservice/dao/SubjectDao.java +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/dao/SubjectDao.java @@ -32,6 +32,10 @@ public Collection getSubjectNames(final String apiKey) { return subjectRepository.getSubjectNames(apiKey); } + public List loadAllEmbeddingsByIds(Iterable ids) { + return embeddingRepository.findAllById(ids); + } + @Transactional public Subject deleteSubjectByName(final String apiKey, final String subjectName) { final Optional subjectOptional = subjectRepository.findByApiKeyAndSubjectNameIgnoreCase(apiKey, subjectName); diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/dto/CacheActionDto.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/dto/CacheActionDto.java index d15d444c25..bb072643c8 100644 --- a/java/api/src/main/java/com/exadel/frs/core/trainservice/dto/CacheActionDto.java +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/dto/CacheActionDto.java @@ -1,21 +1,62 @@ package com.exadel.frs.core.trainservice.dto; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; +import java.util.List; +import java.util.Map; +import java.util.UUID; -@Data -@AllArgsConstructor -@NoArgsConstructor -public class CacheActionDto { +@JsonIgnoreProperties(ignoreUnknown = true) // here and below "ignoreUnknown = true" for backward compatibility +public record CacheActionDto( + CacheAction cacheAction, + String apiKey, + @JsonProperty("uuid") + UUID serverUUID, + T payload +) { + public CacheActionDto withPayload(S payload) { + return new CacheActionDto<>( + cacheAction, + apiKey, + serverUUID, + payload + ); + } - @JsonProperty("cacheAction") - private String cacheAction; + public enum CacheAction { + // UPDATE and DELETE stays here to support rolling update + @Deprecated + UPDATE, + @Deprecated + DELETE, + REMOVE_EMBEDDINGS, + REMOVE_SUBJECTS, + ADD_EMBEDDINGS, + RENAME_SUBJECTS, + INVALIDATE + } - @JsonProperty("apiKey") - private String apiKey; + @JsonIgnoreProperties(ignoreUnknown = true) + public record RemoveEmbeddings( + Map> embeddings + ) { + } - @JsonProperty("uuid") - private String serverUUID; + @JsonIgnoreProperties(ignoreUnknown = true) + public record RemoveSubjects( + List subjects + ) { + } + + @JsonIgnoreProperties(ignoreUnknown = true) + public record AddEmbeddings( + List embeddings + ) { + } + + @JsonIgnoreProperties(ignoreUnknown = true) + public record RenameSubjects( + Map subjectsNamesMapping + ) { + } } diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/service/EmbeddingService.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/service/EmbeddingService.java index 4a64f7e88b..f99bdd5ca2 100644 --- a/java/api/src/main/java/com/exadel/frs/core/trainservice/service/EmbeddingService.java +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/service/EmbeddingService.java @@ -9,7 +9,6 @@ import com.exadel.frs.core.trainservice.system.global.Constants; import java.util.stream.Stream; import lombok.RequiredArgsConstructor; -import lombok.val; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.stereotype.Service; @@ -32,9 +31,9 @@ public int updateEmbedding(UUID embeddingId, double[] embedding, String calculat return embeddingRepository.updateEmbedding(embeddingId, embedding, calculator); } - @Transactional + @org.springframework.transaction.annotation.Transactional(readOnly = true) public T doWithEnhancedEmbeddingProjectionStream(String apiKey, Function, T> func) { - try (val stream = embeddingRepository.findBySubjectApiKey(apiKey)) { + try (var stream = embeddingRepository.findBySubjectApiKey(apiKey)) { return func.apply(stream); } } diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/service/NotificationHandler.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/service/NotificationHandler.java new file mode 100644 index 0000000000..ec3807e9bd --- /dev/null +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/service/NotificationHandler.java @@ -0,0 +1,108 @@ +package com.exadel.frs.core.trainservice.service; + +import com.exadel.frs.commonservice.projection.EmbeddingProjection; +import com.exadel.frs.core.trainservice.cache.EmbeddingCacheProvider; +import com.exadel.frs.core.trainservice.dto.CacheActionDto; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.AddEmbeddings; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.RemoveEmbeddings; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.RemoveSubjects; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.RenameSubjects; +import java.util.Objects; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.stereotype.Service; + +@Slf4j +@Service +@RequiredArgsConstructor +public class NotificationHandler { + private final EmbeddingCacheProvider cacheProvider; + private final SubjectService subjectService; + + public void removeEmbeddings(CacheActionDto action) { + action.payload().embeddings() + .entrySet() + .stream() + .filter(e -> StringUtils.isNotBlank(e.getKey())) + .filter(e -> Objects.nonNull(e.getValue())) + .filter(e -> !e.getValue().isEmpty()) + .flatMap(e -> e.getValue().stream().filter(Objects::nonNull).map(id -> new EmbeddingProjection(id, e.getKey()))) + .forEach( + em -> cacheProvider.exposeIfPresent( + action.apiKey(), + c -> c.removeEmbedding(em) + ) + ); + } + + public void removeSubjects(CacheActionDto action) { + action.payload().subjects() + .stream() + .filter(StringUtils::isNotBlank) + .forEach( + s -> cacheProvider.exposeIfPresent( + action.apiKey(), + c -> c.removeEmbeddingsBySubjectName(s) + ) + ); + } + + + public void addEmbeddings(CacheActionDto action) { + var filtered = action.payload().embeddings() + .stream() + .filter(Objects::nonNull) + .toList(); + subjectService.loadEmbeddingsById(filtered) + .forEach( + em -> cacheProvider.exposeIfPresent( + action.apiKey(), + c -> c.addEmbedding(em) + ) + ); + } + + public void renameSubjects(CacheActionDto action) { + action.payload().subjectsNamesMapping() + .entrySet() + .stream() + .filter(e -> StringUtils.isNotBlank(e.getKey())) + .filter(e -> StringUtils.isNotBlank(e.getValue())) + .forEach( + e -> cacheProvider.exposeIfPresent( + action.apiKey(), + c -> c.updateSubjectName(e.getKey(), e.getValue()) + ) + ); + } + + public void invalidate(CacheActionDto action) { + cacheProvider.exposeIfPresent( + action.apiKey(), + e -> cacheProvider.receiveInvalidateCache(action.apiKey()) + ); + } + + /** + * @param action cacheAction + * @deprecated in favour more granular cache managing. + * See {@link CacheActionDto}. + * Stays here to support rolling update + */ + @Deprecated(forRemoval = true) + public void handleDelete(CacheActionDto action) { + cacheProvider.receiveInvalidateCache(action.apiKey()); + } + + /** + * @param action cacheAction + * @deprecated in favour more granular cache managing. + * See {@link CacheActionDto}. + * Stays here to support rolling update + */ + @Deprecated(forRemoval = true) + public void handleUpdate(CacheActionDto action) { + cacheProvider.receivePutOnCache(action.apiKey()); + } +} diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/service/NotificationReceiverService.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/service/NotificationReceiverService.java index 9d000ba244..f98c17c34b 100644 --- a/java/api/src/main/java/com/exadel/frs/core/trainservice/service/NotificationReceiverService.java +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/service/NotificationReceiverService.java @@ -1,23 +1,24 @@ package com.exadel.frs.core.trainservice.service; -import com.exadel.frs.core.trainservice.cache.EmbeddingCacheProvider; import com.exadel.frs.core.trainservice.dto.CacheActionDto; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.impossibl.postgres.api.jdbc.PGConnection; import com.impossibl.postgres.api.jdbc.PGNotificationListener; import com.impossibl.postgres.jdbc.PGDataSource; - +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Map; +import java.util.Objects; +import javax.annotation.PostConstruct; import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.stereotype.Service; -import javax.annotation.PostConstruct; -import java.sql.SQLException; -import java.sql.Statement; - import static com.exadel.frs.core.trainservice.system.global.Constants.SERVER_UUID; @Service("notificationReceiverService") @@ -25,16 +26,12 @@ @RequiredArgsConstructor public class NotificationReceiverService { + private static final TypeReference>> CACHE_ACTION_DTO_TR = new TypeReference<>() {}; @Qualifier("dsPgNot") private final PGDataSource pgNotificationDatasource; - - private PGConnection connection; - - private final EmbeddingCacheProvider embeddingCacheProvider; - + private final NotificationHandler notificationHandler; private final ObjectMapper objectMapper; - private static PGNotificationListener listener; @PostConstruct @@ -69,25 +66,48 @@ public void closed() { connection.addNotificationListener(listener); } - private void synchronizeCacheWithNotification(String payload) { - + // package private for test purposes + void synchronizeCacheWithNotification(String payload) { try { - CacheActionDto cacheActionDto = objectMapper.readValue(payload, CacheActionDto.class); - if (cacheActionDto != null - && !StringUtils.isBlank(cacheActionDto.getServerUUID()) - && !cacheActionDto.getServerUUID().equals(SERVER_UUID) - && !StringUtils.isBlank(cacheActionDto.getApiKey()) - && !StringUtils.isBlank(cacheActionDto.getCacheAction()) - ) { - - if (cacheActionDto.getCacheAction().equals("UPDATE")) { - embeddingCacheProvider.receivePutOnCache(cacheActionDto.getApiKey()); - } else if (cacheActionDto.getCacheAction().equals("DELETE")) { - embeddingCacheProvider.receiveInvalidateCache(cacheActionDto.getApiKey()); - } + var cacheActionDto = objectMapper.readValue(payload, CACHE_ACTION_DTO_TR); + if (SERVER_UUID.equals(cacheActionDto.serverUUID())) { + return; + } + if (Objects.isNull(cacheActionDto.serverUUID())) { + log.warn("Received notification with empty serverUUID: {}", cacheActionDto); + return; + } + if (StringUtils.isBlank(cacheActionDto.apiKey())) { + log.warn("Received notification with blank api key: {}", cacheActionDto); + return; } + if (Objects.isNull(cacheActionDto.cacheAction())) { + log.warn("Received notification with blank cache action type: {}", cacheActionDto); + return; + } + processNotification(cacheActionDto); } catch (JsonProcessingException e) { log.error(e.getMessage()); } } + + private void processNotification(CacheActionDto> action) { + switch (action.cacheAction()) { + case INVALIDATE -> notificationHandler.invalidate(action); + case REMOVE_SUBJECTS -> notificationHandler.removeSubjects(convert(action)); + case ADD_EMBEDDINGS -> notificationHandler.addEmbeddings(convert(action)); + case REMOVE_EMBEDDINGS -> notificationHandler.removeEmbeddings(convert(action)); + case RENAME_SUBJECTS -> notificationHandler.renameSubjects(convert(action)); + case DELETE -> notificationHandler.handleDelete(action); + case UPDATE -> notificationHandler.handleUpdate(action); + default -> log.error("Can't process action with actionType={}", action.cacheAction()); + } + } + + @SneakyThrows + private CacheActionDto convert(CacheActionDto> action) { + return action.withPayload( + objectMapper.convertValue(action.payload(), new TypeReference<>() {}) + ); + } } diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/service/SubjectService.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/service/SubjectService.java index bd7b0dce91..bb3fe4b139 100644 --- a/java/api/src/main/java/com/exadel/frs/core/trainservice/service/SubjectService.java +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/service/SubjectService.java @@ -1,11 +1,11 @@ package com.exadel.frs.core.trainservice.service; -import static java.math.RoundingMode.HALF_UP; import com.exadel.frs.commonservice.entity.Embedding; import com.exadel.frs.commonservice.entity.Subject; import com.exadel.frs.commonservice.exception.EmbeddingNotFoundException; import com.exadel.frs.commonservice.exception.TooManyFacesException; import com.exadel.frs.commonservice.exception.WrongEmbeddingCountException; +import com.exadel.frs.commonservice.projection.EmbeddingProjection; import com.exadel.frs.commonservice.sdk.faces.FacesApiClient; import com.exadel.frs.commonservice.sdk.faces.feign.dto.FindFacesResponse; import com.exadel.frs.commonservice.sdk.faces.feign.dto.FindFacesResult; @@ -39,6 +39,8 @@ import org.springframework.stereotype.Service; import org.springframework.web.multipart.MultipartFile; +import static java.math.RoundingMode.HALF_UP; + @Service @RequiredArgsConstructor @Slf4j @@ -71,15 +73,16 @@ public int deleteSubjectsByApiKey(final String apiKey) { return deletedCount; } + public List loadEmbeddingsById(Iterable embeddingsIds) { + return subjectDao.loadAllEmbeddingsByIds(embeddingsIds); + } + public int removeAllSubjectEmbeddings(final String apiKey, final String subjectName) { int removed; if (StringUtils.isNotEmpty(subjectName)) { removed = subjectDao.removeAllSubjectEmbeddings(apiKey, subjectName); if (removed > 0) { - embeddingCacheProvider.ifPresent( - apiKey, - c -> c.removeEmbeddingsBySubjectName(subjectName) - ); + embeddingCacheProvider.removeBySubjectName(apiKey, subjectName); } } else { removed = subjectDao.removeAllSubjectEmbeddings(apiKey); @@ -101,10 +104,7 @@ public Subject deleteSubjectByNameAndApiKey(final String apiKey, final String su var subject = subjectDao.deleteSubjectByName(apiKey, subjectName); // remove subject from cache if required - embeddingCacheProvider.ifPresent( - apiKey, - c -> c.removeEmbeddingsBySubjectName(subjectName) - ); + embeddingCacheProvider.removeBySubjectName(apiKey, subjectName); return subject; } @@ -113,10 +113,7 @@ public Embedding removeSubjectEmbedding(final String apiKey, final UUID embeddin var embedding = subjectDao.removeSubjectEmbedding(apiKey, embeddingId); // remove embedding from cache if required - embeddingCacheProvider.ifPresent( - apiKey, - c -> c.removeEmbedding(embedding) - ); + embeddingCacheProvider.removeEmbedding(apiKey, EmbeddingProjection.from(embedding)); return embedding; } @@ -142,10 +139,7 @@ public boolean updateSubjectName(final String apiKey, final String oldSubjectNam if (updated) { // update cache if required - embeddingCacheProvider.ifPresent( - apiKey, - c -> c.updateSubjectName(oldSubjectName, newSubjectName) - ); + embeddingCacheProvider.updateSubjectName(apiKey, oldSubjectName, newSubjectName); } return updated; @@ -217,10 +211,7 @@ private Pair saveCalculatedEmbedding(byte[] content, final Pair pair = subjectDao.addEmbedding(modelKey, subjectName, embeddingToSave); - embeddingCacheProvider.ifPresent( - modelKey, - subjectCollection -> subjectCollection.addEmbedding(pair.getRight()) - ); + embeddingCacheProvider.addEmbedding(modelKey, pair.getRight()); return pair; } diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/system/global/Constants.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/system/global/Constants.java index df0ad55d3b..d55fe42088 100644 --- a/java/api/src/main/java/com/exadel/frs/core/trainservice/system/global/Constants.java +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/system/global/Constants.java @@ -62,6 +62,6 @@ public class Constants { public static final String NUMBER_VALUE_EXAMPLE = "1"; public static final String DEMO_API_KEY = "00000000-0000-0000-0000-000000000002"; public static final String FACENET2018 = "Facenet2018"; - public static final String SERVER_UUID = UUID.randomUUID().toString(); + public static final UUID SERVER_UUID = UUID.randomUUID(); public static final String CACHE_CONTROL_HEADER_VALUE = "public, max-age=31536000"; } diff --git a/java/api/src/test/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProviderTest.java b/java/api/src/test/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProviderTest.java index a2a0cc626f..f4ddf5142d 100644 --- a/java/api/src/test/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProviderTest.java +++ b/java/api/src/test/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProviderTest.java @@ -16,29 +16,54 @@ package com.exadel.frs.core.trainservice.cache; -import static com.exadel.frs.core.trainservice.ItemsBuilder.makeEnhancedEmbeddingProjection; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.notNullValue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.when; +import com.exadel.frs.commonservice.entity.Embedding; +import com.exadel.frs.commonservice.entity.Subject; +import com.exadel.frs.commonservice.projection.EmbeddingProjection; import com.exadel.frs.commonservice.projection.EnhancedEmbeddingProjection; +import com.exadel.frs.core.trainservice.dto.CacheActionDto; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.AddEmbeddings; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.CacheAction; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.RemoveEmbeddings; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.RemoveSubjects; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.RenameSubjects; import com.exadel.frs.core.trainservice.service.EmbeddingService; import com.exadel.frs.core.trainservice.service.NotificationReceiverService; import com.exadel.frs.core.trainservice.service.NotificationSenderService; +import com.exadel.frs.core.trainservice.system.global.Constants; +import java.util.List; +import java.util.Map; +import java.util.UUID; import java.util.function.Function; import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import static com.exadel.frs.core.trainservice.ItemsBuilder.makeEnhancedEmbeddingProjection; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + @ExtendWith(MockitoExtension.class) class EmbeddingCacheProviderTest { private static final String API_KEY = "model_key"; + private static final String SUBJECT_NAME = "subject_name"; + private static final String NEW_SUBJECT_NAME = "new_subject_name"; + private static final UUID EMBEDDING_ID_1 = UUID.randomUUID(); + private static final UUID EMBEDDING_ID_2 = UUID.randomUUID(); + private static final String TEST_CALCULATOR = "test-calculator"; @Mock private EmbeddingService embeddingService; @@ -52,8 +77,21 @@ class EmbeddingCacheProviderTest { @InjectMocks private EmbeddingCacheProvider embeddingCacheProvider; + @BeforeEach + @SuppressWarnings("unchecked") + public void resetStaticCache() { + embeddingCacheProvider.invalidate(API_KEY); + when(embeddingService.doWithEnhancedEmbeddingProjectionStream(eq(API_KEY), any())) + .thenAnswer(invocation -> { + var function = (Function, ?>) invocation.getArgument(1); + return function.apply(Stream.of()); + }); + } + @Test + @SuppressWarnings("unchecked") void getOrLoad() { + reset(embeddingService); var projections = new EnhancedEmbeddingProjection[]{ makeEnhancedEmbeddingProjection("A"), makeEnhancedEmbeddingProjection("B"), @@ -73,4 +111,155 @@ void getOrLoad() { assertThat(actual.getProjections().size(), is(projections.length)); assertThat(actual.getEmbeddings(), notNullValue()); } + + @Test + void removeEmbedding() { + // arrange + embeddingCacheProvider.addEmbedding(API_KEY, buildEmbedding(EMBEDDING_ID_1)); + embeddingCacheProvider.addEmbedding(API_KEY, buildEmbedding(EMBEDDING_ID_2)); + reset(notificationReceiverService); + + // act + embeddingCacheProvider.removeEmbedding(API_KEY, new EmbeddingProjection(EMBEDDING_ID_1, SUBJECT_NAME)); + + // assert + var embeddings = embeddingCacheProvider.getOrLoad(API_KEY); + Assertions.assertThat(embeddings.getEmbeddings().size(0)).isEqualTo(1); + Assertions.assertThat(embeddings.getProjections()).containsOnly(new EmbeddingProjection(EMBEDDING_ID_2, SUBJECT_NAME)); + + verify(notificationSenderService, times(1)).notifyCacheChange( + buildCacheActionDto(CacheAction.REMOVE_EMBEDDINGS, new RemoveEmbeddings(Map.of(SUBJECT_NAME, List.of(EMBEDDING_ID_1)))) + ); + } + + @Test + void updateSubjectName() { + // arrange + embeddingCacheProvider.addEmbedding(API_KEY, buildEmbedding(EMBEDDING_ID_1)); + embeddingCacheProvider.addEmbedding(API_KEY, buildEmbedding(EMBEDDING_ID_2)); + reset(notificationSenderService); + + // act + embeddingCacheProvider.updateSubjectName(API_KEY, SUBJECT_NAME, NEW_SUBJECT_NAME); + + // assert + var embeddings = embeddingCacheProvider.getOrLoad(API_KEY); + Assertions.assertThat(embeddings.getEmbeddings().size(0)).isEqualTo(2); + Assertions.assertThat(embeddings.getProjections()).containsExactly( + new EmbeddingProjection(EMBEDDING_ID_1, NEW_SUBJECT_NAME), + new EmbeddingProjection(EMBEDDING_ID_2, NEW_SUBJECT_NAME) + ); + + verify(notificationSenderService, times(1)).notifyCacheChange( + buildCacheActionDto(CacheAction.RENAME_SUBJECTS, new RenameSubjects(Map.of(SUBJECT_NAME, NEW_SUBJECT_NAME))) + ); + } + + @Test + void removeBySubjectName() { + // arrange + embeddingCacheProvider.addEmbedding(API_KEY, buildEmbedding(EMBEDDING_ID_1)); + embeddingCacheProvider.addEmbedding(API_KEY, buildEmbedding(EMBEDDING_ID_2)); + reset(notificationSenderService); + + // act + embeddingCacheProvider.removeBySubjectName(API_KEY, SUBJECT_NAME); + + // assert + var embeddings = embeddingCacheProvider.getOrLoad(API_KEY); + Assertions.assertThat(embeddings.getEmbeddings().size(0)).isZero(); + Assertions.assertThat(embeddings.getProjections()).isEmpty(); + + verify(notificationSenderService, times(1)).notifyCacheChange( + buildCacheActionDto(CacheAction.REMOVE_SUBJECTS, new RemoveSubjects(List.of(SUBJECT_NAME))) + ); + } + + @Test + void addEmbedding() { + // act + embeddingCacheProvider.addEmbedding(API_KEY, buildEmbedding(EMBEDDING_ID_2)); + + // assert + var embeddings = embeddingCacheProvider.getOrLoad(API_KEY); + Assertions.assertThat(embeddings.getEmbeddings().size(0)).isEqualTo(1); + Assertions.assertThat(embeddings.getProjections()).containsOnly(new EmbeddingProjection(EMBEDDING_ID_2, SUBJECT_NAME)); + + verify(notificationSenderService, times(1)).notifyCacheChange( + buildCacheActionDto(CacheAction.ADD_EMBEDDINGS, new AddEmbeddings(List.of(EMBEDDING_ID_2))) + ); + } + + @Test + void invalidate() { + // arrange + embeddingCacheProvider.addEmbedding(API_KEY, buildEmbedding(EMBEDDING_ID_2)); + reset(notificationSenderService); + + // act + embeddingCacheProvider.invalidate(API_KEY); + + // assert + embeddingCacheProvider.exposeIfPresent(API_KEY, ec -> { + Assertions.assertThat(ec.getEmbeddings().size(0)).isZero(); + Assertions.assertThat(ec.getProjections()).isEmpty(); + }); + verify(notificationSenderService).notifyCacheChange( + buildCacheActionDto( + CacheAction.INVALIDATE, + null + ) + ); + } + + @Test + void receivePutOnCache() { + // arrange + receiveInvalidateCache(); + } + + @Test + void receiveInvalidateCache() { + // arrange + embeddingCacheProvider.addEmbedding(API_KEY, buildEmbedding(EMBEDDING_ID_2)); + + // act + embeddingCacheProvider.receiveInvalidateCache(API_KEY); + + // assert + embeddingCacheProvider.exposeIfPresent(API_KEY, ec -> { + Assertions.assertThat(ec.getEmbeddings().size(0)).isZero(); + Assertions.assertThat(ec.getProjections()).isEmpty(); + }); + } + + @NotNull + private static CacheActionDto buildCacheActionDto( + CacheAction cacheAction, + T payload + ) { + return new CacheActionDto<>( + cacheAction, + API_KEY, + Constants.SERVER_UUID, + payload + ); + } + + static Embedding buildEmbedding( + UUID embeddingId + ) { + var subj = new Subject( + UUID.randomUUID(), + API_KEY, + SUBJECT_NAME + ); + return new Embedding( + embeddingId, + subj, + new double[]{21.22, 222.444}, + TEST_CALCULATOR, + null + ); + } } diff --git a/java/api/src/test/java/com/exadel/frs/core/trainservice/service/NotificationReceiverServiceTest.java b/java/api/src/test/java/com/exadel/frs/core/trainservice/service/NotificationReceiverServiceTest.java new file mode 100644 index 0000000000..703c92574e --- /dev/null +++ b/java/api/src/test/java/com/exadel/frs/core/trainservice/service/NotificationReceiverServiceTest.java @@ -0,0 +1,114 @@ +package com.exadel.frs.core.trainservice.service; + +import com.exadel.frs.core.trainservice.dto.CacheActionDto; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.AddEmbeddings; +import com.exadel.frs.core.trainservice.dto.CacheActionDto.CacheAction; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.impossibl.postgres.api.jdbc.PGConnection; +import com.impossibl.postgres.api.jdbc.PGNotificationListener; +import com.impossibl.postgres.jdbc.PGDataSource; +import java.util.List; +import java.util.UUID; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.Spy; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class NotificationReceiverServiceTest { + static final UUID SERVER_UUID = UUID.randomUUID(); + public static final String API_KEY = "API_KEY"; + @Mock + NotificationHandler handler; + @Mock + PGDataSource pgNotificationDatasource; + @Mock + PGConnection connection; + @Mock + PGNotificationListener listener; + @Spy + ObjectMapper objectMapper; + @InjectMocks + NotificationReceiverService service; + + @ParameterizedTest + @MethodSource("okNotifications") + void synchronizeCacheWithNormalNotification( + CacheActionDto cacheActionDto, + Consumer verifier + ) throws JsonProcessingException { + // act + service.synchronizeCacheWithNotification(objectMapper.writeValueAsString(cacheActionDto)); + + // assert + verifier.accept(handler); + } + + @ParameterizedTest + @MethodSource("badNotifications") + void synchronizeCacheWithBadNotification(CacheActionDto cacheActionDto) throws JsonProcessingException { + // act + service.synchronizeCacheWithNotification(objectMapper.writeValueAsString(cacheActionDto)); + + // assert + Mockito.verifyNoInteractions(handler); + } + + static Stream okNotifications() { + return Stream.of( + Arguments.of( + buildCacheAction(CacheAction.INVALIDATE, null), + (Consumer) h -> Mockito.verify(h, Mockito.only()).invalidate(Mockito.any()) + ), + Arguments.of( + buildCacheAction(CacheAction.ADD_EMBEDDINGS, new AddEmbeddings(List.of(SERVER_UUID))), + (Consumer) h -> Mockito.verify(h, Mockito.only()).addEmbeddings(Mockito.any()) + ), + Arguments.of( + buildCacheAction(CacheAction.REMOVE_EMBEDDINGS, new AddEmbeddings(List.of(SERVER_UUID))), + (Consumer) h -> Mockito.verify(h, Mockito.only()).removeEmbeddings(Mockito.any()) + ), + Arguments.of( + buildCacheAction(CacheAction.REMOVE_SUBJECTS, new AddEmbeddings(List.of(SERVER_UUID))), + (Consumer) h -> Mockito.verify(h, Mockito.only()).removeSubjects(Mockito.any()) + ), + Arguments.of( + buildCacheAction(CacheAction.RENAME_SUBJECTS, new AddEmbeddings(List.of(SERVER_UUID))), + (Consumer) h -> Mockito.verify(h, Mockito.only()).renameSubjects(Mockito.any()) + ), + Arguments.of( + buildCacheAction(CacheAction.UPDATE, new AddEmbeddings(List.of(SERVER_UUID))), + (Consumer) h -> Mockito.verify(h, Mockito.only()).handleUpdate(Mockito.any()) + ), + Arguments.of( + buildCacheAction(CacheAction.DELETE, new AddEmbeddings(List.of(SERVER_UUID))), + (Consumer) h -> Mockito.verify(h, Mockito.only()).handleDelete(Mockito.any()) + ) + ); + } + + static Stream badNotifications() { + return Stream.of( + Arguments.of(new CacheActionDto<>(null, API_KEY, SERVER_UUID, null)), + Arguments.of(new CacheActionDto<>(CacheAction.INVALIDATE, " ", SERVER_UUID, null)), + Arguments.of(new CacheActionDto<>(CacheAction.INVALIDATE, API_KEY, null, null)) + ); + } + + static CacheActionDto buildCacheAction(CacheAction action, T payload) { + return new CacheActionDto<>( + action, + API_KEY, + SERVER_UUID, + payload + ); + } +} \ No newline at end of file diff --git a/java/api/src/test/java/com/exadel/frs/core/trainservice/service/SubjectServiceTest.java b/java/api/src/test/java/com/exadel/frs/core/trainservice/service/SubjectServiceTest.java index 2d43962615..acc5b37191 100644 --- a/java/api/src/test/java/com/exadel/frs/core/trainservice/service/SubjectServiceTest.java +++ b/java/api/src/test/java/com/exadel/frs/core/trainservice/service/SubjectServiceTest.java @@ -21,7 +21,6 @@ import static com.exadel.frs.core.trainservice.system.global.Constants.IMAGE_ID; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.eq; @@ -46,10 +45,10 @@ import com.exadel.frs.core.trainservice.component.FaceClassifierPredictor; import com.exadel.frs.core.trainservice.component.classifiers.EuclideanDistanceClassifier; import com.exadel.frs.core.trainservice.dao.SubjectDao; -import com.exadel.frs.core.trainservice.dto.EmbeddingVerificationProcessResult; import com.exadel.frs.core.trainservice.dto.ProcessEmbeddingsParams; import com.exadel.frs.core.trainservice.dto.ProcessImageParams; import java.io.IOException; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; @@ -116,7 +115,7 @@ void testRemoveAllSubjectEmbeddings() { // verify deleted from DB verify(subjectDao).removeAllSubjectEmbeddings(API_KEY, subjectName); // verify cache - verify(embeddingCacheProvider).ifPresent(eq(API_KEY), any()); + verify(embeddingCacheProvider).removeBySubjectName(API_KEY, subjectName); } @Test @@ -127,21 +126,23 @@ void deleteSubjectByName() { // verify deleted from DB verify(subjectDao).deleteSubjectByName(API_KEY, subjectName); // verify cache - verify(embeddingCacheProvider).ifPresent(eq(API_KEY), any()); + verify(embeddingCacheProvider).removeBySubjectName(eq(API_KEY), any()); } @Test void testRemoveSubjectEmbedding() { var embeddingId = UUID.randomUUID(); - when(subjectDao.removeSubjectEmbedding(API_KEY, embeddingId)).thenReturn(new Embedding()); + Embedding em = new Embedding(); + em.setSubject(new Subject()); + when(subjectDao.removeSubjectEmbedding(API_KEY, embeddingId)).thenReturn(em); subjectService.removeSubjectEmbedding(API_KEY, embeddingId); // verify deleted from DB verify(subjectDao).removeSubjectEmbedding(API_KEY, embeddingId); // verify cache update attempt - verify(embeddingCacheProvider).ifPresent(eq(API_KEY), any()); + verify(embeddingCacheProvider).removeEmbedding(eq(API_KEY), any()); } static Stream subjectNamePairsFailed() { @@ -176,7 +177,7 @@ void testUpdateSubjectNameSuccess(String oldSubjectName, String newSubjectName) assertThat(updated).isTrue(); // verify cache update attempt - verify(embeddingCacheProvider).ifPresent(eq(API_KEY), any()); + verify(embeddingCacheProvider).updateSubjectName(API_KEY, oldSubjectName, newSubjectName); } @Test @@ -322,6 +323,26 @@ void testInvalidImageIdException(boolean status){ )); } + @Test + void loadEmbeddingsById() { + // arrange + var embeddingIds = List.of(UUID.randomUUID(), UUID.randomUUID()); + var embeddings = embeddingIds.stream() + .map(id -> { + var e = new Embedding(); + e.setId(id); + return e; + }).toList(); + when(subjectDao.loadAllEmbeddingsByIds(embeddingIds)) + .thenReturn(embeddings); + + // act + var actual = subjectService.loadEmbeddingsById(embeddingIds); + + // assert + assertThat(actual).isEqualTo(embeddings); + } + private static FindFacesResponse findFacesResponse(int faceCount) { return FindFacesResponse.builder() .result( diff --git a/java/pom.xml b/java/pom.xml index c3d50cdbf6..91e155a63e 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -12,6 +12,16 @@ + + + Mac M1 Max + + 1.0.0-M2.1 + macosx-arm64 + + + + com.exadel frs-java 0.0.1-SNAPSHOT @@ -48,7 +58,7 @@ 1.0.0-beta7 4.8.0 0.8.9 - 1.6.2 + 2.3.0 9.1.6 1.6.10