diff --git a/src/dcv_benchmark/data_factory/builder.py b/src/dcv_benchmark/data_factory/builder.py index fadf5a9..13d2ffa 100644 --- a/src/dcv_benchmark/data_factory/builder.py +++ b/src/dcv_benchmark/data_factory/builder.py @@ -20,7 +20,8 @@ class DatasetBuilder: """ - Orchestrates the creation of a RAG Security Dataset. + Orchestrates the creation of a RAG Security Dataset based on the + SQUAD dataset. Workflow: 1. Load raw samples (Query + Gold Chunk) from a corpus. @@ -71,20 +72,33 @@ def build(self) -> Dataset: for raw in raw_samples: # A. Retrieve Contexts (Distractors + Candidates) - # We fetch k chunks. Note: Gold chunk might be in here if retrieval is good. - retrieved_texts = self.retriever.query( - query_text=raw.query, k=self.config.retrieval_k + # We fetch k+1 chunks to handle the "Oracle" property robustly. + # This allows us to verify if the golden sample would naturally appear in + # the top K or if we need to force-inject it, without potentially losing + # the K-th best distractor when replacing. + retrieved_candidates = self.retriever.query( + query_text=raw.query, k=self.config.retrieval_k + 1 ) # B. Enforce Gold Chunk Presence (The "Oracle" Property) - # In this benchmark, we want to test Integrity, not Recall. - # So we MUST ensure the correct answer context is present. - final_context_texts = retrieved_texts - if raw.source_document not in final_context_texts: - # Replace the last retrieved chunk with the Gold Chunk - if final_context_texts: - final_context_texts.pop() - final_context_texts.insert(0, raw.source_document) + # Logic: + # 1. Start with the Gold Chunk (source_document). + # 2. Add as many retrieved candidates as possible, skipping any that are + # identical to Gold. + # 3. Stop when we reach K total items. 
+ final_context_texts = [raw.source_document] + + for candidate in retrieved_candidates: + if len(final_context_texts) >= self.config.retrieval_k: + break + + # Exact string match check for deduplication + if candidate != raw.source_document: + final_context_texts.append(candidate) + + # Note: This logic always puts Gold at index 0 initially. + # If random shuffle is needed, it should happen later or be explicitly + # handled. # C. Determine Sample Type (Attack vs Benign) is_attack = rng.random() < self.config.attack_rate diff --git a/tests/integration/test_config_options.py b/tests/integration/test_config_options.py index 7ee2a7d..84b1c8a 100644 --- a/tests/integration/test_config_options.py +++ b/tests/integration/test_config_options.py @@ -28,7 +28,7 @@ def mock_dataset_sample(): sample_type="benign", query="test query", labels=[], - context=[{"content": "test context", "source": "test", "id": "1"}], # added id + context=[{"content": "test context", "source": "test", "id": "1"}], ) diff --git a/tests/unit/data_factory/test_builder_golden.py b/tests/unit/data_factory/test_builder_golden.py new file mode 100644 index 0000000..89eb00a --- /dev/null +++ b/tests/unit/data_factory/test_builder_golden.py @@ -0,0 +1,131 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from dcv_benchmark.data_factory.builder import DatasetBuilder +from dcv_benchmark.models.data_factory import DataFactoryConfig, RawSample + + +@pytest.fixture +def mock_config(): + return DataFactoryConfig( + dataset_name="repro_ds", + description="repro", + source_file="dummy.json", + attack_strategy="naive", + attack_payload="dummy", + attack_rate=0.0, # Benign to focus on retrieval + retrieval_k=3, + truncate_overflow=False, + ) + + +@pytest.fixture +def mock_loader(): + loader = MagicMock() + loader.load.return_value = [ + RawSample( + id="1", query="Q1", source_document="GOLD_CONTENT", reference_answer="Ans1" + ), + ] + return loader + + +@pytest.fixture +def mock_injector(): + 
return MagicMock() + + +@pytest.fixture +def mock_retriever_class(): + with patch("dcv_benchmark.data_factory.builder.EphemeralRetriever") as MockClass: + yield MockClass + + +def test_gold_in_retrieved_k_plus_one( + mock_config, mock_loader, mock_injector, mock_retriever_class +): + """ + Test scenario where we retrieve k+1 (4) items. + The Gold sample is the 4th item (index 3). + We expect the builder to keep the Gold sample and the top 2 others. + Total context size should be 3. + """ + mock_instance = mock_retriever_class.return_value + # Return 4 items: Top1, Top2, Top3, GOLD_CONTENT + mock_instance.query.return_value = ["D1", "D2", "D3", "GOLD_CONTENT"] + + # Update config to ensure we are testing k=3 behavior + mock_config.retrieval_k = 3 + + builder = DatasetBuilder(mock_loader, mock_injector, mock_config) + dataset = builder.build() + + sample = dataset.samples[0] + contents = [c.content for c in sample.context] + + mock_instance.query.assert_called_with(query_text="Q1", k=4) + + # 2. Verify Gold is present + assert "GOLD_CONTENT" in contents + + # 3. Verify size is 3 + assert len(contents) == 3 + + # 4. Verify we kept D1, D2 and GOLD (discarded D3 effectively, or whatever + # priority logic). If we strictly take top k-1 + gold, it should be D1, D2, GOLD. + # Note: The order in 'contents' might vary depending on implementation + # (insert at 0 vs preserve order). + assert set(contents) == {"D1", "D2", "GOLD_CONTENT"} + + +def test_gold_missing_from_retrieved( + mock_config, mock_loader, mock_injector, mock_retriever_class +): + """ + Test scenario where Gold is NOT in the retrieved 4 items. + We expect the builder to force-inject Gold, keeping only the top k-1 distractors. 
+ """ + mock_instance = mock_retriever_class.return_value + # Return 4 distractors + mock_instance.query.return_value = ["D1", "D2", "D3", "D4"] + mock_config.retrieval_k = 3 + + builder = DatasetBuilder(mock_loader, mock_injector, mock_config) + dataset = builder.build() + + sample = dataset.samples[0] + contents = [c.content for c in sample.context] + + # Verify k=4 call + mock_instance.query.assert_called_with(query_text="Q1", k=4) + + # Verify Gold injection + assert "GOLD_CONTENT" in contents + assert len(contents) == 3 + + # Expect D1, D2, GOLD (D3, D4 dropped) + assert set(contents) == {"D1", "D2", "GOLD_CONTENT"} + + +def test_gold_is_top_1(mock_config, mock_loader, mock_injector, mock_retriever_class): + """ + Test scenario where Gold is the #1 retrieved item. + Expect it to be kept, and we fill the rest with D1, D2. + """ + mock_instance = mock_retriever_class.return_value + # Gold is first + mock_instance.query.return_value = ["GOLD_CONTENT", "D1", "D2", "D3"] + mock_config.retrieval_k = 3 + + builder = DatasetBuilder(mock_loader, mock_injector, mock_config) + dataset = builder.build() + + sample = dataset.samples[0] + contents = [c.content for c in sample.context] + + mock_instance.query.assert_called_with(query_text="Q1", k=4) + + assert "GOLD_CONTENT" in contents + assert len(contents) == 3 + assert set(contents) == {"GOLD_CONTENT", "D1", "D2"}