Skip to content

Commit 3fd36e3

Browse files
ThejasThejas
Thejas
authored and
Thejas
committed
Add server.py, .env and dependencies
1 parent be4271f commit 3fd36e3

File tree

4 files changed

+55
-0
lines changed

4 files changed

+55
-0
lines changed

.env

+3
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,3 @@
1+
INDEX_ROOT=""
2+
INDEX_NAME=""
3+
PORT="8893"

conda_env.yml

+2
Original file line number | Diff line number | Diff line change
@@ -26,3 +26,5 @@ dependencies:
2626
- tqdm
2727
- transformers
2828
- ujson
29+
- flask
30+
- python-dotenv

conda_env_cpu.yml

+2
Original file line number | Diff line number | Diff line change
@@ -18,3 +18,5 @@ dependencies:
1818
- tqdm
1919
- transformers
2020
- ujson
21+
- flask
22+
- python-dotenv

server.py

+48
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,48 @@
1+
from flask import Flask, render_template, request
2+
from functools import lru_cache
3+
import math
4+
import os
5+
from dotenv import load_dotenv
6+
7+
from colbert.infra import Run, RunConfig, ColBERTConfig
8+
from colbert import Searcher
9+
10+
# Load variables from a .env file (if present) before any os.getenv lookups.
load_dotenv()

# Location of the index; both values are read from the environment and are
# None when the corresponding variable is not set.
INDEX_ROOT = os.getenv("INDEX_ROOT")
INDEX_NAME = os.getenv("INDEX_NAME")

app = Flask(__name__)

# Built once at import time and shared by every request handler.
searcher = Searcher(index=f"{INDEX_ROOT}/{INDEX_NAME}")

# Running tally of /api/search requests, kept in a dict so handlers can
# mutate it without a `global` statement.
counter = {"api": 0}
18+
19+
@lru_cache(maxsize=1000000)
def api_search_query(query, k):
    """Search the index for ``query`` and return the top-``k`` passages.

    Results are memoized per ``(query, k)`` pair by ``lru_cache``, so both
    arguments must be hashable and the returned dict is shared between
    callers of the same pair -- treat it as read-only.

    Args:
        query: Query text handed to ``searcher.search``.
        k: Number of passages to return, as an int, a numeric string, or
            ``None`` (meaning 10). The value is clamped to the range 0..100.

    Returns:
        A dict ``{"query": query, "topk": [...]}``. Each ``topk`` entry has
        the keys ``text``, ``pid``, ``rank``, ``score`` and ``prob``, where
        ``prob`` is the softmax of ``score`` over the returned entries.
        Entries are ordered by descending score, ties broken by ascending
        pid.

    Raises:
        ValueError: If ``k`` is a string that ``int()`` cannot parse.
    """
    print(f"Query={query}")
    if k is None:
        k = 10
    # Clamp on both sides: a negative k would otherwise turn the slices
    # below into "drop the last |k| results" instead of "keep none".
    k = max(0, min(int(k), 100))

    # Always retrieve 100 candidates, then truncate to the requested depth.
    pids, ranks, scores = searcher.search(query, k=100)
    pids, ranks, scores = pids[:k], ranks[:k], scores[:k]

    # Softmax over the kept scores; the normalizer is computed once rather
    # than once per element.
    weights = [math.exp(score) for score in scores]
    total = sum(weights)
    probs = [weight / total for weight in weights]

    topk = [
        {
            'text': searcher.collection[pid],
            'pid': pid,
            'rank': rank,
            'score': score,
            'prob': prob,
        }
        for pid, rank, score, prob in zip(pids, ranks, scores, probs)
    ]
    topk.sort(key=lambda p: (-1 * p['score'], p['pid']))
    return {"query": query, "topk": topk}
36+
37+
@app.route("/api/search", methods=["GET"])
def api_search():
    """Handle ``GET /api/search?query=...&k=...``.

    Reads ``query`` and the optional ``k`` from the URL query string,
    bumps the request counter, and delegates to ``api_search_query``.

    Returns:
        The search result dict on success; ``('', 405)`` for any method
        other than GET; a ``({"error": ...}, 400)`` tuple when ``query`` is
        missing or ``k`` is not an integer.
    """
    # Flask registers HEAD alongside GET, so this guard is still reachable
    # even though the route only lists GET.
    if request.method != "GET":
        return ('', 405)

    counter["api"] += 1
    print("API request count:", counter["api"])

    query = request.args.get("query")
    if query is None:
        # Without this the search would be invoked with None and surface
        # as an opaque 500.
        return ({"error": "missing required parameter: query"}, 400)

    k = request.args.get("k")
    if k is not None:
        # Validate here so a bad value becomes a 400 instead of an
        # uncaught ValueError inside the cached search function.
        try:
            int(k)
        except ValueError:
            return ({"error": "parameter k must be an integer"}, 400)

    # k is passed through unchanged (string or None) so cache keys keep the
    # same form as before.
    return api_search_query(query, k)
45+
46+
if __name__ == "__main__":
    # PORT is read from the environment (populated from .env by
    # load_dotenv). Fall back to 8893 when it is unset or empty so int()
    # is never handed None or "".
    port = int(os.getenv("PORT") or "8893")
    # NOTE(review): binding to 0.0.0.0 exposes the development server on
    # every network interface -- confirm this is intended for deployment.
    app.run("0.0.0.0", port)
48+

0 commit comments

Comments
 (0)