From 64948c57cfe5edb4cde8322476812b4d4d447410 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Tue, 11 Jul 2023 18:51:44 -0700 Subject: [PATCH] Add validations for Personalize input and configurations (#151) (#160) * Add validations for Personalize input and configurations Description: =========== This commit adds following changes: * Change scoring logic to consider ranks instead of scores * Add validation for Personalize inputs and configurations * Add E2E tests in the form of unit tests for Personalize ranking Signed-off-by: Ketan Kulkarni * Add user agent prefix to Personalize client to understand plugin adoption Description: ============ This change adds following: * User agent configuration for Personalize client to understand plugin adoption * Add validation related unit tests Signed-off-by: Ketan Kulkarni * Add validations for Personalize input and configurations Description: =========== This commit adds following changes: * Change scoring logic to consider ranks instead of scores * Add validation for Personalize inputs and configurations * Add E2E tests in the form of unit tests for Personalize ranking Signed-off-by: Ketan Kulkarni * Change response processor config name Description: =========== This commit includes following: * Change response processor name from personalize_ranking to personalized_search_ranking * Change client settings to use same naming convention * Address review comments Signed-off-by: Ketan Kulkarni * Fix typo in unit test Signed-off-by: Ketan Kulkarni --------- Signed-off-by: Ketan Kulkarni (cherry picked from commit e305c138d34253d57d24486fc30d1238fa867a1b) Co-authored-by: kulket <130191298+kulket@users.noreply.github.com> --- .../PersonalizeRankingResponseProcessor.java | 4 +- .../client/PersonalizeClient.java | 5 + .../client/PersonalizeClientSettings.java | 6 +- .../reranker/PersonalizedRanker.java | 8 - .../impl/AmazonPersonalizedRankerImpl.java | 156 ++++---- .../utils/ValidationUtil.java | 60 ++++ ...sonalizeRankingResponseProcessorTests.java | 335 ++++++++++++++++++ .../PersonalizeResponseProcessorTests.java | 214 ----------- .../AmazonPersonalizeRankerImplTests.java | 147 +------- .../utils/PersonalizeRuntimeTestUtil.java | 63 ++-- .../utils/SearchTestUtil.java | 25 +- .../utils/ValidationUtilTests.java | 94 +++++ 12 files changed, 623 insertions(+), 494 deletions(-) create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtil.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessorTests.java delete mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeResponseProcessorTests.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtilTests.java diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessor.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessor.java index e7adc67..3bbc9d6 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessor.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessor.java @@ -29,6 +29,7 @@ import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRanker; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRankerFactory; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.ValidationUtil; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -41,7 +42,7 @@ public class PersonalizeRankingResponseProcessor extends AbstractProcessor imple private static final Logger logger = LogManager.getLogger(PersonalizeRankingResponseProcessor.class); - public static final String TYPE = "personalize_ranking"; + public static final String TYPE = "personalized_search_ranking"; private final String tag; private final String description; private final PersonalizeClient personalizeClient; @@ -163,6 +164,7 @@ public PersonalizeRankingResponseProcessor create(Map) () -> AmazonPersonalizeRuntimeClientBuilder.standard() .withCredentials(credentialsProvider) .withRegion(awsRegion) + .withClientConfiguration(clientConfiguration) .build()); } diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientSettings.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientSettings.java index c3de442..234aa19 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientSettings.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/client/PersonalizeClientSettings.java @@ -28,17 +28,17 @@ public final class PersonalizeClientSettings { /** * The access key (ie login id) for connecting to Personalize. */ - public static final Setting ACCESS_KEY_SETTING = SecureSetting.secureString("personalize_intelligent_ranking.aws.access_key", null); + public static final Setting ACCESS_KEY_SETTING = SecureSetting.secureString("personalized_search_ranking.aws.access_key", null); /** * The secret key (ie password) for connecting to Personalize. */ - public static final Setting SECRET_KEY_SETTING = SecureSetting.secureString("personalize_intelligent_ranking.aws.secret_key", null); + public static final Setting SECRET_KEY_SETTING = SecureSetting.secureString("personalized_search_ranking.aws.secret_key", null); /** * The session token for connecting to Personalize. */ - public static final Setting SESSION_TOKEN_SETTING = SecureSetting.secureString("personalize_intelligent_ranking.aws.session_token", null); + public static final Setting SESSION_TOKEN_SETTING = SecureSetting.secureString("personalized_search_ranking.aws.session_token", null); private final AWSCredentials credentials; diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRanker.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRanker.java index 4699897..8470f9f 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRanker.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRanker.java @@ -19,12 +19,4 @@ public interface PersonalizedRanker { * @return Re ranked search hits */ SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestParameters); - - /** - * Validate Personalize configuration for calling Personalize service - * @param requestParameters Request parameters for Personalize present in search request - * @return True if valid configuration present else false. - */ - boolean isValidPersonalizeConfigPresent(PersonalizeRequestParameters requestParameters); - } diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java index 5d2e5db..cb8cac1 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java @@ -12,9 +12,10 @@ import com.amazonaws.services.personalizeruntime.model.PredictedItem; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.lucene.search.TotalHits; +import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.PersonalizeRankingResponseProcessor; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters; @@ -23,9 +24,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; -import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.List; +import java.util.LinkedList; import java.util.Map; import java.util.stream.Collectors; @@ -51,15 +51,17 @@ public AmazonPersonalizedRankerImpl(PersonalizeIntelligentRankerConfiguration co @Override public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestParameters) { try { - if (!isValidPersonalizeConfigPresent(requestParameters)) { - throw new IllegalArgumentException("Required configurations missing from Personalize " + - "response processor configuration or search request parameters"); - } + validatePersonalizeRequestParams(requestParameters); List originalHits = Arrays.asList(hits.getHits()); + // Do not make Personalize call if weight is zero which implies Personalization is turned off. + if (rankerConfig.getWeight() == 0) { + logger.info("Not applying Personalized ranking. Given value for weight configuration: {}", rankerConfig.getWeight()); + return hits; + } String itemIdfield = rankerConfig.getItemIdField(); List documentIdsToRank; // If item field is not specified in the configuration then use default _id field. - if (!itemIdfield.isEmpty()) { + if (itemIdfield != null && !itemIdfield.isBlank()) { documentIdsToRank = originalHits.stream() .filter(h -> h.getSourceAsMap().get(itemIdfield) != null) .map(h -> h.getSourceAsMap().get(itemIdfield).toString()) @@ -70,13 +72,17 @@ public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestPa .map(h -> h.getId()) .collect(Collectors.toList()); } + if (documentIdsToRank.size() == 0) { + throw ConfigurationUtils.newConfigurationException(PersonalizeRankingResponseProcessor.TYPE, "", "item_id_field", + "no item ids found to apply Personalized reranking. Please check configured value for item_id_field"); + } logger.info("Document Ids to re-rank with Personalize: {}", Arrays.toString(documentIdsToRank.toArray())); String userId = requestParameters.getUserId(); Map context = requestParameters.getContext() != null ? requestParameters.getContext().entrySet().stream() - .collect(Collectors.toMap(Map.Entry::getKey, e -> isValidPersonalizeContext(e))) + .collect(Collectors.toMap(Map.Entry::getKey, e -> (String)e.getValue())) : null; - logger.info("User ID from request parameters. User ID: {}", userId); + logger.info("User ID from personalize request parameters - User ID: {}", userId); if (context != null && !context.isEmpty()) { logger.info("Personalize context provided in the search request"); } @@ -88,109 +94,67 @@ public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestPa .withUserId(userId); GetPersonalizedRankingResult result = personalizeClient.getPersonalizedRanking(personalizeRequest); - List personalizeRrankingResult = result.getPersonalizedRanking(); - Map idToPersonalizeRankingScoreMap = new HashMap<>(); - Map idToOpenSearchScoreMap = new HashMap<>(); - Map itemIdToSearchHitMap = new HashMap<>(); - // Build a map with key as item id and value as personalize ranking score - for (PredictedItem item : personalizeRrankingResult) { - idToPersonalizeRankingScoreMap.put(item.getItemId(), item.getScore().floatValue()); - } - - // Build a map with key as item id and value as open search scores and another map - // with key as item id and value as corresponding search hit - for (SearchHit hit : originalHits) { - if (!itemIdfield.isEmpty()){ - idToOpenSearchScoreMap.put(hit.getSourceAsMap().get(itemIdfield).toString(), hit.getScore()); - itemIdToSearchHitMap.put(hit.getSourceAsMap().get(itemIdfield).toString(), hit); - } - else{ - idToOpenSearchScoreMap.put(hit.getId(), hit.getScore()); - itemIdToSearchHitMap.put(hit.getId(), hit); - } - } - - - float weight = (float) rankerConfig.getWeight(); - SearchHits newHits = combineScores(idToPersonalizeRankingScoreMap, idToOpenSearchScoreMap, - itemIdToSearchHitMap, hits.getTotalHits(), weight); - return newHits; + SearchHits personalizedHits = combineScores(hits, result); + return personalizedHits; } catch (Exception ex) { - logger.error("Failed to re rank with Personalize. Returning original search results without Personalize re ranking.", ex); - return hits; + logger.error("Failed to re rank with Personalize.", ex); + throw ex; } } //Combine open search hits and personalize campaign response - public SearchHits combineScores(Map idToPersonalizeRankingScoreMap, - Map idToOpenSearchScoreMap, - Map itemIdToSearchHitMap, - TotalHits totalHits, float weight) { - //Update open search score based on the personalize campaign response for each item id - List openSearchItemId = new ArrayList(idToOpenSearchScoreMap.keySet()); - for (String itemId : openSearchItemId) { - if(idToPersonalizeRankingScoreMap.containsKey(itemId)){ - float personalizedScore = idToPersonalizeRankingScoreMap.get(itemId); - float openSearchScore = idToOpenSearchScoreMap.get(itemId); - float combinedScore = (float) (weight / Math.log(openSearchScore + 1) - + (1 - weight) / Math.log(personalizedScore + 1)); - idToOpenSearchScoreMap.put(itemId, combinedScore); + private SearchHits combineScores(SearchHits originalHits, GetPersonalizedRankingResult personalizedRankingResult) { + List personalziedRanking = personalizedRankingResult.getPersonalizedRanking(); + List personalizedRankedItemsList = new LinkedList<>(); + for (PredictedItem item : personalziedRanking) { + personalizedRankedItemsList.add(item.getItemId()); + } + int totalHits = originalHits.getHits().length; + List rerankedHits = new ArrayList<>(totalHits); + float maxScore = 0f; + double weight = rankerConfig.getWeight(); + for (int i = 0 ; i < totalHits ; i++) { + String openSearchItemId; + SearchHit hit = originalHits.getAt(i); + String itemIdField = rankerConfig.getItemIdField(); + if (itemIdField != null && !(itemIdField.isBlank())) { + openSearchItemId = hit.getSourceAsMap().get(rankerConfig.getItemIdField()).toString(); + } else { + openSearchItemId = hit.getId(); } + int openSearchRank = i + 1; + int personalizedRank = personalizedRankedItemsList.indexOf(openSearchItemId) + 1; + float combinedScore = (float) (((1- weight) / (Math.log(openSearchRank + 1) / Math.log(2))) + + ((weight) / (Math.log(personalizedRank + 1) / Math.log(2)))); + maxScore = Math.max(maxScore, combinedScore); + hit.score(combinedScore); + rerankedHits.add(hit); } - - //Create a new list of search hits in the decreasing order of the combined scores - Map sortedScores = sortByValue(idToOpenSearchScoreMap); - - List rerankedHits = sortedScores.keySet().stream() - .map(itemId -> { - SearchHit hit = itemIdToSearchHitMap.get(itemId); - hit.score(sortedScores.get(itemId)); - return hit; - }) - .collect(Collectors.toList()); - float maxScore = sortedScores.values().stream().max(Float::compare).orElse(0f); - return new SearchHits(rerankedHits.toArray(new SearchHit[0]), totalHits, maxScore); - } - - - //Sort map by reverse order of the values - public Map sortByValue(Map map) { - return map.entrySet().stream() - .sorted(Map.Entry.comparingByValue().reversed()) - .collect(Collectors.toMap( - Map.Entry::getKey, - Map.Entry::getValue, - (oldValue, newValue) -> oldValue, LinkedHashMap::new)); + rerankedHits.sort(Comparator.comparing(SearchHit::getScore).reversed()); + return new SearchHits(rerankedHits.toArray(new SearchHit[0]), originalHits.getTotalHits(), maxScore); } - /** * Validate Personalize configuration for calling Personalize service * @param requestParameters Request parameters for Personalize present in search request - * @return True if valid configuration present else false. */ - public boolean isValidPersonalizeConfigPresent(PersonalizeRequestParameters requestParameters) { - boolean isValidPersonalizeConfig = true; - - if (requestParameters == null || requestParameters.getUserId().isEmpty()) { - isValidPersonalizeConfig = false; - logger.error("Required Personalize parameters are not provided in the search request"); + private void validatePersonalizeRequestParams(PersonalizeRequestParameters requestParameters) { + if (requestParameters == null || requestParameters.getUserId() == null || requestParameters.getUserId().isBlank()) { + throw ConfigurationUtils.newConfigurationException(PersonalizeRankingResponseProcessor.TYPE, "", "user_id", + "required Personalize request parameter is missing"); } - - if (rankerConfig == null || rankerConfig.getPersonalizeCampaign().isEmpty() || - rankerConfig.getWeight() < 0.0 || rankerConfig.getWeight() > 1.0) { - isValidPersonalizeConfig = false; - logger.error("Required Personalized ranker configuration is missing"); + if (requestParameters.getContext() != null) { + try { + requestParameters.getContext().entrySet().stream().forEach(e -> isValidPersonalizeContext(e)); + } catch (IllegalArgumentException iae) { + throw ConfigurationUtils.newConfigurationException(PersonalizeRankingResponseProcessor.TYPE, "", "context", iae.getMessage()); + } } - return isValidPersonalizeConfig; } - private String isValidPersonalizeContext(Map.Entry contextEntry) throws IllegalArgumentException { - if (contextEntry.getValue() instanceof String) { - return (String) contextEntry.getValue(); - } else { - throw new IllegalArgumentException("Personalize context value is not of type String. " + - "Invalid context value: " + contextEntry.getValue()); + private void isValidPersonalizeContext(Map.Entry contextEntry) throws IllegalArgumentException { + if (!(contextEntry.getValue() instanceof String)) { + throw new IllegalArgumentException("Personalize context value is not of type String. Invalid context value: " + contextEntry.getValue()); } } } diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtil.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtil.java new file mode 100644 index 0000000..c318d46 --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtil.java @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils; +import com.amazonaws.arn.Arn; +import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration; + +import java.util.Arrays; +import java.util.Set; +import java.util.HashSet; + +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME; + +public class ValidationUtil { + private static Set SUPPORTED_PERSONALIZE_RECIPES = new HashSet<>(Arrays.asList(AMAZON_PERSONALIZED_RANKING_RECIPE_NAME)); + + /** + * Validate Personalize configuration for calling Personalize service. + * Throws OpenSearchParseException type exception if validation fails. + * @param config Personalize intelligent ranker configuration + * @param processorType Name of search pipeline processor + * @param processorTag Name of processor tag + */ + public static void validatePersonalizeIntelligentRankerConfiguration (PersonalizeIntelligentRankerConfiguration config, + String processorType, + String processorTag + ) { + // Validate weight value + if (config.getWeight() < 0.0 || config.getWeight() > 1.0) { + throw ConfigurationUtils.newConfigurationException(processorType, processorTag, "weight", "invalid value for weight"); + } + // Validate Personalize campaign ARN + if(!isValidCampaignOrRoleArn(config.getPersonalizeCampaign(), "personalize")) { + throw ConfigurationUtils.newConfigurationException(processorType, processorTag, "campaign_arn", "invalid format for Personalize campaign arn"); + } + // Validate IAM Role Arn for Personalize access + String iamRoleArn = config.getIamRoleArn(); + if(!(iamRoleArn == null || iamRoleArn.isBlank()) && !isValidCampaignOrRoleArn(iamRoleArn, "iam")) { + throw ConfigurationUtils.newConfigurationException(processorType, processorTag, "iam_role_arn", "invalid format for Personalize iam role arn"); + } + // Validate Personalize recipe + if(!SUPPORTED_PERSONALIZE_RECIPES.contains(config.getRecipe())) { + throw ConfigurationUtils.newConfigurationException(processorType, processorTag, "recipe", "not supported recipe provided"); + } + } + + private static boolean isValidCampaignOrRoleArn(String arn, String expectedService) { + try { + Arn arnObj = Arn.fromString(arn); + return arnObj.getService().equals(expectedService); + } catch (IllegalArgumentException iae) { + return false; + } + } +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessorTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessorTests.java new file mode 100644 index 0000000..71cbfdf --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessorTests.java @@ -0,0 +1,335 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking; + +import com.amazonaws.http.IdleConnectionReaper; +import org.apache.lucene.search.TotalHits; +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.common.settings.Settings; +import org.opensearch.env.Environment; +import org.opensearch.env.TestEnvironment; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClientSettings; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.PersonalizeRuntimeTestUtil; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.SearchTestUtil; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Arrays; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.mockito.Mockito.mock; +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME; + +public class PersonalizeRankingResponseProcessorTests extends OpenSearchTestCase { + + private static final String TYPE = PersonalizeRankingResponseProcessor.TYPE; + private Settings settings = buildEnvSettings(Settings.EMPTY); + private Environment env = TestEnvironment.newEnvironment(settings); + private String personalizeCampaign = "arn:aws:personalize:us-west-2:000000000000:campaign/test-campaign"; + private String iamRoleArn = "arn:aws:iam::000000000000:role/test"; + private String itemIdField = "ITEM_ID"; + private String region = "us-west-2"; + private double weight = 1.0; + private int numHits = 10; + + private PersonalizeClientSettings clientSettings = PersonalizeClientSettings.getClientSettings(env.settings()); + + public void testCreateFactoryThrowsExceptionWithEmptyConfig() { + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings); + expectThrows(OpenSearchParseException.class, () -> factory.create( + Collections.emptyMap(), + null, + null, + false, + Collections.emptyMap(), + null + )); + IdleConnectionReaper.shutdown(); + } + + public void testFactory() { + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings); + // Test config without campaign + Map configuration = new HashMap<>(); + configuration.put("item_id_field", itemIdField); + configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); + configuration.put("weight", String.valueOf(weight)); + configuration.put("iam_role_arn", iamRoleArn); + configuration.put("aws_region", region); + + expectThrows(OpenSearchParseException.class, () -> factory.create( + Collections.emptyMap(), + null, + null, + false, + configuration, + null + )); + configuration.clear(); + + // Test config without recipe + configuration.put("campaign_arn", personalizeCampaign); + configuration.put("item_id_field", itemIdField); + configuration.put("weight", String.valueOf(weight)); + configuration.put("iam_role_arn", iamRoleArn); + configuration.put("aws_region", region); + + expectThrows(OpenSearchParseException.class, () -> factory.create( + Collections.emptyMap(), + null, + null, + false, + configuration, + null + )); + configuration.clear(); + + // Test config without region + configuration.put("campaign_arn", personalizeCampaign); + configuration.put("item_id_field", itemIdField); + configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); + configuration.put("weight", String.valueOf(weight)); + configuration.put("iam_role_arn", iamRoleArn); + expectThrows(OpenSearchParseException.class, () -> factory.create( + Collections.emptyMap(), + null, + null, + false, + configuration, + null + )); + configuration.clear(); + + // Test config without weight + configuration.put("campaign_arn", personalizeCampaign); + configuration.put("item_id_field", itemIdField); + configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); + configuration.put("iam_role_arn", iamRoleArn); + configuration.put("aws_region", region); + expectThrows(OpenSearchParseException.class, () -> factory.create( + Collections.emptyMap(), + null, + null, + false, + configuration, + null + )); + configuration.clear(); + + // Test configuration with invalid weight value + configuration.put("campaign_arn", personalizeCampaign); + configuration.put("item_id_field", itemIdField); + configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); + configuration.put("weight", "invalid"); + configuration.put("iam_role_arn", iamRoleArn); + configuration.put("aws_region", region); + expectThrows(OpenSearchParseException.class, () -> factory.create( + Collections.emptyMap(), + null, + null, + false, + configuration, + null + )); + configuration.clear(); + IdleConnectionReaper.shutdown(); + } + + public void testCreateFactoryWithAllPersonalizeConfig() throws Exception { + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings); + + Map configuration = buildPersonalizeResponseProcessorConfig(); + + PersonalizeRankingResponseProcessor personalizeResponseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); + + assertEquals(TYPE, personalizeResponseProcessor.getType()); + assertEquals("testTag", personalizeResponseProcessor.getTag()); + assertEquals("testingAllFields", personalizeResponseProcessor.getDescription()); + IdleConnectionReaper.shutdown(); + } + + public void testProcessorWithNoHits() throws Exception { + PersonalizeClient mockClient = mock(PersonalizeClient.class); + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> mockClient); + + Map configuration = buildPersonalizeResponseProcessorConfig(); + + PersonalizeRankingResponseProcessor personalizeResponseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); + SearchRequest searchRequest = new SearchRequest(); + SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f); + SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, false, null, 0); + SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 1, new ShardSearchFailure[0], null); + + SearchResponse response = personalizeResponseProcessor.processResponse(searchRequest, searchResponse); + assertEquals(hits.getTotalHits().value, response.getHits().getTotalHits().value); + IdleConnectionReaper.shutdown(); + } + + public void testProcessorWithPersonalizeContext() throws Exception { + PersonalizeClient mockClient = PersonalizeRuntimeTestUtil.buildMockPersonalizeClient(); + + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> mockClient); + + Map configuration = buildPersonalizeResponseProcessorConfig(); + PersonalizeRankingResponseProcessor personalizeResponseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); + + Map personalizeContext = new HashMap<>(); + personalizeContext.put("contextKey2", "contextValue2"); + + SearchResponse personalizedResponse = + getPersonalizedRankingProcessorResponse(personalizeResponseProcessor, personalizeContext, numHits); + + List transformedHits = Arrays.asList(personalizedResponse.getHits().getHits()); + List rerankedDocumentIds; + rerankedDocumentIds = transformedHits.stream() + .filter(h -> h.getSourceAsMap().get(itemIdField) != null) + .map(h -> h.getSourceAsMap().get(itemIdField).toString()) + .collect(Collectors.toList()); + + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(numHits, 1); + assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); + IdleConnectionReaper.shutdown(); + } + + public void testProcessorWithHitsWithInvalidPersonalizeContext() throws Exception { + PersonalizeClient mockClient = PersonalizeRuntimeTestUtil.buildMockPersonalizeClient();; + + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> mockClient); + + Map configuration = buildPersonalizeResponseProcessorConfig(); + PersonalizeRankingResponseProcessor personalizeResponseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); + + Map personalizeContext = new HashMap<>(); + personalizeContext.put("contextKey2", 5); + + expectThrows(OpenSearchParseException.class, () -> + getPersonalizedRankingProcessorResponse(personalizeResponseProcessor, personalizeContext, numHits)); + IdleConnectionReaper.shutdown(); + } + + public void testPersonalizeRankingResponse() throws Exception { + PersonalizeClient personalizeClient = PersonalizeRuntimeTestUtil.buildMockPersonalizeClient(); + + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> personalizeClient); + + String itemField = "ITEM_ID"; + Map configuration = buildPersonalizeResponseProcessorConfig(); + + PersonalizeRankingResponseProcessor responseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); + + SearchResponse personalizedResponse = getPersonalizedRankingProcessorResponse(responseProcessor, null, numHits); + + List transformedHits = Arrays.asList(personalizedResponse.getHits().getHits()); + List rerankedDocumentIds; + rerankedDocumentIds = transformedHits.stream() + .filter(h -> h.getSourceAsMap().get(itemField) != null) + .map(h -> h.getSourceAsMap().get(itemField).toString()) + .collect(Collectors.toList()); + + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(numHits, 1); + assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); + IdleConnectionReaper.shutdown(); + } + + public void testPersonalizeRankingResponseWithInvalidItemIdFieldName() throws Exception { + PersonalizeClient personalizeClient = PersonalizeRuntimeTestUtil.buildMockPersonalizeClient(); + + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> personalizeClient); + + String itemFieldInvalid = "ITEM_ID_NOT_VALID"; + Map configuration = buildPersonalizeResponseProcessorConfig(); + configuration.put("item_id_field", itemFieldInvalid); + + PersonalizeRankingResponseProcessor responseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); + + expectThrows(OpenSearchParseException.class, () -> + getPersonalizedRankingProcessorResponse(responseProcessor, null, numHits)); + IdleConnectionReaper.shutdown(); + } + + public void testPersonalizeRankingResponseWithDefaultItemIdField() throws Exception { + PersonalizeClient personalizeClient = PersonalizeRuntimeTestUtil.buildMockPersonalizeClient(); + + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> personalizeClient); + + String itemIdFieldEmpty = ""; + Map configuration = buildPersonalizeResponseProcessorConfig(); + configuration.put("item_id_field", itemIdFieldEmpty); + + PersonalizeRankingResponseProcessor responseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); + + SearchResponse personalizedResponse = getPersonalizedRankingProcessorResponse(responseProcessor, null, numHits); + + List transformedHits = Arrays.asList(personalizedResponse.getHits().getHits()); + List rerankedDocumentIds; + rerankedDocumentIds = transformedHits.stream() + .filter(h -> h.getId() != null) + .map(h -> h.getId()) + .collect(Collectors.toList()); + + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(numHits, 1); + + assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); + IdleConnectionReaper.shutdown(); + } + + private SearchResponse getPersonalizedRankingProcessorResponse(PersonalizeRankingResponseProcessor responseProcessor, + Map personalizeContext, + int numHits) throws Exception { + + PersonalizeRequestParameters personalizeRequestParams = new PersonalizeRequestParameters("user_1", personalizeContext); + SearchRequest request = SearchTestUtil.createSearchRequestWithPersonalizeRequest(personalizeRequestParams); + + SearchHits searchHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numHits); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 1, new ShardSearchFailure[0], null); + + SearchResponse personalizedResponse = responseProcessor.processResponse(request, searchResponse); + + return personalizedResponse; + } + + private Map buildPersonalizeResponseProcessorConfig() { + Map configuration = new HashMap<>(); + configuration.put("campaign_arn", personalizeCampaign); + configuration.put("item_id_field", itemIdField); + configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); + configuration.put("weight", String.valueOf(weight)); + configuration.put("iam_role_arn", iamRoleArn); + configuration.put("aws_region", region); + return configuration; + } +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeResponseProcessorTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeResponseProcessorTests.java deleted file mode 100644 index 7d0f2eb..0000000 --- a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeResponseProcessorTests.java +++ /dev/null @@ -1,214 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.search.relevance.transformer.personalizeintelligentranking; - -import com.amazonaws.http.IdleConnectionReaper; -import org.apache.lucene.search.TotalHits; -import org.opensearch.OpenSearchParseException; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.search.SearchResponseSections; -import org.opensearch.action.search.ShardSearchFailure; -import org.opensearch.common.settings.Settings; -import org.opensearch.env.Environment; -import org.opensearch.env.TestEnvironment; -import org.opensearch.search.SearchHit; -import org.opensearch.search.SearchHits; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; -import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClientSettings; -import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters; -import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParametersExtBuilder; -import org.opensearch.test.OpenSearchTestCase; - -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.mockito.Mockito.mock; -import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME; - -public class PersonalizeResponseProcessorTests extends OpenSearchTestCase { - - private static final String TYPE = "personalize_ranking"; - private Settings settings = buildEnvSettings(Settings.EMPTY); - private Environment env = TestEnvironment.newEnvironment(settings); - private String personalizeCampaign = "arn:aws:personalize:us-west-2:000000000000:campaign/test-campaign"; - private String iamRoleArn = ""; - private String recipe = "sample-personalize-recipe"; - private String itemIdField = ""; - private String region = "us-west-2"; - private double weight = 0.25; - - private PersonalizeClientSettings clientSettings = PersonalizeClientSettings.getClientSettings(env.settings()); - - public void testCreateFactoryThrowsExceptionWithEmptyConfig() { - PersonalizeRankingResponseProcessor.Factory factory - = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings); - expectThrows(OpenSearchParseException.class, () -> factory.create( - Collections.emptyMap(), - null, - null, - false, - Collections.emptyMap(), - null - )); - } - - public void testCreateFactoryWithAllPersonalizeConfig() throws Exception { - PersonalizeRankingResponseProcessor.Factory factory - = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings); - - Map configuration = new HashMap<>(); - configuration.put("campaign_arn", personalizeCampaign); - configuration.put("item_id_field", itemIdField); - configuration.put("recipe", recipe); - configuration.put("weight", String.valueOf(weight)); - configuration.put("iam_role_arn", iamRoleArn); - configuration.put("aws_region", region); - - PersonalizeRankingResponseProcessor personalizeResponseProcessor = - factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); - - assertEquals(TYPE, personalizeResponseProcessor.getType()); - assertEquals("testTag", personalizeResponseProcessor.getTag()); - assertEquals("testingAllFields", personalizeResponseProcessor.getDescription()); - IdleConnectionReaper.shutdown(); - } - - public void testProcessorWithNoHits() throws Exception { - PersonalizeClient mockClient = mock(PersonalizeClient.class); - PersonalizeRankingResponseProcessor.Factory factory - = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> mockClient); - - Map configuration = new HashMap<>(); - configuration.put("campaign_arn", personalizeCampaign); - configuration.put("item_id_field", itemIdField); - configuration.put("recipe", recipe); - configuration.put("weight", String.valueOf(weight)); - configuration.put("iam_role_arn", iamRoleArn); - configuration.put("aws_region", region); - - PersonalizeRankingResponseProcessor personalizeResponseProcessor = - factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); - SearchRequest searchRequest = new SearchRequest(); - SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f); - SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, false, null, 0); - SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 1, new ShardSearchFailure[0], null); - - personalizeResponseProcessor.processResponse(searchRequest, searchResponse); - } - - public void testProcessorWithHits() throws Exception { - PersonalizeClient mockClient = mock(PersonalizeClient.class); - - PersonalizeRankingResponseProcessor.Factory factory - = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> mockClient); - - Map configuration = new HashMap<>(); - configuration.put("campaign_arn", personalizeCampaign); - configuration.put("item_id_field", itemIdField); - configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); - configuration.put("weight", String.valueOf(weight)); - configuration.put("iam_role_arn", iamRoleArn); - configuration.put("aws_region", region); - - PersonalizeRankingResponseProcessor personalizeResponseProcessor = - factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); - SearchRequest searchRequest = new SearchRequest(); - SearchHit[] searchHits = new SearchHit[10]; - for (int i = 0; i < searchHits.length; i++) { - searchHits[i] = new SearchHit(i, Integer.toString(i), Collections.emptyMap(), Collections.emptyMap()); - searchHits[i].score(1.0f); - } - SearchHits hits = new SearchHits(searchHits, new TotalHits(searchHits.length, TotalHits.Relation.EQUAL_TO), 1.0f); - SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, false, null, 0); - SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 1, new ShardSearchFailure[0], null); - - personalizeResponseProcessor.processResponse(searchRequest, searchResponse); - } - - public void testProcessorWithHitsAndSearchProcessorExt() throws Exception { - PersonalizeClient mockClient = mock(PersonalizeClient.class); - - PersonalizeRankingResponseProcessor.Factory factory - = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> mockClient); - - Map configuration = new HashMap<>(); - configuration.put("campaign_arn", personalizeCampaign); - configuration.put("item_id_field", itemIdField); - configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); - configuration.put("weight", String.valueOf(weight)); - configuration.put("iam_role_arn", iamRoleArn); - configuration.put("aws_region", region); - - PersonalizeRankingResponseProcessor personalizeResponseProcessor = - factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); - - Map personalizeContext = new HashMap<>(); - personalizeContext.put("contextKey2", "contextValue2"); - PersonalizeRequestParameters personalizeRequestParams = new PersonalizeRequestParameters("user_1", personalizeContext); - PersonalizeRequestParametersExtBuilder extBuilder = new PersonalizeRequestParametersExtBuilder(); - extBuilder.setRequestParameters(personalizeRequestParams); - - SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource() - .ext(List.of(extBuilder)); - - SearchRequest searchRequest = new SearchRequest().source(sourceBuilder); - SearchHit[] searchHits = new SearchHit[10]; - for (int i = 0; i < searchHits.length; i++) { - searchHits[i] = new SearchHit(i, Integer.toString(i), Collections.emptyMap(), Collections.emptyMap()); - searchHits[i].score(1.0f); - } - SearchHits hits = new SearchHits(searchHits, new TotalHits(searchHits.length, TotalHits.Relation.EQUAL_TO), 1.0f); - SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, false, null, 0); - SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 1, new ShardSearchFailure[0], null); - - personalizeResponseProcessor.processResponse(searchRequest, searchResponse); - } - - public void testProcessorWithHitsWithInvalidPersonalizeContext() throws Exception { - PersonalizeClient mockClient = mock(PersonalizeClient.class); - - PersonalizeRankingResponseProcessor.Factory factory - = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> mockClient); - - Map configuration = new HashMap<>(); - configuration.put("campaign_arn", personalizeCampaign); - configuration.put("item_id_field", itemIdField); - configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); - configuration.put("weight", String.valueOf(weight)); - configuration.put("iam_role_arn", iamRoleArn); - configuration.put("aws_region", region); - - PersonalizeRankingResponseProcessor personalizeResponseProcessor = - factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration,null); - - Map personalizeContext = new HashMap<>(); - personalizeContext.put("contextKey2", 5); - PersonalizeRequestParameters personalizeRequestParams = new PersonalizeRequestParameters("user_1", personalizeContext); - PersonalizeRequestParametersExtBuilder extBuilder = new PersonalizeRequestParametersExtBuilder(); - extBuilder.setRequestParameters(personalizeRequestParams); - - SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource() - .ext(List.of(extBuilder)); - - SearchRequest searchRequest = new SearchRequest().source(sourceBuilder); - SearchHit[] searchHits = new SearchHit[10]; - for (int i = 0; i < searchHits.length; i++) { - searchHits[i] = new SearchHit(i, Integer.toString(i), Collections.emptyMap(), Collections.emptyMap()); - searchHits[i].score(1.0f); - } - SearchHits hits = new SearchHits(searchHits, new TotalHits(searchHits.length, TotalHits.Relation.EQUAL_TO), 1.0f); - SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, false, null, 0); - SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 1, new ShardSearchFailure[0], null); - - personalizeResponseProcessor.processResponse(searchRequest, searchResponse); - } -} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java index 5e42712..32b4a61 100644 --- a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java @@ -9,6 +9,7 @@ package org.opensearch.search.relevance.transformer.personalizeintelligentranking.ranker.impl; import org.mockito.Mockito; +import org.opensearch.OpenSearchParseException; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; @@ -19,7 +20,6 @@ import org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.SearchTestUtil; import org.opensearch.test.OpenSearchTestCase; - import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -99,8 +99,8 @@ public void testReRankWithInvalidRequestParameterContext() throws IOException { requestParameters.setUserId("28"); requestParameters.setContext(context); SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); - SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); - assertEquals(responseHits.getHits().length, transformedHits.getHits().length); + expectThrows(OpenSearchParseException.class, () -> + ranker.rerank(responseHits, requestParameters)); } public void testReRankWithNoUserId() throws IOException { @@ -115,8 +115,8 @@ public void testReRankWithNoUserId() throws IOException { PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); requestParameters.setContext(context); SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); - SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); - assertEquals(responseHits.getHits().length, transformedHits.getHits().length); + expectThrows(OpenSearchParseException.class, () -> + ranker.rerank(responseHits, requestParameters)); } public void testReRankWithEmptyItemIdField() throws IOException { @@ -148,7 +148,6 @@ public void testReRankWithNullItemIdField() throws IOException { assertEquals(responseHits.getHits().length, transformedHits.getHits().length); } - public void testReRankWithWeightAsZero() throws IOException { PersonalizeIntelligentRankerConfiguration rankerConfig = new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, 0); @@ -170,12 +169,10 @@ public void testReRankWithWeightAsZero() throws IOException { .map(h -> h.getSourceAsMap().get(itemIdfield).toString()) .collect(Collectors.toList()); - ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZero(numOfHits); - + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(numOfHits, 0); assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); } - public void testReRankWithWeightAsOne() throws IOException { PersonalizeIntelligentRankerConfiguration rankerConfig = new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, 1); @@ -198,12 +195,11 @@ public void testReRankWithWeightAsOne() throws IOException { .map(h -> h.getSourceAsMap().get(itemIdfield).toString()) .collect(Collectors.toList()); - ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsOne(numOfHits); - + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(numOfHits, 1); assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); } - public void testReRankWithWeightAsNietherZeroOrOne() throws IOException { + public void testReRankWithWeightAsNeitherZeroOrOne() throws IOException { PersonalizeIntelligentRankerConfiguration rankerConfig = new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, weight); PersonalizeClient client = Mockito.mock(PersonalizeClient.class); @@ -224,69 +220,18 @@ public void testReRankWithWeightAsNietherZeroOrOne() throws IOException { .map(h -> h.getSourceAsMap().get(itemIdfield).toString()) .collect(Collectors.toList()); - ArrayList rerankedDocumentIdsWhenWeightIsOne = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsOne(numOfHits); - ArrayList rerankedDocumentIdsWhenWeightIsZero = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZero(numOfHits); + ArrayList rerankedDocumentIdsWhenWeightIsOne = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(numOfHits, 1); + ArrayList rerankedDocumentIdsWhenWeightIsZero = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(numOfHits, 0); assertNotEquals(rerankedDocumentIdsWhenWeightIsOne, rerankedDocumentIds); assertNotEquals(rerankedDocumentIdsWhenWeightIsZero, rerankedDocumentIds); } - public void testReRankWithWeightAsGreaterThanOne() throws IOException { - PersonalizeIntelligentRankerConfiguration rankerConfig = - new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, 2); - PersonalizeClient client = Mockito.mock(PersonalizeClient.class); - Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult(numOfHits)); - - AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); - PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); - requestParameters.setUserId("28"); - SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); - SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); - - List originalHits = Arrays.asList(transformedHits.getHits()); - String itemIdfield = rankerConfig.getItemIdField(); - List rerankedDocumentIds; - - rerankedDocumentIds = originalHits.stream() - .filter(h -> h.getSourceAsMap().get(itemIdfield) != null) - .map(h -> h.getSourceAsMap().get(itemIdfield).toString()) - .collect(Collectors.toList()); - - ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZero(numOfHits); - assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); - } - - public void testReRankWithWeightAsLessThanZero() throws IOException { - PersonalizeIntelligentRankerConfiguration rankerConfig = - new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, -1); - PersonalizeClient client = Mockito.mock(PersonalizeClient.class); - Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult(numOfHits)); - - AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); - PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); - requestParameters.setUserId("28"); - SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); - SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); - - List originalHits = Arrays.asList(transformedHits.getHits()); - String itemIdfield = rankerConfig.getItemIdField(); - List rerankedDocumentIds; - - rerankedDocumentIds = originalHits.stream() - .filter(h -> h.getSourceAsMap().get(itemIdfield) != null) - .map(h -> h.getSourceAsMap().get(itemIdfield).toString()) - .collect(Collectors.toList()); - - ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZero(numOfHits); - assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); - } - - public void testReRankWithWeightAsZeroWithNullItemIdField() throws IOException { PersonalizeIntelligentRankerConfiguration rankerConfig = new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, "", region, 0); PersonalizeClient client = Mockito.mock(PersonalizeClient.class); - Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResultWhenItemIdConfigIsEmpty(numOfHits)); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult(numOfHits)); AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); @@ -295,7 +240,6 @@ public void testReRankWithWeightAsZeroWithNullItemIdField() throws IOException { SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); List originalHits = Arrays.asList(transformedHits.getHits()); - String itemIdfield = rankerConfig.getItemIdField(); List rerankedDocumentIds; rerankedDocumentIds = originalHits.stream() @@ -303,8 +247,7 @@ public void testReRankWithWeightAsZeroWithNullItemIdField() throws IOException { .map(h -> h.getId()) .collect(Collectors.toList()); - ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZeroWithNullItemIdField(numOfHits); - + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(numOfHits, 0); assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); } @@ -313,7 +256,7 @@ public void testReRankWithWeightAsOneWithNullItemIdField() throws IOException { PersonalizeIntelligentRankerConfiguration rankerConfig = new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, "", region, 1); PersonalizeClient client = Mockito.mock(PersonalizeClient.class); - Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResultWhenItemIdConfigIsEmpty(numOfHits)); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult(numOfHits)); AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); @@ -323,7 +266,6 @@ public void testReRankWithWeightAsOneWithNullItemIdField() throws IOException { SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); List originalHits = Arrays.asList(transformedHits.getHits()); - String itemIdfield = rankerConfig.getItemIdField(); List rerankedDocumentIds; rerankedDocumentIds = originalHits.stream() @@ -331,16 +273,15 @@ public void testReRankWithWeightAsOneWithNullItemIdField() throws IOException { .map(h -> h.getId()) .collect(Collectors.toList()); - ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsOneWithNullItemIdField(numOfHits); - + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(numOfHits, 1); assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); } - public void testReRankWithWeightAsNietherZeroOrOneWithNullItemIdField() throws IOException { + public void testReRankWithWeightAsNeitherZeroOrOneWithNullItemIdField() throws IOException { PersonalizeIntelligentRankerConfiguration rankerConfig = new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, "", region, weight); PersonalizeClient client = Mockito.mock(PersonalizeClient.class); - Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResultWhenItemIdConfigIsEmpty(numOfHits)); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult(numOfHits)); AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); @@ -349,7 +290,6 @@ public void testReRankWithWeightAsNietherZeroOrOneWithNullItemIdField() throws I SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); List originalHits = Arrays.asList(transformedHits.getHits()); - String itemIdfield = rankerConfig.getItemIdField(); List rerankedDocumentIds; rerankedDocumentIds = originalHits.stream() @@ -357,61 +297,10 @@ public void testReRankWithWeightAsNietherZeroOrOneWithNullItemIdField() throws I .map(h -> h.getId()) .collect(Collectors.toList()); - ArrayList rerankedDocumentIdsWhenWeightIsOne = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsOneWithNullItemIdField(numOfHits); - ArrayList rerankedDocumentIdsWhenWeightIsZero = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZeroWithNullItemIdField(numOfHits); + ArrayList rerankedDocumentIdsWhenWeightIsOne = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(numOfHits, 1); + ArrayList rerankedDocumentIdsWhenWeightIsZero = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(numOfHits, 0); assertNotEquals(rerankedDocumentIdsWhenWeightIsOne, rerankedDocumentIds); assertNotEquals(rerankedDocumentIdsWhenWeightIsZero, rerankedDocumentIds); } - - public void testReRankWithWeightAsGreaterThanOneWithNullItemIdField() throws IOException { - PersonalizeIntelligentRankerConfiguration rankerConfig = - new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, "", region, 2); - PersonalizeClient client = Mockito.mock(PersonalizeClient.class); - Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResultWhenItemIdConfigIsEmpty(numOfHits)); - - AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); - PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); - requestParameters.setUserId("28"); - SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); - SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); - - List originalHits = Arrays.asList(transformedHits.getHits()); - String itemIdfield = rankerConfig.getItemIdField(); - List rerankedDocumentIds; - - rerankedDocumentIds = originalHits.stream() - .filter(h -> h.getId() != null) - .map(h -> h.getId()) - .collect(Collectors.toList()); - - ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZeroWithNullItemIdField(numOfHits); - assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); - } - - public void testReRankWithWeightAsLessThanZeroWithNullItemIdField() throws IOException { - PersonalizeIntelligentRankerConfiguration rankerConfig = - new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, "", region, -1); - PersonalizeClient client = Mockito.mock(PersonalizeClient.class); - Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResultWhenItemIdConfigIsEmpty(numOfHits)); - - AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); - PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); - requestParameters.setUserId("28"); - SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); - SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); - - List originalHits = Arrays.asList(transformedHits.getHits()); - String itemIdfield = rankerConfig.getItemIdField(); - List rerankedDocumentIds; - - rerankedDocumentIds = originalHits.stream() - .filter(h -> h.getId() != null) - .map(h -> h.getId()) - .collect(Collectors.toList()); - - ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZeroWithNullItemIdField(numOfHits); - assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); - } - } diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeRuntimeTestUtil.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeRuntimeTestUtil.java index 4ca4b34..4028c3d 100644 --- a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeRuntimeTestUtil.java +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeRuntimeTestUtil.java @@ -10,9 +10,14 @@ import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingRequest; import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingResult; import com.amazonaws.services.personalizeruntime.model.PredictedItem; +import org.mockito.Mockito; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; import java.util.ArrayList; import java.util.List; +import java.util.function.Function; + +import static org.mockito.ArgumentMatchers.any; public class PersonalizeRuntimeTestUtil { @@ -46,51 +51,31 @@ public static GetPersonalizedRankingResult buildGetPersonalizedRankingResult(int return result; } - - public static GetPersonalizedRankingResult buildGetPersonalizedRankingResultWhenItemIdConfigIsEmpty(int numOfHits) { - List predictedItems = new ArrayList<>(); - for(int i = numOfHits; i >= 1; i--){ - PredictedItem predictedItem = new PredictedItem(). - withScore((double) i/10). - withItemId("doc"+ (i - 1)); - predictedItems.add(predictedItem); - } - - GetPersonalizedRankingResult result = new GetPersonalizedRankingResult() - .withPersonalizedRanking(predictedItems) - .withRecommendationId("sampleRecommendationId"); - return result; - } - - public static ArrayList expectedRankedItemIdsWhenWeightIsOne(int numOfHits){ - ArrayList expectedRankedItemIds = new ArrayList<>(); - for(int i = numOfHits; i >= 1; i--){ - expectedRankedItemIds.add(String.valueOf(i-1)); - } - return expectedRankedItemIds; - } - - public static ArrayList expectedRankedItemIdsWhenWeightIsZero(int numOfHits){ + public static ArrayList expectedRankedItemIdsForGivenWeight(int numOfHits, int weight){ ArrayList expectedRankedItemIds = new ArrayList<>(); - for(int i = 0; i = 1; i--){ + expectedRankedItemIds.add(String.valueOf(i-1)); + } } return expectedRankedItemIds; } - public static ArrayList expectedRankedItemIdsWhenWeightIsOneWithNullItemIdField(int numOfHits){ - ArrayList expectedRankedItemIds = new ArrayList<>(); - for(int i = numOfHits; i >= 1; i--){ - expectedRankedItemIds.add("doc" + (i - 1)); - } - return expectedRankedItemIds; + public static PersonalizeClient buildMockPersonalizeClient() { + return buildMockPersonalizeClient(r -> buildGetPersonalizedRankingResult(10)); } - public static ArrayList expectedRankedItemIdsWhenWeightIsZeroWithNullItemIdField(int numOfHits){ - ArrayList expectedRankedItemIds = new ArrayList<>(); - for(int i = 0; i mockGetPersonalizedRankingImpl) { + PersonalizeClient personalizeClient = Mockito.mock(PersonalizeClient.class); + Mockito.doAnswer(invocation -> { + GetPersonalizedRankingRequest request = invocation.getArgument(0); + return mockGetPersonalizedRankingImpl.apply(request); + }).when(personalizeClient).getPersonalizedRanking(any(GetPersonalizedRankingRequest.class)); + return personalizeClient; } } diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/SearchTestUtil.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/SearchTestUtil.java index 945788d..2fa1b5b 100644 --- a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/SearchTestUtil.java +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/SearchTestUtil.java @@ -8,31 +8,48 @@ package org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils; import org.apache.lucene.search.TotalHits; +import org.opensearch.action.search.SearchRequest; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParametersExtBuilder; import java.io.IOException; +import java.util.List; import java.util.Map; public class SearchTestUtil { public static SearchHits getSampleSearchHitsForPersonalize(int numHits) throws IOException { SearchHit[] hitsArray = new SearchHit[numHits]; + float maxScore = 0.0f; for (int i = 0; i < numHits; i++) { XContentBuilder sourceContent = JsonXContent.contentBuilder() .startObject() - .field("_id", String.valueOf(i)) .field("ITEM_ID", String.valueOf(i)) .field("body", "Body text for document number " + i) .field("title", "This is the title for document " + i) .endObject(); - hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); - hitsArray[i].score((float) (numHits-i)/10); + hitsArray[i] = new SearchHit(i, String.valueOf(i), Map.of(), Map.of()); + float score = (float)(numHits-i)/10; + maxScore = Math.max(score, maxScore); + hitsArray[i].score(score); hitsArray[i].sourceRef(BytesReference.bytes(sourceContent)); } - SearchHits searchHits = new SearchHits(hitsArray, new TotalHits(numHits, TotalHits.Relation.EQUAL_TO), 1.0f); + SearchHits searchHits = new SearchHits(hitsArray, new TotalHits(numHits, TotalHits.Relation.EQUAL_TO), maxScore); return searchHits; } + + public static SearchRequest createSearchRequestWithPersonalizeRequest(PersonalizeRequestParameters personalizeRequestParams) { + PersonalizeRequestParametersExtBuilder extBuilder = new PersonalizeRequestParametersExtBuilder(); + extBuilder.setRequestParameters(personalizeRequestParams); + SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource() + .ext(List.of(extBuilder)); + + SearchRequest searchRequest = new SearchRequest().source(sourceBuilder); + return searchRequest; + } } diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtilTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtilTests.java new file mode 100644 index 0000000..537d78f --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtilTests.java @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils; + +import com.amazonaws.http.IdleConnectionReaper; +import org.opensearch.OpenSearchParseException; +import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration; +import org.opensearch.test.OpenSearchTestCase; + +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME; + +public class ValidationUtilTests extends OpenSearchTestCase { + + private static final String TYPE = "personalize_ranking"; + private static final String TAG = "test_tag"; + private String personalizeCampaign = "arn:aws:personalize:us-west-2:000000000000:campaign/test-campaign"; + private String iamRoleArn = "arn:aws:iam::000000000000:role/test"; + private String itemIdField = "ITEM_ID"; + private String region = "us-west-2"; + private double weight = 1.0; + + public void testValidRankerConfig () { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, AMAZON_PERSONALIZED_RANKING_RECIPE_NAME, itemIdField, region, weight); + ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, TAG); + } + + public void testInvalidCampaignArn () { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration("invalid:campaign/test", iamRoleArn, AMAZON_PERSONALIZED_RANKING_RECIPE_NAME, itemIdField, region, weight); + expectThrows(OpenSearchParseException.class, () -> + ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, TAG)); + } + + public void testEmptyCampaignArn () { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration("", iamRoleArn, AMAZON_PERSONALIZED_RANKING_RECIPE_NAME, itemIdField, region, weight); + expectThrows(OpenSearchParseException.class, () -> + ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, TAG)); + } + + public void testNonPersonalizeArnAsCampaignArn () { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration("arn:aws:es:us-west-2:000000000000:domain/testmovies", iamRoleArn, AMAZON_PERSONALIZED_RANKING_RECIPE_NAME, itemIdField, region, weight); + expectThrows(OpenSearchParseException.class, () -> + ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, TAG)); + } + + public void testInvalidIamRoleArn () { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, "invalid:arn/test", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME, itemIdField, region, weight); + expectThrows(OpenSearchParseException.class, () -> + ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, TAG)); + } + + public void testNonIamArnAsIamRoleArn () { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, "arn:aws:es:us-west-2:000000000000:domain/testmovies", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME, itemIdField, region, weight); + expectThrows(OpenSearchParseException.class, () -> + ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, TAG)); + } + + public void testEmptyIamRoleArnAllowed () { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, "", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME, itemIdField, region, weight); + ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, TAG); + } + + public void testInvalidWeightValueGreaterThanRange () { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, AMAZON_PERSONALIZED_RANKING_RECIPE_NAME, itemIdField, region, 3.0); + expectThrows(OpenSearchParseException.class, () -> + ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, TAG)); + } + + public void testInvalidWeightValueLessThanRange () { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, AMAZON_PERSONALIZED_RANKING_RECIPE_NAME, itemIdField, region, -1.0); + expectThrows(OpenSearchParseException.class, () -> + ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, TAG)); + } + + public void testNonPersonalizedRankingRecipeConfig () { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, "aws-user-personalization", itemIdField, region, -1.0); + expectThrows(OpenSearchParseException.class, () -> + ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, TAG)); + } +}