Skip to content

Commit

Permalink
update function
Browse files Browse the repository at this point in the history
  • Loading branch information
rbs333 committed Jul 25, 2024
1 parent cafde63 commit 434b181
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 53 deletions.
13 changes: 0 additions & 13 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@ jobs:
fail-fast: false
matrix:
python-version: ["3.11"]
# python-version: ["3.9", "3.10", "3.11"]
redis-stack-version: ['latest']
# python-version: [3.9, 3.10, 3.11] # idk if we need all of this for a demo repo
# connection: ['hiredis', 'plain']
# redis-stack-version: ['6.2.6-v9', 'latest', 'edge']

services:
redis:
Expand All @@ -46,15 +42,6 @@ jobs:
run: |
poetry install --all-extras
# - name: Install hiredis if needed
# if: matrix.connection == 'hiredis'
# run: |
# poetry add hiredis

- name: Set Redis version
run: |
echo "REDIS_VERSION=${{ matrix.redis-stack-version }}" >> $GITHUB_ENV
- name: Run tests
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
Expand Down
4 changes: 3 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
{
"python.testing.pytestArgs": [],
"python.testing.pytestArgs": [
"backend"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.testing.cwd": "${workspaceFolder}/backend/",
Expand Down
18 changes: 6 additions & 12 deletions backend/arxivsearch/api/routes/papers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def get_papers(
filter papers. Defaults to "".
Returns:
dict: Dictionary containing total count and list of papers.
SearchResponse: Pydantic model containing papers and total count.
"""

# Build queries
Expand All @@ -67,7 +67,6 @@ async def get_papers(
total_count, result_papers = await asyncio.gather(
index.query(count_query), index.query(filter_query)
)
# await index.client.aclose()
return SearchResponse(total=total_count, papers=result_papers)


Expand All @@ -81,12 +80,12 @@ async def find_papers_by_paper(
similarity.
Args:
similarity_request (SimilarityRequest): Similarity request object
PaperSimilarityRequest: Similarity request object
containing paper_id, provider, number_of_results, years, and
categories for filtering.
Returns:
dict: Dictionary containing total count and list of similar papers.
VectorSearchResponse: Pydantic model with paper content.
"""

# Fetch paper vector from the HASH, cast to numpy array
Expand All @@ -111,28 +110,25 @@ async def find_papers_by_paper(
total_count, result_papers = await asyncio.gather(
index.query(count_query), index.query(paper_similarity_query)
)
# Get Paper records of those results
# await index.client.aclose()
return VectorSearchResponse(total=total_count, papers=result_papers)


@router.post("/vector_search/by_text", response_model=VectorSearchResponse)
async def find_papers_by_text(
similarity_request: UserTextSimilarityRequest,
index: AsyncSearchIndex = Depends(redis_helpers.get_async_index),
# embeddings: Embeddings = Depends(get_embeddings),
):
"""
Find and return papers similar to user-provided text based on
vector similarity.
Args:
similarity_request (UserTextSimilarityRequest): Similarity request
UserTextSimilarityRequest: Similarity request
object containing user_text, provider, number_of_results, years,
and categories for filtering.
Returns:
dict: Dictionary containing total count and list of similar papers.
VectorSearchResponse: Pydantic model with paper content.
"""

# Build filter expression
Expand All @@ -155,7 +151,5 @@ async def find_papers_by_text(
# Execute searches
total_count, result_papers = await asyncio.gather(
index.query(count_query), index.query(paper_similarity_query)
) # Get Paper records of those results

# await index.client.aclose()
)
return VectorSearchResponse(total=total_count, papers=result_papers)
26 changes: 16 additions & 10 deletions backend/arxivsearch/db/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@
logger = logging.getLogger(__name__)


def read_from_s3(path: str, *, timeout: float = 60.0) -> List[Dict[str, Any]]:
    """
    Download the paper dataset from S3 and cache it locally when possible.

    The JSON document at ``config.S3_DATA_URL`` is fetched and decoded. If
    the ``config.DATA_LOCATION`` directory exists, the decoded data is also
    written to ``path`` as JSON; otherwise the write is skipped with a
    warning and the data is only returned.

    Args:
        path (str): Destination file for the local copy of the dataset.
        timeout (float): Seconds to wait for the server to connect/send
            data before giving up. Defaults to 60.

    Returns:
        List[Dict[str, Any]]: The decoded JSON payload.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # Without a timeout a stalled connection would block the caller forever.
    res = requests.get(config.S3_DATA_URL, timeout=timeout)
    # Fail loudly on HTTP errors so an error body is never mistaken for
    # (and cached as) the dataset.
    res.raise_for_status()
    data = res.json()

    if os.path.isdir(config.DATA_LOCATION):
        logger.info("Writing s3 file to %s", path)
        # Explicit encoding keeps the cached file portable across platforms.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f)
    else:
        logger.warning(
            "Data directory %s not found. Skipping write of S3 data",
            config.DATA_LOCATION,
        )
    return data


def read_paper_json() -> List[Dict[str, Any]]:
"""
Load JSON array of arXiv papers and embeddings.
Expand All @@ -26,17 +41,8 @@ def read_paper_json() -> List[Dict[str, Any]]:
data = json.load(f)
except:
logger.info(f"Failed to read {path} => getting from s3")
res = requests.get(config.S3_DATA_URL)
data = res.json()
data = read_from_s3(path)

if os.path.isdir(config.DATA_LOCATION):
logger.info(f"Writing s3 file to {path}")
with open(path, "w") as f:
json.dump(data, f)
else:
logger.warning(
f"Data directory {config.DATA_LOCATION} not found. Skipping write of S3 data"
)
return data


Expand Down
23 changes: 6 additions & 17 deletions backend/arxivsearch/tests/db/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,20 @@ def test_read_paper_json_local(mock_json_load, mock_file_open, mock_path_join):


# Test when the file needs to be fetched from S3
@patch("arxivsearch.db.load.os.path.isdir")
@patch("arxivsearch.db.load.os.path.join")
@patch("arxivsearch.db.load.requests.get")
@patch("arxivsearch.db.load.open", new_callable=mock_open)
@patch("arxivsearch.db.load.json.dump")
@patch("arxivsearch.db.load.read_from_s3")
@patch("arxivsearch.db.load.json.load", side_effect=Exception("File not found"))
def test_read_paper_json_s3(
mock_json_load,
mock_json_dump,
mock_file_open,
mock_requests_get,
mock_read_from_s3,
mock_path_join,
mock_isdir,
):
mock_isdir.return_value = True
mock_path_join.return_value = "dummy_path"
mock_requests_get.return_value.json.return_value = [
{"id": "5678", "title": "Test Paper from S3"}
]
mock_read_from_s3.return_value = [{"id": "5678", "title": "Test Paper from S3"}]

result = read_paper_json()

mock_requests_get.assert_called_once()
mock_file_open.assert_called_with("dummy_path", "w")
mock_json_dump.assert_called_once_with(
[{"id": "5678", "title": "Test Paper from S3"}], mock_file_open()
)
mock_read_from_s3.assert_called_once()
mock_read_from_s3.assert_called_with("dummy_path")

assert result == [{"id": "5678", "title": "Test Paper from S3"}]

0 comments on commit 434b181

Please sign in to comment.