Skip to content

Commit

Permalink
provide more granular way to manage embedding cache
Browse files Browse the repository at this point in the history
  • Loading branch information
ivan-kripakov-m10 committed Nov 6, 2023
1 parent 4d826e5 commit 93e4459
Show file tree
Hide file tree
Showing 12 changed files with 629 additions and 99 deletions.
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
package com.exadel.frs.core.trainservice.cache;

import com.exadel.frs.commonservice.entity.Embedding;
import com.exadel.frs.commonservice.projection.EmbeddingProjection;
import com.exadel.frs.core.trainservice.dto.CacheActionDto;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.AddEmbeddings;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.CacheAction;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.RemoveEmbeddings;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.RemoveSubjects;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.RenameSubjects;
import com.exadel.frs.core.trainservice.service.EmbeddingService;
import com.exadel.frs.core.trainservice.service.NotificationSenderService;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import static com.exadel.frs.core.trainservice.system.global.Constants.SERVER_UUID;

@Component
Expand All @@ -34,34 +41,60 @@ public class EmbeddingCacheProvider {
.build();

public EmbeddingCollection getOrLoad(final String apiKey) {

var result = cache.getIfPresent(apiKey);

if (result == null) {
result = embeddingService.doWithEnhancedEmbeddingProjectionStream(apiKey, EmbeddingCollection::from);

cache.put(apiKey, result);

notifyCacheEvent("UPDATE", apiKey);
}

return result;
}

public void ifPresent(String apiKey, Consumer<EmbeddingCollection> consumer) {
Optional.ofNullable(cache.getIfPresent(apiKey))
.ifPresent(consumer);
public void removeEmbedding(String apiKey, EmbeddingProjection embedding) {
getOrLoad(apiKey).removeEmbedding(embedding);
notifyCacheEvent(
CacheAction.REMOVE_EMBEDDINGS,
apiKey,
new RemoveEmbeddings(Map.of(embedding.subjectName(), List.of(embedding.embeddingId())))
);
}

public void updateSubjectName(String apiKey, String oldSubjectName, String newSubjectName) {
getOrLoad(apiKey).updateSubjectName(oldSubjectName, newSubjectName);
notifyCacheEvent(CacheAction.RENAME_SUBJECTS, apiKey, new RenameSubjects(Map.of(oldSubjectName, newSubjectName)));
}

cache.getIfPresent(apiKey);
notifyCacheEvent("UPDATE", apiKey);
public void removeBySubjectName(String apiKey, String subjectName) {
getOrLoad(apiKey).removeEmbeddingsBySubjectName(subjectName);
notifyCacheEvent(CacheAction.REMOVE_SUBJECTS, apiKey, new RemoveSubjects(List.of(subjectName)));
}


public void addEmbedding(String apiKey, Embedding embedding) {
getOrLoad(apiKey).addEmbedding(embedding);
notifyCacheEvent(CacheAction.ADD_EMBEDDINGS, apiKey, new AddEmbeddings(List.of(embedding.getId())));
}

/**
* Method can be used to make changes in cache without sending notification.
* Use it carefully, because changes you do will not be visible for other compreface-api instances
*
* @param apiKey domain
* @param action what to do with {@link EmbeddingCollection}
*/
public void exposeIfPresent(String apiKey, Consumer<EmbeddingCollection> action) {
action.accept(getOrLoad(apiKey));
}

public void invalidate(final String apiKey) {
cache.invalidate(apiKey);
notifyCacheEvent("DELETE", apiKey);
notifyCacheEvent(CacheAction.INVALIDATE, apiKey, null);
}


/**
* @deprecated
* See {@link com.exadel.frs.core.trainservice.service.NotificationHandler#handleUpdate(CacheActionDto)}
*/
@Deprecated(forRemoval = true)
public void receivePutOnCache(String apiKey) {
var result = embeddingService.doWithEnhancedEmbeddingProjectionStream(apiKey, EmbeddingCollection::from);
cache.put(apiKey, result);
Expand All @@ -71,8 +104,8 @@ public void receiveInvalidateCache(final String apiKey) {
cache.invalidate(apiKey);
}

private void notifyCacheEvent(String event, String apiKey) {
CacheActionDto cacheActionDto = new CacheActionDto(event, apiKey, SERVER_UUID);
private void notifyCacheEvent(CacheAction event, String apiKey, Object action) {
CacheActionDto<?> cacheActionDto = new CacheActionDto<>(event, apiKey, SERVER_UUID, action);
notificationSenderService.notifyCacheChange(cacheActionDto);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ public Collection<String> getSubjectNames(final String apiKey) {
return subjectRepository.getSubjectNames(apiKey);
}

public List<Embedding> loadAllEmbeddingsByIds(Iterable<UUID> ids) {
return embeddingRepository.findAllById(ids);
}

@Transactional
public Subject deleteSubjectByName(final String apiKey, final String subjectName) {
final Optional<Subject> subjectOptional = subjectRepository.findByApiKeyAndSubjectNameIgnoreCase(apiKey, subjectName);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,62 @@
package com.exadel.frs.core.trainservice.dto;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
import java.util.Map;
import java.util.UUID;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class CacheActionDto {
@JsonIgnoreProperties(ignoreUnknown = true) // here and below "ignoreUnknown = true" for backward compatibility
public record CacheActionDto<T>(
CacheAction cacheAction,
String apiKey,
@JsonProperty("uuid")
UUID serverUUID,
T payload
) {
public <S> CacheActionDto<S> withPayload(S payload) {
return new CacheActionDto<>(
cacheAction,
apiKey,
serverUUID,
payload
);
}

@JsonProperty("cacheAction")
private String cacheAction;
public enum CacheAction {
// UPDATE and DELETE stays here to support rolling update
@Deprecated
UPDATE,
@Deprecated
DELETE,
REMOVE_EMBEDDINGS,
REMOVE_SUBJECTS,
ADD_EMBEDDINGS,
RENAME_SUBJECTS,
INVALIDATE
}

@JsonProperty("apiKey")
private String apiKey;
@JsonIgnoreProperties(ignoreUnknown = true)
public record RemoveEmbeddings(
Map<String, List<UUID>> embeddings
) {
}

@JsonProperty("uuid")
private String serverUUID;
@JsonIgnoreProperties(ignoreUnknown = true)
public record RemoveSubjects(
List<String> subjects
) {
}

@JsonIgnoreProperties(ignoreUnknown = true)
public record AddEmbeddings(
List<UUID> embeddings
) {
}

@JsonIgnoreProperties(ignoreUnknown = true)
public record RenameSubjects(
Map<String, String> subjectsNamesMapping
) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import com.exadel.frs.core.trainservice.system.global.Constants;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import lombok.val;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.stereotype.Service;
Expand All @@ -32,9 +31,9 @@ public int updateEmbedding(UUID embeddingId, double[] embedding, String calculat
return embeddingRepository.updateEmbedding(embeddingId, embedding, calculator);
}

@Transactional
@org.springframework.transaction.annotation.Transactional(readOnly = true)
public <T> T doWithEnhancedEmbeddingProjectionStream(String apiKey, Function<Stream<EnhancedEmbeddingProjection>, T> func) {
try (val stream = embeddingRepository.findBySubjectApiKey(apiKey)) {
try (var stream = embeddingRepository.findBySubjectApiKey(apiKey)) {
return func.apply(stream);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package com.exadel.frs.core.trainservice.service;

import com.exadel.frs.commonservice.projection.EmbeddingProjection;
import com.exadel.frs.core.trainservice.cache.EmbeddingCacheProvider;
import com.exadel.frs.core.trainservice.dto.CacheActionDto;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.AddEmbeddings;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.RemoveEmbeddings;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.RemoveSubjects;
import com.exadel.frs.core.trainservice.dto.CacheActionDto.RenameSubjects;
import java.util.Objects;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;

@Slf4j
@Service
@RequiredArgsConstructor
public class NotificationHandler {
private final EmbeddingCacheProvider cacheProvider;
private final SubjectService subjectService;

public void removeEmbeddings(CacheActionDto<RemoveEmbeddings> action) {
action.payload().embeddings()
.entrySet()
.stream()
.filter(e -> StringUtils.isNotBlank(e.getKey()))
.filter(e -> Objects.nonNull(e.getValue()))
.filter(e -> !e.getValue().isEmpty())
.flatMap(e -> e.getValue().stream().filter(Objects::nonNull).map(id -> new EmbeddingProjection(id, e.getKey())))
.forEach(
em -> cacheProvider.exposeIfPresent(
action.apiKey(),
c -> c.removeEmbedding(em)
)
);
}

public void removeSubjects(CacheActionDto<RemoveSubjects> action) {
action.payload().subjects()
.stream()
.filter(StringUtils::isNotBlank)
.forEach(
s -> cacheProvider.exposeIfPresent(
action.apiKey(),
c -> c.removeEmbeddingsBySubjectName(s)
)
);
}


public void addEmbeddings(CacheActionDto<AddEmbeddings> action) {
var filtered = action.payload().embeddings()
.stream()
.filter(Objects::nonNull)
.toList();
subjectService.loadEmbeddingsById(filtered)
.forEach(
em -> cacheProvider.exposeIfPresent(
action.apiKey(),
c -> c.addEmbedding(em)
)
);
}

public void renameSubjects(CacheActionDto<RenameSubjects> action) {
action.payload().subjectsNamesMapping()
.entrySet()
.stream()
.filter(e -> StringUtils.isNotBlank(e.getKey()))
.filter(e -> StringUtils.isNotBlank(e.getValue()))
.forEach(
e -> cacheProvider.exposeIfPresent(
action.apiKey(),
c -> c.updateSubjectName(e.getKey(), e.getValue())
)
);
}

public <T> void invalidate(CacheActionDto<T> action) {
cacheProvider.exposeIfPresent(
action.apiKey(),
e -> cacheProvider.receiveInvalidateCache(action.apiKey())
);
}

/**
* @param action cacheAction
* @deprecated in favour more granular cache managing.
* See {@link CacheActionDto}.
* Stays here to support rolling update
*/
@Deprecated(forRemoval = true)
public <T> void handleDelete(CacheActionDto<T> action) {
cacheProvider.receiveInvalidateCache(action.apiKey());
}

/**
* @param action cacheAction
* @deprecated in favour more granular cache managing.
* See {@link CacheActionDto}.
* Stays here to support rolling update
*/
@Deprecated(forRemoval = true)
public <T> void handleUpdate(CacheActionDto<T> action) {
cacheProvider.receivePutOnCache(action.apiKey());
}
}
Loading

0 comments on commit 93e4459

Please sign in to comment.