diff --git a/dti_reviewer/backend/api.py b/dti_reviewer/backend/api.py
index 42f7c40..702f7eb 100644
--- a/dti_reviewer/backend/api.py
+++ b/dti_reviewer/backend/api.py
@@ -1,5 +1,5 @@
 import logging
-from flask import Blueprint, request, jsonify
+from flask import Blueprint, request, jsonify, current_app
 from tasks import query_experts_task
 from celery.result import AsyncResult
 from celery_app import celery
@@ -16,65 +16,38 @@ def search():
     Search for experts based on a research abstract.
     The endpoint processes the input query and returns a ranked list
     of experts whose work most closely matches the research abstract.
-    """
-    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
-    logger.info(f"Search request received from {client_ip}")
-
+    """
     try:
-        data = request.get_json()
+        data = request.get_json() or {}
         logger.debug(f"Request data received: {data}")
-
-        query = data.get("query", None) if data else None
-
-        if query is None:
-            logger.warning(f"Missing query parameter in request from {client_ip}")
-            return (
-                jsonify(message="Missing 'query' parameter", results=[]),
-                400,
-            )
-
-        query = query.strip()
-        logger.debug(f"Query after stripping: '{query}' (length: {len(query)})")
-
-        if not query:
-            logger.warning(f"Empty query received from {client_ip}")
-            return (
-                jsonify(message="Query cannot be empty", results=[]),
-                400,
-            )
+
+        query = (data.get("query") or "").strip()
+        engine_id = (data.get("engine_id") or "").strip()
+
+        for name, value in [("query", query), ("engine_id", engine_id)]:
+            if not value:
+                logger.warning(f"Missing or empty '{name}' parameter in request")
+                return jsonify(message=f"Missing '{name}' parameter", results=[]), 400
 
         if len(query) < 3:
-            logger.warning(
-                f"Query too short from {client_ip}: '{query}' (length: {len(query)})"
-            )
-            return (
-                jsonify(message="Query too short. Minimum 3 characters", results=[]),
-                400,
-            )
-
-        logger.info(
-            f"Processing valid query from {client_ip}: '{query[:100]}{'...' if len(query) > 100 else ''}'"
-        )
-
-        celery_task = query_experts_task.delay(query, top_n=25)
+            logger.warning(f"Query too short: '{query}' (length: {len(query)})")
+            return jsonify(message="Query too short. Minimum 3 characters", results=[]), 400
+
+        # available_engines holds dicts ({engine_id, name, description}), so
+        # membership must be checked against the ids, not the dicts themselves
+        available_engines = current_app.config.get('available_engines', [])
+        valid_engine_ids = {engine["engine_id"] for engine in available_engines}
+        if engine_id not in valid_engine_ids:
+            logger.warning(f"Invalid engine_id: {engine_id}")
+            return jsonify(
+                message="Invalid engine_id.",
+                results=[]
+            ), 400
+
+        celery_task = query_experts_task.delay(query, engine_id=engine_id, top_n=25)
         task_id = celery_task.id
-        logger.info(
-            f"Celery task {task_id} submitted successfully for query from {client_ip}"
-        )
-
-        return (
-            jsonify(
-                message="Task submitted",
-                task_id=task_id,
-            ),
-            202,
-        )
-
+        logger.info(f"Celery task {task_id} submitted successfully")
+
+        return jsonify(message="Task submitted", task_id=task_id), 202
+
     except Exception as e:
-        logger.error(
-            f"Error processing search request from {client_ip}: {str(e)}", exc_info=True
-        )
+        logger.error(f"Error processing search request: {str(e)}", exc_info=True)
         return jsonify(message="Server error enqueuing task", task_id=None), 500
 
 
@@ -113,7 +86,6 @@ def task_status(task_id):
         )
         return jsonify(resp), 200
 
-    # Handle failure states (FAILURE, RETRY, REVOKED, etc.)
     error_info = str(async_result.info) if async_result.info else "Unknown error"
     resp["message"] = error_info
     logger.error(f"Task {task_id} failed with state {state}: {error_info}")
@@ -127,3 +99,19 @@ def task_status(task_id):
     return jsonify(
         {"state": "ERROR", "message": "Error retrieving task status"}
     ), 500
+
+
+@api_bp.route("/available_engines", methods=["GET"])
+def available_engines():
+    """
+    Get a list of available similarity engines.
+    """
+    try:
+        resp = {
+            "engines": current_app.config.get('available_engines', []),
+            "message": "Available similarity engines retrieved successfully"
+        }
+        return jsonify(resp), 200
+    except Exception as e:
+        logger.error(f"Error retrieving available engines: {str(e)}", exc_info=True)
+        return jsonify(message="Error retrieving available engines"), 500
\ No newline at end of file
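For reference, the request/response flow above can be exercised end to end with a small client. A sketch only — it assumes the backend is reachable at `http://localhost:5000`, that the status route is `/status/<task_id>` (as the nginx config further down suggests), and that `requests` is installed:

```python
import time

import requests

BASE = "http://localhost:5000"  # assumed dev address; adjust to your deployment

# Discover engines, then submit a search with an explicit engine_id.
engines = requests.get(f"{BASE}/available_engines").json()["engines"]
engine_id = engines[0]["engine_id"]

resp = requests.post(
    f"{BASE}/search",
    json={"query": "spectral modelling of core-collapse supernovae", "engine_id": engine_id},
)
resp.raise_for_status()  # expect 202 with {"message": ..., "task_id": ...}
task_id = resp.json()["task_id"]

# Poll until the task leaves the PENDING/PROGRESS states.
while True:
    status = requests.get(f"{BASE}/status/{task_id}").json()
    if status["state"] not in ("PENDING", "PROGRESS"):
        break
    time.sleep(1)

print(status["state"], status.get("message"))
```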
diff --git a/dti_reviewer/backend/app.py b/dti_reviewer/backend/app.py
index a14a473..ca07067 100644
--- a/dti_reviewer/backend/app.py
+++ b/dti_reviewer/backend/app.py
@@ -1,5 +1,6 @@
 from flask import Flask
 from flask_cors import CORS
+from similarity_engine import list_all_engines
 
 def create_app(test_config=None):
     # create and configure the app
@@ -8,6 +9,7 @@ def create_app(test_config=None):
     if test_config is None:
         # load the instance config, if it exists, when not testing
         app.config.from_pyfile('config.py', silent=True)
+        app.config['available_engines'] = list_all_engines()
     else:
         # load the test config if passed in
         app.config.from_mapping(test_config)
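One behavioral note on this wiring: `list_all_engines()` (defined in similarity_engine.py below) enumerates `BaseSimilarityEngine.__subclasses__()`, which only contains classes whose defining modules have actually been imported by the time `create_app()` runs. A minimal illustration, with hypothetical classes:

```python
from abc import ABC

class Base(ABC):
    engine_id = "base"

def registered():
    # __subclasses__() only sees subclasses defined/imported so far
    return [cls.engine_id for cls in Base.__subclasses__()]

print(registered())  # [] -- nothing registered yet

class EngineA(Base):
    engine_id = "a"

print(registered())  # ['a'] -- visible only after the class body runs
```

An engine living in a module that app.py never imports would be missing from `available_engines`, and `/search` would then reject its `engine_id`.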
diff --git a/dti_reviewer/backend/ggg.ipynb b/dti_reviewer/backend/ggg.ipynb
new file mode 100644
index 0000000..7e1b685
--- /dev/null
+++ b/dti_reviewer/backend/ggg.ipynb
@@ -0,0 +1,88 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "9299e07d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from similarity_engine import SimilarityEngineOrcid\n",
+    "import numpy as np"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b2726896",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 0. Instantiate your engine once (so it loads the index, vectorizer, etc.)\n",
+    "engine = SimilarityEngineOrcid()\n",
+    "\n",
+    "# 1. Now benchmark using that instance:\n",
+    "import time\n",
+    "\n",
+    "queries = [\"stellar mass\", \"supernovae\"]\n",
+    "\n",
+    "orig_times = []\n",
+    "faiss_times = []\n",
+    "\n",
+    "for q in queries:\n",
+    "    # time original\n",
+    "    t0 = time.perf_counter()\n",
+    "    orig_results = engine.query_experts(q)\n",
+    "    dt_orig = time.perf_counter() - t0\n",
+    "    orig_times.append(dt_orig)\n",
+    "\n",
+    "    # time FAISS\n",
+    "    t1 = time.perf_counter()\n",
+    "    faiss_results = engine.query_experts_with_faiss(q)\n",
+    "    dt_faiss = time.perf_counter() - t1\n",
+    "    faiss_times.append(dt_faiss)\n",
+    "\n",
+    "    print(f\"Query '{q}': original = {dt_orig:.6f}s, FAISS = {dt_faiss:.6f}s\")\n",
+    "\n",
+    "# summary\n",
+    "print(\"\\nAverage timings:\")\n",
+    "print(f\"  Original avg time: {np.mean(orig_times):.6f} s\")\n",
+    "print(f\"  FAISS avg time:    {np.mean(faiss_times):.6f} s\\n\")\n",
+    "\n",
+    "print(\"Sample top-5 results:\")\n",
+    "print(\"  Original:\", orig_results[:5])\n",
+    "print(\"  FAISS:   \", faiss_results[:5])\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "56831229",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "deepreviewer",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.13.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
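A caveat on the benchmark in this notebook: `query_experts_with_faiss` (below) rebuilds the flat index inside every call, so `faiss_times` includes index construction, not just search. A sketch that times the two phases separately, reusing the engine attributes named in similarity_engine.py:

```python
import time

import faiss
import numpy as np

from similarity_engine import SimilarityEngineOrcid

engine = SimilarityEngineOrcid()
engine.load_index_or_build()

# Phase 1: densify, normalize, and index the whole TF-IDF matrix once.
t0 = time.perf_counter()
docs = engine.tfidf_matrix.toarray().astype(np.float32)
faiss.normalize_L2(docs)  # unit vectors: inner product == cosine similarity
index = faiss.IndexFlatIP(docs.shape[1])
index.add(docs)
print(f"build:  {time.perf_counter() - t0:.3f}s")

# Phase 2: run only the searches against the prebuilt index.
t0 = time.perf_counter()
for q in ["stellar mass", "supernovae"]:
    q_vec = engine.vectorizer.transform([q]).toarray().astype(np.float32)
    faiss.normalize_L2(q_vec)
    D, I = index.search(q_vec, 25)
print(f"search: {time.perf_counter() - t0:.3f}s")
```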
""" + engine_id = "orcid_similarity_engine" + name = "ORCID Similarity Engine" + description = "TF-IDF based similarity engine for ORCID authors" def __init__(self): self.dataset_path = Path("expert-data/LSPO_v1.h5") @@ -78,14 +82,25 @@ def load_index_or_build(self): with open(self.index_dir / "authors.pkl", "rb") as f: self.authors = pickle.load(f) - def query_experts(self, query_text: str, top_n: int = 25): + def query_experts(self, query_text: str, top_n: int = 25, progress_callback=None): self.load_index_or_build() + + if progress_callback: + progress_callback(percent=0.10, message="Transforming query...") q_vec = self.vectorizer.transform([query_text]) + + if progress_callback: + progress_callback(percent=0.20, message="Calculating similarities...") sims = cosine_similarity(q_vec, self.tfidf_matrix).flatten() + + if progress_callback: + progress_callback(percent=0.80, message="Identifying top experts...") top_indices = sims.argsort()[::-1][:top_n] top_authors = self.combined_texts.iloc[top_indices].copy() top_authors["similarity"] = sims[top_indices] + if progress_callback: + progress_callback(percent=0.90, message="Fetching author information...") author_info = ( self.authors[["@path", "author"]] # , 'doi'] # .dropna(subset=['doi']) # Remove missing DOIs @@ -100,6 +115,9 @@ def query_experts(self, query_text: str, top_n: int = 25): ) results = top_authors.merge(author_info, on="@path", how="left") results = results[["@path", "author", "similarity"]] + + if progress_callback: + progress_callback(percent=0.95, message="Resolving name variations...") name_variations = ( self.authors[["@path", "author"]] .dropna() @@ -112,4 +130,74 @@ def query_experts(self, query_text: str, top_n: int = 25): results = results.merge(name_variations, on="@path", how="left") results = results.rename(columns={"@path": "orcid"}) return results + + def query_experts_with_faiss(self, query_text: str, top_n: int = 25, progress_callback=None): + self.load_index_or_build() + + # 1. Transform to TF-IDF (sparse) and convert to dense + if progress_callback: + progress_callback(percent=0.10, message="Transforming query...") + q_vec = self.vectorizer.transform([query_text]).toarray().astype(np.float32) + docs = self.tfidf_matrix.toarray().astype(np.float32) + + # 2. Normalize both query and documents to unit length for cosine via IP + # (FAISS only supports L2 or inner-product directly) + faiss.normalize_L2(q_vec) + faiss.normalize_L2(docs) + + # 3. (Re)build or load a FAISS index + d = docs.shape[1] # dimensionality + index = faiss.IndexFlatIP(d) # flat (exact) inner-product index + index.add(docs) # add all doc vectors + + if progress_callback: + progress_callback(percent=0.20, message="Index built, searching top experts…") + D, I = index.search(q_vec, top_n) # D: similarities, I: indices + + # 4. 
diff --git a/dti_reviewer/backend/tasks.py b/dti_reviewer/backend/tasks.py
index a510af8..f5ba410 100644
--- a/dti_reviewer/backend/tasks.py
+++ b/dti_reviewer/backend/tasks.py
@@ -1,5 +1,5 @@
 from celery_app import celery
-from similarity_engine import SimilarityEngine
+from similarity_engine import SimilarityEngineOrcid
 from sklearn.metrics.pairwise import cosine_similarity
 
 engine = None
@@ -7,7 +7,7 @@ def initialize_similarity_engine():
     global engine
 
     if engine is None:
-        engine = SimilarityEngine()
+        engine = SimilarityEngineOrcid()
         engine.load_index_or_build()
@@ -55,3 +55,18 @@ def query_experts_task(self, query_text: str, top_n: int = 25):
 
     self.update_state(state="PROGRESS", meta={"percent": 1.0})
     return results.to_dict(orient="records")
+
+
+@celery.task(bind=True)
+def query_experts_faiss_task(self, query_text: str, top_n: int = 25):
+    """FAISS-backed variant of query_experts_task; delegates to the engine method."""
+    initialize_similarity_engine()
+
+    def progress_callback(percent, message=""):
+        self.update_state(state="PROGRESS", meta={"percent": percent, "message": message})
+
+    results = engine.query_experts_with_faiss(
+        query_text, top_n=top_n, progress_callback=progress_callback
+    )
+    self.update_state(state="PROGRESS", meta={"percent": 1.0})
+    return results.to_dict(orient="records")
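Note that `/search` now calls `query_experts_task.delay(query, engine_id=engine_id, top_n=25)`, while the task signature shown in the hunk context above (`query_experts_task(self, query_text, top_n=25)`) does not accept `engine_id`, so the keyword would fail when the worker executes the task. One way to route it — a sketch only, assuming a module-level cache keyed by `engine_id`; none of these names are in the diff:

```python
from celery_app import celery
from similarity_engine import BaseSimilarityEngine

_ENGINES = {}  # engine_id -> initialized engine instance


def get_engine(engine_id: str):
    if engine_id not in _ENGINES:
        for cls in BaseSimilarityEngine.__subclasses__():
            if cls.engine_id == engine_id:
                instance = cls()
                instance.load_index_or_build()
                _ENGINES[engine_id] = instance
                break
        else:
            raise ValueError(f"Unknown engine_id: {engine_id}")
    return _ENGINES[engine_id]


@celery.task(bind=True)
def query_experts_task(self, query_text: str, engine_id: str, top_n: int = 25):
    engine = get_engine(engine_id)
    results = engine.query_experts(query_text, top_n=top_n)
    return results.to_dict(orient="records")
```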
diff --git a/dti_reviewer/docker-compose.yml b/dti_reviewer/docker-compose.yml
index 0b8f022..c94b704 100644
--- a/dti_reviewer/docker-compose.yml
+++ b/dti_reviewer/docker-compose.yml
@@ -12,8 +12,6 @@ services:
     environment:
       - CELERY_BROKER_URL=redis://redis:6379/0
       - CELERY_RESULT_BACKEND=redis://redis:6379/0
-    # ports:
-    #   - "5000:5000"
    depends_on:
       - redis
     restart: unless-stopped
diff --git a/dti_reviewer/my-app/nginx/default.conf b/dti_reviewer/my-app/nginx/default.conf
index d4e4b58..754b1b6 100644
--- a/dti_reviewer/my-app/nginx/default.conf
+++ b/dti_reviewer/my-app/nginx/default.conf
@@ -1,3 +1,8 @@
+# NOTE: nginx cannot read environment variables from directives, and `env` is only
+# valid in the main context (not in conf.d includes). Ship this file through the
+# official image's envsubst templating (e.g. default.conf.template) so that
+# ${BACKEND_API_TOKEN} is substituted before nginx starts.
+
 server {
     listen 80;
     server_name _;
@@ -13,24 +18,22 @@ server {
 
     # Proxy both the POST /search and the OPTIONS /search preflight
     location /search {
+        # forward to your backend
         proxy_pass http://backend:5000/search;
         proxy_http_version 1.1;
-        proxy_set_header Host $host;
-        proxy_set_header X-Real-IP $remote_addr;
-        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
-        # Allow websockets if you ever use them
-        proxy_set_header Upgrade $http_upgrade;
-        proxy_set_header Connection "upgrade";
+        proxy_set_header Host            $host;
+        proxy_set_header X-Real-IP       $remote_addr;
+        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+        proxy_set_header Authorization   "Bearer ${BACKEND_API_TOKEN}";
     }
 
     # Proxy the status polling
     location ~ ^/status/ {
         proxy_pass http://backend:5000;
         proxy_http_version 1.1;
-        proxy_set_header Host $host;
-        proxy_set_header X-Real-IP $remote_addr;
-        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
-        proxy_set_header Upgrade $http_upgrade;
-        proxy_set_header Connection "upgrade";
+        proxy_set_header Host            $host;
+        proxy_set_header X-Real-IP       $remote_addr;
+        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+        proxy_set_header Authorization   "Bearer ${BACKEND_API_TOKEN}";
     }
 }
diff --git a/dti_reviewer/my-app/package-lock.json b/dti_reviewer/my-app/package-lock.json
index 23b8543..3254d7c 100644
--- a/dti_reviewer/my-app/package-lock.json
+++ b/dti_reviewer/my-app/package-lock.json
@@ -10,6 +10,7 @@
       "dependencies": {
         "@radix-ui/react-dialog": "^1.1.14",
         "@radix-ui/react-navigation-menu": "^1.2.13",
+        "@radix-ui/react-select": "^2.2.5",
         "@radix-ui/react-separator": "^1.1.7",
         "@radix-ui/react-slot": "^1.2.3",
         "@radix-ui/react-tooltip": "^1.2.7",
@@ -3313,6 +3314,12 @@
         "node": ">= 8"
       }
     },
+    "node_modules/@radix-ui/number": {
+      "version": "1.1.1",
+      "resolved": "https://registry.npmjs.org/@radix-ui/number/-/number-1.1.1.tgz",
"https://registry.npmjs.org/@radix-ui/number/-/number-1.1.1.tgz", + "integrity": "sha512-MkKCwxlXTgz6CFoJx3pCwn07GKp36+aZyu/u2Ln2VrA5DcdyCZkASEDBTd8x5whTQQL5CiYf4prXKLcgQdv29g==", + "license": "MIT" + }, "node_modules/@radix-ui/primitive": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.2.tgz", @@ -3534,6 +3541,46 @@ } } }, + "node_modules/@radix-ui/react-menu": { + "version": "2.1.15", + "resolved": "https://registry.npmjs.org/@radix-ui/react-menu/-/react-menu-2.1.15.tgz", + "integrity": "sha512-tVlmA3Vb9n8SZSd+YSbuFR66l87Wiy4du+YE+0hzKQEANA+7cWKH1WgqcEX4pXqxUFQKrWQGHdvEfw00TjFiew==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-collection": "1.1.7", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-direction": "1.1.1", + "@radix-ui/react-dismissable-layer": "1.1.10", + "@radix-ui/react-focus-guards": "1.1.2", + "@radix-ui/react-focus-scope": "1.1.7", + "@radix-ui/react-id": "1.1.1", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", + "@radix-ui/react-presence": "1.1.4", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", + "@radix-ui/react-slot": "1.2.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "aria-hidden": "^1.2.4", + "react-remove-scroll": "^2.6.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-navigation-menu": { "version": "1.2.13", "resolved": "https://registry.npmjs.org/@radix-ui/react-navigation-menu/-/react-navigation-menu-1.2.13.tgz", @@ -3673,6 +3720,80 @@ } } }, + "node_modules/@radix-ui/react-roving-focus": { + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-roving-focus/-/react-roving-focus-1.1.10.tgz", + "integrity": "sha512-dT9aOXUen9JSsxnMPv/0VqySQf5eDQ6LCk5Sw28kamz8wSOW2bJdlX2Bg5VUIIcV+6XlHpWTIuTPCf/UNIyq8Q==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-collection": "1.1.7", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-direction": "1.1.1", + "@radix-ui/react-id": "1.1.1", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "@radix-ui/react-use-controllable-state": "1.2.2" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-select": { + "version": "2.2.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-select/-/react-select-2.2.5.tgz", + "integrity": "sha512-HnMTdXEVuuyzx63ME0ut4+sEMYW6oouHWNGUZc7ddvUWIcfCva/AMoqEW/3wnEllriMWBa0RHspCYnfCWJQYmA==", + "license": "MIT", + "dependencies": { + "@radix-ui/number": "1.1.1", + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-collection": "1.1.7", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-direction": "1.1.1", + "@radix-ui/react-dismissable-layer": 
"1.1.10", + "@radix-ui/react-focus-guards": "1.1.2", + "@radix-ui/react-focus-scope": "1.1.7", + "@radix-ui/react-id": "1.1.1", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "@radix-ui/react-use-controllable-state": "1.2.2", + "@radix-ui/react-use-layout-effect": "1.1.1", + "@radix-ui/react-use-previous": "1.1.1", + "@radix-ui/react-visually-hidden": "1.2.3", + "aria-hidden": "^1.2.4", + "react-remove-scroll": "^2.6.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-separator": { "version": "1.1.7", "resolved": "https://registry.npmjs.org/@radix-ui/react-separator/-/react-separator-1.1.7.tgz", diff --git a/dti_reviewer/my-app/package.json b/dti_reviewer/my-app/package.json index 79f94bb..ffe9df1 100644 --- a/dti_reviewer/my-app/package.json +++ b/dti_reviewer/my-app/package.json @@ -13,6 +13,7 @@ "dependencies": { "@radix-ui/react-dialog": "^1.1.14", "@radix-ui/react-navigation-menu": "^1.2.13", + "@radix-ui/react-select": "^2.2.5", "@radix-ui/react-separator": "^1.1.7", "@radix-ui/react-slot": "^1.2.3", "@radix-ui/react-tooltip": "^1.2.7", diff --git a/dti_reviewer/my-app/src/components/pages/FormPage.tsx b/dti_reviewer/my-app/src/components/pages/FormPage.tsx index d38d8ab..e3d368c 100644 --- a/dti_reviewer/my-app/src/components/pages/FormPage.tsx +++ b/dti_reviewer/my-app/src/components/pages/FormPage.tsx @@ -1,11 +1,27 @@ -import { useState, useRef, type ChangeEvent } from "react" +import { useState, useRef, type ChangeEvent, useEffect } from "react" +import { Skeleton } from "@/components/ui/skeleton" import { Textarea } from "@/components/ui/textarea" import { Button } from "@/components/ui/button" import { ResultTable } from "../ResultTable" import logo from "../../assets/logo.png" import Search from "../../assets/search.svg" import NoResults from "../../assets/noResults.png" - +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectLabel, + SelectTrigger, + SelectValue, +} from "@/components/ui/select" + + +interface Engine { + engine_id: string + name: string; + description: string; +} const FormPage = () => { const [query, setQuery] = useState("") @@ -14,17 +30,40 @@ const FormPage = () => { const [hasSearched, setHasSearched] = useState(false) const [percent, setPercent] = useState(null) const taskIdRef = useRef(null) + const [allEngines, setAllEngines] = useState([]) + const [selectedEngineId, setSelectedEngineId] = useState("") + const [enginesLoading, setEnginesLoading] = useState(true); const handleQuery = (e: ChangeEvent): void => { setQuery(e.target.value) } - const API_BASE_URL = import.meta.env.VITE_API_BASE_URL || "" + // const API_BASE_URL = import.meta.env.VITE_API_BASE_URL || "" // uncomment this in development - // const API_BASE_URL = "http://localhost:5000" - + const API_BASE_URL = "http://localhost:5000" + + useEffect(() => { + const fetchAvailableEngines = async () => { + setEnginesLoading(true); + try { + const resp = await fetch(`${API_BASE_URL}/available_engines`) + if (!resp.ok) throw new Error(`Failed to fetch engines: HTTP ${resp.status}`) + const availableEngines = await 
+        const engines = availableEngines?.engines
+        if (Array.isArray(engines) && engines.length > 0) {
+          setAllEngines(engines)
+          setSelectedEngineId(engines[0].engine_id)
+        }
+      } catch (error) {
+        console.error("Error fetching available engines:", error)
+      } finally {
+        setEnginesLoading(false)
+      }
+    }
+    fetchAvailableEngines()
+  }, [])
 
   const handleSubmit = async (e: React.FormEvent) => {
     e.preventDefault()
@@ -40,7 +79,7 @@ const FormPage = () => {
       const resp = await fetch(`${API_BASE_URL}/search`, {
         method: "POST",
         headers: { "Content-Type": "application/json" },
-        body: JSON.stringify({ query }),
+        body: JSON.stringify({ query, engine_id: selectedEngineId }),
       })
       if (!resp.ok) throw new Error(`Enqueue failed: HTTP ${resp.status}`)
       const { task_id } = await resp.json()
@@ -105,13 +144,42 @@ const FormPage = () => {
             className="mb-2 w-full min-h-[50px] md:min-h-[60px] resize-y max-h-[300px]"
             required
           />
-
+          <div className="mb-4 flex items-center gap-2">
+            {enginesLoading ? (
+              <Skeleton className="h-9 w-[220px]" />
+            ) : allEngines.length > 0 && (
+              <>
+                <Select value={selectedEngineId} onValueChange={setSelectedEngineId}>
+                  <SelectTrigger className="w-[220px]">
+                    <SelectValue placeholder="Select a similarity engine" />
+                  </SelectTrigger>
+                  <SelectContent>
+                    <SelectGroup>
+                      <SelectLabel>Similarity engines</SelectLabel>
+                      {allEngines.map((engine) => (
+                        <SelectItem key={engine.engine_id} value={engine.engine_id}>
+                          {engine.name}
+                        </SelectItem>
+                      ))}
+                    </SelectGroup>
+                  </SelectContent>
+                </Select>
+              </>
+            )}
+          </div>
diff --git a/dti_reviewer/my-app/src/components/ui/select.tsx b/dti_reviewer/my-app/src/components/ui/select.tsx
new file mode 100644
index 0000000..51f466e
--- /dev/null
+++ b/dti_reviewer/my-app/src/components/ui/select.tsx
@@ -0,0 +1,183 @@
+import * as React from "react"
+import * as SelectPrimitive from "@radix-ui/react-select"
+import { CheckIcon, ChevronDownIcon, ChevronUpIcon } from "lucide-react"
+
+import { cn } from "@/lib/utils"
+
+function Select({
+  ...props
+}: React.ComponentProps<typeof SelectPrimitive.Root>) {
+  return <SelectPrimitive.Root data-slot="select" {...props} />
+}
+
+function SelectGroup({
+  ...props
+}: React.ComponentProps<typeof SelectPrimitive.Group>) {
+  return <SelectPrimitive.Group data-slot="select-group" {...props} />
+}
+
+function SelectValue({
+  ...props
+}: React.ComponentProps<typeof SelectPrimitive.Value>) {
+  return <SelectPrimitive.Value data-slot="select-value" {...props} />
+}
+
+function SelectTrigger({
+  className,
+  size = "default",
+  children,
+  ...props
+}: React.ComponentProps<typeof SelectPrimitive.Trigger> & {
+  size?: "sm" | "default"
+}) {
+  return (
+    <SelectPrimitive.Trigger
+      data-slot="select-trigger"
+      data-size={size}
+      className={cn(
+        "border-input data-[placeholder]:text-muted-foreground focus-visible:border-ring focus-visible:ring-ring/50 flex w-fit items-center justify-between gap-2 rounded-md border bg-transparent px-3 py-2 text-sm whitespace-nowrap shadow-xs outline-none focus-visible:ring-[3px] disabled:cursor-not-allowed disabled:opacity-50 data-[size=default]:h-9 data-[size=sm]:h-8 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4",
+        className
+      )}
+      {...props}
+    >
+      {children}
+      <SelectPrimitive.Icon asChild>
+        <ChevronDownIcon className="size-4 opacity-50" />
+      </SelectPrimitive.Icon>
+    </SelectPrimitive.Trigger>
+  )
+}
+
+function SelectContent({
+  className,
+  children,
+  position = "popper",
+  ...props
+}: React.ComponentProps<typeof SelectPrimitive.Content>) {
+  return (
+    <SelectPrimitive.Portal>
+      <SelectPrimitive.Content
+        data-slot="select-content"
+        className={cn(
+          "bg-popover text-popover-foreground data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 relative z-50 min-w-[8rem] overflow-x-hidden overflow-y-auto rounded-md border shadow-md",
+          position === "popper" &&
+            "data-[side=bottom]:translate-y-1 data-[side=left]:-translate-x-1 data-[side=right]:translate-x-1 data-[side=top]:-translate-y-1",
+          className
+        )}
+        position={position}
+        {...props}
+      >
+        <SelectScrollUpButton />
+        <SelectPrimitive.Viewport
+          className={cn(
+            "p-1",
+            position === "popper" &&
+              "h-[var(--radix-select-trigger-height)] w-full min-w-[var(--radix-select-trigger-width)] scroll-my-1"
+          )}
+        >
+          {children}
+        </SelectPrimitive.Viewport>
+        <SelectScrollDownButton />
+      </SelectPrimitive.Content>
+    </SelectPrimitive.Portal>
+  )
+}
+
+function SelectLabel({
+  className,
+  ...props
+}: React.ComponentProps<typeof SelectPrimitive.Label>) {
+  return (
+    <SelectPrimitive.Label
+      data-slot="select-label"
+      className={cn("text-muted-foreground px-2 py-1.5 text-xs", className)}
+      {...props}
+    />
+  )
+}
+
+function SelectItem({
+  className,
+  children,
+  ...props
+}: React.ComponentProps<typeof SelectPrimitive.Item>) {
+  return (
+    <SelectPrimitive.Item
+      data-slot="select-item"
+      className={cn(
+        "focus:bg-accent focus:text-accent-foreground relative flex w-full cursor-default items-center gap-2 rounded-sm py-1.5 pr-8 pl-2 text-sm outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
+        className
+      )}
+      {...props}
+    >
+      <span className="absolute right-2 flex size-3.5 items-center justify-center">
+        <SelectPrimitive.ItemIndicator>
+          <CheckIcon className="size-4" />
+        </SelectPrimitive.ItemIndicator>
+      </span>
+      <SelectPrimitive.ItemText>{children}</SelectPrimitive.ItemText>
+    </SelectPrimitive.Item>
+  )
+}
+
+function SelectSeparator({
+  className,
+  ...props
+}: React.ComponentProps<typeof SelectPrimitive.Separator>) {
+  return (
+    <SelectPrimitive.Separator
+      data-slot="select-separator"
+      className={cn("bg-border pointer-events-none -mx-1 my-1 h-px", className)}
+      {...props}
+    />
+  )
+}
+
+function SelectScrollUpButton({
+  className,
+  ...props
+}: React.ComponentProps<typeof SelectPrimitive.ScrollUpButton>) {
+  return (
+    <SelectPrimitive.ScrollUpButton
+      data-slot="select-scroll-up-button"
+      className={cn("flex cursor-default items-center justify-center py-1", className)}
+      {...props}
+    >
+      <ChevronUpIcon className="size-4" />
+    </SelectPrimitive.ScrollUpButton>
+  )
+}
+
+function SelectScrollDownButton({
+  className,
+  ...props
+}: React.ComponentProps<typeof SelectPrimitive.ScrollDownButton>) {
+  return (
+    <SelectPrimitive.ScrollDownButton
+      data-slot="select-scroll-down-button"
+      className={cn("flex cursor-default items-center justify-center py-1", className)}
+      {...props}
+    >
+      <ChevronDownIcon className="size-4" />
+    </SelectPrimitive.ScrollDownButton>
+  )
+}
+
+export {
+  Select,
+  SelectContent,
+  SelectGroup,
+  SelectItem,
+  SelectLabel,
+  SelectScrollDownButton,
+  SelectScrollUpButton,
+  SelectSeparator,
+  SelectTrigger,
+  SelectValue,
+}
diff --git a/setup.cfg b/setup.cfg
index 985a26f..cf21c60 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -34,7 +34,7 @@ testpaths = "dti_reviewer" "docs"
 astropy_header = true
 doctest_plus = enabled
 text_file_format = rst
-addopts = --doctest-rst
+# addopts = --doctest-rst
 
 [coverage:run]
 omit =