97 changes: 43 additions & 54 deletions dti_reviewer/backend/api.py
@@ -1,5 +1,5 @@
import logging
from flask import Blueprint, request, jsonify
from flask import Blueprint, request, jsonify, current_app
from tasks import query_experts_task
from celery.result import AsyncResult
from celery_app import celery
@@ -16,65 +16,40 @@ def search():
Search for experts based on a research abstract.
The endpoint processes the input query and returns a ranked list of experts
whose work most closely matches the research abstract.
"""
client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
logger.info(f"Search request received from {client_ip}")

"""
try:
data = request.get_json()
data = request.get_json() or {}
logger.debug(f"Request data received: {data}")

query = data.get("query", None) if data else None

if query is None:
logger.warning(f"Missing query parameter in request from {client_ip}")
return (
jsonify(message="Missing 'query' parameter", results=[]),
400,
)

query = query.strip()
logger.debug(f"Query after stripping: '{query}' (length: {len(query)})")

if not query:
logger.warning(f"Empty query received from {client_ip}")
return (
jsonify(message="Query cannot be empty", results=[]),
400,
)

query = (data.get("query") or "").strip()
engine_id = (data.get("engine_id") or "").strip()

for name, value in (("query", query), ("engine_id", engine_id)):
if not value:
logger.warning(f"Missing or empty '{name}' parameter in request")
return jsonify(message=f"Missing '{name}' parameter", results=[]), 400

if len(query) < 3:
logger.warning(
f"Query too short from {client_ip}: '{query}' (length: {len(query)})"
)
return (
jsonify(message="Query too short. Minimum 3 characters", results=[]),
400,
)

logger.info(
f"Processing valid query from {client_ip}: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)

celery_task = query_experts_task.delay(query, top_n=25)
logger.warning(f"Query too short: '{query}' (length: {len(query)})")
return jsonify(message="Query too short. Minimum 3 characters", results=[]), 400

available_engines = current_app.config.get('available_engines', [])
if engine_id not in available_engines:
logger.warning(f"Invalid engine_id: {engine_id}")
return jsonify(
message="Invalid engine_id.",
results=[]
), 400

celery_task = query_experts_task.delay(query, engine_id=engine_id, top_n=25)
task_id = celery_task.id

logger.info(
f"Celery task {task_id} submitted successfully for query from {client_ip}"
)

return (
jsonify(
message="Task submitted",
task_id=task_id,
),
202,
)

logger.info(f"Celery task {task_id} submitted successfully")

return jsonify(message="Task submitted", task_id=task_id), 202

except Exception as e:
logger.error(
f"Error processing search request from {client_ip}: {str(e)}", exc_info=True
)
logger.error(f"Error processing search request: {str(e)}", exc_info=True)
return jsonify(message="Server error enqueuing task", task_id=None), 500


@@ -113,7 +88,6 @@ def task_status(task_id):
)
return jsonify(resp), 200

# Handle failure states (FAILURE, RETRY, REVOKED, etc.)
error_info = str(async_result.info) if async_result.info else "Unknown error"
resp["message"] = error_info
logger.error(f"Task {task_id} failed with state {state}: {error_info}")
@@ -127,3 +101,18 @@ def task_status(task_id):
return jsonify(
{"state": "ERROR", "message": "Error retrieving task status"}
), 500

@api_bp.route("/available_engines", methods=["GET"])
def available_engines():
"""
Get a list of available similarity engines.
"""
try:
resp = {
"engines": current_app.config.get('available_engines', []),
"message": "Available similarity engines retrieved successfully"
}
return jsonify(resp), 200
except Exception as e:
logger.error(f"Error retrieving available engines: {str(e)}", exc_info=True)
return jsonify(message="Error retrieving available engines"), 500
2 changes: 2 additions & 0 deletions dti_reviewer/backend/app.py
@@ -1,5 +1,6 @@
from flask import Flask
from flask_cors import CORS
from similarity_engine import list_all_engines

def create_app(test_config=None):
# create and configure the app
@@ -8,6 +9,7 @@ def create_app(test_config=None):
if test_config is None:
# load the instance config, if it exists, when not testing
app.config.from_pyfile('config.py', silent=True)
app.config['available_engines'] = list_all_engines()
else:
# load the test config if passed in
app.config.from_mapping(test_config)
88 changes: 88 additions & 0 deletions dti_reviewer/backend/ggg.ipynb
@@ -0,0 +1,88 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "9299e07d",
"metadata": {},
"outputs": [],
"source": [
"from similarity_engine import SimilarityEngineOrcid\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2726896",
"metadata": {},
"outputs": [],
"source": [
"# 0. Instantiate your engine once (so it loads the index, vectorizer, etc.)\n",
"engine = SimilarityEngineOrcid()\n",
"\n",
"# 1. Now benchmark using that instance:\n",
"import numpy as np\n",
"import time\n",
"\n",
"queries = [\"stellar mass\", \"supernovae\"]\n",
"\n",
"orig_times = []\n",
"faiss_times = []\n",
"\n",
"for q in queries:\n",
" # time original\n",
" t0 = time.perf_counter()\n",
" orig_idx = engine.query_experts(q) \n",
" dt_orig = time.perf_counter() - t0\n",
" orig_times.append(dt_orig)\n",
"\n",
" # time FAISS\n",
" t1 = time.perf_counter()\n",
" faiss_idx = engine.query_experts_with_faiss(q)\n",
" dt_faiss = time.perf_counter() - t1\n",
" faiss_times.append(dt_faiss)\n",
"\n",
" print(f\"Query '{q}': original = {dt_orig:.6f}s, FAISS = {dt_faiss:.6f}s\")\n",
"\n",
"# summary\n",
"print(\"\\nAverage timings:\")\n",
"print(f\" Original avg time: {np.mean(orig_times):.6f} s\")\n",
"print(f\" FAISS avg time: {np.mean(faiss_times):.6f} s\\n\")\n",
"\n",
"print(\"Sample top-5 (idx, score):\")\n",
"print(\" Original:\", list(zip(orig_idx[:5])))\n",
"print(\" FAISS: \", list(zip(faiss_idx[:5])))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "56831229",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "deepreviewer",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
100 changes: 94 additions & 6 deletions dti_reviewer/backend/similarity_engine.py
@@ -5,21 +5,25 @@
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from abc import ABC, abstractmethod

import faiss
import numpy as np
class BaseSimilarityEngine(ABC):
engine_id = "base_similarity_engine"
name = "BaseSimilarityEngine"
description = "Base class for similarity engines"
def __init__(self):
pass

@abstractmethod
def query_experts(self, query_text: str, top_n: int = 25):
raise NotImplementedError("Subclasses must implement this method")
def query_experts(self, query_text: str, top_n: int = 25, progress_callback=None):
pass

class SimilarityEngineOrcid(BaseSimilarityEngine):
"""
A class to handle the similarity engine for expert authors.
It builds and queries a TF-IDF index of author texts.
"""
engine_id = "orcid_similarity_engine"
name = "ORCID Similarity Engine"
description = "TF-IDF based similarity engine for ORCID authors"

def __init__(self):
self.dataset_path = Path("expert-data/LSPO_v1.h5")
@@ -78,14 +82,25 @@ def load_index_or_build(self):
with open(self.index_dir / "authors.pkl", "rb") as f:
self.authors = pickle.load(f)

def query_experts(self, query_text: str, top_n: int = 25):
def query_experts(self, query_text: str, top_n: int = 25, progress_callback=None):
self.load_index_or_build()

if progress_callback:
progress_callback(percent=0.10, message="Transforming query...")
q_vec = self.vectorizer.transform([query_text])

if progress_callback:
progress_callback(percent=0.20, message="Calculating similarities...")
sims = cosine_similarity(q_vec, self.tfidf_matrix).flatten()

if progress_callback:
progress_callback(percent=0.80, message="Identifying top experts...")
top_indices = sims.argsort()[::-1][:top_n]
top_authors = self.combined_texts.iloc[top_indices].copy()
top_authors["similarity"] = sims[top_indices]

if progress_callback:
progress_callback(percent=0.90, message="Fetching author information...")
author_info = (
self.authors[["@path", "author"]] # , 'doi']
# .dropna(subset=['doi']) # Remove missing DOIs
@@ -100,6 +115,9 @@ def query_experts(self, query_text: str, top_n: int = 25):
)
results = top_authors.merge(author_info, on="@path", how="left")
results = results[["@path", "author", "similarity"]]

if progress_callback:
progress_callback(percent=0.95, message="Resolving name variations...")
name_variations = (
self.authors[["@path", "author"]]
.dropna()
@@ -112,4 +130,74 @@ def query_experts(self, query_text: str, top_n: int = 25):
results = results.merge(name_variations, on="@path", how="left")
results = results.rename(columns={"@path": "orcid"})
return results

def query_experts_with_faiss(self, query_text: str, top_n: int = 25, progress_callback=None):
self.load_index_or_build()

# 1. Transform to TF-IDF (sparse) and convert to dense
if progress_callback:
progress_callback(percent=0.10, message="Transforming query...")
q_vec = self.vectorizer.transform([query_text]).toarray().astype(np.float32)
docs = self.tfidf_matrix.toarray().astype(np.float32)

# 2. Normalize both query and documents to unit length for cosine via IP
# (FAISS only supports L2 or inner-product directly)
faiss.normalize_L2(q_vec)
faiss.normalize_L2(docs)

# 3. (Re)build or load a FAISS index
d = docs.shape[1] # dimensionality
index = faiss.IndexFlatIP(d) # flat (exact) inner-product index
index.add(docs) # add all doc vectors

if progress_callback:
progress_callback(percent=0.20, message="Index built, searching top experts…")
D, I = index.search(q_vec, top_n) # D: similarities, I: indices

# 4. Gather results exactly as before
top_indices = I.flatten()
top_authors = self.combined_texts.iloc[top_indices].copy()
top_authors["similarity"] = D.flatten()

# …and then the same author-info and name-variation merges you already have:
if progress_callback:
progress_callback(percent=0.80, message="Fetching author information…")
author_info = (
self.authors[["@path", "author"]]
.groupby("@path")
.agg({"author": "first"})
.reset_index()
)
results = top_authors.merge(author_info, on="@path", how="left")

if progress_callback:
progress_callback(percent=0.90, message="Resolving name variations…")
name_variations = (
self.authors[["@path", "author"]]
.dropna()
.groupby("@path")["author"]
.apply(lambda names: list(sorted(set(names))))
.reset_index()
.rename(columns={"author": "name_variations"})
)
results = (
results
.merge(name_variations, on="@path", how="left")
.rename(columns={"@path": "orcid"})
)

if progress_callback:
progress_callback(percent=0.95, message="Done.")
return results

def list_all_engines():
available_engines = []
for cls in BaseSimilarityEngine.__subclasses__():
engine_data = {
"engine_id": cls.engine_id or cls.__name__.lower(),
"name": cls.name,
"description": cls.description
}
available_engines.append(engine_data)

return available_engines
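
For reference, a hypothetical sketch of how an additional engine would plug into this registry: any direct subclass of BaseSimilarityEngine that defines the three class attributes is picked up by list_all_engines() via __subclasses__() when create_app() builds the config. The KeywordOverlapEngine below is illustrative only and not part of this change.

class KeywordOverlapEngine(BaseSimilarityEngine):
    # Hypothetical engine, for illustration only -- not part of this PR.
    engine_id = "keyword_overlap_engine"
    name = "Keyword Overlap Engine"
    description = "Toy engine that would rank authors by raw keyword overlap"

    def query_experts(self, query_text: str, top_n: int = 25, progress_callback=None):
        ...  # would return a results DataFrame, like SimilarityEngineOrcid

# list_all_engines() only sees subclasses that have been imported by the time it runs,
# so a new engine must be defined (or imported) in this module before create_app() calls it.
print(list_all_engines())
# -> [{'engine_id': 'orcid_similarity_engine', 'name': 'ORCID Similarity Engine', ...},
#     {'engine_id': 'keyword_overlap_engine', 'name': 'Keyword Overlap Engine', ...}]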