Skip to content

Commit

Permalink
Add multi-turn self-refine for entity relationship extractor (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
NumberChiffre authored Oct 10, 2024
1 parent cddcda7 commit c061781
Show file tree
Hide file tree
Showing 7 changed files with 570 additions and 58 deletions.
15 changes: 12 additions & 3 deletions docs/benchmark-dspy-entity-extraction.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Main Takeaways
# Chain Of Thought Prompting with DSPy-AI (v2.4.16)
## Main Takeaways
- Time difference: 156.99 seconds
- Execution time with DSPy-AI: 304.38 seconds
- Execution time without DSPy-AI: 147.39 seconds
- Entities extracted: 22 (without DSPy-AI) vs 37 (with DSPy-AI)
- Relationships extracted: 21 (without DSPy-AI) vs 36 (with DSPy-AI)


# Results
## Results
```markdown
> python examples/benchmarks/dspy_entity.py

Expand Down Expand Up @@ -264,4 +265,12 @@ Relationships:
"朱元璋早年为刘德放牛,这段经历对他的成长有重要影响。"
- "朱元璋" -> "吴老太":
"朱元璋曾希望托吴老太找一个媳妇,显示了他对家庭的渴望。"
```
```

# Self-Refine with DSPy-AI (v2.5.6)
## Main Takeaways
- Time difference: 66.24 seconds
- Execution time with DSPy-AI: 211.04 seconds
- Execution time without DSPy-AI: 144.80 seconds
- Entities extracted: 38 (without DSPy-AI) vs 16 (with DSPy-AI)
- Relationships extracted: 38 (without DSPy-AI) vs 16 (with DSPy-AI)
16 changes: 7 additions & 9 deletions examples/benchmarks/dspy_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import time
import shutil
from nano_graphrag.entity_extraction.extract import extract_entities_dspy
from nano_graphrag._storage import NetworkXStorage, BaseKVStorage
from nano_graphrag.base import BaseKVStorage
from nano_graphrag._storage import NetworkXStorage
from nano_graphrag._utils import compute_mdhash_id, compute_args_hash
from nano_graphrag._op import extract_entities

Expand Down Expand Up @@ -106,14 +107,12 @@ def print_extraction_results(graph_storage: NetworkXStorage):
async def run_benchmark(text: str):
print("\nRunning benchmark with DSPy-AI:")
system_prompt = """
You are a world-class AI system, capable of complex rationale and reflection.
Reason through the query, and then provide your final response.
If you detect that you made a mistake in your rationale at any point, correct yourself.
Think carefully.
You are an expert system specialized in entity and relationship extraction from complex texts.
Your task is to thoroughly analyze the given text and extract all relevant entities and their relationships with utmost precision and completeness.
"""
system_prompt_dspy = f"{system_prompt} Time: {time.time()}."
lm = dspy.OpenAI(
model="deepseek-chat",
lm = dspy.LM(
model="deepseek/deepseek-chat",
model_type="chat",
api_provider="openai",
api_key=os.environ["DEEPSEEK_API_KEY"],
Expand All @@ -127,7 +126,6 @@ async def run_benchmark(text: str):
print(f"Execution time with DSPy-AI: {time_with_dspy:.2f} seconds")
print_extraction_results(graph_storage_with_dspy)

import pdb; pdb.set_trace()
print("Running benchmark without DSPy-AI:")
system_prompt_no_dspy = f"{system_prompt} Time: {time.time()}."
graph_storage_without_dspy, time_without_dspy = await benchmark_entity_extraction(text, system_prompt_no_dspy, use_dspy=False)
Expand All @@ -148,7 +146,7 @@ async def run_benchmark(text: str):


if __name__ == "__main__":
with open("./examples/data/test.txt", encoding="utf-8-sig") as f:
with open("./tests/zhuyuanzhang.txt", encoding="utf-8-sig") as f:
text = f.read()

asyncio.run(run_benchmark(text=text))
11 changes: 2 additions & 9 deletions examples/using_dspy_entity_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,19 +130,12 @@ def query():


if __name__ == "__main__":
system_prompt = """
You are a world-class AI system, capable of complex rationale and reflection.
Reason through the query, and then provide your final response.
If you detect that you made a mistake in your rationale at any point, correct yourself.
Think carefully.
"""
lm = dspy.OpenAI(
model="deepseek-chat",
lm = dspy.LM(
model="deepseek/deepseek-chat",
model_type="chat",
api_provider="openai",
api_key=os.environ["DEEPSEEK_API_KEY"],
base_url=os.environ["DEEPSEEK_BASE_URL"],
system_prompt=system_prompt,
temperature=1.0,
max_tokens=8192
)
Expand Down
4 changes: 2 additions & 2 deletions nano_graphrag/entity_extraction/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def generate_dataset(
save_dataset: bool = True,
global_config: dict = {},
) -> list[dspy.Example]:
entity_extractor = TypedEntityRelationshipExtractor()
entity_extractor = TypedEntityRelationshipExtractor(num_refine_turns=1, self_refine=True)

if global_config.get("use_compiled_dspy_entity_relationship", False):
entity_extractor.load(global_config["entity_relationship_module_path"])
Expand Down Expand Up @@ -84,7 +84,7 @@ async def extract_entities_dspy(
entity_vdb: BaseVectorStorage,
global_config: dict,
) -> Union[BaseGraphStorage, None]:
entity_extractor = TypedEntityRelationshipExtractor()
entity_extractor = TypedEntityRelationshipExtractor(num_refine_turns=1, self_refine=True)

if global_config.get("use_compiled_dspy_entity_relationship", False):
entity_extractor.load(global_config["entity_relationship_module_path"])
Expand Down
173 changes: 142 additions & 31 deletions nano_graphrag/entity_extraction/module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Union
import dspy
from pydantic import BaseModel, Field
from nano_graphrag._utils import clean_str
from nano_graphrag._utils import logger


"""
Expand Down Expand Up @@ -75,6 +75,14 @@ class Entity(BaseModel):
description="Importance score of the entity. Should be between 0 and 1 with 1 being the most important.",
)

def to_dict(self):
return {
"entity_name": clean_str(self.entity_name.upper()),
"entity_type": clean_str(self.entity_type.upper()),
"description": clean_str(self.description),
"importance_score": float(self.importance_score),
}


class Relationship(BaseModel):
src_id: str = Field(..., description="The name of the source entity.")
Expand All @@ -96,6 +104,15 @@ class Relationship(BaseModel):
description="The order of the relationship. 1 for direct relationships, 2 for second-order, 3 for third-order.",
)

def to_dict(self):
return {
"src_id": clean_str(self.src_id.upper()),
"tgt_id": clean_str(self.tgt_id.upper()),
"description": clean_str(self.description),
"weight": float(self.weight),
"order": int(self.order),
}


class CombinedExtraction(dspy.Signature):
"""
Expand Down Expand Up @@ -134,8 +151,85 @@ class CombinedExtraction(dspy.Signature):
entity_types: list[str] = dspy.InputField(
desc="List of entity types used for extraction."
)
entities_relationships: list[Union[Entity, Relationship]] = dspy.OutputField(
desc="List of entities and relationships extracted from the text."
entities: list[Entity] = dspy.OutputField(
desc="List of entities extracted from the text and the entity types."
)
relationships: list[Relationship] = dspy.OutputField(
desc="List of relationships extracted from the text and the entity types."
)


class CritiqueCombinedExtraction(dspy.Signature):
"""
Critique the current extraction of entities and relationships from a given text.
Focus on completeness, accuracy, and adherence to the provided entity types and extraction guidelines.
Critique Guidelines:
1. Evaluate if all relevant entities from the input text are captured and correctly typed.
2. Check if entity descriptions are comprehensive and follow the provided guidelines.
3. Assess the completeness of relationship extractions, including higher-order relationships.
4. Verify that relationship descriptions are detailed and follow the provided guidelines.
5. Identify any inconsistencies, errors, or missed opportunities in the current extraction.
6. Suggest specific improvements or additions to enhance the quality of the extraction.
"""

input_text: str = dspy.InputField(
desc="The original text from which entities and relationships were extracted."
)
entity_types: list[str] = dspy.InputField(
desc="List of valid entity types for this extraction task."
)
current_entities: list[Entity] = dspy.InputField(
desc="List of currently extracted entities to be critiqued."
)
current_relationships: list[Relationship] = dspy.InputField(
desc="List of currently extracted relationships to be critiqued."
)
entity_critique: str = dspy.OutputField(
desc="Detailed critique of the current entities, highlighting areas for improvement for completeness and accuracy.."
)
relationship_critique: str = dspy.OutputField(
desc="Detailed critique of the current relationships, highlighting areas for improvement for completeness and accuracy.."
)


class RefineCombinedExtraction(dspy.Signature):
"""
Refine the current extraction of entities and relationships based on the provided critique.
Improve completeness, accuracy, and adherence to the extraction guidelines.
Refinement Guidelines:
1. Address all points raised in the entity and relationship critiques.
2. Add missing entities and relationships identified in the critique.
3. Improve entity and relationship descriptions as suggested.
4. Ensure all refinements still adhere to the original extraction guidelines.
5. Maintain consistency between entities and relationships during refinement.
6. Focus on enhancing the overall quality and comprehensiveness of the extraction.
"""

input_text: str = dspy.InputField(
desc="The original text from which entities and relationships were extracted."
)
entity_types: list[str] = dspy.InputField(
desc="List of valid entity types for this extraction task."
)
current_entities: list[Entity] = dspy.InputField(
desc="List of currently extracted entities to be refined."
)
current_relationships: list[Relationship] = dspy.InputField(
desc="List of currently extracted relationships to be refined."
)
entity_critique: str = dspy.InputField(
desc="Detailed critique of the current entities to guide refinement."
)
relationship_critique: str = dspy.InputField(
desc="Detailed critique of the current relationships to guide refinement."
)
refined_entities: list[Entity] = dspy.OutputField(
desc="List of refined entities, addressing the entity critique and improving upon the current entities."
)
refined_relationships: list[Relationship] = dspy.OutputField(
desc="List of refined relationships, addressing the relationship critique and improving upon the current relationships."
)


Expand All @@ -159,7 +253,7 @@ def forward(self, **kwargs):

except Exception as e:
if isinstance(e, self.exception_types):
return dspy.Prediction(entities_relationships=[])
return dspy.Prediction(entities=[], relationships=[])

raise e

Expand All @@ -168,46 +262,63 @@ class TypedEntityRelationshipExtractor(dspy.Module):
def __init__(
self,
lm: dspy.LM = None,
reasoning: dspy.OutputField = None,
max_retries: int = 3,
entity_types: list[str] = ENTITY_TYPES,
self_refine: bool = False,
num_refine_turns: int = 1
):
super().__init__()
self.lm = lm
self.entity_types = ENTITY_TYPES
self.extractor = dspy.TypedChainOfThought(
signature=CombinedExtraction, reasoning=reasoning, max_retries=max_retries
)
self.entity_types = entity_types
self.self_refine = self_refine
self.num_refine_turns = num_refine_turns

self.extractor = dspy.TypedChainOfThought(signature=CombinedExtraction, max_retries=max_retries)
self.extractor = TypedEntityRelationshipExtractorException(
self.extractor, exception_types=(ValueError,)
)

if self.self_refine:
self.critique = dspy.TypedChainOfThought(
signature=CritiqueCombinedExtraction,
max_retries=max_retries
)
self.refine = dspy.TypedChainOfThought(
signature=RefineCombinedExtraction,
max_retries=max_retries
)

def forward(self, input_text: str) -> dspy.Prediction:
with dspy.context(lm=self.lm if self.lm is not None else dspy.settings.lm):
extraction_result = self.extractor(
input_text=input_text, entity_types=self.entity_types
)

current_entities: list[Entity] = extraction_result.entities
current_relationships: list[Relationship] = extraction_result.relationships

if self.self_refine:
for _ in range(self.num_refine_turns):
critique_result = self.critique(
input_text=input_text,
entity_types=self.entity_types,
current_entities=current_entities,
current_relationships=current_relationships
)
refined_result = self.refine(
input_text=input_text,
entity_types=self.entity_types,
current_entities=current_entities,
current_relationships=current_relationships,
entity_critique=critique_result.entity_critique,
relationship_critique=critique_result.relationship_critique
)
logger.debug(f"entities: {len(current_entities)} | refined_entities: {len(refined_result.refined_entities)}")
logger.debug(f"relationships: {len(current_relationships)} | refined_relationships: {len(refined_result.refined_relationships)}")
current_entities = refined_result.refined_entities
current_relationships = refined_result.refined_relationships

entities = [
dict(
entity_name=clean_str(entity.entity_name.upper()),
entity_type=clean_str(entity.entity_type.upper()),
description=clean_str(entity.description),
importance_score=float(entity.importance_score),
)
for entity in extraction_result.entities_relationships
if isinstance(entity, Entity)
]

relationships = [
dict(
src_id=clean_str(relationship.src_id.upper()),
tgt_id=clean_str(relationship.tgt_id.upper()),
description=clean_str(relationship.description),
weight=float(relationship.weight),
order=int(relationship.order),
)
for relationship in extraction_result.entities_relationships
if isinstance(relationship, Relationship)
]
entities = [entity.to_dict() for entity in current_entities]
relationships = [relationship.to_dict() for relationship in current_relationships]

return dspy.Prediction(entities=entities, relationships=relationships)
Loading

0 comments on commit c061781

Please sign in to comment.