Skip to content

Commit

Permalink
Add validations for Personalize input and configurations
Browse files Browse the repository at this point in the history
Description:
===========
This commit adds following changes:
* Change scoring logic to consider ranks instead of scores
* Add validation for Personalize inputs and configurations
* Add E2E tests in the form of unit tests for Personalize ranking

Signed-off-by: Ketan Kulkarni <[email protected]>
  • Loading branch information
kulket committed Jul 5, 2023
1 parent 25d3e08 commit eb62fd1
Show file tree
Hide file tree
Showing 10 changed files with 594 additions and 488 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRanker;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRankerFactory;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.ValidationUtil;

import java.util.Map;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -160,6 +161,7 @@ public PersonalizeRankingResponseProcessor create(Map<String, Processor.Factory<

PersonalizeIntelligentRankerConfiguration rankerConfig =
new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, awsRegion, weight);
ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, tag);
AWSCredentialsProvider credentialsProvider = PersonalizeCredentialsProviderFactory.getCredentialsProvider(personalizeClientSettings, iamRoleArn, awsRegion);
PersonalizeClient personalizeClient = clientBuilder.apply(credentialsProvider, awsRegion);
return new PersonalizeRankingResponseProcessor(tag, description, rankerConfig, personalizeClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,4 @@ public interface PersonalizedRanker {
* @return Re ranked search hits
*/
SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestParameters);

/**
* Validate Personalize configuration for calling Personalize service
* @param requestParameters Request parameters for Personalize present in search request
* @return True if valid configuration present else false.
*/
boolean isValidPersonalizeConfigPresent(PersonalizeRequestParameters requestParameters);

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingRequest;
import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingResult;
import com.amazonaws.services.personalizeruntime.model.PredictedItem;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.TotalHits;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.PersonalizeRankingResponseProcessor;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters;
Expand All @@ -23,9 +25,8 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.LinkedList;
import java.util.Map;
import java.util.stream.Collectors;

Expand All @@ -51,15 +52,17 @@ public AmazonPersonalizedRankerImpl(PersonalizeIntelligentRankerConfiguration co
@Override
public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestParameters) {
try {
if (!isValidPersonalizeConfigPresent(requestParameters)) {
throw new IllegalArgumentException("Required configurations missing from Personalize " +
"response processor configuration or search request parameters");
}
validatePersonalizeRequestParams(requestParameters);
List<SearchHit> originalHits = Arrays.asList(hits.getHits());
// Do not make Personalize call if weight is zero which implies Personalization is turned off.
if (rankerConfig.getWeight() == 0) {
logger.info("Not applying Personalized ranking. Given value for weight configuration: {}", rankerConfig.getWeight());
return hits;
}
String itemIdfield = rankerConfig.getItemIdField();
List<String> documentIdsToRank;
// If item field is not specified in the configuration then use default _id field.
if (!itemIdfield.isEmpty()) {
if (!StringUtils.isEmpty(itemIdfield)) {
documentIdsToRank = originalHits.stream()
.filter(h -> h.getSourceAsMap().get(itemIdfield) != null)
.map(h -> h.getSourceAsMap().get(itemIdfield).toString())
Expand All @@ -70,13 +73,17 @@ public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestPa
.map(h -> h.getId())
.collect(Collectors.toList());
}
if (documentIdsToRank.size() == 0) {
throw ConfigurationUtils.newConfigurationException(PersonalizeRankingResponseProcessor.TYPE, "", "item_id_field",
"no item ids found to apply Personalized reranking. Please check configured value for item_id_field");
}
logger.info("Document Ids to re-rank with Personalize: {}", Arrays.toString(documentIdsToRank.toArray()));
String userId = requestParameters.getUserId();
Map<String, String> context = requestParameters.getContext() != null ?
requestParameters.getContext().entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> isValidPersonalizeContext(e)))
.collect(Collectors.toMap(Map.Entry::getKey, e -> (String)e.getValue()))
: null;
logger.info("User ID from request parameters. User ID: {}", userId);
logger.info("User ID from personalize request parameters - User ID: {}", userId);
if (context != null && !context.isEmpty()) {
logger.info("Personalize context provided in the search request");
}
Expand All @@ -88,109 +95,66 @@ public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestPa
.withUserId(userId);
GetPersonalizedRankingResult result = personalizeClient.getPersonalizedRanking(personalizeRequest);

List<PredictedItem> personalizeRrankingResult = result.getPersonalizedRanking();
Map<String, Float> idToPersonalizeRankingScoreMap = new HashMap<>();
Map<String, Float> idToOpenSearchScoreMap = new HashMap<>();
Map<String, SearchHit> itemIdToSearchHitMap = new HashMap<>();
// Build a map with key as item id and value as personalize ranking score
for (PredictedItem item : personalizeRrankingResult) {
idToPersonalizeRankingScoreMap.put(item.getItemId(), item.getScore().floatValue());
}

// Build a map with key as item id and value as open search scores and another map
// with key as item id and value as corresponding search hit
for (SearchHit hit : originalHits) {
if (!itemIdfield.isEmpty()){
idToOpenSearchScoreMap.put(hit.getSourceAsMap().get(itemIdfield).toString(), hit.getScore());
itemIdToSearchHitMap.put(hit.getSourceAsMap().get(itemIdfield).toString(), hit);
}
else{
idToOpenSearchScoreMap.put(hit.getId(), hit.getScore());
itemIdToSearchHitMap.put(hit.getId(), hit);
}
}


float weight = (float) rankerConfig.getWeight();
SearchHits newHits = combineScores(idToPersonalizeRankingScoreMap, idToOpenSearchScoreMap,
itemIdToSearchHitMap, hits.getTotalHits(), weight);
return newHits;
SearchHits personalizedHits = combineScores(hits, result);
return personalizedHits;
} catch (Exception ex) {
logger.error("Failed to re rank with Personalize. Returning original search results without Personalize re ranking.", ex);
return hits;
logger.error("Failed to re rank with Personalize.", ex);
throw ex;
}
}

//Combine open search hits and personalize campaign response
public SearchHits combineScores(Map<String, Float> idToPersonalizeRankingScoreMap,
Map<String, Float> idToOpenSearchScoreMap,
Map<String, SearchHit> itemIdToSearchHitMap,
TotalHits totalHits, float weight) {
//Update open search score based on the personalize campaign response for each item id
List<String> openSearchItemId = new ArrayList<String>(idToOpenSearchScoreMap.keySet());
for (String itemId : openSearchItemId) {
if(idToPersonalizeRankingScoreMap.containsKey(itemId)){
float personalizedScore = idToPersonalizeRankingScoreMap.get(itemId);
float openSearchScore = idToOpenSearchScoreMap.get(itemId);
float combinedScore = (float) (weight / Math.log(openSearchScore + 1)
+ (1 - weight) / Math.log(personalizedScore + 1));
idToOpenSearchScoreMap.put(itemId, combinedScore);
private SearchHits combineScores(SearchHits originalHits, GetPersonalizedRankingResult personalizedRankingResult) {
List<PredictedItem> personalziedRanking = personalizedRankingResult.getPersonalizedRanking();
List<String> personalizedRankedItemsList = new LinkedList<>();
for (PredictedItem item : personalziedRanking) {
personalizedRankedItemsList.add(item.getItemId());
}
int totalHits = originalHits.getHits().length;
List<SearchHit> rerankedHits = new ArrayList<>(totalHits);
float maxScore = 0f;
double weight = rankerConfig.getWeight();
for (int i = 0 ; i < totalHits ; i++) {
String openSearchItemId;
SearchHit hit = originalHits.getAt(i);
if (!StringUtils.isEmpty(rankerConfig.getItemIdField())) {
openSearchItemId = hit.getSourceAsMap().get(rankerConfig.getItemIdField()).toString();
} else {
openSearchItemId = hit.getId();
}
int openSearchRank = i + 1;
int personalizedRank = personalizedRankedItemsList.indexOf(openSearchItemId) + 1;
float combinedScore = (float) (((1- weight) / (Math.log(openSearchRank + 1) / Math.log(2)))
+ ((weight) / (Math.log(personalizedRank + 1) / Math.log(2))));
maxScore = Math.max(maxScore, combinedScore);
hit.score(combinedScore);
rerankedHits.add(hit);
}

//Create a new list of search hits in the decreasing order of the combined scores
Map<String, Float> sortedScores = sortByValue(idToOpenSearchScoreMap);

List<SearchHit> rerankedHits = sortedScores.keySet().stream()
.map(itemId -> {
SearchHit hit = itemIdToSearchHitMap.get(itemId);
hit.score(sortedScores.get(itemId));
return hit;
})
.collect(Collectors.toList());
float maxScore = sortedScores.values().stream().max(Float::compare).orElse(0f);
return new SearchHits(rerankedHits.toArray(new SearchHit[0]), totalHits, maxScore);
}


//Sort map by reverse order of the values
public Map<String, Float> sortByValue(Map<String, Float> map) {
return map.entrySet().stream()
.sorted(Map.Entry.<String, Float>comparingByValue().reversed())
.collect(Collectors.toMap(
Map.Entry::getKey,
Map.Entry::getValue,
(oldValue, newValue) -> oldValue, LinkedHashMap::new));
rerankedHits.sort(Comparator.comparing(SearchHit::getScore).reversed());
return new SearchHits(rerankedHits.toArray(new SearchHit[0]), originalHits.getTotalHits(), maxScore);
}


/**
* Validate Personalize configuration for calling Personalize service
* @param requestParameters Request parameters for Personalize present in search request
* @return True if valid configuration present else false.
*/
public boolean isValidPersonalizeConfigPresent(PersonalizeRequestParameters requestParameters) {
boolean isValidPersonalizeConfig = true;

if (requestParameters == null || requestParameters.getUserId().isEmpty()) {
isValidPersonalizeConfig = false;
logger.error("Required Personalize parameters are not provided in the search request");
private void validatePersonalizeRequestParams(PersonalizeRequestParameters requestParameters) {
if (requestParameters == null || StringUtils.isEmpty(requestParameters.getUserId())) {
throw ConfigurationUtils.newConfigurationException(PersonalizeRankingResponseProcessor.TYPE, "", "user_id",
"required Personalize request parameter is missing");
}

if (rankerConfig == null || rankerConfig.getPersonalizeCampaign().isEmpty() ||
rankerConfig.getWeight() < 0.0 || rankerConfig.getWeight() > 1.0) {
isValidPersonalizeConfig = false;
logger.error("Required Personalized ranker configuration is missing");
if (requestParameters.getContext() != null) {
try {
requestParameters.getContext().entrySet().stream().forEach(e -> isValidPersonalizeContext(e));
} catch (IllegalArgumentException iae) {
throw ConfigurationUtils.newConfigurationException(PersonalizeRankingResponseProcessor.TYPE, "", "context", iae.getMessage());
}
}
return isValidPersonalizeConfig;
}

private String isValidPersonalizeContext(Map.Entry<String, Object> contextEntry) throws IllegalArgumentException {
if (contextEntry.getValue() instanceof String) {
return (String) contextEntry.getValue();
} else {
throw new IllegalArgumentException("Personalize context value is not of type String. " +
"Invalid context value: " + contextEntry.getValue());
private void isValidPersonalizeContext(Map.Entry<String, Object> contextEntry) throws IllegalArgumentException {
if (!(contextEntry.getValue() instanceof String)) {
throw new IllegalArgumentException("Personalize context value is not of type String. Invalid context value: " + contextEntry.getValue());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils;
import com.amazonaws.arn.Arn;
import org.apache.commons.lang3.StringUtils;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration;

import java.util.Arrays;
import java.util.Set;
import java.util.HashSet;

import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME;

public class ValidationUtil {

Check warning on line 20 in src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtil.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtil.java#L20

Added line #L20 was not covered by tests
private static Set<String> SUPPORTED_PERSONALIZE_RECIPES = new HashSet<>(Arrays.asList(AMAZON_PERSONALIZED_RANKING_RECIPE_NAME));

/**
* Validate Personalize configuration for calling Personalize service.
* Throws OpenSearchParseException type exception if validation fails.
* @param config Personalize intelligent ranker configuration
* @param processorType Name of search pipeline processor
* @param processorTag Name of processor tag
*/
public static void validatePersonalizeIntelligentRankerConfiguration (PersonalizeIntelligentRankerConfiguration config,
String processorType,
String processorTag
) {
// Validate weight value
if (config.getWeight() < 0.0 || config.getWeight() > 1.0) {
throw ConfigurationUtils.newConfigurationException(processorType, processorTag, "weight", "invalid value for weight");
}
// Validate Personalize campaign ARN
if(!isValidCampaignOrRoleArn(config.getPersonalizeCampaign(), "personalize")) {
throw ConfigurationUtils.newConfigurationException(processorType, processorTag, "campaign_arn", "invalid format for Personalize campaign arn");
}
// Validate IAM Role Arn for Personalize access
String iamRoleArn = config.getIamRoleArn();
if(!StringUtils.isEmpty(iamRoleArn) && !isValidCampaignOrRoleArn(iamRoleArn, "iam")) {
throw ConfigurationUtils.newConfigurationException(processorType, processorTag, "iam_role_arn", "invalid format for Personalize iam role arn");
}
// Validate Personalize recipe
if(!SUPPORTED_PERSONALIZE_RECIPES.contains(config.getRecipe())) {
throw ConfigurationUtils.newConfigurationException(processorType, processorTag, "recipe", "not supported recipe provided");

Check warning on line 49 in src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtil.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtil.java#L49

Added line #L49 was not covered by tests
}
}

private static boolean isValidCampaignOrRoleArn(String arn, String expectedService) {
try {
Arn arnObj = Arn.fromString(arn);
String arnService = arnObj.getService();
if (arnObj.getResource() == null || !arnService.equals(expectedService)) {
return false;
}
} catch (IllegalArgumentException iae) {
return false;
}
return true;
}
}
Loading

0 comments on commit eb62fd1

Please sign in to comment.