Skip to content

Commit

Permalink
lint, mypy, etc
Browse files Browse the repository at this point in the history
  • Loading branch information
rbs333 committed Jul 19, 2024
1 parent 19ad80d commit 7a0898d
Show file tree
Hide file tree
Showing 21 changed files with 670 additions and 109 deletions.
71 changes: 71 additions & 0 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
name: Test Suite

on:
pull_request:
branches:
- main

push:
branches:
- main

jobs:
test:
name: Python ${{ matrix.python-version }} - ${{ matrix.connection }} [redis-stack ${{matrix.redis-stack-version}}]
runs-on: ubuntu-latest

strategy:
fail-fast: false
matrix:
python-version: [3.9, 3.10, 3.11]
redis-stack-version: ['latest']
# python-version: [3.9, 3.10, 3.11] # idk if we need all of this for a demo repo
# connection: ['hiredis', 'plain']
# redis-stack-version: ['6.2.6-v9', 'latest', 'edge']

services:
redis:
image: redis/redis-stack-server:${{matrix.redis-stack-version}}
ports:
- 6379:6379

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'

- name: Install Poetry
uses: snok/install-poetry@v1

- name: change directory
run: cd backend/

- name: Install dependencies
run: |
poetry install --all-extras
- name: Install hiredis if needed
if: matrix.connection == 'hiredis'
run: |
poetry add hiredis
- name: Set Redis version
run: |
echo "REDIS_VERSION=${{ matrix.redis-stack-version }}" >> $GITHUB_ENV
# - name: Authenticate to Google Cloud
# uses: google-github-actions/auth@v1
# with:
# credentials_json: ${{ secrets.GOOGLE_CREDENTIALS }}

- name: Run tests
if: matrix.connection == 'plain' && matrix.redis-stack-version == 'latest'
env:
OPENAI_API_VERSION: ${{secrets.OPENAI_API_VERSION}}
OPENAI_API_KEY: ${{ secrets.OPENAI_KEY }}
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
run: |
poetry run test
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ venv
__pycache__
new_backend/arxivsearch/templates/
*/.nvm
.coverage*
coverage.*
htmlcov/
Empty file removed backend/__init__.py
Empty file.
11 changes: 5 additions & 6 deletions backend/arxivsearch/api/routes/papers.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import asyncio
import numpy as np
import logging

from fastapi import APIRouter, Query, Depends

import numpy as np
from fastapi import APIRouter, Depends, Query
from redisvl.index import AsyncSearchIndex
from redisvl.query import VectorQuery, FilterQuery, CountQuery
from redisvl.query import CountQuery, FilterQuery, VectorQuery

from arxivsearch import config
from arxivsearch.db import redis_helpers
from arxivsearch.utils.embeddings import embeddings
from arxivsearch.schema.similarity import (
PaperSimilarityRequest,
UserTextSimilarityRequest,
SearchResponse,
UserTextSimilarityRequest,
VectorSearchResponse,
)
from arxivsearch.utils.embeddings import embeddings

logger = logging.getLogger(__name__)

Expand Down
12 changes: 5 additions & 7 deletions backend/arxivsearch/db/load.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import asyncio
import numpy as np
import json
import os
import logging

import os
from typing import Any, Dict, List

from arxivsearch import config
from arxivsearch.schema.provider import Provider
from arxivsearch.db import redis_helpers

import numpy as np
from redisvl.index import AsyncSearchIndex

from arxivsearch import config
from arxivsearch.db import redis_helpers
from arxivsearch.schema.provider import Provider

logger = logging.getLogger(__name__)

Expand Down
12 changes: 7 additions & 5 deletions backend/arxivsearch/db/redis_helpers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
import logging
import os
from typing import List

from redis.asyncio import Redis
from arxivsearch import config
from redisvl.schema import IndexSchema
from redisvl.index import AsyncSearchIndex, SearchIndex
from redisvl.query.filter import Tag, FilterExpression
from redisvl.query.filter import FilterExpression, Tag
from redisvl.schema import IndexSchema

from arxivsearch import config

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -34,7 +36,7 @@ async def get_async_index():


def build_filter_expression(
years: List[int], categories: List[str]
years: List[str], categories: List[str]
) -> FilterExpression:
"""
Construct a filter expression based on the provided years and categories.
Expand Down
5 changes: 1 addition & 4 deletions backend/arxivsearch/main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import uvicorn
import logging

from pathlib import Path

import uvicorn
from fastapi import FastAPI

from fastapi.staticfiles import StaticFiles
from starlette.middleware.cors import CORSMiddleware

from arxivsearch import config
from arxivsearch.api.main import api_router
from arxivsearch.spa import SinglePageApplication


logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
Expand Down
4 changes: 3 additions & 1 deletion backend/arxivsearch/schema/provider.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from enum import Enum


class Provider(str, Enum):
"""Embedding model provider"""

huggingface = "huggingface"
openai = "openai"
cohere = "cohere"
cohere = "cohere"
8 changes: 3 additions & 5 deletions backend/arxivsearch/spa.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,15 @@ def __init__(self, directory: os.PathLike, index="index.html") -> None:

# set html=True to resolve the index even when no
# the base path is passed in
super().__init__(
directory=directory, packages=None, html=True, check_dir=True
)
super().__init__(directory=directory, packages=None, html=True, check_dir=True)

async def get_response(self, path: str, scope):
response = await super().get_response(path, scope)
if response.status_code == 404:
response = await super().get_response('.', scope)
response = await super().get_response(".", scope)
return response

def lookup_path(self, path: str) -> Tuple[str, os.stat_result]:
def lookup_path(self, path: str) -> Tuple[str, os.stat_result | None]:
results = super().lookup_path(path)
full_path, stat_result = results

Expand Down
28 changes: 14 additions & 14 deletions backend/arxivsearch/tests/api/routes/test_papers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from arxivsearch.main import app
from arxivsearch.schema.similarity import (
UserTextSimilarityRequest,
PaperSimilarityRequest,
UserTextSimilarityRequest,
)


Expand Down Expand Up @@ -100,24 +100,24 @@ async def test_vector_by_text_bad_input(async_client: AsyncClient, bad_req_json:
assert response.status_code == 422


# @pytest.mark.asyncio(scope="session")
# async def test_vector_by_paper(
# async_client: AsyncClient,
# paper_req: PaperSimilarityRequest,
# ):
# response = await async_client.post(
# f"papers/vector_search/by_paper", json=paper_req.model_dump()
# )
@pytest.mark.asyncio(scope="session")
async def test_vector_by_paper(
async_client: AsyncClient,
paper_req: PaperSimilarityRequest,
):
response = await async_client.post(
f"papers/vector_search/by_paper", json=paper_req.model_dump()
)

# assert response.status_code == 200
# content = response.json()
assert response.status_code == 200
content = response.json()

# assert content["total"] == 2
# assert len(content["papers"]) == 2
assert content["total"] == 2
assert len(content["papers"]) == 2


@pytest.mark.asyncio(scope="session")
async def test_vector_by_text_bad_input(async_client: AsyncClient, bad_req_json: dict):
async def test_vector_by_paper_bad_input(async_client: AsyncClient, bad_req_json: dict):

response = await async_client.post(
f"papers/vector_search/by_paper", json=bad_req_json
Expand Down
13 changes: 7 additions & 6 deletions backend/arxivsearch/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from asyncio import get_event_loop
from typing import Generator
import pytest_asyncio

import pytest
from redis.asyncio import Redis as AsyncRedis
import pytest_asyncio
from httpx import AsyncClient
from asyncio import get_event_loop
from redis.asyncio import Redis
from arxivsearch.db import redis_helpers
from arxivsearch.tests.utils.seed import seed_test_db
from arxivsearch import config
from redis.asyncio import Redis as AsyncRedis
from redisvl.index import AsyncSearchIndex

from arxivsearch import config
from arxivsearch.db import redis_helpers
from arxivsearch.main import app
from arxivsearch.tests.utils.seed import seed_test_db


@pytest.fixture(scope="module")
Expand Down
2 changes: 2 additions & 0 deletions backend/arxivsearch/tests/utils/seed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import os

import numpy as np

from arxivsearch import config
from arxivsearch.db import redis_helpers

Expand Down
2 changes: 1 addition & 1 deletion backend/arxivsearch/utils/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from redisvl.utils.vectorize import (
CohereTextVectorizer,
OpenAITextVectorizer,
HFTextVectorizer,
OpenAITextVectorizer,
)

from arxivsearch import config
Expand Down
Loading

0 comments on commit 7a0898d

Please sign in to comment.