diff --git a/amazon-personalize-ranking/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/Constants.java b/amazon-personalize-ranking/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/Constants.java index eefea08..40105f9 100644 --- a/amazon-personalize-ranking/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/Constants.java +++ b/amazon-personalize-ranking/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/configuration/Constants.java @@ -13,4 +13,5 @@ */ public class Constants { public static final String AMAZON_PERSONALIZED_RANKING_RECIPE_NAME = "aws-personalized-ranking"; + public static final String AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME = "aws-personalized-ranking-v2"; } diff --git a/amazon-personalize-ranking/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRankerFactory.java b/amazon-personalize-ranking/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRankerFactory.java index 1e30ec9..ead8dd8 100644 --- a/amazon-personalize-ranking/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRankerFactory.java +++ b/amazon-personalize-ranking/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/PersonalizedRankerFactory.java @@ -14,6 +14,7 @@ import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.impl.AmazonPersonalizedRankerImpl; import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME; +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME; /** * Factory for creating Personalize ranker instance based on Personalize ranker configuration @@ -29,7 +30,9 @@ public class PersonalizedRankerFactory { */ public PersonalizedRanker getPersonalizedRanker(PersonalizeIntelligentRankerConfiguration config, PersonalizeClient client){ PersonalizedRanker ranker = null; - if (config.getRecipe().equals(AMAZON_PERSONALIZED_RANKING_RECIPE_NAME)) { + String recipeInConfig = config.getRecipe(); + if (recipeInConfig.equals(AMAZON_PERSONALIZED_RANKING_RECIPE_NAME) + || recipeInConfig.equals(AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME)) { ranker = new AmazonPersonalizedRankerImpl(config, client); } else { logger.error("Personalize recipe provided in configuration is not supported for re ranking search results"); diff --git a/amazon-personalize-ranking/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtil.java b/amazon-personalize-ranking/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtil.java index d92a974..c6c92f7 100644 --- a/amazon-personalize-ranking/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtil.java +++ b/amazon-personalize-ranking/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtil.java @@ -15,9 +15,13 @@ import java.util.HashSet; import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME; +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME; public class ValidationUtil { - private static Set SUPPORTED_PERSONALIZE_RECIPES = new HashSet<>(Arrays.asList(AMAZON_PERSONALIZED_RANKING_RECIPE_NAME)); + private static Set SUPPORTED_PERSONALIZE_RECIPES = new HashSet<>(Arrays.asList( + AMAZON_PERSONALIZED_RANKING_RECIPE_NAME, + AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME + )); /** * Validate Personalize configuration for calling Personalize service. diff --git a/amazon-personalize-ranking/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessorTests.java b/amazon-personalize-ranking/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessorTests.java index 7e7087b..f83bbcb 100644 --- a/amazon-personalize-ranking/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessorTests.java +++ b/amazon-personalize-ranking/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessorTests.java @@ -39,6 +39,7 @@ import static org.mockito.Mockito.mock; import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.PersonalizeRankingResponseProcessor.TYPE; import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME; +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME; public class PersonalizeRankingResponseProcessorTests extends OpenSearchTestCase { @@ -284,6 +285,81 @@ public void testPersonalizeRankingResponse() throws Exception { IdleConnectionReaper.shutdown(); } + public void testPersonalizeRankingV2Response() throws Exception { + PersonalizeClient personalizeClient = PersonalizeRuntimeTestUtil.buildMockPersonalizeClient(); + + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> personalizeClient); + + String itemField = "ITEM_ID"; + Map configuration = buildPersonalizeResponseProcessorConfig(); + configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME); + + PersonalizeRankingResponseProcessor responseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, UPDATE_CONTEXT); + + SearchResponse personalizedResponse = createPersonalizedRankingProcessorResponse(responseProcessor, null, NUM_HITS); + + List transformedHits = Arrays.asList(personalizedResponse.getHits().getHits()); + List rerankedDocumentIds; + rerankedDocumentIds = transformedHits.stream() + .filter(h -> h.getSourceAsMap().get(itemField) != null) + .map(h -> h.getSourceAsMap().get(itemField).toString()) + .collect(Collectors.toList()); + + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(NUM_HITS, 1); + assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); + IdleConnectionReaper.shutdown(); + } + + public void testPersonalizeRankingV2ResponseWithInvalidItemIdFieldName() throws Exception { + PersonalizeClient personalizeClient = PersonalizeRuntimeTestUtil.buildMockPersonalizeClient(); + + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> personalizeClient); + + String itemFieldInvalid = "ITEM_ID_NOT_VALID"; + Map configuration = buildPersonalizeResponseProcessorConfig(); + configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME); + configuration.put("item_id_field", itemFieldInvalid); + + PersonalizeRankingResponseProcessor responseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, UPDATE_CONTEXT); + + expectThrows(OpenSearchParseException.class, () -> + createPersonalizedRankingProcessorResponse(responseProcessor, null, NUM_HITS)); + IdleConnectionReaper.shutdown(); + } + + public void testPersonalizeRankingV2ResponseWithDefaultItemIdField() throws Exception { + PersonalizeClient personalizeClient = PersonalizeRuntimeTestUtil.buildMockPersonalizeClient(); + + PersonalizeRankingResponseProcessor.Factory factory + = new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> personalizeClient); + + String itemIdFieldEmpty = ""; + Map configuration = buildPersonalizeResponseProcessorConfig(); + configuration.put("item_id_field", itemIdFieldEmpty); + configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME); + + PersonalizeRankingResponseProcessor responseProcessor = + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, UPDATE_CONTEXT); + + SearchResponse personalizedResponse = createPersonalizedRankingProcessorResponse(responseProcessor, null, NUM_HITS); + + List transformedHits = Arrays.asList(personalizedResponse.getHits().getHits()); + List rerankedDocumentIds; + rerankedDocumentIds = transformedHits.stream() + .map(SearchHit::getId) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(NUM_HITS, 1); + + assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); + IdleConnectionReaper.shutdown(); + } + public void testPersonalizeRankingResponseWithInvalidItemIdFieldName() throws Exception { PersonalizeClient personalizeClient = PersonalizeRuntimeTestUtil.buildMockPersonalizeClient(); diff --git a/amazon-personalize-ranking/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtilTests.java b/amazon-personalize-ranking/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtilTests.java index 537d78f..1227915 100644 --- a/amazon-personalize-ranking/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtilTests.java +++ b/amazon-personalize-ranking/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/ValidationUtilTests.java @@ -13,6 +13,7 @@ import org.opensearch.test.OpenSearchTestCase; import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME; +import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME; public class ValidationUtilTests extends OpenSearchTestCase { @@ -30,6 +31,12 @@ public void testValidRankerConfig () { ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, TAG); } + public void testValidRankerConfigPersonalizedRankingV2 () { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME, itemIdField, region, weight); + ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, TAG); + } + public void testInvalidCampaignArn () { PersonalizeIntelligentRankerConfiguration rankerConfig = new PersonalizeIntelligentRankerConfiguration("invalid:campaign/test", iamRoleArn, AMAZON_PERSONALIZED_RANKING_RECIPE_NAME, itemIdField, region, weight);