Merged
Changes from 7 commits
6 changes: 5 additions & 1 deletion .github/workflows/python-ci.yml
@@ -54,4 +54,8 @@ jobs:
 
       - name: benchmark-qed - Check
         run: |
-          uv run poe check
+          uv run poe check
+
+      - name: benchmark-qed - Test
+        run: |
+          uv run poe test
2 changes: 1 addition & 1 deletion .gitignore
@@ -171,7 +171,7 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
 
 # Ruff stuff:
 .ruff_cache/
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20251217232258224287.json
@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Support parquet inputs"
+}
103 changes: 89 additions & 14 deletions benchmark_qed/autod/io/document.py
@@ -1,4 +1,5 @@
-# Copyright (c) 2025 Microsoft Corporation.
+# Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 """Load input files into Document objects."""
 
 import datetime
@@ -106,17 +107,16 @@ def load_text_dir(
     return documents
 
 
-def load_csv_doc(
-    file_path: str,
-    encoding: str = defs.FILE_ENCODING,
+def _load_docs_from_dataframe(
+    data_df: pd.DataFrame,
+    input_type: InputDataType,
+    title: str,
     text_tag: str = defs.TEXT_COLUMN,
     metadata_tags: list[str] | None = None,
     max_text_length: int | None = None,
 ) -> list[Document]:
-    """Load a CSV file and return a Document object."""
-    data_df = pd.read_csv(file_path, encoding=encoding)
-
     documents: list[Document] = []
 
     for index, row in enumerate(data_df.itertuples()):
         text = getattr(row, text_tag, "")
         if max_text_length is not None:
@@ -127,6 +127,7 @@ def load_csv_doc(
             for tag in metadata_tags:
                 if tag in data_df.columns:
                     metadata[tag] = getattr(row, tag)
+
         if "date_created" not in metadata:
             metadata["date_created"] = datetime.datetime.now(
                 tz=datetime.UTC
@@ -136,15 +137,33 @@
             Document(
                 id=str(uuid4()),
                 short_id=str(index),
-                title=str(file_path.replace(".csv", "")),
-                type="csv",
+                title=title,
+                type=str(input_type),
                 text=text,
                 attributes=metadata,
             )
         )
     return documents
 
 
+def load_csv_doc(
+    file_path: str,
+    encoding: str = defs.FILE_ENCODING,
+    text_tag: str = defs.TEXT_COLUMN,
+    metadata_tags: list[str] | None = None,
+    max_text_length: int | None = None,
+) -> list[Document]:
+    """Load a CSV file and return a Document object."""
+    return _load_docs_from_dataframe(
+        data_df=pd.read_csv(file_path, encoding=encoding),
+        input_type=InputDataType.CSV,
+        title=str(file_path.replace(".csv", "")),
+        text_tag=text_tag,
+        metadata_tags=metadata_tags,
+        max_text_length=max_text_length,
+    )
+
+
 def load_csv_dir(
     dir_path: str,
     encoding: str = defs.FILE_ENCODING,
@@ -171,6 +190,47 @@ def load_csv_dir(
     return documents
 
 
+def load_parquet_doc(
+    file_path: str,
+    text_tag: str = defs.TEXT_COLUMN,
+    metadata_tags: list[str] | None = None,
+    max_text_length: int | None = None,
+) -> list[Document]:
+    """Load Documents from a parquet file."""
+    return _load_docs_from_dataframe(
+        data_df=pd.read_parquet(file_path),
+        input_type=InputDataType.PARQUET,
+        title=str(file_path.replace(".parquet", "")),
+        text_tag=text_tag,
+        metadata_tags=metadata_tags,
+        max_text_length=max_text_length,
+    )
+
+
+def load_parquet_dir(
+    dir_path: str,
+    text_tag: str = defs.TEXT_COLUMN,
+    metadata_tags: list[str] | None = None,
+    max_text_length: int | None = None,
+) -> list[Document]:
+    """Load a directory of parquet files and return a list of Document objects."""
+    documents: list[Document] = []
+    for file_path in Path(dir_path).rglob("*.parquet"):
+        documents.extend(
+            load_parquet_doc(
+                file_path=str(file_path),
+                text_tag=text_tag,
+                metadata_tags=metadata_tags,
+                max_text_length=max_text_length,
+            )
+        )
+
+    for index, document in enumerate(documents):
+        document.short_id = str(index)
+
+    return documents
+
+
 def create_documents(
     input_path: str,
     input_type: InputDataType | str = InputDataType.JSON,
@@ -205,6 +265,13 @@ def create_documents(
                 metadata_tags=metadata_tags,
                 max_text_length=max_text_length,
             )
+        case InputDataType.PARQUET:
+            documents = load_parquet_dir(
+                dir_path=str(input_path),
+                text_tag=text_tag,
+                metadata_tags=metadata_tags,
+                max_text_length=max_text_length,
+            )
         case _:
             msg = f"Unsupported input type: {input_type}"
             raise ValueError(msg)
@@ -236,6 +303,13 @@ def create_documents(
                 metadata_tags=metadata_tags,
                 max_text_length=max_text_length,
             )
+        case InputDataType.PARQUET:
+            documents = load_parquet_doc(
+                file_path=str(input_path),
+                text_tag=text_tag,
+                metadata_tags=metadata_tags,
+                max_text_length=max_text_length,
+            )
         case _:
             msg = f"Unsupported input type: {input_type}"
             raise ValueError(msg)
@@ -254,18 +328,19 @@ def load_documents(
     """Read documents from a dataframe using pre-converted records."""
     records = df.to_dict("records")
 
+    def _get_attributes(row: dict) -> dict[str, Any]:
+        attributes = row.get("attributes", {})
+        selected_attributes = attributes_cols or []
+        return {attr: attributes.get(attr, None) for attr in selected_attributes}
+
     return [
         Document(
             id=row.get(id_col, str(uuid4())),
             short_id=row.get(short_id_col, str(index)),
             title=row.get(title_col, ""),
             type=row.get(type_col, ""),
             text=row.get(text_col, ""),
-            attributes=(
-                {col: row.get(col) for col in attributes_cols}
-                if attributes_cols
-                else {}
-            ),
+            attributes=_get_attributes(row),
         )
         for index, row in enumerate(records)
     ]
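Taken together, the parquet path mirrors the existing CSV path. Below is a minimal usage sketch, not part of the diff: the module paths come from the file names above, while the "text" column name (standing in for defs.TEXT_COLUMN) and the directory-versus-file dispatch in create_documents are assumptions read off the two match blocks.

from pathlib import Path

import pandas as pd

from benchmark_qed.autod.io.document import create_documents
from benchmark_qed.autod.io.enums import InputDataType

# Write a tiny parquet corpus; "text" stands in for defs.TEXT_COLUMN (assumed).
Path("corpus").mkdir(exist_ok=True)
pd.DataFrame({
    "text": ["first document", "second document"],
    "author": ["alice", "bob"],
}).to_parquet("corpus/docs.parquet")

# Pointing create_documents at a directory appears to dispatch to
# load_parquet_dir, which globs *.parquet recursively and renumbers
# short_ids across files; a single file goes through load_parquet_doc.
docs = create_documents(
    input_path="corpus",
    input_type=InputDataType.PARQUET,
    metadata_tags=["author"],
)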
1 change: 1 addition & 0 deletions benchmark_qed/autod/io/enums.py
@@ -10,3 +10,4 @@ class InputDataType(StrEnum):
     JSON = "json"
     CSV = "csv"
     TEXT = "text"
+    PARQUET = "parquet"
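InputDataType is a StrEnum, so members stringify to their bare values, which is what lets the refactored _load_docs_from_dataframe set Document.type with str(input_type). A minimal standalone sketch of that behavior, restating the enum rather than importing it from the package:

from enum import StrEnum  # Python 3.11+


class InputDataType(StrEnum):
    JSON = "json"
    CSV = "csv"
    TEXT = "text"
    PARQUET = "parquet"


# str() yields the value, not "InputDataType.PARQUET".
assert str(InputDataType.PARQUET) == "parquet"
# The raw string round-trips back to the same member.
assert InputDataType("parquet") is InputDataType.PARQUET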
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -84,7 +84,7 @@ test = "pytest tests"
 serve_docs = "mkdocs serve"
 build_docs = "mkdocs build"
 
-_test_with_coverage = 'coverage run --source=benchmark_qed -m pytest tests/unit'
+_test_with_coverage = 'coverage run --source=benchmark_qed -m pytest tests'
 _coverage_report = 'coverage report --fail-under=100 --show-missing --omit="benchmark_qed/doc_gen/__main__.py"'
 _generate_coverage_xml = 'coverage xml --omit="benchmark_qed/doc_gen/__main__.py"'
 _generate_coverage_html = 'coverage html --omit="benchmark_qed/doc_gen/__main__.py"'
@@ -120,3 +120,6 @@ sequence = [
 [tool.pyright]
 include = ["benchmark_qed", "tests"]
 exclude = ["**/__pycache__"]
+
+[pytest]
+tmp_path_retention_policy = "failed"
3 changes: 3 additions & 0 deletions ruff.toml
@@ -95,3 +95,6 @@ builtins-ignorelist = ["input", "id", "bytes"]
 
 [lint.pydocstyle]
 convention = "numpy"
+
+[lint.flake8-copyright]
+notice-rgx = "(?i)Copyright \\(C\\) (\\d{4} )?Microsoft Corporation"
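The new notice-rgx makes the year group optional, so both copyright header styles touched in document.py pass the flake8-copyright lint. A quick sanity check of the pattern, written as a Python raw string (the TOML's doubled backslashes become single ones):

import re

NOTICE_RGX = r"(?i)Copyright \(C\) (\d{4} )?Microsoft Corporation"

# The year-stamped header removed in document.py and the year-less header
# that replaced it both match, thanks to the optional (\d{4} )? group.
for header in (
    "# Copyright (c) 2025 Microsoft Corporation.",
    "# Copyright (c) Microsoft Corporation.",
):
    assert re.search(NOTICE_RGX, header) is not None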
2 changes: 2 additions & 0 deletions tests/autod/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
2 changes: 2 additions & 0 deletions tests/autod/io/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.