6 changes: 5 additions & 1 deletion .github/workflows/python-ci.yml
@@ -54,4 +54,8 @@ jobs:

- name: benchmark-qed - Check
run: |
-          uv run poe check
+          uv run poe check
+
+      - name: benchmark-qed - Test
+        run: |
+          uv run poe test
2 changes: 1 addition & 1 deletion .gitignore
@@ -171,7 +171,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/

# Ruff stuff:
.ruff_cache/
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20251217232258224287.json
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Support parquet inputs"
}
103 changes: 89 additions & 14 deletions benchmark_qed/autod/io/document.py
@@ -1,4 +1,5 @@
-# Copyright (c) 2025 Microsoft Corporation.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Load input files into Document objects."""

import datetime
@@ -106,17 +107,16 @@ def load_text_dir(
return documents


def load_csv_doc(
file_path: str,
encoding: str = defs.FILE_ENCODING,
def _load_docs_from_dataframe(
data_df: pd.DataFrame,
input_type: InputDataType,
title: str,
text_tag: str = defs.TEXT_COLUMN,
metadata_tags: list[str] | None = None,
max_text_length: int | None = None,
) -> list[Document]:
"""Load a CSV file and return a Document object."""
data_df = pd.read_csv(file_path, encoding=encoding)

documents: list[Document] = []

for index, row in enumerate(data_df.itertuples()):
text = getattr(row, text_tag, "")
if max_text_length is not None:
@@ -127,6 +127,7 @@ def load_csv_doc(
for tag in metadata_tags:
if tag in data_df.columns:
metadata[tag] = getattr(row, tag)

if "date_created" not in metadata:
metadata["date_created"] = datetime.datetime.now(
tz=datetime.UTC
@@ -136,15 +137,33 @@ def load_csv_doc(
Document(
id=str(uuid4()),
short_id=str(index),
-title=str(file_path.replace(".csv", "")),
-type="csv",
+title=title,
+type=str(input_type),
text=text,
attributes=metadata,
)
)
return documents


def load_csv_doc(
file_path: str,
encoding: str = defs.FILE_ENCODING,
text_tag: str = defs.TEXT_COLUMN,
metadata_tags: list[str] | None = None,
max_text_length: int | None = None,
) -> list[Document]:
"""Load a CSV file and return a Document object."""
return _load_docs_from_dataframe(
data_df=pd.read_csv(file_path, encoding=encoding),
input_type=InputDataType.CSV,
title=str(file_path.replace(".csv", "")),
text_tag=text_tag,
metadata_tags=metadata_tags,
max_text_length=max_text_length,
)


def load_csv_dir(
dir_path: str,
encoding: str = defs.FILE_ENCODING,
@@ -171,6 +190,47 @@ def load_csv_dir(
return documents


def load_parquet_doc(
file_path: str,
text_tag: str = defs.TEXT_COLUMN,
metadata_tags: list[str] | None = None,
max_text_length: int | None = None,
) -> list[Document]:
"""Load Documents from a parquet file."""
return _load_docs_from_dataframe(
data_df=pd.read_parquet(file_path),
input_type=InputDataType.PARQUET,
title=str(file_path.replace(".parquet", "")),
text_tag=text_tag,
metadata_tags=metadata_tags,
max_text_length=max_text_length,
)


def load_parquet_dir(
dir_path: str,
text_tag: str = defs.TEXT_COLUMN,
metadata_tags: list[str] | None = None,
max_text_length: int | None = None,
) -> list[Document]:
"""Load a directory of parquet files and return a list of Document objects."""
documents: list[Document] = []
for file_path in Path(dir_path).rglob("*.parquet"):
documents.extend(
load_parquet_doc(
file_path=str(file_path),
text_tag=text_tag,
metadata_tags=metadata_tags,
max_text_length=max_text_length,
)
)

for index, document in enumerate(documents):
document.short_id = str(index)

return documents


def create_documents(
input_path: str,
input_type: InputDataType | str = InputDataType.JSON,
@@ -205,6 +265,13 @@ def create_documents(
metadata_tags=metadata_tags,
max_text_length=max_text_length,
)
case InputDataType.PARQUET:
documents = load_parquet_dir(
dir_path=str(input_path),
text_tag=text_tag,
metadata_tags=metadata_tags,
max_text_length=max_text_length,
)
case _:
msg = f"Unsupported input type: {input_type}"
raise ValueError(msg)
@@ -236,6 +303,13 @@ def create_documents(
metadata_tags=metadata_tags,
max_text_length=max_text_length,
)
case InputDataType.PARQUET:
documents = load_parquet_doc(
file_path=str(input_path),
text_tag=text_tag,
metadata_tags=metadata_tags,
max_text_length=max_text_length,
)
case _:
msg = f"Unsupported input type: {input_type}"
raise ValueError(msg)
@@ -254,18 +328,19 @@ def load_documents(
"""Read documents from a dataframe using pre-converted records."""
records = df.to_dict("records")

def _get_attributes(row: dict) -> dict[str, Any]:
attributes = row.get("attributes", {})
selected_attributes = attributes_cols or []
return {attr: attributes.get(attr, None) for attr in selected_attributes}

return [
Document(
id=row.get(id_col, str(uuid4())),
short_id=row.get(short_id_col, str(index)),
title=row.get(title_col, ""),
type=row.get(type_col, ""),
text=row.get(text_col, ""),
-attributes=(
-    {col: row.get(col) for col in attributes_cols}
-    if attributes_cols
-    else {}
-),
+attributes=_get_attributes(row),
)
for index, row in enumerate(records)
]
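Taken together, the new loaders make parquet a first-class input alongside JSON, CSV, and plain text. A minimal usage sketch follows; the sample data, column names, and the `corpus/` path are illustrative assumptions, not part of the PR:

```python
from pathlib import Path

import pandas as pd

from benchmark_qed.autod.io.document import create_documents
from benchmark_qed.autod.io.enums import InputDataType

# Write a tiny parquet corpus. The column names here are assumptions for
# the sketch; the loader reads text from text_tag (default defs.TEXT_COLUMN).
Path("corpus").mkdir(exist_ok=True)
pd.DataFrame({
    "text": ["First document body.", "Second document body."],
    "source": ["web", "pdf"],
}).to_parquet("corpus/docs.parquet")

# create_documents dispatches on InputDataType; for a directory path it
# calls load_parquet_dir, which rglobs *.parquet and renumbers short_ids.
docs = create_documents(
    input_path="corpus",
    input_type=InputDataType.PARQUET,
    text_tag="text",
    metadata_tags=["source"],
)
for doc in docs:
    print(doc.short_id, doc.title, doc.attributes)
```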
1 change: 1 addition & 0 deletions benchmark_qed/autod/io/enums.py
@@ -10,3 +10,4 @@ class InputDataType(StrEnum):
JSON = "json"
CSV = "csv"
TEXT = "text"
PARQUET = "parquet"
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -84,7 +84,7 @@ test = "pytest tests"
serve_docs = "mkdocs serve"
build_docs = "mkdocs build"

-_test_with_coverage = 'coverage run --source=benchmark_qed -m pytest tests/unit'
+_test_with_coverage = 'coverage run --source=benchmark_qed -m pytest tests'
_coverage_report = 'coverage report --fail-under=100 --show-missing --omit="benchmark_qed/doc_gen/__main__.py"'
_generate_coverage_xml = 'coverage xml --omit="benchmark_qed/doc_gen/__main__.py"'
_generate_coverage_html = 'coverage html --omit="benchmark_qed/doc_gen/__main__.py"'
@@ -120,3 +120,6 @@ sequence = [
[tool.pyright]
include = ["benchmark_qed", "tests"]
exclude = ["**/__pycache__"]

[tool.pytest.ini_options]
tmp_path_retention_policy = "failed"
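The coverage task now runs the whole `tests` tree, and `tmp_path_retention_policy = "failed"` keeps pytest's `tmp_path` directories on disk only for failing tests. A hypothetical test sketch showing why that pairs well with parquet fixtures (the test name and column are illustrative, not from this PR):

```python
import pandas as pd

from benchmark_qed.autod.io.document import load_parquet_doc

def test_load_parquet_doc(tmp_path):
    # With tmp_path_retention_policy = "failed", this directory survives
    # only when the assertion fails, making the fixture easy to inspect.
    file_path = tmp_path / "docs.parquet"
    pd.DataFrame({"text": ["hello parquet"]}).to_parquet(file_path)

    docs = load_parquet_doc(file_path=str(file_path), text_tag="text")
    assert docs[0].text == "hello parquet"
```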
3 changes: 3 additions & 0 deletions ruff.toml
@@ -95,3 +95,6 @@ builtins-ignorelist = ["input", "id", "bytes"]

[lint.pydocstyle]
convention = "numpy"

[lint.flake8-copyright]
notice-rgx = "(?i)Copyright \\(C\\) (\\d{4} )?Microsoft Corporation"
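This `flake8-copyright` rule explains the header churn in `document.py`: the year is optional and the match is case-insensitive, so both the old and new headers pass. Checking the regex directly:

```python
import re

# The notice-rgx added above, applied to both header styles seen in this PR.
pattern = re.compile(r"(?i)Copyright \(C\) (\d{4} )?Microsoft Corporation")
assert pattern.search("# Copyright (c) Microsoft Corporation.")
assert pattern.search("# Copyright (c) 2025 Microsoft Corporation.")
```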
2 changes: 2 additions & 0 deletions tests/autod/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
2 changes: 2 additions & 0 deletions tests/autod/io/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.