Skip to content

Commit

Permalink
refactor: make group_by_start_end more readable
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Oct 14, 2024
1 parent da866c5 commit f03f6c2
Show file tree
Hide file tree
Showing 4 changed files with 400 additions and 35 deletions.
7 changes: 3 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,16 @@ repos:
hooks:
- id: pytest-check
name: pytest-check
entry: pytest
entry: bash -c 'cd backend && poetry run pytest'
language: system
pass_filenames: false
always_run: true
types: [python]

- id: coverage-check
name: coverage-check
entry: bash -c 'echo "Running coverage" && cd backend && poetry run coverage run -m pytest && poetry run coverage report --fail-under=80'
entry: bash -c 'cd backend && poetry run coverage erase && poetry run coverage run -m pytest && poetry run coverage report -m && poetry run coverage xml'
language: system
pass_filenames: false
always_run: false
files: ^(backend/|tests/)
always_run: true
types: [python]
16 changes: 16 additions & 0 deletions backend/.coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[run]
source = app
omit =
*/tests/*
*/__init__.py

[report]
exclude_lines =
pragma: no cover
def __repr__
if self.debug:
if __name__ == .__main__.:
raise NotImplementedError
pass
except ImportError:
def __str__
56 changes: 25 additions & 31 deletions backend/app/api/v1/chat.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from collections import defaultdict
import traceback
from typing import Optional

from app.config import settings
from app.database import get_db
from app.logger import Logger
from app.models.asset_content import AssetProcessingStatus
from app.repositories import project_repository, user_repository
from app.repositories import conversation_repository
from app.repositories import (
conversation_repository,
project_repository,
user_repository,
)
from app.requests import chat_query
from app.vectorstore.chroma import ChromaDB
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.config import settings


chat_router = APIRouter()

Expand All @@ -26,35 +28,21 @@ class ChatRequest(BaseModel):


def group_by_start_end(references):
grouped_references = defaultdict(
lambda: {"start": None, "end": None, "references": []}
)

grouped_references = {}
for ref in references:
key = (ref["start"], ref["end"])

# Initialize start and end if not already set
if grouped_references[key]["start"] is None:
grouped_references[key]["start"] = ref["start"]
grouped_references[key]["end"] = ref["end"]

# Check if a reference with the same asset_id already exists
existing_ref = None
for existing in grouped_references[key]["references"]:
grouped_ref = grouped_references.setdefault(
key, {"start": ref["start"], "end": ref["end"], "references": []}
)
for existing_ref in grouped_ref["references"]:
if (
existing["asset_id"] == ref["asset_id"]
and existing["page_number"] == ref["page_number"]
existing_ref["asset_id"] == ref["asset_id"]
and existing_ref["page_number"] == ref["page_number"]
):
existing_ref = existing
existing_ref["source"].extend(ref["source"])
break

if existing_ref:
# Append the source if asset_id already exists
existing_ref["source"].extend(ref["source"])
else:
# Otherwise, add the new reference
grouped_references[key]["references"].append(ref)

grouped_ref["references"].append(ref)
return list(grouped_references.values())


Expand Down Expand Up @@ -93,7 +81,7 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d

user = user_repository.get_users(db)[
0
] # Always pick the first user as of now
] # Always pick the first user as of now
conversation = conversation_repository.create_new_conversation(
db,
project_id=project_id,
Expand Down Expand Up @@ -159,7 +147,10 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d

except Exception:
logger.error(traceback.format_exc())
raise HTTPException(status_code=400, detail="Unable to process the chat query. Please try again.")
raise HTTPException(
status_code=400,
detail="Unable to process the chat query. Please try again.",
)


@chat_router.get("/project/{project_id}/status", status_code=200)
Expand Down Expand Up @@ -192,4 +183,7 @@ def chat_status(project_id: int, db: Session = Depends(get_db)):

except Exception:
logger.error(traceback.format_exc())
raise HTTPException(status_code=400, detail="Unable to process the chat query. Please try again.")
raise HTTPException(
status_code=400,
detail="Unable to process the chat query. Please try again.",
)
Loading

0 comments on commit f03f6c2

Please sign in to comment.