From 041ca215883f503fa5e03b38d5283dc7ca775818 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B0=B8=E5=BC=98?= Date: Wed, 26 May 2021 11:24:50 +0800 Subject: [PATCH] add retry for download index & support MipsSquaredEuclidean --- mars/learn/proxima/simple_index/searcher.py | 36 ++++++++++++++------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/mars/learn/proxima/simple_index/searcher.py b/mars/learn/proxima/simple_index/searcher.py index 8a4d88d11e..c5eca7a3bd 100644 --- a/mars/learn/proxima/simple_index/searcher.py +++ b/mars/learn/proxima/simple_index/searcher.py @@ -313,16 +313,30 @@ def _execute_download(cls, ctx, op: "ProximaSearcher"): exist_state = False if not os.path.exists(local_path.rsplit("/", 1)[0]): os.mkdir(local_path.rsplit("/", 1)[0]) - with open(local_path, 'wb') as out_f: - with fs.open(index_path, 'rb') as in_f: - # 32M - chunk_bytes = 32 * 1024 ** 2 - while True: - data = in_f.read(chunk_bytes) - if data: - out_f.write(data) - else: - break + + def read_index(): + with open(local_path, 'wb') as out_f: + with fs.open(index_path, 'rb') as in_f: + # 32M + chunk_bytes = 32 * 1024 ** 2 + while True: + data = in_f.read(chunk_bytes) + if data: + out_f.write(data) + else: + break + + # retry 3 times + for _ in range(3): + try: + read_index() + logger.warning(f"read success") + break + except: # noqa: E722 # nosec # pylint: disable=bare-except + logger.warning(f"read index file faild for times {_}") + os.remove(local_path) + logger.warning(f"remove {local_path} success") + continue logger.warning(f'ReadingFromVolume({op.key}), index path: {index_path}, ' f'local_path {local_path}' @@ -406,7 +420,7 @@ def _execute_agg(cls, ctx, op: "ProximaSearcher"): topk = op.topk # calculate topk on rows - if op.distance_metric == "InnerProduct": + if op.distance_metric == "InnerProduct" or op.distance_metric == "MipsSquaredEuclidean": inds = np.argsort(distances, axis=1)[:, -1:-topk - 1:-1] else: inds = np.argsort(distances, axis=1)[:, :topk]