Empty Entity extraction results from Llama 3.1 8B #15

Closed

NumberChiffre opened this issue Aug 27, 2024 · 1 comment
Comments

@NumberChiffre (Collaborator) commented Aug 27, 2024

Description

Hey guys,

I'm trying out the repo with the same structure as the DeepSeek example, using a local Llama 3.1 8B as both the best and the cheap model. The problem is that I get empty dictionaries from the entity extraction results. I first thought it was a token-context problem, so I reduced the max token size to 4096 and then to 1024, but the outcome was the same, so that was probably not the cause. I need your help to figure out why this is not working with Llama 3.1 8B. So far I think the cause could be a combination of these:

  • Model capabilities: Llama 3.1 8B might not be as capable as GPT-4 at understanding and following complex entity-extraction instructions.
  • Prompt format: the entity-extraction prompt might be tuned for GPT-4 and not suitable for Llama 3.1 8B.

I can confirm that this works with either GPT-4 or DeepSeek-v2 chat. Do we need some kind of prompt format specifically for smaller models?

Updates:

Simplifying the entity extraction prompt down to something very basic for Llama 3.1 8B did not work either:

entity_extract_prompt = """
Extract entities from the following text. For each entity, provide:
1. Entity type (PERSON, LOCATION, ORGANIZATION, etc.)
2. Entity name
3. Brief description

Text: {input_text}

Entities:
"""

Error output:

Below is the error I got with Llama 3.1 8B:

DEBUG:nano-graphrag:Entity extraction results: [({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {})]
INFO:nano-graphrag:Inserting 0 vectors to entities
Traceback (most recent call last):
  File "/Users/tiliu/Documents/nano-graphrag/examples/using_ollama_as_llm.py", line 103, in <module>
    insert(text=text)
  File "/Users/tiliu/Documents/nano-graphrag/examples/using_ollama_as_llm.py", line 95, in insert
    rag.insert(text)
  File "/Users/tiliu/Documents/nano-graphrag/nano_graphrag/graphrag.py", line 145, in insert
    return loop.run_until_complete(self.ainsert(string_or_strings))
  File "/opt/homebrew/Caskroom/miniconda/base/envs/nano-graphrag/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/Users/tiliu/Documents/nano-graphrag/nano_graphrag/graphrag.py", line 226, in ainsert
    self.chunk_entity_relation_graph = await extract_entities(
  File "/Users/tiliu/Documents/nano-graphrag/nano_graphrag/_op.py", line 335, in extract_entities
    await entity_vdb.upsert(data_for_vdb)
  File "/Users/tiliu/Documents/nano-graphrag/nano_graphrag/_storage.py", line 108, in upsert
    embeddings = np.concatenate(embeddings_list)
ValueError: need at least one array to concatenate
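For context on the traceback: every chunk's extraction result is an empty dict, so nothing is passed to the entity vector store, the embeddings list ends up empty, and np.concatenate raises because it needs at least one array. A minimal sketch of the failure mode (illustrative only, not the library's actual code):

import numpy as np

embeddings_list = []  # what upsert ends up with when extraction returns nothing
if embeddings_list:
    embeddings = np.concatenate(embeddings_list)
else:
    # Guarding the empty case would avoid the crash, but the root cause is still
    # that the model returned no entities to embed.
    embeddings = None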

Code to reproduce the error:

import os
import logging
from ollama import AsyncClient
from nano_graphrag import GraphRAG, QueryParam
from nano_graphrag.base import BaseKVStorage
from nano_graphrag._utils import compute_args_hash

logging.basicConfig(level=logging.WARNING)
logging.getLogger("nano-graphrag").setLevel(logging.DEBUG)

os.environ["OPENAI_API_KEY"] = "sk-......"
OLLAMA_MODEL = "llama3.1"
WORKING_DIR = "./nano_graphrag_cache_ollama_TEST"


async def ollama_model_if_cache(
    prompt: str, system_prompt: str = None, history_messages: list = [], **kwargs
) -> str:
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Return the cached response if one exists ----------------------
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(OLLAMA_MODEL, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    # -----------------------------------------------------

    client = AsyncClient()
    response = await client.chat(model=OLLAMA_MODEL, messages=messages)
    content = response['message']['content']

    # Cache the response for reuse -----------------------------------
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": content, "model": OLLAMA_MODEL}}
        )
    # -----------------------------------------------------
    return content


def remove_if_exist(file):
    if os.path.exists(file):
        os.remove(file)


def load_files(file_directory: str) -> list[str]:
    file_paths = [os.path.join(file_directory, file) for file in os.listdir(file_directory)]
    contents = []
    for file_path in file_paths:
        if os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8-sig') as file:
                contents.append(file.read().strip())
        else:
            print(f"Warning: File not found - {file_path}")
    return contents


def query(query: str, param: QueryParam):
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        best_model_func=ollama_model_if_cache,
        cheap_model_func=ollama_model_if_cache,
        best_model_max_token_size=4096,
        best_model_max_async=8,
        cheap_model_max_token_size=4096,
        cheap_model_max_async=8,
    )
    print(rag.query(query=query, param=param))


def insert(text: str | list[str]):
    from time import time
    remove_if_exist(f"{WORKING_DIR}/milvus_lite.db")
    remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
    remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        best_model_func=ollama_model_if_cache,
        cheap_model_func=ollama_model_if_cache,
        best_model_max_token_size=1024,
        best_model_max_async=2,
        cheap_model_max_token_size=1024,
        cheap_model_max_async=2,
    )
    start = time()
    rag.insert(text)
    print("indexing time:", time() - start)


if __name__ == "__main__":
    with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
        text = f.read()
    insert(text=text)
    # query(
    #     query="What are the main themes in these documents?",
    #     param=QueryParam(mode="global"),
    # )
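One more debugging idea that builds on the script above (the wrapper and the 500-character truncation are illustrative, not part of nano-graphrag): log each raw completion so you can see whether Llama 3.1 is actually following the delimited record format that the default extraction prompt expects.

async def ollama_model_debug(
    prompt: str, system_prompt: str = None, history_messages: list = [], **kwargs
) -> str:
    # Delegate to the cached model function defined above, then log the raw
    # completion; if the model answers in free-form prose instead of the
    # delimited records the extraction prompt asks for, it shows up here.
    content = await ollama_model_if_cache(
        prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs
    )
    logging.getLogger("nano-graphrag").debug("Raw completion: %s", content[:500])
    return content

Passing ollama_model_debug as best_model_func/cheap_model_func instead of ollama_model_if_cache would surface the raw outputs in the existing DEBUG logs.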
@gusye1234 (Owner) commented:
Yeah, I agree with you on needing specific prompts for smaller models.
Many developers have said that smaller models like qwen2-7B have trouble extracting entities and relations; I added a FAQ.md to document this problem.
