diff --git a/backend/arxivsearch/api/routes/papers.py b/backend/arxivsearch/api/routes/papers.py index 3c171a2..697f751 100644 --- a/backend/arxivsearch/api/routes/papers.py +++ b/backend/arxivsearch/api/routes/papers.py @@ -1,18 +1,15 @@ import asyncio import numpy as np import logging -from fastapi import FastAPI -from contextlib import asynccontextmanager from fastapi import APIRouter, Query, Depends -from redis.asyncio import Redis from redisvl.index import AsyncSearchIndex from redisvl.query import VectorQuery, FilterQuery, CountQuery from arxivsearch import config from arxivsearch.db import redis_helpers -from arxivsearch.utils.embeddings import Embeddings +from arxivsearch.utils.embeddings import embeddings from arxivsearch.schema.similarity import ( PaperSimilarityRequest, UserTextSimilarityRequest, @@ -20,35 +17,9 @@ VectorSearchResponse, ) - logger = logging.getLogger(__name__) -# class DB: -# client = None -# schema = None -# index = None - - -# db = DB() - -# Initialize embeddings once -embeddings = Embeddings() - -# client = Redis.from_url(config.REDIS_URL) -# schema = redis_helpers.get_schema() -# index = AsyncSearchIndex(schema, client) - - -# @asynccontextmanager -# async def lifespan(app: FastAPI): -# db.client = Redis.from_url(config.REDIS_URL) -# db.schema = redis_helpers.get_schema() -# db.index = AsyncSearchIndex(db.schema, db.client) -# yield -# db.client.aclose() - - # Initialize the API router router = APIRouter() diff --git a/backend/arxivsearch/db/redis_helpers.py b/backend/arxivsearch/db/redis_helpers.py index e1c0259..fdff8f3 100644 --- a/backend/arxivsearch/db/redis_helpers.py +++ b/backend/arxivsearch/db/redis_helpers.py @@ -1,28 +1,19 @@ import os import logging from typing import List -from redis.asyncio import Redis, ConnectionPool +from redis.asyncio import Redis from arxivsearch import config from redisvl.schema import IndexSchema from redisvl.index import AsyncSearchIndex, SearchIndex from redisvl.query.filter import Tag, FilterExpression -from contextlib import asynccontextmanager logger = logging.getLogger(__name__) -async def get_async_client(): - async with Redis.from_url(config.REDIS_URL) as session: - yield session - await session.aclose() - - -print("\n getting in pool \n") dir_path = os.path.dirname(os.path.realpath(__file__)) schema = IndexSchema.from_yaml(os.path.join(dir_path, "index.yaml")) client = Redis.from_url(config.REDIS_URL) global_index = None -# client = get_async_client( def get_schema(): @@ -35,33 +26,11 @@ def get_index(): return SearchIndex.from_yaml(os.path.join(dir_path, "index.yaml")) -# async def get_async_client(): -# return Redis.from_url(config.REDIS_URL) - - async def get_async_index(): - try: - # schema = IndexSchema.from_yaml(os.path.join(dir_path, "index.yaml")) - # client = Redis.from_url(config.REDIS_URL) - global global_index - if not global_index: - global_index = AsyncSearchIndex(schema, client) - yield global_index - # yield AsyncSearchIndex(schema, client) - - finally: - # await global_index.client.aclose() - pass - # yield AsyncSearchIndex(schema, client) - # await client.aclose() - # async with Redis.from_pool(pool) as session: - # print("using session") - # index = AsyncSearchIndex(schema, session) - # yield index - # await index.client.aclose() - - # yield index - # await index.client.aclose() + global global_index + if not global_index: + global_index = AsyncSearchIndex(schema, client) + yield global_index def build_filter_expression( diff --git a/backend/arxivsearch/schema/similarity.py b/backend/arxivsearch/schema/similarity.py index 5cdacea..8f9414c 100644 --- a/backend/arxivsearch/schema/similarity.py +++ b/backend/arxivsearch/schema/similarity.py @@ -20,7 +20,7 @@ class UserTextSimilarityRequest(BaseRequest): class Paper(BaseModel): - paper_id: str = Field(alias="id") + paper_id: str # = Field(alias="id") authors: str categories: str year: str diff --git a/backend/arxivsearch/tests/utils/seed.py b/backend/arxivsearch/tests/utils/seed.py index 869704c..d678b76 100644 --- a/backend/arxivsearch/tests/utils/seed.py +++ b/backend/arxivsearch/tests/utils/seed.py @@ -20,5 +20,5 @@ def seed_test_db(): index = redis_helpers.get_index() index.connect(redis_url=config.REDIS_URL) - index.load(data=papers, id_field="id") + index.load(data=papers, id_field="paper_id") return papers diff --git a/backend/arxivsearch/tests/utils/test_vectors.json b/backend/arxivsearch/tests/utils/test_vectors.json index 01e2119..8b1e626 100644 --- a/backend/arxivsearch/tests/utils/test_vectors.json +++ b/backend/arxivsearch/tests/utils/test_vectors.json @@ -1,6 +1,6 @@ [ { - "id": "1234.5678", + "paper_id": "1234.5678", "title": "Exploring the Universe of Deep Learning", "authors": "Jane Doe", "year": "3000", @@ -3342,7 +3342,7 @@ ] }, { - "id": "8765.4321", + "paper_id": "8765.4321", "title": "Exploring the Galaxy of Deep Learning", "authors": "John Doe", "year": "2021", diff --git a/backend/arxivsearch/utils/embeddings.py b/backend/arxivsearch/utils/embeddings.py index e60d752..475fa23 100644 --- a/backend/arxivsearch/utils/embeddings.py +++ b/backend/arxivsearch/utils/embeddings.py @@ -63,3 +63,6 @@ async def get(self, provider: str, text: str): return self.co_vectorizer.embed( text, input_type="search_query", preprocess=preprocess_text ) + + +embeddings = Embeddings() diff --git a/frontend/src/api.ts b/frontend/src/api.ts index d89965e..19cdb5d 100644 --- a/frontend/src/api.ts +++ b/frontend/src/api.ts @@ -49,12 +49,14 @@ export const getPapers = async (limit = 15, skip = 0, years: string[] = [], cate // get papers from Redis through the FastAPI backend -export const getSemanticallySimilarPapers = async (paper_id: string, +export const getSemanticallySimilarPapers = async ( + paper_id: string, years: string[], categories: string[], provider: string, search = 'KNN', - limit = 15) => { + limit = 15 +) => { console.log(paper_id); let body = { diff --git a/frontend/src/config/index.tsx b/frontend/src/config/index.tsx index d6b0033..033684a 100644 --- a/frontend/src/config/index.tsx +++ b/frontend/src/config/index.tsx @@ -1,2 +1,3 @@ export const BASE_URL: string = ''; export const MASTER_URL: string = '/api/v1/papers/'; +export const EMAIL = "applied.ai@redis.com" diff --git a/frontend/src/views/Footer.tsx b/frontend/src/views/Footer.tsx index d81be68..c76d683 100644 --- a/frontend/src/views/Footer.tsx +++ b/frontend/src/views/Footer.tsx @@ -1,4 +1,5 @@ /* eslint-disable jsx-a11y/anchor-is-valid */ +import { EMAIL } from '../config' import '../styles/Footer.css'; export const Footer = () => { @@ -8,7 +9,6 @@ export const Footer = () => {
All Redis software used in this demo is licensed according to the Redis Stack License.
-
-
Redis AI Resources @@ -26,6 +26,7 @@ export const Footer = () => { Vector Search Docs
+
contact: {EMAIL}
); diff --git a/frontend/src/views/Header.tsx b/frontend/src/views/Header.tsx index 46b5514..2de3e0c 100644 --- a/frontend/src/views/Header.tsx +++ b/frontend/src/views/Header.tsx @@ -1,7 +1,7 @@ -import { BASE_URL } from "../config"; +import { BASE_URL, EMAIL } from "../config"; +import Tooltip from '@mui/material/Tooltip'; import '../styles/Header.css'; - /* eslint-disable jsx-a11y/anchor-is-valid */ export const Header = () => { return ( @@ -25,7 +25,9 @@ export const Header = () => { className="header-icon-link" > - Talk with us! + + Talk with us! + diff --git a/frontend/src/views/Home.tsx b/frontend/src/views/Home.tsx index 10bd321..fc98518 100644 --- a/frontend/src/views/Home.tsx +++ b/frontend/src/views/Home.tsx @@ -107,12 +107,11 @@ export const Home = (props: Props) => { const { target: { value }, } = event; + setSkip(0); setYears( // On autofill we get a stringified value. typeof value === 'string' ? value.split(',') : value, - ); - setSkip(0); - console.log(years) + ) }; return ( @@ -139,7 +138,7 @@ export const Home = (props: Props) => { } function CategoryOptions() { - const handleChange = (event: SelectChangeEvent) => { + const handleChange = (event: SelectChangeEvent) => { const { target: { value }, } = event; @@ -148,7 +147,6 @@ export const Home = (props: Props) => { typeof value === 'string' ? value.split(',') : value, ); setSkip(0); - console.log(years) }; return ( @@ -198,13 +196,13 @@ export const Home = (props: Props) => { } }; - // Execute this one when the component loads up useEffect(() => { - setPapers([]); - setCategories([]); - setYears([]); queryPapers(); - }, []); + }, [categories]) + + useEffect(() => { + queryPapers(); + }, [years]) return ( <> @@ -257,6 +255,7 @@ export const Home = (props: Props) => {
{papers.map((paper) => (