diff --git a/it/test_test_mode_as_param.py b/it/test_test_mode_as_param.py new file mode 100644 index 000000000..18f053534 --- /dev/null +++ b/it/test_test_mode_as_param.py @@ -0,0 +1,39 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +pytest_rally = pytest.importorskip("pytest_rally") + +BASE_PARAMS = {"test_mode": True} + + +def params(updates=None): + base = BASE_PARAMS.copy() + if updates is None: + return base + else: + return {**base, **updates} + + +class TestQueryWithTestModeParam: + def test_msmarco(self, es_cluster, rally): + ret = rally.race( + track="msmarco-v2-vector", + track_params=params(), + ) + assert ret == 0 diff --git a/msmarco-v2-vector/queries-recall-test.json.bz2 b/msmarco-v2-vector/queries-recall-test.json.bz2 new file mode 100644 index 000000000..75db84ea4 Binary files /dev/null and b/msmarco-v2-vector/queries-recall-test.json.bz2 differ diff --git a/msmarco-v2-vector/queries-test.json.bz2 b/msmarco-v2-vector/queries-test.json.bz2 new file mode 100644 index 000000000..07b1a79b3 Binary files /dev/null and b/msmarco-v2-vector/queries-test.json.bz2 differ diff --git a/msmarco-v2-vector/track.json b/msmarco-v2-vector/track.json index 31f829705..79abc9af1 100644 --- a/msmarco-v2-vector/track.json +++ b/msmarco-v2-vector/track.json @@ -7,6 +7,10 @@ "pytrec_eval==0.5", "numpy" ], + "parameters": { + "test-mode": {{ test_mode | default(false) |tojson}} + }, + "indices": [ { "name": "msmarco-v2", diff --git a/msmarco-v2-vector/track.py b/msmarco-v2-vector/track.py index f6e100460..fc1a25a42 100644 --- a/msmarco-v2-vector/track.py +++ b/msmarco-v2-vector/track.py @@ -1,11 +1,14 @@ import bz2 import csv import json +import logging import os import statistics from collections import defaultdict from typing import Any, Dict, List +from esrally.driver.runner import Runner + Qrels = Dict[str, Dict[str, int]] Results = Dict[str, Dict[str, float]] @@ -90,6 +93,7 @@ class KnnParamSource: def __init__(self, track, params, **kwargs): # choose a suitable index: if there is only one defined for this track # choose that one, but let the user always override index + self.logger = logging.getLogger(__name__) if len(track.indices) == 1: default_index = track.indices[0].name else: @@ -99,13 +103,22 @@ def __init__(self, track, params, **kwargs): self._cache = params.get("cache", False) self._params = params self._queries = [] - + self.test_mode = track.selected_challenge_or_default.parameters.get("test-mode", False) cwd = os.path.dirname(__file__) - with bz2.open(os.path.join(cwd, QUERIES_FILENAME), "r") as queries_file: + + if self.test_mode: + queries_filename = QUERIES_FILENAME.replace(".json.bz2", "-test.json.bz2") + if not os.path.exists(os.path.join(cwd, queries_filename)): + self.logger.warning("Test mode enabled but test queries file not found, using default queries file") + queries_filename = QUERIES_FILENAME + else: + queries_filename = QUERIES_FILENAME + + with bz2.open(os.path.join(cwd, queries_filename), "r") as queries_file: for vector_query in queries_file: self._queries.append(json.loads(vector_query)) self._iters = 0 - self._maxIters = len(self._queries) + self._max_iters = len(self._queries) self.infinite = True def partition(self, partition_index, total_partitions): @@ -131,7 +144,7 @@ def params(self): result["body"]["knn"]["filter"] = self._params["filter"] self._iters += 1 - if self._iters >= self._maxIters: + if self._iters >= self._max_iters: self._iters = 0 return result @@ -161,8 +174,9 @@ def params(self): } -class KnnRecallRunner: +class KnnRecallRunner(Runner): async def __call__(self, es, params): + self.logger = logging.getLogger(__name__) top_k = params["size"] num_candidates = params["num_candidates"] num_rescore = params["num_rescore"] @@ -177,7 +191,16 @@ async def __call__(self, es, params): exact_total = 0 min_recall = top_k nodes_visited = [] - with bz2.open(os.path.join(cwd, QUERIES_RECALL_FILENAME), "r") as queries_file: + if self.test_mode: + queries_recall_filename = QUERIES_RECALL_FILENAME.replace(".json.bz2", "-test.json.bz2") + if not os.path.exists(queries_recall_filename): + self.logger.warning( + "Test mode enabled but test queries file not found %s, using default queries file", queries_recall_filename + ) + queries_recall_filename = QUERIES_RECALL_FILENAME + else: + queries_recall_filename = QUERIES_RECALL_FILENAME + with bz2.open(os.path.join(cwd, queries_recall_filename), "r") as queries_file: for line in queries_file: query = json.loads(line) query_id = query["query_id"] @@ -233,6 +256,7 @@ def __repr__(self, *args, **kwargs): def register(registry): + config = registry.config registry.register_param_source("knn-param-source", KnnParamSource) registry.register_param_source("knn-recall-param-source", KnnRecallParamSource) - registry.register_runner("knn-recall", KnnRecallRunner(), async_runner=True) + registry.register_runner("knn-recall", KnnRecallRunner(config=config), async_runner=True)