|
| 1 | +from flask import Flask, render_template, request |
| 2 | +from functools import lru_cache |
| 3 | +import math |
| 4 | +import os |
| 5 | +from dotenv import load_dotenv |
| 6 | + |
| 7 | +from colbert.infra import Run, RunConfig, ColBERTConfig |
| 8 | +from colbert import Searcher |
| 9 | + |
# Pull variables from a local .env file into the process environment before
# any of them are read below.
load_dotenv()

# Index location, taken from the environment. Either value is None when the
# corresponding variable is unset.
INDEX_NAME = os.environ.get("INDEX_NAME")
INDEX_ROOT = os.environ.get("INDEX_ROOT")

app = Flask(__name__)

# One module-level searcher over "<INDEX_ROOT>/<INDEX_NAME>", built at import
# time and used by the search function further down.
searcher = Searcher(index="{}/{}".format(INDEX_ROOT, INDEX_NAME))

# Running count of API requests, keyed by endpoint family.
counter = {"api": 0}
| 18 | + |
@lru_cache(maxsize=1000000)
def api_search_query(query, k):
    """Search the index for ``query`` and return the top-``k`` passages.

    Results are memoized by ``lru_cache`` on the exact ``(query, k)``
    argument pair, so both arguments must be hashable. The returned dict is
    the cached object itself; callers should treat it as read-only.

    Args:
        query: Query text handed to ``searcher.search``.
        k: Number of passages to return. ``None`` means 10; any other value
            is converted with ``int()`` and clamped to the range 0..100.

    Returns:
        ``{"query": query, "topk": [...]}``. Each ``topk`` entry is a dict
        with keys ``text``, ``pid``, ``rank``, ``score`` and ``prob``, where
        ``prob`` is the softmax of ``score`` over the returned passages.
        Entries are ordered by descending score, ties broken by ascending
        pid.

    Raises:
        ValueError, TypeError: If ``k`` is not ``None`` and ``int(k)`` fails.
    """
    default_k, max_k = 10, 100

    print(f"Query={query}")
    k = default_k if k is None else int(k)
    # Clamp at both ends: a negative k would otherwise act as a negative
    # slice bound below and return almost the whole candidate list.
    k = max(0, min(k, max_k))

    # Always retrieve max_k candidates and truncate afterwards, as before.
    # NOTE(review): presumably so the ranking does not depend on k - confirm.
    pids, ranks, scores = searcher.search(query, k=max_k)
    pids, ranks, scores = pids[:k], ranks[:k], scores[:k]

    # Softmax over the returned scores. Shifting by the maximum leaves the
    # result mathematically unchanged but keeps math.exp from overflowing;
    # the normaliser is computed once instead of once per element.
    top_score = max(scores, default=0.0)
    weights = [math.exp(score - top_score) for score in scores]
    total = sum(weights)
    probs = [weight / total for weight in weights]

    topk = [
        {
            'text': searcher.collection[pid],
            'pid': pid,
            'rank': rank,
            'score': score,
            'prob': prob,
        }
        for pid, rank, score, prob in zip(pids, ranks, scores, probs)
    ]
    topk.sort(key=lambda p: (-p['score'], p['pid']))
    return {"query": query, "topk": topk}
| 36 | + |
@app.route("/api/search", methods=["GET"])
def api_search():
    """Handle ``GET /api/search?query=...&k=...``.

    Increments and prints the API request counter, validates the query-string
    parameters, and returns the result of ``api_search_query``.

    Returns:
        The search result dict on success; a ``({"error": ...}, 400)`` pair
        when ``query`` is missing or ``k`` is not an integer; ``('', 405)``
        for any request method other than GET.
    """
    if request.method != "GET":
        return ('', 405)

    counter["api"] += 1
    print("API request count:", counter["api"])

    query = request.args.get("query")
    k = request.args.get("k")

    # Reject bad input here so it becomes a client error rather than an
    # unhandled exception inside the search function.
    if query is None:
        return ({"error": "missing required parameter 'query'"}, 400)
    if k is not None:
        try:
            int(k)
        except ValueError:
            return ({"error": "parameter 'k' must be an integer"}, 400)

    # k is forwarded as the raw string (or None), exactly as before.
    return api_search_query(query, k)
| 45 | + |
if __name__ == "__main__":
    # Script entry point: serve on all interfaces at the port named by the
    # PORT environment variable. int() raises if PORT is unset or not numeric.
    port = int(os.getenv("PORT"))
    app.run(host="0.0.0.0", port=port)
| 48 | + |
0 commit comments