Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions src/dcv_benchmark/data_factory/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

class DatasetBuilder:
"""
Orchestrates the creation of a RAG Security Dataset.
Orchestrates the creation of a RAG Security Dataset based on the
    SQuAD dataset.

Workflow:
1. Load raw samples (Query + Gold Chunk) from a corpus.
Expand Down Expand Up @@ -71,20 +72,33 @@ def build(self) -> Dataset:

for raw in raw_samples:
# A. Retrieve Contexts (Distractors + Candidates)
# We fetch k chunks. Note: Gold chunk might be in here if retrieval is good.
retrieved_texts = self.retriever.query(
query_text=raw.query, k=self.config.retrieval_k
# We fetch k+1 chunks to handle the "Oracle" property robustly.
# This allows us to verify if the golden sample would naturally appear in
# the top K or if we need to force-inject it, without potentially losing
# the K-th best distractor when replacing.
retrieved_candidates = self.retriever.query(
query_text=raw.query, k=self.config.retrieval_k + 1
)

# B. Enforce Gold Chunk Presence (The "Oracle" Property)
# In this benchmark, we want to test Integrity, not Recall.
# So we MUST ensure the correct answer context is present.
final_context_texts = retrieved_texts
if raw.source_document not in final_context_texts:
# Replace the last retrieved chunk with the Gold Chunk
if final_context_texts:
final_context_texts.pop()
final_context_texts.insert(0, raw.source_document)
# Logic:
# 1. Start with the Gold Chunk (source_document).
# 2. Add as many retrieved candidates as possible, skipping any that are
# identical to Gold.
# 3. Stop when we reach K total items.
final_context_texts = [raw.source_document]

for candidate in retrieved_candidates:
if len(final_context_texts) >= self.config.retrieval_k:
break

# Exact string match check for deduplication
if candidate != raw.source_document:
final_context_texts.append(candidate)

# Note: This logic always puts Gold at index 0 initially.
# If random shuffle is needed, it should happen later or be explicitly
# handled.

# C. Determine Sample Type (Attack vs Benign)
is_attack = rng.random() < self.config.attack_rate
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_config_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def mock_dataset_sample():
sample_type="benign",
query="test query",
labels=[],
context=[{"content": "test context", "source": "test", "id": "1"}], # added id
context=[{"content": "test context", "source": "test", "id": "1"}],
)


Expand Down
131 changes: 131 additions & 0 deletions tests/unit/data_factory/test_builder_golden.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from unittest.mock import MagicMock, patch

import pytest

from dcv_benchmark.data_factory.builder import DatasetBuilder
from dcv_benchmark.models.data_factory import DataFactoryConfig, RawSample


@pytest.fixture
def mock_config():
    """Benign-only configuration (attack_rate=0.0) so the builder's retrieval
    and gold-injection logic can be exercised without attack samples."""
    options = {
        "dataset_name": "repro_ds",
        "description": "repro",
        "source_file": "dummy.json",
        "attack_strategy": "naive",
        "attack_payload": "dummy",
        "attack_rate": 0.0,  # benign samples only — focus on retrieval
        "retrieval_k": 3,
        "truncate_overflow": False,
    }
    return DataFactoryConfig(**options)


@pytest.fixture
def mock_loader():
    """Loader stub yielding exactly one raw sample whose gold chunk is
    the string "GOLD_CONTENT"."""
    sample = RawSample(
        id="1",
        query="Q1",
        source_document="GOLD_CONTENT",
        reference_answer="Ans1",
    )
    stub = MagicMock()
    stub.load.return_value = [sample]
    return stub


@pytest.fixture
def mock_injector():
    """Injector double; its behavior is irrelevant for benign-only configs."""
    injector = MagicMock()
    return injector


@pytest.fixture
def mock_retriever_class():
    """Patch EphemeralRetriever at the builder's import site and yield the
    mock class so tests can configure its return_value instance."""
    patcher = patch("dcv_benchmark.data_factory.builder.EphemeralRetriever")
    mock_cls = patcher.start()
    try:
        yield mock_cls
    finally:
        patcher.stop()


def test_gold_in_retrieved_k_plus_one(
    mock_config, mock_loader, mock_injector, mock_retriever_class
):
    """
    Gold appears as the (k+1)-th retrieved candidate.

    With k=3 the builder over-fetches 4 chunks and GOLD_CONTENT is the last
    one. The final context must contain gold plus the two best distractors,
    for a total of exactly k=3 items.
    """
    retriever = mock_retriever_class.return_value
    # Four candidates: three distractors, then the gold chunk in last place.
    retriever.query.return_value = ["D1", "D2", "D3", "GOLD_CONTENT"]
    mock_config.retrieval_k = 3  # explicit: this test pins k=3 behavior

    dataset = DatasetBuilder(mock_loader, mock_injector, mock_config).build()
    contents = [chunk.content for chunk in dataset.samples[0].context]

    # Builder must request one extra candidate (k + 1 == 4).
    retriever.query.assert_called_with(query_text="Q1", k=4)

    # Gold chunk survives and the context holds exactly k items.
    assert "GOLD_CONTENT" in contents
    assert len(contents) == 3

    # Top-2 distractors are kept; D3 is the one effectively dropped.
    # Ordering is implementation-defined (gold may be inserted at index 0),
    # so compare as a set.
    assert set(contents) == {"D1", "D2", "GOLD_CONTENT"}


def test_gold_missing_from_retrieved(
    mock_config, mock_loader, mock_injector, mock_retriever_class
):
    """
    Gold is absent from all 4 retrieved candidates.

    The builder must force-inject GOLD_CONTENT and keep only the best k-1
    distractors, so the final context has exactly k=3 items.
    """
    retriever = mock_retriever_class.return_value
    # Four pure distractors — retrieval never surfaces the gold chunk.
    retriever.query.return_value = ["D1", "D2", "D3", "D4"]
    mock_config.retrieval_k = 3

    dataset = DatasetBuilder(mock_loader, mock_injector, mock_config).build()
    contents = [chunk.content for chunk in dataset.samples[0].context]

    # Over-fetch by one (k + 1 == 4).
    retriever.query.assert_called_with(query_text="Q1", k=4)

    # Gold was injected and the size cap still holds.
    assert "GOLD_CONTENT" in contents
    assert len(contents) == 3

    # D3 and D4 are dropped to make room for the injected gold chunk.
    assert set(contents) == {"D1", "D2", "GOLD_CONTENT"}


def test_gold_is_top_1(mock_config, mock_loader, mock_injector, mock_retriever_class):
    """
    Gold is the best-ranked retrieved chunk.

    No injection is needed; the builder keeps gold plus the next two
    distractors, for a total of exactly k=3 items.
    """
    retriever = mock_retriever_class.return_value
    # Gold is ranked first, followed by three distractors.
    retriever.query.return_value = ["GOLD_CONTENT", "D1", "D2", "D3"]
    mock_config.retrieval_k = 3

    dataset = DatasetBuilder(mock_loader, mock_injector, mock_config).build()
    contents = [chunk.content for chunk in dataset.samples[0].context]

    # Over-fetch by one (k + 1 == 4).
    retriever.query.assert_called_with(query_text="Q1", k=4)

    assert "GOLD_CONTENT" in contents
    assert len(contents) == 3
    # Gold is not duplicated; the two best distractors fill the rest.
    assert set(contents) == {"GOLD_CONTENT", "D1", "D2"}