diff --git a/evaluators/langevals/langevals_langevals/grapheval.py b/evaluators/langevals/langevals_langevals/grapheval.py
new file mode 100644
index 0000000..fab2b52
--- /dev/null
+++ b/evaluators/langevals/langevals_langevals/grapheval.py
@@ -0,0 +1,325 @@
+# DISCLAIMER: some prompts are taken from the research paper https://arxiv.org/pdf/2407.10793.
+# Creation of this module was inspired by that paper, so cheers to the authors!
+import json
+import logging
+from typing import Optional, cast
+
+import litellm
+from dotenv import load_dotenv
+from langevals_core.base_evaluator import (
+    BaseEvaluator,
+    EvaluatorEntry,
+    EvaluationResult,
+    SingleEvaluationResult,
+    LLMEvaluatorSettings,
+    Money,
+)
+from litellm import Choices, Message
+from litellm.cost_calculator import completion_cost
+from litellm.types.utils import ModelResponse
+from pydantic import Field
+
+load_dotenv()
+
+
+class GraphEvalEntry(EvaluatorEntry):
+    input: Optional[str] = Field(default="")
+    output: str
+    contexts: list[str]
+
+
+class GraphEvalSettings(LLMEvaluatorSettings):
+    kg_construction_prompt: str = Field(
+        default="""You are an expert at extracting information in structured formats to build a knowledge graph.
+Step 1 - Entity detection: Identify all entities in the raw text. Make sure not to miss any out.
+Entities should be basic and simple; they are akin to Wikipedia nodes.
+Step 2 - Coreference resolution: Find all expressions in the text that refer to the same entity. Make sure entities are not duplicated. In particular, do not include entities that are more specific versions of other entities, e.g. "a detailed view of jupiter's atmosphere" and "jupiter's atmosphere"; only include the most specific version of the entity.
+Step 3 - Relation extraction: Identify semantic relationships between the entities you have identified.
+Format: Return the knowledge graph as a list of triples, i.e. ["entity 1", "relation 1-2", "entity 2"], in Python code."""
+    )
+    context_to_knowledge_graph_comparison_prompt: str = Field(
+        default="""You are an expert at evaluating knowledge graph triples for factual accuracy and faithfulness to the given context. Your task is to:
+
+1. Compare each triple in the knowledge graph against the provided context(s)
+2. Check for:
+   - Direct contradictions with the context
+   - Incorrect relationships between entities
+   - Misrepresented or fabricated facts
+
+For each triple, determine if it is:
+- SUPPORTED: All information is explicitly stated in or can be directly inferred from the context
+- CONTRADICTED: The triple conflicts with information in the context
+- UNVERIFIABLE: The triple makes claims that cannot be verified from the context
+
+Return false if any triple is CONTRADICTED or UNVERIFIABLE, true if all triples are SUPPORTED."""
+    )
+    model: str = Field(
+        default="claude-3-5-sonnet-20240620",
+        description="The model to use for evaluation",
+    )
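+
+
+# Note: any of the defaults above can be overridden when instantiating the
+# evaluator, e.g. (illustrative model name):
+#   GraphEvalEvaluator(settings=GraphEvalSettings(model="gpt-4o-mini"))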
+
+
+class GraphEvalResult(EvaluationResult):
+    score: float = Field(default=0.0)
+    passed: Optional[bool] = Field(
+        default=True, description="True if the response is faithful, False otherwise"
+    )
+
+
+class GraphEvalEvaluator(
+    BaseEvaluator[GraphEvalEntry, GraphEvalSettings, GraphEvalResult]
+):
+    """
+    Allows you to check for hallucinations by utilizing knowledge graphs
+    """
+
+    name = "GraphEval"
+    category = "custom"
+    default_settings = GraphEvalSettings()
+    is_guardrail = True
+
+    def evaluate(self, entry: GraphEvalEntry) -> SingleEvaluationResult:
+        # Initialize up front so a failure in either LLM call below cannot
+        # leave these names undefined when they are read later.
+        passed = None
+        knowledge_graph = None
+        cost = 0.0
+        try:
+            knowledge_graph_response = self._construct_knowledge_graph(entry.output)
+            cost = completion_cost(knowledge_graph_response) or 0.0
+            knowledge_graph = self._get_arguments(
+                knowledge_graph_response, value="triples"
+            )
+        except Exception as e:
+            logging.error("Caught an exception while creating a knowledge graph: %s", e)
+
+        try:
+            if isinstance(knowledge_graph, list):
+                passed_response = self._compare_knowledge_graph_with_contexts(
+                    knowledge_graph=knowledge_graph, contexts=entry.contexts
+                )
+                cost += completion_cost(passed_response) or 0.0
+                passed = self._get_arguments(passed_response, value="result")
+        except Exception as e:
+            logging.error(
+                "Caught an exception while comparing the knowledge graph with contexts: %s",
+                e,
+            )
+
+        if isinstance(passed, bool):
+            return GraphEvalResult(
+                passed=passed,
+                details=f"The following entity_1-relationship->entity_2 triples were found in the output: {knowledge_graph}",
+                cost=Money(amount=cost, currency="USD") if cost else None,
+            )
+        return GraphEvalResult(
+            passed=False,
+            details="We could not evaluate the faithfulness of the output",
+            cost=Money(amount=cost, currency="USD") if cost else None,
+        )
+
+    def _construct_knowledge_graph(self, output: str) -> ModelResponse:
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "create_knowledge_graph",
+                    "description": "Create a knowledge graph from input text",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "triples": {
+                                "type": "array",
+                                "items": {
+                                    "type": "object",
+                                    "properties": {
+                                        "entity_1": {
+                                            "type": "string",
+                                            "description": "First entity in the relationship",
+                                        },
+                                        "relationship": {
+                                            "type": "string",
+                                            "description": "Relationship between entities",
+                                        },
+                                        "entity_2": {
+                                            "type": "string",
+                                            "description": "Second entity in the relationship",
+                                        },
+                                    },
+                                    "required": [
+                                        "entity_1",
+                                        "relationship",
+                                        "entity_2",
+                                    ],
+                                },
+                                "description": "List of entity-relationship triples that construct a knowledge graph",
+                            }
+                        },
+                        "required": ["triples"],
+                    },
+                },
+            }
+        ]
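+        # With this schema, a successful tool call is expected to carry
+        # arguments shaped like (illustrative values):
+        #   {"triples": [{"entity_1": "Paris", "relationship": "capital of",
+        #                 "entity_2": "France"}]}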
+        response = litellm.completion(
+            model=self.settings.model,
+            messages=[
+                {
+                    "role": "system",
+                    "content": self.settings.kg_construction_prompt,
+                },
+                {
+                    "role": "user",
+                    "content": f"""Use the given format to extract information from the following input: {output}. Skip the preamble and output the result as a list within <python> tags.
+Important Tips
+1. Make sure all information is included in the knowledge graph.
+2. Each triple must only contain three strings! None of the strings should be empty.
+3. Do not split up related information into separate triples because this could change the meaning.
+4. Make sure all brackets and quotation marks are matched.
+5. Before adding a triple to the knowledge graph, check that the concatenated triple makes sense as a sentence. If not, discard it.
+
+Here are some example input and output pairs.
+
+## Example 1.
+Input:
+"The Walt Disney Company, commonly known as Disney, is an American multinational mass media and entertainment conglomerate that is headquartered at the Walt Disney Studios complex in Burbank, California."
+Output:
+[["The Walt Disney Company", "headquartered at", "Walt Disney Studios complex in Burbank, California"],
+["The Walt Disney Company", "commonly known as", "Disney"],
+["The Walt Disney Company", "instance of", "American multinational mass media and entertainment conglomerate"]]
+
+## Example 2.
+Input:
+"Amanda Jackson was born in Springfield, Ohio, USA on June 1, 1985. She was a basketball player for the U.S. women's team."
+Output:
+[["Amanda Jackson", "born in", "Springfield, Ohio, USA"],
+["Amanda Jackson", "born on", "June 1, 1985"],
+["Amanda Jackson", "occupation", "basketball player"],
+["Amanda Jackson", "played for", "U.S. women's basketball team"]]
+
+## Example 3.
+Input:
+"Music executive Darius Van Arman was born in Pennsylvania. He attended Gonzaga College High School and is a human being."
+Output:
+[["Darius Van Arman", "occupation", "Music executive"],
+["Darius Van Arman", "born in", "Pennsylvania"],
+["Darius Van Arman", "attended", "Gonzaga College High School"],
+["Darius Van Arman", "instance of", "human being"]]
+
+## Example 4.
+Input:
+"Italy had 3.6x times more cases of coronavirus than China."
+Output:
+[["Italy", "had 3.6x times more cases of coronavirus than", "China"]]
+""",
+                },
+            ],
+            tools=tools,
+            tool_choice={
+                "type": "function",
+                "function": {"name": "create_knowledge_graph"},
+            },
+        )
+        response = cast(ModelResponse, response)
+        return response
+
+    def _compare_knowledge_graph_with_contexts(
+        self,
+        knowledge_graph: list,
+        contexts: list[str],
+    ) -> ModelResponse:
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "compare_knowledge_graph_with_contexts",
+                    "description": "Check if the knowledge graph triples are faithful and factually accurate to the contexts",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "result": {
+                                "type": "boolean",
+                                "description": "True if the knowledge graph is faithful and factually accurate to the contexts, False otherwise",
+                            }
+                        },
+                        "required": ["result"],
+                    },
+                },
+            }
+        ]
+
+        response = litellm.completion(
+            model=self.settings.model,
+            messages=[
+                {
+                    "role": "system",
+                    "content": self.settings.context_to_knowledge_graph_comparison_prompt,
+                },
+                {
+                    "role": "user",
+                    "content": f"""{knowledge_graph}
+
+{contexts}
+""",
+                },
+            ],
+            tools=tools,
+            tool_choice={
+                "type": "function",
+                "function": {"name": "compare_knowledge_graph_with_contexts"},
+            },
+        )
+        response = cast(ModelResponse, response)
+        return response
+
+    def _get_arguments(self, response: ModelResponse, value: str) -> str | bool | list:
+        choice = cast(Choices, response.choices[0])
+        arguments = json.loads(
+            cast(Message, choice.message).tool_calls[0].function.arguments  # type: ignore
+        )
+        return arguments.get(
+            value,
+            f"{value} was not found in the arguments",
+        )
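+
+
+# Example usage (a sketch; assumes model credentials are available via the
+# environment, e.g. ANTHROPIC_API_KEY for the default Claude model):
+#
+#   evaluator = GraphEvalEvaluator(settings=GraphEvalSettings())
+#   result = evaluator.evaluate(
+#       GraphEvalEntry(
+#           output="Paris is the capital of France.",
+#           contexts=["Paris is the capital and largest city of France."],
+#       )
+#   )
+#   print(result.passed, result.details)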
diff --git a/evaluators/langevals/tests/test_grapheval.py b/evaluators/langevals/tests/test_grapheval.py
new file mode 100644
index 0000000..a8b8e4b
--- /dev/null
+++ b/evaluators/langevals/tests/test_grapheval.py
@@ -0,0 +1,179 @@
+from unittest.mock import patch, MagicMock
+
+from langevals_langevals.grapheval import (
+    GraphEvalEvaluator,
+    GraphEvalEntry,
+    GraphEvalSettings,
+)
+from langevals_core.base_evaluator import EvaluationResultError, EvaluationResultSkipped
+from litellm.types.utils import ModelResponse
+
+
+def test_grapheval_evaluator_passed():
+    entry = GraphEvalEntry(
+        output="Your effort is really appreciated!",
+        contexts=["John's efforts are really appreciated"],
+    )
+    evaluator = GraphEvalEvaluator()
+    result = evaluator.evaluate(entry)
+
+    assert result.status == "processed"
+    assert result.passed == True
+
+
+def test_grapheval_empty():
+    entry = GraphEvalEntry(
+        output="",
+        contexts=[""],
+    )
+    evaluator = GraphEvalEvaluator(settings=GraphEvalSettings())
+    result = evaluator.evaluate(entry)
+
+    assert result.status == "processed"
+    assert result.passed == False
+
+
+def test_grapheval_multiple_contexts():
+    evaluator = GraphEvalEvaluator()
+    result = evaluator.evaluate(
+        GraphEvalEntry(
+            output="The capital of France is Paris.",
+            contexts=[
+                "France is a country in Europe.",
+                "Paris is a city in France.",
+                "Paris is the capital of France",
+            ],
+        )
+    )
+    assert result.status == "processed"
+    assert result.passed == True
+
+
+def test_grapheval_evaluator_failed():
+    entry = GraphEvalEntry(
+        output="John's effort is really appreciated!",
+        contexts=["Frank, your effort is really appreciated"],
+    )
+    evaluator = GraphEvalEvaluator(settings=GraphEvalSettings())
+    result = evaluator.evaluate(entry)
+
+    assert result.status == "processed"
+    assert result.passed == False
+
+
+def test_grapheval_i_dont_know():
+    entry = GraphEvalEntry(
+        output="I don't know the answer, please try again later.",
+        contexts=[
+            "SRP is applicable both to classes and functions.",
+            "OCP is applicable to classes as well as to functions.",
+        ],
+    )
+    evaluator = GraphEvalEvaluator(settings=GraphEvalSettings())
+    result = evaluator.evaluate(entry)
+
+    assert result.status == "processed"
+    assert result.passed == False
+    assert result.cost
+
+
+def test_grapheval_knowledge_graph_construction_output():
+    entry = GraphEvalEntry(
+        output="Italy had 3.6x times more cases of coronavirus than China",
+        contexts=["John's efforts are really appreciated"],
+    )
+    evaluator = GraphEvalEvaluator(settings=GraphEvalSettings())
+    result = evaluator._construct_knowledge_graph(entry.output)
+    print(result)
+    assert isinstance(result, ModelResponse)
+    # The tool schema defines each triple as an object, so the parsed
+    # arguments are a list of dicts rather than a list of lists.
+    assert evaluator._get_arguments(result, value="triples") == [
+        {
+            "entity_1": "Italy",
+            "relationship": "had 3.6x times more cases of coronavirus than",
+            "entity_2": "China",
+        }
+    ]
+
+
+def test_grapheval_cost():
+    entry = GraphEvalEntry(
+        output="Italy had 3.6x times more cases of coronavirus than China",
+        contexts=["John's efforts are really appreciated"],
+    )
+    evaluator = GraphEvalEvaluator(settings=GraphEvalSettings())
+    result = evaluator.evaluate(entry)
+    assert result.status == "processed"
+    assert result.cost and result.cost.amount > 0.0001
+
+
+def test_grapheval_malformed_model_response():
+    """Test when the model response is missing expected fields."""
+    entry = GraphEvalEntry(
+        output="Some output",
+        contexts=["Some context"],
+    )
+    evaluator = GraphEvalEvaluator()
+    with patch.object(evaluator, "_construct_knowledge_graph") as mock_kg:
+        mock_kg.return_value = MagicMock()
+        with patch.object(
+            evaluator,
+            "_get_arguments",
+            return_value="triples was not found in the arguments",
+        ):
+            result = evaluator.evaluate(entry)
+            if not isinstance(result, (EvaluationResultError, EvaluationResultSkipped)):
+                assert result.passed is False
+                assert "could not evaluate" in (result.details or "").lower()
+
+
+def test_grapheval_empty_output_and_contexts():
+    """Test with both output and contexts empty."""
+    entry = GraphEvalEntry(
+        output="",
+        contexts=[],
+    )
+    evaluator = GraphEvalEvaluator()
+    result = evaluator.evaluate(entry)
+    if not isinstance(result, (EvaluationResultError, EvaluationResultSkipped)):
+        assert result.passed == False
+
+
+def test_grapheval_partial_triple():
+    """Test when a triple is missing required fields."""
+    entry = GraphEvalEntry(
+        output="Some output",
+        contexts=["Some context"],
+    )
+    evaluator = GraphEvalEvaluator()
+    malformed_triple = [{"entity_1": "A", "relationship": "rel"}]  # missing entity_2
+    with patch.object(evaluator, "_construct_knowledge_graph") as mock_kg:
+        mock_kg.return_value = MagicMock()
+        with patch.object(
+            evaluator, "_get_arguments", side_effect=[malformed_triple, True]
+        ):
+            result = evaluator.evaluate(entry)
+            if not isinstance(result, (EvaluationResultError, EvaluationResultSkipped)):
+                # Should not crash when an entity is missing; a boolean verdict
+                # of either polarity is acceptable here.
+                assert isinstance(result.passed, bool)
+
+
+def test_grapheval_evaluator_details_check():
+    entry = GraphEvalEntry(
+        output="Your effort is really appreciated!",
+        contexts=["John's efforts are really appreciated"],
+    )
+    evaluator = GraphEvalEvaluator()
+    result = evaluator.evaluate(entry)
+    print(result.details)
+
+    assert result.status == "processed"
+    assert result.passed == True
+    # The exact triples are model-dependent, so check the stable prefix of the
+    # details message instead of an exact string match.
+    assert (result.details or "").startswith(
+        "The following entity_1-relationship->entity_2 triples were found in the output:"
+    )
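+
+
+# A possible fully-offline variant (a sketch, not part of the original suite):
+# it stubs both model calls and the module's completion_cost so the control
+# flow of evaluate() can be exercised without network access or API keys. The
+# stubbed triples and return values below are illustrative assumptions, not
+# real model output.
+def test_grapheval_offline_happy_path_sketch():
+    entry = GraphEvalEntry(
+        output="Paris is the capital of France.",
+        contexts=["Paris is the capital of France."],
+    )
+    evaluator = GraphEvalEvaluator()
+    triples = [
+        {"entity_1": "Paris", "relationship": "capital of", "entity_2": "France"}
+    ]
+    with patch(
+        "langevals_langevals.grapheval.completion_cost", return_value=0.0
+    ), patch.object(
+        evaluator, "_construct_knowledge_graph", return_value=MagicMock()
+    ), patch.object(
+        evaluator, "_compare_knowledge_graph_with_contexts", return_value=MagicMock()
+    ), patch.object(evaluator, "_get_arguments", side_effect=[triples, True]):
+        result = evaluator.evaluate(entry)
+
+    assert result.passed is True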