-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_types.py
65 lines (47 loc) · 1.49 KB
/
data_types.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from abc import abstractmethod
from typing import Callable, TypeAlias
from pydantic import BaseModel
from pymilvus import MilvusClient
from chromadb.api.client import Client as ChromaClient
# Alias covering either supported vector-database client class
# (the Chroma client or the Milvus client imported above).
VectorDbClient: TypeAlias = ChromaClient | MilvusClient
class AnnotatedDoc(BaseModel):
    """A document identifier paired with a numeric score."""

    # Identifier of the document.
    id: str
    # Numeric score attached to the document; both ints and floats are
    # accepted. Presumably a relevance annotation -- confirm with the
    # dataset that produces these records.
    score: int | float
class QueryDatapoint(BaseModel):
    """A query text with an optional id and optional scored documents."""

    # Text of the query.
    text: str
    # Optional identifier of the query; defaults to None.
    id: str | None = None
    # Optional scored documents attached to the query; defaults to None.
    annotated_docs: list[AnnotatedDoc] | None = None

    def get_relevant_documents(
        self, relevance_func: Callable[[float], bool]
    ) -> list[AnnotatedDoc]:
        """Return the annotated documents whose score passes ``relevance_func``.

        Args:
            relevance_func: Predicate called with each document's score; a
                document is kept when the predicate returns a truthy value.

        Returns:
            The matching documents in their original order, or an empty
            list when ``annotated_docs`` is ``None`` or empty.
        """
        # ``annotated_docs`` defaults to None; treat that as "no documents"
        # rather than raising TypeError while iterating.
        return [
            doc for doc in (self.annotated_docs or [])
            if relevance_func(doc.score)
        ]
class SemanticSearchResult(BaseModel):
    """Document ids returned for one query, with optional distances."""

    # Identifier of the query these results belong to.
    query_id: str
    # Identifiers of the returned documents.
    doc_ids: list[str]
    # Optional list of distances; defaults to None.
    # NOTE(review): presumably one distance per entry in ``doc_ids``, in the
    # same order -- confirm against the code that builds this object.
    distances: list[float] | None = None
class RetrievedDocuments(BaseModel):
    """Document objects retrieved for one query."""

    # Identifier of the query the documents belong to.
    query_id: str
    # Retrieved documents as plain dictionaries; the dictionary schema is
    # not constrained by this model.
    document_objects: list[dict]
class MetricsClassAtK(BaseModel):
    """Interface for metrics that take a cutoff ``k``.

    Declares two abstract hooks and no fields of its own; subclasses
    supply the implementations.
    """

    @abstractmethod
    def compute_metrics_for_datapoint(
        self,
        query_dp: QueryDatapoint,
        query_results: SemanticSearchResult,
        k: int,
        **kwargs,
    ) -> None:
        """Handle one query together with its search results.

        Args:
            query_dp: The query datapoint being evaluated.
            query_results: Search results for that query.
            k: Integer cutoff; presumably the number of top results to
                consider -- confirm with the concrete implementations.
            **kwargs: Extra keyword arguments for implementations.
        """

    @abstractmethod
    def average_results(self) -> None:
        """Abstract hook with no arguments and no return value.

        NOTE(review): presumably averages metrics accumulated across
        datapoints -- confirm with the concrete implementations.
        """
class MetricsClass(BaseModel):
    """Interface for a metrics container with optional per-key sub-metrics.

    Declares one optional field and two abstract hooks; subclasses supply
    the implementations.
    """

    # Optional mapping from a string key to a ``MetricsClassAtK`` instance;
    # defaults to None.
    # NOTE(review): keys presumably identify the cutoff -- confirm.
    results_in_top: dict[str, MetricsClassAtK] | None = None

    @abstractmethod
    def compute_metrics_for_datapoint(
        self,
        query_dp: QueryDatapoint,
        query_results: SemanticSearchResult,
        **kwargs,
    ) -> None:
        """Handle one query together with its search results.

        Args:
            query_dp: The query datapoint being evaluated.
            query_results: Search results for that query.
            **kwargs: Extra keyword arguments for implementations.
        """

    @abstractmethod
    def average_results(self) -> None:
        """Abstract hook with no arguments and no return value.

        NOTE(review): presumably averages metrics accumulated across
        datapoints -- confirm with the concrete implementations.
        """