Resolve some linting issues
andrejridzik committed Mar 7, 2025
1 parent 30bd13b commit 49d0277
Showing 13 changed files with 50 additions and 41 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -23,6 +23,7 @@ repos:
- id: ruff
args: [ --fix ]
- id: ruff-format
+# TODO: Add mypy once all related issues are resolved
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v0.991
# hooks:
2 changes: 1 addition & 1 deletion api/.env.app.sample
@@ -18,7 +18,7 @@ OLLAMA__NUM_CTX=4096

AIOD__URL=""
AIOD__COMMA_SEPARATED_ASSET_TYPES="ASSET1,ASSET2,ASSET3"
-AIOD__COMMA_SEPARATED_ASSET_TYPES_FOR_METADATA_EXTRACTON="ASSET1,ASSET2" # for now we only support "datasets" asset type
+AIOD__COMMA_SEPARATED_ASSET_TYPES_FOR_METADATA_EXTRACTION="ASSET1,ASSET2" # for now we only support "datasets" asset type
AIOD__WINDOW_SIZE=1000
AIOD__WINDOW_OVERLAP=0.1
AIOD__JOB_WAIT_INBETWEEN_REQUESTS_SEC=1
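The `AIOD__`-prefixed names follow the double-underscore nesting convention of pydantic-settings. A minimal sketch of how such variables could map onto a nested config object (illustrative only; the project's real `Settings` class lives in `api/app/config.py` and carries more fields):

```python
from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict


class AIoDConfig(BaseModel):
    URL: str
    WINDOW_SIZE: int = 1000


class Settings(BaseSettings):
    # env_nested_delimiter="__" makes AIOD__WINDOW_SIZE populate AIOD.WINDOW_SIZE
    model_config = SettingsConfigDict(env_nested_delimiter="__")

    AIOD: AIoDConfig


# With AIOD__URL="https://example.org" and AIOD__WINDOW_SIZE=500 exported,
# Settings().AIOD.WINDOW_SIZE == 500.
```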
2 changes: 1 addition & 1 deletion api/README.md
@@ -89,7 +89,7 @@ In this file you find the following ENV variables:
- `OLLAMA__NUM_CTX`: The maximum number of tokens that are considered to be within model context when an LLM generates an output for metadata filtering purposes.
- `AIOD__URL`: URL of the AIoD API we use to retrieve information about the assets and assets themselves.
- `AIOD__COMMA_SEPARATED_ASSET_TYPES`: Comma-separated list of values representing all the asset types we wish to process
-- `AIOD__COMMA_SEPARATED_ASSET_TYPES_FOR_METADATA_EXTRACTON`: Comma-separated list of values representing all the asset types we wish to apply metadata filtering on. Only include an asset type into this list if all the setup regarding metadata filtering (manual/automatic extraction of metadata from assets, automatic extraction of filter in user queries)
+- `AIOD__COMMA_SEPARATED_ASSET_TYPES_FOR_METADATA_EXTRACTION`: Comma-separated list of values representing all the asset types we wish to apply metadata filtering on. Only include an asset type into this list if all the setup regarding metadata filtering (manual/automatic extraction of metadata from assets, automatic extraction of filter in user queries)
- `AIOD__WINDOW_SIZE`: Asset window size (limit of pagination) we use for retrieving assets from AIoD API during the initial setup, by iterating over all the AIoD assets.
- `AIOD__WINDOW_OVERLAP`: Asset window overlap representing relative size of an overlap we maintain between the pages in pagination. The overlap is necessary so that we don't potentially skip some new assets to process if any particular assets were to be deleted in parallel with our update logic, making the whole data returned by the AIoD platform slightly shifted.
- `AIOD__JOB_WAIT_INBETWEEN_REQUESTS_SEC`: Number of seconds we wait when performing JOBs (for updating/deleting assets) in between AIoD requests in order not to overload their API.
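The window arithmetic described above can be checked directly; the numbers below use the sample defaults and mirror the `OFFSET_INCREMENT` property visible in the `api/app/config.py` hunk further down:

```python
WINDOW_SIZE = 1000    # AIOD__WINDOW_SIZE
WINDOW_OVERLAP = 0.1  # AIOD__WINDOW_OVERLAP

# Each page starts 900 assets after the previous one, so consecutive windows
# share 100 assets. Deletions happening mid-traversal can therefore shift the
# data by up to 100 positions without any asset being skipped.
OFFSET_INCREMENT = int(WINDOW_SIZE * (1 - WINDOW_OVERLAP))
assert OFFSET_INCREMENT == 900

offsets = [i * OFFSET_INCREMENT for i in range(3)]  # [0, 900, 1800]
```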
10 changes: 5 additions & 5 deletions api/app/config.py
@@ -54,7 +54,7 @@ class OllamaConfig(BaseModel):
class AIoDConfig(BaseModel):
URL: AnyUrl = Field(...)
COMMA_SEPARATED_ASSET_TYPES: str = Field(...)
-COMMA_SEPARATED_ASSET_TYPES_FOR_METADATA_EXTRACTON: str = Field(...)
+COMMA_SEPARATED_ASSET_TYPES_FOR_METADATA_EXTRACTION: str = Field(...)
WINDOW_SIZE: int = Field(1000, le=1000, gt=1)
WINDOW_OVERLAP: float = Field(0.1, lt=1, ge=0)
JOB_WAIT_INBETWEEN_REQUESTS_SEC: float = Field(1, ge=0)
@@ -71,7 +71,7 @@ def convert_csv_to_asset_types(cls, value: str) -> list[AssetType]:

@field_validator(
"COMMA_SEPARATED_ASSET_TYPES",
"COMMA_SEPARATED_ASSET_TYPES_FOR_METADATA_EXTRACTON",
"COMMA_SEPARATED_ASSET_TYPES_FOR_METADATA_EXTRACTION",
mode="before",
)
@classmethod
@@ -92,13 +92,13 @@ def OFFSET_INCREMENT(self) -> int:
return int(settings.AIOD.WINDOW_SIZE * (1 - settings.AIOD.WINDOW_OVERLAP))

@property
-def ASSET_TYPES(self) -> list[str]:
+def ASSET_TYPES(self) -> list[AssetType]:
return self.convert_csv_to_asset_types(self.COMMA_SEPARATED_ASSET_TYPES)

@property
-def ASSET_TYPES_FOR_METADATA_EXTRACTION(self) -> list[str]:
+def ASSET_TYPES_FOR_METADATA_EXTRACTION(self) -> list[AssetType]:
types = self.convert_csv_to_asset_types(
-self.COMMA_SEPARATED_ASSET_TYPES_FOR_METADATA_EXTRACTON
+self.COMMA_SEPARATED_ASSET_TYPES_FOR_METADATA_EXTRACTION
)

if not set(types).issubset(set(self.ASSET_TYPES)):
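A condensed sketch of the pattern this hunk renames, with a stand-in `AssetType` enum (the real enum and the validator body are not shown in this diff, so both are assumptions):

```python
from enum import Enum

from pydantic import BaseModel, Field, field_validator


class AssetType(str, Enum):  # assumed shape of the app's enum
    DATASETS = "datasets"
    ML_MODELS = "ml_models"


class AIoDConfig(BaseModel):
    COMMA_SEPARATED_ASSET_TYPES: str = Field(...)
    COMMA_SEPARATED_ASSET_TYPES_FOR_METADATA_EXTRACTION: str = Field(...)

    @field_validator(
        "COMMA_SEPARATED_ASSET_TYPES",
        "COMMA_SEPARATED_ASSET_TYPES_FOR_METADATA_EXTRACTION",
        mode="before",
    )
    @classmethod
    def normalize(cls, value: str) -> str:
        return value.strip().lower()  # illustrative normalization only

    @classmethod
    def convert_csv_to_asset_types(cls, value: str) -> list[AssetType]:
        return [AssetType(v.strip()) for v in value.split(",") if v.strip()]

    @property
    def ASSET_TYPES(self) -> list[AssetType]:
        # the typed return annotation matches the fix in the hunk above
        return self.convert_csv_to_asset_types(self.COMMA_SEPARATED_ASSET_TYPES)


cfg = AIoDConfig(
    COMMA_SEPARATED_ASSET_TYPES="datasets,ml_models",
    COMMA_SEPARATED_ASSET_TYPES_FOR_METADATA_EXTRACTION="datasets",
)
assert cfg.ASSET_TYPES == [AssetType.DATASETS, AssetType.ML_MODELS]
```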
14 changes: 7 additions & 7 deletions api/app/schemas/asset_metadata/dataset_metadata.py
@@ -5,7 +5,7 @@

from pydantic import BaseModel, Field, field_validator

-# Rules how to setup these schemas representing asset metadata
+# Rules how to set up these schemas representing asset metadata
# 1. Each pydantic.Field only contains a default value and a description.
# Other arguments are not copied over when creating dynamic schemas.
# 2. If you wish to apply some additional value constraints, feel free to do so, but
@@ -17,7 +17,7 @@ class HuggingFaceDatasetMetadataTemplate(BaseModel):
Extraction of relevant metadata we wish to retrieve from ML assets
"""

-_ALL_VALID_VALUES: ClassVar[list[list[str]] | None] = None
+_ALL_VALID_VALUES: ClassVar[dict[str, list[str]] | None] = None
_PATH_TO_VALID_VALUES: ClassVar[Path] = Path("data/valid_metadata_values.json")

date_published: Optional[str] = Field(
@@ -58,29 +58,29 @@ def _load_all_valid_values(cls) -> None:
def get_field_valid_values(cls, field: str) -> list[str]:
if cls._ALL_VALID_VALUES is None:
cls._load_all_valid_values()
-return cls._ALL_VALID_VALUES.get(field, None)
+return cls._ALL_VALID_VALUES.get(field, [])

@classmethod
def exists_field_valid_values(cls, field: str) -> bool:
if cls._ALL_VALID_VALUES is None:
cls._load_all_valid_values()
return field in cls._ALL_VALID_VALUES.keys()

@field_validator("date_published", mode="before")
@classmethod
@field_validator("date_published", mode="before")
def check_date_published(cls, value: str) -> str | None:
pattern = r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z$"
return value if bool(re.match(pattern, value)) else None

@field_validator("license", mode="before")
@classmethod
@field_validator("license", mode="before")
def check_license(cls, value: str) -> str | None:
if cls._ALL_VALID_VALUES is None:
cls._load_all_valid_values()
return value if value in cls.get_field_valid_values("license") else None

@field_validator("task_types", mode="before")
@classmethod
@field_validator("task_types", mode="before")
def check_task_types(cls, values: list[str]) -> list[str] | None:
if cls._ALL_VALID_VALUES is None:
cls._load_all_valid_values()
@@ -89,7 +89,7 @@ def check_task_types(cls, values: list[str]) -> list[str] | None:
]
return valid_values

@field_validator("languages", mode="before")
@classmethod
@field_validator("languages", mode="before")
def check_languages(cls, values: list[str]) -> list[str] | None:
return [val.lower() for val in values if len(val) == 2]
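Why `.get(field, [])` rather than `.get(field, None)`: the method is annotated to return `list[str]`, and callers run membership checks against the result, which `None` would break. A small self-contained illustration:

```python
valid_values: dict[str, list[str]] = {"license": ["mit", "apache-2.0"]}

# "mit" in None would raise TypeError; an empty list keeps the declared
# list[str] contract and makes the membership test simply come out False.
assert "mit" in valid_values.get("license", [])
assert "mit" not in valid_values.get("task_types", [])
```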
2 changes: 1 addition & 1 deletion api/app/schemas/request_params.py
@@ -11,7 +11,7 @@ class RequestParams(BaseModel):
from_time: datetime | None = None
to_time: datetime | None = None

-def new_page(self, offset: int = None, limit: int = None) -> RequestParams:
+def new_page(self, offset: int | None = None, limit: int | None = None) -> RequestParams:
new_obj = RequestParams(**self.model_dump())

if offset is not None:
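The signature change removes an implicit-`Optional` annotation, the pattern PEP 484 disallows and linters such as Ruff (rule RUF013) flag. The two forms side by side, as a standalone sketch:

```python
def new_page_old(offset: int = None): ...   # flagged: annotation says int, default is None

def new_page_new(offset: int | None = None): ...  # explicit union admits the default
```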
2 changes: 1 addition & 1 deletion api/app/services/aiod.py
@@ -62,7 +62,7 @@ def check_aiod_document(doc_id: str, asset_type: AssetType, sleep_time: float =


def _build_aiod_url_queries(url_params: RequestParams) -> dict:
-def translate_datetime_to_aiod_params(date: datetime) -> str | None:
+def translate_datetime_to_aiod_params(date: datetime | None = None) -> str | None:
if date is None:
return date
return f"{date.year}-{date.month}-{date.day}"
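For reference, the translated query parameter is built with plain attribute interpolation, which yields unpadded month and day values (whether the AIoD API also accepts zero-padded dates is not established by this diff):

```python
from datetime import datetime, timezone

date = datetime(2025, 3, 7, tzinfo=timezone.utc)
print(f"{date.year}-{date.month}-{date.day}")  # 2025-3-7, not 2025-03-07
```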
8 changes: 5 additions & 3 deletions api/app/services/embedding_store.py
@@ -49,7 +49,7 @@ def retrieve_topk_document_ids(
query_text: str | None = None,
topk: int = 10,
filter: str = "",
-query_embeddings: list[list[float]] = None,
+query_embeddings: list[float] | None = None,
) -> SearchResults:
pass

@@ -72,14 +72,15 @@ def __init__(
self.chunk_embedding_store = settings.MILVUS.STORE_CHUNKS
self.verbose = verbose

-self.client = None
+self.client: MilvusClient | None = None

+@staticmethod
async def init() -> MilvusEmbeddingStore:
obj = MilvusEmbeddingStore()
await obj.init_connection()
return obj

-async def init_connection(self) -> None:
+async def init_connection(self) -> bool:
for _ in range(5):
try:
self.client = MilvusClient(
@@ -174,6 +175,7 @@ def store_embeddings(
loader: DataLoader,
asset_type: AssetType,
milvus_batch_size: int = 50,
+**kwargs,
) -> int:
collection_name = self.get_collection_name(asset_type)
self._create_collection(asset_type)
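The `bool` return type on `init_connection` matches what a bounded retry loop naturally produces. A minimal sketch of the pattern (the retry count of five comes from the hunk above; the backoff and exception type are assumptions):

```python
import time
from typing import Callable


def init_connection_sketch(connect: Callable[[], None]) -> bool:
    """Attempt to connect a few times; report success instead of failing silently."""
    for attempt in range(5):
        try:
            connect()  # e.g. constructing a MilvusClient
            return True
        except ConnectionError:
            time.sleep(2**attempt)  # assumed backoff, not part of the diff
    return False
```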
13 changes: 7 additions & 6 deletions api/app/services/inference/architecture.py
@@ -1,4 +1,4 @@
-from abc import ABC, abstractmethod
+from abc import ABCMeta, abstractmethod
from typing import Callable

import numpy as np
@@ -7,7 +7,7 @@
from transformers import PreTrainedModel, PreTrainedTokenizer


-class EmbeddingModel(ABC):
+class EmbeddingModel(torch.nn.Module, metaclass=ABCMeta):
@abstractmethod
def forward(self, texts: list[str]) -> list[torch.Tensor]:
"""
@@ -17,6 +17,7 @@ def forward(self, texts: list[str]) -> list[torch.Tensor]:
Returns a list of tensors representing either entire documents or
the chunks documents consist of
"""
+pass

@abstractmethod
def _forward(self, encodings: dict[str, torch.Tensor]) -> list[torch.Tensor]:
@@ -105,7 +106,7 @@ def forward(
return [encodings["sentence_embedding"]]


-class Basic_EmbeddingModel(torch.nn.Module, EmbeddingModel):
+class Basic_EmbeddingModel(EmbeddingModel):
"""
Class representing models that process the input documents in their entirety
without needing to divide them into separate chunks.
@@ -173,7 +174,7 @@ def preprocess_input(self, texts: list[str]) -> dict:
return encodings


-class Hierarchical_EmbeddingModel(torch.nn.Module, EmbeddingModel):
+class Hierarchical_EmbeddingModel(EmbeddingModel):
"""
Class representing models that process the input documents by firstly individually
processing their chunks before further accumulating the chunk information to
@@ -240,7 +241,7 @@ def forward(self, texts: list[str]) -> list[torch.Tensor]:

def _forward(self, encodings: dict[str, torch.Tensor]) -> list[torch.Tensor]:
chunk_embeddings = self._first_level_forward(
encodings["input_encodings"], encodings["max_num_chunks"]
encodings["input_encodings"], int(encodings["max_num_chunks"])
)
doc_embeddings = self._second_level_forward(
chunk_embeddings,
@@ -343,7 +344,7 @@ def preprocess_input(self, texts: list[str]) -> dict:
"max_num_chunks": max_chunks,
}

-# input_encoddings
+# input_encodings
transposed_texts = np.array(padded_texts).T.tolist()
encodings = [
self.tokenizer(
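Making the base class inherit `torch.nn.Module` once, with `metaclass=ABCMeta` preserving the abstract interface, lets the concrete models drop their `torch.nn.Module, EmbeddingModel` double inheritance. A self-contained sketch of the pattern with a dummy subclass:

```python
from abc import ABCMeta, abstractmethod

import torch


class EmbeddingModel(torch.nn.Module, metaclass=ABCMeta):
    @abstractmethod
    def forward(self, texts: list[str]) -> list[torch.Tensor]:
        """Embed whole documents or the chunks they consist of."""


class DummyEmbeddingModel(EmbeddingModel):
    # Subclasses now get nn.Module behaviour (parameters, .to(device),
    # state_dict) from the shared base class alone.
    def forward(self, texts: list[str]) -> list[torch.Tensor]:
        return [torch.zeros(4) for _ in texts]  # placeholder embeddings


model = DummyEmbeddingModel()
print(model(["a document"])[0].shape)  # nn.Module.__call__ dispatches to forward

# EmbeddingModel() itself raises TypeError: abstract methods block instantiation.
```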
5 changes: 3 additions & 2 deletions api/app/services/inference/model.py
@@ -24,9 +24,10 @@ def get_device(cls) -> str:

def __init__(self, device: torch.device = "cpu") -> None:
self.use_chunking = settings.MILVUS.STORE_CHUNKS
-self.model = self.load_model(self.use_chunking, device)
+self.model = AiModel.load_model(self.use_chunking, device)

-def load_model(self, use_chunking: bool, device: torch.device = "cpu") -> EmbeddingModel:
+@staticmethod
+def load_model(use_chunking: bool, device: torch.device = "cpu") -> EmbeddingModel:
transformer = SentenceTransformerToHF(settings.MODEL_LOADPATH, trust_remote_code=True)
text_splitter = TokenizerTextSplitter(
transformer.tokenizer, chunk_size=512, chunk_overlap=0.25
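`load_model` no longer reads anything from `self`, so it becomes a `@staticmethod` and `__init__` calls it through the class. The shape of the change, reduced to a sketch (the model-loading body is elided and the string return is a stand-in):

```python
class AiModel:
    def __init__(self, device: str = "cpu") -> None:
        self.use_chunking = True  # stands in for settings.MILVUS.STORE_CHUNKS
        self.model = AiModel.load_model(self.use_chunking, device)

    @staticmethod
    def load_model(use_chunking: bool, device: str = "cpu") -> str:
        # The real method builds a SentenceTransformer-backed EmbeddingModel.
        return f"model(chunking={use_chunking}, device={device})"
```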
9 changes: 5 additions & 4 deletions api/app/services/inference/text_operations.py
@@ -274,7 +274,7 @@ def _extract_very_basic_fields(cls, data: dict) -> dict:
"platform": data["platform"],
"name": data["name"],
}
-description = cls._get_text_like_field(data)
+description = cls._get_text_like_field(data, "description")
if description is not None:
new_object["description"] = description

@@ -349,9 +349,9 @@ def _extract_relevant_fields(cls, data: dict, asset_type: AssetType) -> dict:
new_dist = {k: dist[k] for k in dist_relevant_fields if k in dist}

if new_dist.get("content_size_kb", None) is not None:
size_kb = new_dist["content_size_kb"]
new_dist["content_size_mb"] = float(f"{(size_kb / 1024):.2f}")
new_dist["content_size_gb"] = float(f"{(size_kb / 1024**2):.2f}")
size_kb = int(new_dist["content_size_kb"])
new_dist["content_size_mb"] = round(size_kb / 1024, 2)
new_dist["content_size_gb"] = round(size_kb / 1024**2, 2)
if new_dist != {}:
new_object[field_name].append(new_dist)

@@ -370,6 +370,7 @@ def _extract_relevant_fields(cls, data: dict, asset_type: AssetType) -> dict:

@classmethod
def _get_text_like_field(cls, data: dict, field: str) -> str | None:
+# TODO: Simplify this logic
description = data.get(field, None)
if description is None:
return None
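The size conversion drops the round-trip through an f-string in favour of `round`, and the `int(...)` cast guards against the field arriving as a string. Both formulations agree to two decimals:

```python
size_kb = int("1536")  # the cast also accepts an already-int value

assert float(f"{size_kb / 1024:.2f}") == round(size_kb / 1024, 2) == 1.5
assert round(size_kb / 1024**2, 2) == 0.0  # 1536 KB is roughly 0.0015 GB
```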
19 changes: 11 additions & 8 deletions api/app/services/threads/embedding_thread.py
@@ -161,7 +161,7 @@ def process_aiod_assets_wrapper(
stringified_assets = [stringify_function(obj) for obj in assets_to_add]
asset_ids = [str(obj["identifier"]) for obj in assets_to_add]

-metadata = [{} for _ in assets_to_add]
+metadata: list[dict] = [{} for _ in assets_to_add]
if extract_metadata_function is not None:
metadata = [extract_metadata_function(obj) for obj in assets_to_add]

@@ -201,7 +201,7 @@ def get_assets_to_add_and_delete(
newly_added_doc_ids: list[str],
last_db_sync_datetime: datetime | None,
) -> tuple[list[dict] | None, list[str] | None]:
-mark_recursions = []
+mark_recursions: list[int] = []
assets = recursive_aiod_asset_fetch(asset_type, url_params, mark_recursions)

if len(assets) == 0 and len(mark_recursions) == 0:
@@ -247,12 +247,15 @@ def parse_aiod_asset_date(
none_value: Literal["none", "now", "zero"] = "none",
) -> datetime | None:
string_time = asset.get("aiod_entry", {}).get(field, None)
-    if string_time is None:
+
+    if string_time is not None:
+        return datetime.fromisoformat(string_time).replace(tzinfo=timezone.utc)
+    else:
         if none_value == "none":
             return None
-        if none_value == "now":
+        elif none_value == "now":
             return datetime.now(tz=timezone.utc)
-        if none_value == "zero":
-            return datetime.fromtimestamp(0, tz=timezone)
-
-    return datetime.fromisoformat(string_time).replace(tzinfo=timezone.utc)
+        elif none_value == "zero":
+            return datetime.fromtimestamp(0, tz=timezone.utc)
+        else:
+            return None
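Besides flattening the branching into an `if`/`elif` chain, this hunk fixes a latent bug: the old `"zero"` branch passed the `timezone` class itself as `tz` rather than the `timezone.utc` instance, which raises as soon as that branch runs:

```python
from datetime import datetime, timezone

print(datetime.fromtimestamp(0, tz=timezone.utc))  # 1970-01-01 00:00:00+00:00

try:
    datetime.fromtimestamp(0, tz=timezone)  # pre-fix call: a class, not a tzinfo instance
except TypeError as err:
    print(err)
```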
4 changes: 2 additions & 2 deletions api/app/services/threads/milvus_gc_thread.py
@@ -57,7 +57,7 @@ def delete_asset_embeddings(
# iterate over entirety of AIoD database, store all the doc IDs
while True:
assets_to_add, _ = get_assets_to_add_and_delete(
-url=settings.AIOD.get_assets_url(asset_type),
+asset_type=asset_type,
url_params=url_params,
existing_doc_ids_from_past=[],
newly_added_doc_ids=[],
@@ -68,7 +68,7 @@
all_aiod_doc_ids.extend([str(obj["identifier"]) for obj in assets_to_add])

# during the traversal of AIoD assets, some of them may be deleted in between
-# which would make us skip some assets if we were to use tradinational
+# which would make us skip some assets if we were to use traditional
# pagination without any overlap, hence the need for an overlap
url_params.offset += settings.AIOD.OFFSET_INCREMENT

