From a35cb12741c171483d69afe395ccdf2c5fceb3a7 Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Fri, 3 Jan 2025 13:59:26 -0800 Subject: [PATCH] Remove datashaper strip code (#1581) Remove datashaper --- .../minor-20241231213627966329.json | 4 + .../minor-20241231214323349946.json | 4 + dictionary.txt | 4 - docs/examples_notebooks/index_migration.ipynb | 3 +- docs/index/architecture.md | 28 +- examples/README.md | 19 -- examples/__init__.py | 2 - examples/custom_input/__init__.py | 2 - examples/custom_input/pipeline.yml | 24 -- examples/custom_input/run.py | 46 --- examples/single_verb/__init__.py | 2 - examples/single_verb/input/data.csv | 3 - examples/single_verb/pipeline.yml | 12 - examples/single_verb/run.py | 77 ----- examples/use_built_in_workflows/__init__.py | 2 - examples/use_built_in_workflows/pipeline.yml | 23 -- examples/use_built_in_workflows/run.py | 118 -------- graphrag/api/index.py | 42 ++- graphrag/api/prompt_tune.py | 2 +- graphrag/callbacks/blob_workflow_callbacks.py | 3 +- .../callbacks/console_workflow_callbacks.py | 2 +- .../callbacks/delegating_verb_callbacks.py | 46 +++ graphrag/callbacks/factory.py | 3 +- graphrag/callbacks/file_workflow_callbacks.py | 2 +- graphrag/callbacks/noop_verb_callbacks.py | 35 +++ graphrag/callbacks/noop_workflow_callbacks.py | 46 +++ .../callbacks/progress_workflow_callbacks.py | 13 +- graphrag/callbacks/verb_callbacks.py | 38 +++ graphrag/callbacks/workflow_callbacks.py | 58 ++++ .../callbacks/workflow_callbacks_manager.py | 83 ++++++ graphrag/cli/query.py | 48 ++- graphrag/config/create_graphrag_config.py | 2 +- graphrag/config/defaults.py | 3 +- graphrag/config/enums.py | 7 + graphrag/config/models/llm_config.py | 2 +- graphrag/index/config/embeddings.py | 42 +++ graphrag/index/config/input.py | 6 - graphrag/index/config/workflow.py | 8 - graphrag/index/context.py | 7 - graphrag/index/create_pipeline_config.py | 48 +-- graphrag/index/exporter.py | 55 ---- graphrag/index/flows/__init__.py | 2 +- .../index/flows/create_base_text_units.py | 85 +----- .../flows/create_final_community_reports.py | 6 +- .../index/flows/create_final_covariates.py | 6 +- graphrag/index/flows/create_final_nodes.py | 4 +- graphrag/index/flows/extract_graph.py | 6 +- .../index/flows/generate_text_embeddings.py | 13 +- graphrag/index/llm/load_llm.py | 3 +- graphrag/index/load_pipeline_config.py | 78 ----- .../index/operations/chunk_text/chunk_text.py | 7 +- .../index/operations/chunk_text/strategies.py | 2 +- .../index/operations/chunk_text/typing.py | 3 +- .../index/operations/embed_text/embed_text.py | 2 +- .../operations/embed_text/strategies/mock.py | 4 +- .../embed_text/strategies/openai.py | 3 +- .../embed_text/strategies/typing.py | 3 +- .../extract_covariates/extract_covariates.py | 10 +- .../operations/extract_covariates/typing.py | 3 +- .../extract_entities/extract_entities.py | 10 +- .../graph_intelligence_strategy.py | 2 +- .../extract_entities/nltk_strategy.py | 2 +- .../operations/extract_entities/typing.py | 2 +- .../operations/layout_graph/layout_graph.py | 2 +- graphrag/index/operations/snapshot.py | 24 -- graphrag/index/operations/snapshot_graphml.py | 2 +- .../prepare_community_reports.py | 6 +- .../summarize_communities/strategies.py | 2 +- .../summarize_communities.py | 14 +- .../summarize_communities/typing.py | 2 +- .../graph_intelligence_strategy.py | 2 +- .../summarize_descriptions.py | 7 +- .../summarize_descriptions/typing.py | 3 +- graphrag/index/run/__init__.py | 4 - graphrag/index/run/derive_from_rows.py | 158 ++++++++++ 
graphrag/index/run/postprocess.py | 50 ---- graphrag/index/run/profiling.py | 71 ----- graphrag/index/run/run.py | 282 ------------------ graphrag/index/run/run_workflows.py | 199 ++++++++++++ graphrag/index/run/utils.py | 94 +----- graphrag/index/run/workflow.py | 154 ---------- graphrag/index/update/entities.py | 16 +- graphrag/index/update/incremental_index.py | 89 +++--- graphrag/index/utils/ds_util.py | 32 -- graphrag/index/utils/load_graph.py | 11 - graphrag/index/utils/topological_sort.py | 12 - graphrag/index/validate_config.py | 3 +- graphrag/index/workflows/__init__.py | 121 ++++++-- .../index/workflows/compute_communities.py | 40 +++ .../index/workflows/create_base_text_units.py | 41 +++ .../workflows/create_final_communities.py | 43 +++ .../create_final_community_reports.py | 55 ++++ .../workflows/create_final_covariates.py | 49 +++ .../index/workflows/create_final_documents.py | 37 +++ .../index/workflows/create_final_entities.py | 33 ++ .../index/workflows/create_final_nodes.py | 49 +++ .../workflows/create_final_relationships.py | 33 ++ .../workflows/create_final_text_units.py | 49 +++ graphrag/index/workflows/default_workflows.py | 93 ------ graphrag/index/workflows/extract_graph.py | 71 +++++ .../workflows/generate_text_embeddings.py | 57 ++++ graphrag/index/workflows/load.py | 169 ----------- graphrag/index/workflows/typing.py | 33 -- graphrag/index/workflows/v1/__init__.py | 4 - .../index/workflows/v1/compute_communities.py | 89 ------ .../workflows/v1/create_base_text_units.py | 107 ------- .../workflows/v1/create_final_communities.py | 60 ---- .../v1/create_final_community_reports.py | 107 ------- .../workflows/v1/create_final_covariates.py | 80 ----- .../workflows/v1/create_final_documents.py | 67 ----- .../workflows/v1/create_final_entities.py | 56 ---- .../index/workflows/v1/create_final_nodes.py | 76 ----- .../v1/create_final_relationships.py | 59 ---- .../workflows/v1/create_final_text_units.py | 86 ------ graphrag/index/workflows/v1/extract_graph.py | 130 -------- .../workflows/v1/generate_text_embeddings.py | 110 ------- graphrag/logger/base.py | 2 +- graphrag/logger/progress.py | 82 +++++ graphrag/logger/rich_progress.py | 4 +- graphrag/prompt_tune/loader/input.py | 2 +- graphrag/storage/blob_pipeline_storage.py | 2 +- graphrag/storage/cosmosdb_pipeline_storage.py | 2 +- graphrag/storage/file_pipeline_storage.py | 2 +- graphrag/utils/storage.py | 23 +- poetry.lock | 30 +- pyproject.toml | 4 +- tests/unit/config/test_default_config.py | 57 +--- tests/unit/indexing/config/__init__.py | 2 - ...ault_config_with_everything_overridden.yml | 20 -- .../default_config_with_overridden_input.yml | 5 - ...fault_config_with_overridden_workflows.yml | 6 - tests/unit/indexing/config/helpers.py | 59 ---- tests/unit/indexing/config/test_load.py | 131 -------- tests/unit/indexing/test_exports.py | 10 - tests/unit/indexing/workflows/__init__.py | 2 - tests/unit/indexing/workflows/helpers.py | 31 -- tests/unit/indexing/workflows/test_export.py | 125 -------- tests/unit/indexing/workflows/test_load.py | 237 --------------- tests/verbs/test_compute_communities.py | 33 +- tests/verbs/test_create_base_text_units.py | 58 +--- tests/verbs/test_create_final_communities.py | 36 ++- .../test_create_final_community_reports.py | 90 +++--- tests/verbs/test_create_final_covariates.py | 69 ++--- tests/verbs/test_create_final_documents.py | 61 ++-- tests/verbs/test_create_final_entities.py | 27 +- tests/verbs/test_create_final_nodes.py | 43 +-- .../verbs/test_create_final_relationships.py | 28 
+- tests/verbs/test_create_final_text_units.py | 82 +++-- tests/verbs/test_extract_graph.py | 125 +++----- tests/verbs/test_generate_text_embeddings.py | 71 ++--- tests/verbs/util.py | 65 +--- 151 files changed, 2033 insertions(+), 4066 deletions(-) create mode 100644 .semversioner/next-release/minor-20241231213627966329.json create mode 100644 .semversioner/next-release/minor-20241231214323349946.json delete mode 100644 examples/README.md delete mode 100644 examples/__init__.py delete mode 100644 examples/custom_input/__init__.py delete mode 100644 examples/custom_input/pipeline.yml delete mode 100644 examples/custom_input/run.py delete mode 100644 examples/single_verb/__init__.py delete mode 100644 examples/single_verb/input/data.csv delete mode 100644 examples/single_verb/pipeline.yml delete mode 100644 examples/single_verb/run.py delete mode 100644 examples/use_built_in_workflows/__init__.py delete mode 100644 examples/use_built_in_workflows/pipeline.yml delete mode 100644 examples/use_built_in_workflows/run.py create mode 100644 graphrag/callbacks/delegating_verb_callbacks.py create mode 100644 graphrag/callbacks/noop_verb_callbacks.py create mode 100644 graphrag/callbacks/noop_workflow_callbacks.py create mode 100644 graphrag/callbacks/verb_callbacks.py create mode 100644 graphrag/callbacks/workflow_callbacks.py create mode 100644 graphrag/callbacks/workflow_callbacks_manager.py delete mode 100644 graphrag/index/exporter.py delete mode 100644 graphrag/index/load_pipeline_config.py delete mode 100644 graphrag/index/operations/snapshot.py create mode 100644 graphrag/index/run/derive_from_rows.py delete mode 100644 graphrag/index/run/postprocess.py delete mode 100644 graphrag/index/run/profiling.py delete mode 100644 graphrag/index/run/run.py create mode 100644 graphrag/index/run/run_workflows.py delete mode 100644 graphrag/index/run/workflow.py delete mode 100644 graphrag/index/utils/ds_util.py delete mode 100644 graphrag/index/utils/load_graph.py delete mode 100644 graphrag/index/utils/topological_sort.py create mode 100644 graphrag/index/workflows/compute_communities.py create mode 100644 graphrag/index/workflows/create_base_text_units.py create mode 100644 graphrag/index/workflows/create_final_communities.py create mode 100644 graphrag/index/workflows/create_final_community_reports.py create mode 100644 graphrag/index/workflows/create_final_covariates.py create mode 100644 graphrag/index/workflows/create_final_documents.py create mode 100644 graphrag/index/workflows/create_final_entities.py create mode 100644 graphrag/index/workflows/create_final_nodes.py create mode 100644 graphrag/index/workflows/create_final_relationships.py create mode 100644 graphrag/index/workflows/create_final_text_units.py delete mode 100644 graphrag/index/workflows/default_workflows.py create mode 100644 graphrag/index/workflows/extract_graph.py create mode 100644 graphrag/index/workflows/generate_text_embeddings.py delete mode 100644 graphrag/index/workflows/load.py delete mode 100644 graphrag/index/workflows/typing.py delete mode 100644 graphrag/index/workflows/v1/__init__.py delete mode 100644 graphrag/index/workflows/v1/compute_communities.py delete mode 100644 graphrag/index/workflows/v1/create_base_text_units.py delete mode 100644 graphrag/index/workflows/v1/create_final_communities.py delete mode 100644 graphrag/index/workflows/v1/create_final_community_reports.py delete mode 100644 graphrag/index/workflows/v1/create_final_covariates.py delete mode 100644 
graphrag/index/workflows/v1/create_final_documents.py delete mode 100644 graphrag/index/workflows/v1/create_final_entities.py delete mode 100644 graphrag/index/workflows/v1/create_final_nodes.py delete mode 100644 graphrag/index/workflows/v1/create_final_relationships.py delete mode 100644 graphrag/index/workflows/v1/create_final_text_units.py delete mode 100644 graphrag/index/workflows/v1/extract_graph.py delete mode 100644 graphrag/index/workflows/v1/generate_text_embeddings.py create mode 100644 graphrag/logger/progress.py delete mode 100644 tests/unit/indexing/config/__init__.py delete mode 100644 tests/unit/indexing/config/default_config_with_everything_overridden.yml delete mode 100644 tests/unit/indexing/config/default_config_with_overridden_input.yml delete mode 100644 tests/unit/indexing/config/default_config_with_overridden_workflows.yml delete mode 100644 tests/unit/indexing/config/helpers.py delete mode 100644 tests/unit/indexing/config/test_load.py delete mode 100644 tests/unit/indexing/test_exports.py delete mode 100644 tests/unit/indexing/workflows/__init__.py delete mode 100644 tests/unit/indexing/workflows/helpers.py delete mode 100644 tests/unit/indexing/workflows/test_export.py delete mode 100644 tests/unit/indexing/workflows/test_load.py diff --git a/.semversioner/next-release/minor-20241231213627966329.json b/.semversioner/next-release/minor-20241231213627966329.json new file mode 100644 index 0000000000..93dbd4f4a0 --- /dev/null +++ b/.semversioner/next-release/minor-20241231213627966329.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Remove old pipeline runner." +} diff --git a/.semversioner/next-release/minor-20241231214323349946.json b/.semversioner/next-release/minor-20241231214323349946.json new file mode 100644 index 0000000000..a62cae7b78 --- /dev/null +++ b/.semversioner/next-release/minor-20241231214323349946.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Remove DataShaper (first steps)." +} diff --git a/dictionary.txt b/dictionary.txt index 575fe55548..02851fcd7b 100644 --- a/dictionary.txt +++ b/dictionary.txt @@ -148,10 +148,6 @@ codebases # Microsoft MSRC -# Broken Upstream -# TODO FIX IN DATASHAPER -Arrary - # Prompt Inputs ABILA Abila diff --git a/docs/examples_notebooks/index_migration.ipynb b/docs/examples_notebooks/index_migration.ipynb index a0ba6ae471..5021fa2cbb 100644 --- a/docs/examples_notebooks/index_migration.ipynb +++ b/docs/examples_notebooks/index_migration.ipynb @@ -206,9 +206,8 @@ "metadata": {}, "outputs": [], "source": [ - "from datashaper import NoopVerbCallbacks\n", - "\n", "from graphrag.cache.factory import create_cache\n", + "from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks\n", "from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n", "\n", "# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n", diff --git a/docs/index/architecture.md b/docs/index/architecture.md index 12c6015012..2d5d110ba7 100644 --- a/docs/index/architecture.md +++ b/docs/index/architecture.md @@ -8,33 +8,9 @@ In order to support the GraphRAG system, the outputs of the indexing engine (in This model is designed to be an abstraction over the underlying data storage technology, and to provide a common interface for the GraphRAG system to interact with. 
In normal use-cases the outputs of the GraphRAG Indexer would be loaded into a database system, and the GraphRAG's Query Engine would interact with the database using the knowledge model data-store types. -### DataShaper Workflows - -GraphRAG's Indexing Pipeline is built on top of our open-source library, [DataShaper](https://github.com/microsoft/datashaper). -DataShaper is a data processing library that allows users to declaratively express data pipelines, schemas, and related assets using well-defined schemas. -DataShaper has implementations in JavaScript and Python, and is designed to be extensible to other languages. - -One of the core resource types within DataShaper is a [Workflow](https://github.com/microsoft/datashaper/blob/main/javascript/schema/src/workflow/WorkflowSchema.ts). -Workflows are expressed as sequences of steps, which we call [verbs](https://github.com/microsoft/datashaper/blob/main/javascript/schema/src/workflow/verbs.ts). -Each step has a verb name and a configuration object. -In DataShaper, these verbs model relational concepts such as SELECT, DROP, JOIN, etc.. Each verb transforms an input data table, and that table is passed down the pipeline. - -```mermaid ---- -title: Sample Workflow ---- -flowchart LR - input[Input Table] --> select[SELECT] --> join[JOIN] --> binarize[BINARIZE] --> output[Output Table] -``` - -### LLM-based Workflow Steps - -GraphRAG's Indexing Pipeline implements a handful of custom verbs on top of the standard, relational verbs that our DataShaper library provides. These verbs give us the ability to augment text documents with rich, structured data using the power of LLMs such as GPT-4. We utilize these verbs in our standard workflow to extract entities, relationships, claims, community structures, and community reports and summaries. This behavior is customizable and can be extended to support many kinds of AI-based data enrichment and extraction tasks. - -### Workflow Graphs +### Workflows Because of the complexity of our data indexing tasks, we needed to be able to express our data pipeline as series of multiple, interdependent workflows. -In the GraphRAG Indexing Pipeline, each workflow may define dependencies on other workflows, effectively forming a directed acyclic graph (DAG) of workflows, which is then used to schedule processing. ```mermaid --- @@ -55,7 +31,7 @@ stateDiagram-v2 The primary unit of communication between workflows, and between workflow steps is an instance of `pandas.DataFrame`. Although side-effects are possible, our goal is to be _data-centric_ and _table-centric_ in our approach to data processing. This allows us to easily reason about our data, and to leverage the power of dataframe-based ecosystems. -Our underlying dataframe technology may change over time, but our primary goal is to support the DataShaper workflow schema while retaining single-machine ease of use and developer ergonomics. +Our underlying dataframe technology may change over time, but our primary goal is to support the workflow schema while retaining single-machine ease of use and developer ergonomics. ### LLM Caching diff --git a/examples/README.md b/examples/README.md deleted file mode 100644 index 5d80b4fd4c..0000000000 --- a/examples/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# Indexing Engine Examples -This directory contains several examples of how to use the indexing engine. - -Most examples include two different forms of running the pipeline, both are contained in the examples `run.py` -1. Using mostly the Python API -2. 
Using mostly the a pipeline configuration file - -# Running an Example -First run `poetry shell` to activate a virtual environment with the required dependencies. - -Then run `PYTHONPATH="$(pwd)" python examples/path_to_example/run.py` from the `python/graphrag` directory. - -For example to run the single_verb example, you would run the following commands: - -```bash -cd python/graphrag -poetry shell -PYTHONPATH="$(pwd)" python examples/single_verb/run.py -``` \ No newline at end of file diff --git a/examples/__init__.py b/examples/__init__.py deleted file mode 100644 index 0a3e38adfb..0000000000 --- a/examples/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License diff --git a/examples/custom_input/__init__.py b/examples/custom_input/__init__.py deleted file mode 100644 index 0a3e38adfb..0000000000 --- a/examples/custom_input/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License diff --git a/examples/custom_input/pipeline.yml b/examples/custom_input/pipeline.yml deleted file mode 100644 index 80340c8291..0000000000 --- a/examples/custom_input/pipeline.yml +++ /dev/null @@ -1,24 +0,0 @@ - -# Setup reporting however you'd like -reporting: - type: console - -# Setup storage however you'd like -storage: - type: memory - -# Setup cache however you'd like -cache: - type: memory - -# Just a simple workflow -workflows: - - # This is an anonymous workflow, it doesn't have a name - - steps: - - # Unpack the nodes from the graph - - verb: fill - args: - to: filled_column - value: "Filled Value" \ No newline at end of file diff --git a/examples/custom_input/run.py b/examples/custom_input/run.py deleted file mode 100644 index debb022379..0000000000 --- a/examples/custom_input/run.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License -import asyncio -import os - -import pandas as pd - -from graphrag.index.run import run_pipeline_with_config - -pipeline_file = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "./pipeline.yml" -) - - -async def run(): - # Load your dataset - dataset = _load_dataset_some_unique_way() - - # Load your config without the input section - config = pipeline_file - - # Grab the last result from the pipeline, should be our entity extraction - outputs = [] - async for output in run_pipeline_with_config( - config_or_path=config, dataset=dataset - ): - outputs.append(output) - pipeline_result = outputs[-1] - - if pipeline_result.result is not None: - # Should look something like - # col1 col2 filled_column - # 0 2 4 Filled Value - # 1 5 10 Filled Value - print(pipeline_result.result) - else: - print("No results!") - - -def _load_dataset_some_unique_way() -> pd.DataFrame: - # Totally loaded from some other place - return pd.DataFrame([{"col1": 2, "col2": 4}, {"col1": 5, "col2": 10}]) - - -if __name__ == "__main__": - asyncio.run(run()) diff --git a/examples/single_verb/__init__.py b/examples/single_verb/__init__.py deleted file mode 100644 index 0a3e38adfb..0000000000 --- a/examples/single_verb/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License diff --git a/examples/single_verb/input/data.csv b/examples/single_verb/input/data.csv deleted file mode 100644 index d1aaf77bfe..0000000000 --- a/examples/single_verb/input/data.csv +++ /dev/null @@ -1,3 +0,0 @@ -col1,col2 -2,4 -5,10 \ No newline at end of file diff --git a/examples/single_verb/pipeline.yml b/examples/single_verb/pipeline.yml deleted file mode 100644 index 9e8046124d..0000000000 --- a/examples/single_verb/pipeline.yml +++ /dev/null @@ -1,12 +0,0 @@ -input: - file_type: csv - base_dir: ./input - file_pattern: .*\.csv$ -workflows: - - steps: - - verb: derive # https://github.com/microsoft/datashaper/blob/main/python/datashaper/datashaper/verbs/derive.py - args: - column1: "col1" - column2: "col2" - to: "col_multiplied" - operator: "*" diff --git a/examples/single_verb/run.py b/examples/single_verb/run.py deleted file mode 100644 index 99f8137a98..0000000000 --- a/examples/single_verb/run.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License -import asyncio -import os - -import pandas as pd - -from graphrag.index.config.workflow import PipelineWorkflowReference -from graphrag.index.run import run_pipeline, run_pipeline_with_config - -# our fake dataset -dataset = pd.DataFrame([{"col1": 2, "col2": 4}, {"col1": 5, "col2": 10}]) - - -async def run_with_config(): - """Run a pipeline with a config file""" - # load pipeline.yml in this directory - config_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "./pipeline.yml" - ) - - tables = [] - async for table in run_pipeline_with_config( - config_or_path=config_path, dataset=dataset - ): - tables.append(table) - pipeline_result = tables[-1] - - if pipeline_result.result is not None: - # Should look something like this, which should be identical to the python example: - # col1 col2 col_multiplied - # 0 2 4 8 - # 1 5 10 50 - print(pipeline_result.result) - else: - print("No results!") - - -async def run_python(): - """Run a pipeline using the python API""" - workflows: list[PipelineWorkflowReference] = [ - PipelineWorkflowReference( - steps=[ - { - # built-in verb - "verb": "derive", # https://github.com/microsoft/datashaper/blob/main/python/datashaper/datashaper/verbs/derive.py - "args": { - "column1": "col1", # from above - "column2": "col2", # from above - "to": "col_multiplied", # new column name - "operator": "*", # multiply the two columns - }, - # Since we're trying to act on the default input, we don't need explicitly to specify an input - } - ] - ), - ] - - # Grab the last result from the pipeline, should be our entity extraction - tables = [] - async for table in run_pipeline(dataset=dataset, workflows=workflows): - tables.append(table) - pipeline_result = tables[-1] - - if pipeline_result.result is not None: - # Should look something like this: - # col1 col2 col_multiplied - # 0 2 4 8 - # 1 5 10 50 - print(pipeline_result.result) - else: - print("No results!") - - -if __name__ == "__main__": - asyncio.run(run_with_config()) - asyncio.run(run_python()) diff --git a/examples/use_built_in_workflows/__init__.py b/examples/use_built_in_workflows/__init__.py deleted file mode 100644 index 0a3e38adfb..0000000000 --- a/examples/use_built_in_workflows/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License diff --git a/examples/use_built_in_workflows/pipeline.yml b/examples/use_built_in_workflows/pipeline.yml deleted file mode 100644 index cb1896857f..0000000000 --- a/examples/use_built_in_workflows/pipeline.yml +++ /dev/null @@ -1,23 +0,0 @@ -workflows: - - name: "entity_extraction" - config: - entity_extract: - strategy: - type: "nltk" - - - name: "entity_graph" - config: - cluster_graph: - strategy: - type: "leiden" - embed_graph: - strategy: - type: "node2vec" - num_walks: 10 - walk_length: 40 - window_size: 2 - iterations: 3 - random_seed: 597832 - layout_graph: - strategy: - type: "umap" \ No newline at end of file diff --git a/examples/use_built_in_workflows/run.py b/examples/use_built_in_workflows/run.py deleted file mode 100644 index adda7f6b4c..0000000000 --- a/examples/use_built_in_workflows/run.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License -import asyncio -import os - -from graphrag.index.config.input import PipelineCSVInputConfig -from graphrag.index.config.workflow import PipelineWorkflowReference -from graphrag.index.input.factory import create_input -from graphrag.index.run import run_pipeline, run_pipeline_with_config - -sample_data_dir = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "../_sample_data/" -) - -# Load our dataset once -shared_dataset = asyncio.run( - create_input( - PipelineCSVInputConfig( - file_pattern=".*\\.csv$", - base_dir=sample_data_dir, - source_column="author", - text_column="message", - timestamp_column="date(yyyyMMddHHmmss)", - timestamp_format="%Y%m%d%H%M%S", - title_column="message", - ), - ) -) - - -async def run_with_config(): - """Run a pipeline with a config file""" - # We're cheap, and this is an example, lets just do 10 - dataset = shared_dataset.head(10) - - # load pipeline.yml in this directory - config_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "./pipeline.yml" - ) - - # Grab the last result from the pipeline, should be our entity extraction - tables = [] - async for table in run_pipeline_with_config( - config_or_path=config_path, dataset=dataset - ): - tables.append(table) - pipeline_result = tables[-1] - - if pipeline_result.result is not None: - # The output of this should match the run_python() example - first_result = pipeline_result.result.head(1) - print(f"level: {first_result['level'][0]}") - print(f"embeddings: {first_result['embeddings'][0]}") - print(f"entity_graph_positions: {first_result['node_positions'][0]}") - else: - print("No results!") - - -async def run_python(): - # We're cheap, and this is an example, lets just do 10 - dataset = shared_dataset.head(10) - - workflows: list[PipelineWorkflowReference] = [ - # This workflow reference here is only necessary - # because we want to customize the entity_extraction workflow is configured - # otherwise, it can be omitted, but you're stuck with the default configuration for entity_extraction - PipelineWorkflowReference( - name="entity_extraction", - config={ - "entity_extract": { - "strategy": { - "type": "nltk", - } - } - }, - ), - PipelineWorkflowReference( - name="entity_graph", - config={ - "cluster_graph": {"strategy": {"type": "leiden"}}, - "embed_graph": { - "strategy": { - "type": "node2vec", - "num_walks": 10, - "walk_length": 40, - "window_size": 2, - "iterations": 3, - "random_seed": 597832, - } - }, - "layout_graph": { - "strategy": { - "type": "umap", - }, - }, - }, - ), - ] - - # Grab the last result from the pipeline, should 
be our entity extraction - tables = [] - async for table in run_pipeline(dataset=dataset, workflows=workflows): - tables.append(table) - pipeline_result = tables[-1] - - # The output will contain entity graphs per hierarchical level, with embeddings per entity - if pipeline_result.result is not None: - first_result = pipeline_result.result.head(1) - print(f"level: {first_result['level'][0]}") - print(f"embeddings: {first_result['embeddings'][0]}") - print(f"entity_graph_positions: {first_result['node_positions'][0]}") - else: - print("No results!") - - -if __name__ == "__main__": - asyncio.run(run_python()) - asyncio.run(run_with_config()) diff --git a/graphrag/api/index.py b/graphrag/api/index.py index d12afc06da..009a1ba5cf 100644 --- a/graphrag/api/index.py +++ b/graphrag/api/index.py @@ -8,17 +8,19 @@ Backwards compatibility is not guaranteed at this time. """ -from datashaper import WorkflowCallbacks +import logging from graphrag.cache.noop_pipeline_cache import NoopPipelineCache from graphrag.callbacks.factory import create_pipeline_reporter +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import CacheType from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.create_pipeline_config import create_pipeline_config -from graphrag.index.run import run_pipeline_with_config +from graphrag.index.run.run_workflows import run_workflows from graphrag.index.typing import PipelineRunResult from graphrag.logger.base import ProgressLogger +log = logging.getLogger(__name__) + async def build_index( config: GraphRagConfig, @@ -56,7 +58,6 @@ async def build_index( msg = "Cannot resume and update a run at the same time." raise ValueError(msg) - pipeline_config = create_pipeline_config(config) pipeline_cache = ( NoopPipelineCache() if config.cache.type == CacheType.none is None else None ) @@ -65,14 +66,19 @@ async def build_index( callbacks = callbacks or [] callbacks.append(create_pipeline_reporter(config.reporting, None)) # type: ignore outputs: list[PipelineRunResult] = [] - async for output in run_pipeline_with_config( - pipeline_config, - run_id=run_id, - memory_profile=memory_profile, + + if memory_profile: + log.warning("New pipeline does not yet support memory profiling.") + + workflows = _get_workflows_list(config) + + async for output in run_workflows( + workflows, + config, cache=pipeline_cache, callbacks=callbacks, logger=progress_logger, - is_resume_run=is_resume_run, + run_id=run_id, is_update_run=is_update_run, ): outputs.append(output) @@ -82,4 +88,22 @@ async def build_index( else: progress_logger.success(output.workflow) progress_logger.info(str(output.result)) + return outputs + + +def _get_workflows_list(config: GraphRagConfig) -> list[str]: + return [ + "create_base_text_units", + "create_final_documents", + "extract_graph", + "compute_communities", + "create_final_entities", + "create_final_relationships", + "create_final_nodes", + "create_final_communities", + *(["create_final_covariates"] if config.claim_extraction.enabled else []), + "create_final_text_units", + "create_final_community_reports", + "generate_text_embeddings", + ] diff --git a/graphrag/api/prompt_tune.py b/graphrag/api/prompt_tune.py index 9d0823e93f..98c1dac3ba 100644 --- a/graphrag/api/prompt_tune.py +++ b/graphrag/api/prompt_tune.py @@ -11,9 +11,9 @@ Backwards compatibility is not guaranteed at this time. 
""" -from datashaper import NoopVerbCallbacks from pydantic import PositiveInt, validate_call +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.llm.load_llm import load_llm from graphrag.logger.print_progress import PrintProgressLogger diff --git a/graphrag/callbacks/blob_workflow_callbacks.py b/graphrag/callbacks/blob_workflow_callbacks.py index 56ed317a9f..36bd5f9e83 100644 --- a/graphrag/callbacks/blob_workflow_callbacks.py +++ b/graphrag/callbacks/blob_workflow_callbacks.py @@ -10,7 +10,8 @@ from azure.identity import DefaultAzureCredential from azure.storage.blob import BlobServiceClient -from datashaper import NoopWorkflowCallbacks + +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks class BlobWorkflowCallbacks(NoopWorkflowCallbacks): diff --git a/graphrag/callbacks/console_workflow_callbacks.py b/graphrag/callbacks/console_workflow_callbacks.py index 4e70ba7109..a2ab6ef08a 100644 --- a/graphrag/callbacks/console_workflow_callbacks.py +++ b/graphrag/callbacks/console_workflow_callbacks.py @@ -3,7 +3,7 @@ """A logger that emits updates from the indexing engine to the console.""" -from datashaper import NoopWorkflowCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks): diff --git a/graphrag/callbacks/delegating_verb_callbacks.py b/graphrag/callbacks/delegating_verb_callbacks.py new file mode 100644 index 0000000000..11687f3a24 --- /dev/null +++ b/graphrag/callbacks/delegating_verb_callbacks.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Contains the DelegatingVerbCallback definition.""" + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks +from graphrag.logger.progress import Progress + + +class DelegatingVerbCallbacks(VerbCallbacks): + """A wrapper that implements VerbCallbacks that delegates to the underlying WorkflowCallbacks.""" + + _workflow_callbacks: WorkflowCallbacks + _name: str + + def __init__(self, name: str, workflow_callbacks: WorkflowCallbacks): + """Create a new instance of DelegatingVerbCallbacks.""" + self._workflow_callbacks = workflow_callbacks + self._name = name + + def progress(self, progress: Progress) -> None: + """Handle when progress occurs.""" + self._workflow_callbacks.on_step_progress(self._name, progress) + + def error( + self, + message: str, + cause: BaseException | None = None, + stack: str | None = None, + details: dict | None = None, + ) -> None: + """Handle when an error occurs.""" + self._workflow_callbacks.on_error(message, cause, stack, details) + + def warning(self, message: str, details: dict | None = None) -> None: + """Handle when a warning occurs.""" + self._workflow_callbacks.on_warning(message, details) + + def log(self, message: str, details: dict | None = None) -> None: + """Handle when a log occurs.""" + self._workflow_callbacks.on_log(message, details) + + def measure(self, name: str, value: float, details: dict | None = None) -> None: + """Handle when a measurement occurs.""" + self._workflow_callbacks.on_measure(name, value, details) diff --git a/graphrag/callbacks/factory.py b/graphrag/callbacks/factory.py index 26b33b713b..bffc3f2cc2 100644 --- a/graphrag/callbacks/factory.py +++ b/graphrag/callbacks/factory.py @@ -6,11 +6,10 @@ from pathlib import Path from typing import cast 
-from datashaper import WorkflowCallbacks - from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import ReportingType from graphrag.index.config.reporting import ( PipelineBlobReportingConfig, diff --git a/graphrag/callbacks/file_workflow_callbacks.py b/graphrag/callbacks/file_workflow_callbacks.py index 95ccfea272..b3b5ca1963 100644 --- a/graphrag/callbacks/file_workflow_callbacks.py +++ b/graphrag/callbacks/file_workflow_callbacks.py @@ -8,7 +8,7 @@ from io import TextIOWrapper from pathlib import Path -from datashaper import NoopWorkflowCallbacks +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks log = logging.getLogger(__name__) diff --git a/graphrag/callbacks/noop_verb_callbacks.py b/graphrag/callbacks/noop_verb_callbacks.py new file mode 100644 index 0000000000..5a2000af67 --- /dev/null +++ b/graphrag/callbacks/noop_verb_callbacks.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Defines the interface for verb callbacks.""" + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.logger.progress import Progress + + +class NoopVerbCallbacks(VerbCallbacks): + """A noop implementation of the verb callbacks.""" + + def __init__(self) -> None: + pass + + def progress(self, progress: Progress) -> None: + """Report a progress update from the verb execution".""" + + def error( + self, + message: str, + cause: BaseException | None = None, + stack: str | None = None, + details: dict | None = None, + ) -> None: + """Report a error from the verb execution.""" + + def warning(self, message: str, details: dict | None = None) -> None: + """Report a warning from verb execution.""" + + def log(self, message: str, details: dict | None = None) -> None: + """Report an informational message from the verb execution.""" + + def measure(self, name: str, value: float) -> None: + """Report a telemetry measurement from the verb execution.""" diff --git a/graphrag/callbacks/noop_workflow_callbacks.py b/graphrag/callbacks/noop_workflow_callbacks.py new file mode 100644 index 0000000000..2e8d6b883d --- /dev/null +++ b/graphrag/callbacks/noop_workflow_callbacks.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A no-op implementation of WorkflowCallbacks.""" + +from typing import Any + +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks +from graphrag.logger.progress import Progress + + +class NoopWorkflowCallbacks(WorkflowCallbacks): + """A no-op implementation of WorkflowCallbacks.""" + + def on_workflow_start(self, name: str, instance: object) -> None: + """Execute this callback when a workflow starts.""" + + def on_workflow_end(self, name: str, instance: object) -> None: + """Execute this callback when a workflow ends.""" + + def on_step_start(self, step_name: str) -> None: + """Execute this callback every time a step starts.""" + + def on_step_end(self, step_name: str, result: Any) -> None: + """Execute this callback every time a step ends.""" + + def on_step_progress(self, step_name: str, progress: Progress) -> None: + """Handle when progress occurs.""" + + def on_error( + self, + message: str, + cause: BaseException | None = None, + stack: str | None = None, + details: dict | None = None, + ) -> None: + """Handle when an error occurs.""" + + def on_warning(self, message: str, details: dict | None = None) -> None: + """Handle when a warning occurs.""" + + def on_log(self, message: str, details: dict | None = None) -> None: + """Handle when a log message occurs.""" + + def on_measure(self, name: str, value: float, details: dict | None = None) -> None: + """Handle when a measurement occurs.""" diff --git a/graphrag/callbacks/progress_workflow_callbacks.py b/graphrag/callbacks/progress_workflow_callbacks.py index 9fda1e0c06..1dc4ada022 100644 --- a/graphrag/callbacks/progress_workflow_callbacks.py +++ b/graphrag/callbacks/progress_workflow_callbacks.py @@ -5,9 +5,9 @@ from typing import Any -from datashaper import ExecutionNode, NoopWorkflowCallbacks, Progress, TableContainer - +from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.logger.base import ProgressLogger +from graphrag.logger.progress import Progress class ProgressWorkflowCallbacks(NoopWorkflowCallbacks): @@ -39,16 +39,15 @@ def on_workflow_end(self, name: str, instance: object) -> None: """Execute this callback when a workflow ends.""" self._pop() - def on_step_start(self, node: ExecutionNode, inputs: dict[str, Any]) -> None: + def on_step_start(self, step_name: str) -> None: """Execute this callback every time a step starts.""" - verb_id_str = f" ({node.node_id})" if node.has_explicit_id else "" - self._push(f"Verb {node.verb.name}{verb_id_str}") + self._push(f"Step {step_name}") self._latest(Progress(percent=0)) - def on_step_end(self, node: ExecutionNode, result: TableContainer | None) -> None: + def on_step_end(self, step_name: str, result: Any) -> None: """Execute this callback every time a step ends.""" self._pop() - def on_step_progress(self, node: ExecutionNode, progress: Progress) -> None: + def on_step_progress(self, step_name: str, progress: Progress) -> None: """Handle when progress occurs.""" self._latest(progress) diff --git a/graphrag/callbacks/verb_callbacks.py b/graphrag/callbacks/verb_callbacks.py new file mode 100644 index 0000000000..9489b4cab3 --- /dev/null +++ b/graphrag/callbacks/verb_callbacks.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Defines the interface for verb callbacks.""" + +from typing import Protocol + +from graphrag.logger.progress import Progress + + +class VerbCallbacks(Protocol): + """Provides a way to report status updates from the pipeline.""" + + def progress(self, progress: Progress) -> None: + """Report a progress update from the verb execution".""" + ... + + def error( + self, + message: str, + cause: BaseException | None = None, + stack: str | None = None, + details: dict | None = None, + ) -> None: + """Report a error from the verb execution.""" + ... + + def warning(self, message: str, details: dict | None = None) -> None: + """Report a warning from verb execution.""" + ... + + def log(self, message: str, details: dict | None = None) -> None: + """Report an informational message from the verb execution.""" + ... + + def measure(self, name: str, value: float) -> None: + """Report a telemetry measurement from the verb execution.""" + ... diff --git a/graphrag/callbacks/workflow_callbacks.py b/graphrag/callbacks/workflow_callbacks.py new file mode 100644 index 0000000000..f1adec6cb6 --- /dev/null +++ b/graphrag/callbacks/workflow_callbacks.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Collection of callbacks that can be used to monitor the workflow execution.""" + +from typing import Any, Protocol + +from graphrag.logger.progress import Progress + + +class WorkflowCallbacks(Protocol): + """ + A collection of callbacks that can be used to monitor the workflow execution. + + This base class is a "noop" implementation so that clients may implement just the callbacks they need. + """ + + def on_workflow_start(self, name: str, instance: object) -> None: + """Execute this callback when a workflow starts.""" + ... + + def on_workflow_end(self, name: str, instance: object) -> None: + """Execute this callback when a workflow ends.""" + ... + + def on_step_start(self, step_name: str) -> None: + """Execute this callback every time a step starts.""" + ... + + def on_step_end(self, step_name: str, result: Any) -> None: + """Execute this callback every time a step ends.""" + ... + + def on_step_progress(self, step_name: str, progress: Progress) -> None: + """Handle when progress occurs.""" + ... + + def on_error( + self, + message: str, + cause: BaseException | None = None, + stack: str | None = None, + details: dict | None = None, + ) -> None: + """Handle when an error occurs.""" + ... + + def on_warning(self, message: str, details: dict | None = None) -> None: + """Handle when a warning occurs.""" + ... + + def on_log(self, message: str, details: dict | None = None) -> None: + """Handle when a log message occurs.""" + ... + + def on_measure(self, name: str, value: float, details: dict | None = None) -> None: + """Handle when a measurement occurs.""" + ... diff --git a/graphrag/callbacks/workflow_callbacks_manager.py b/graphrag/callbacks/workflow_callbacks_manager.py new file mode 100644 index 0000000000..d677462cb7 --- /dev/null +++ b/graphrag/callbacks/workflow_callbacks_manager.py @@ -0,0 +1,83 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing the WorkflowCallbacks registry.""" + +from typing import Any + +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks +from graphrag.logger.progress import Progress + + +class WorkflowCallbacksManager(WorkflowCallbacks): + """A registry of WorkflowCallbacks.""" + + _callbacks: list[WorkflowCallbacks] + + def __init__(self): + """Create a new instance of WorkflowCallbacksRegistry.""" + self._callbacks = [] + + def register(self, callbacks: WorkflowCallbacks) -> None: + """Register a new WorkflowCallbacks type.""" + self._callbacks.append(callbacks) + + def on_workflow_start(self, name: str, instance: object) -> None: + """Execute this callback when a workflow starts.""" + for callback in self._callbacks: + if hasattr(callback, "on_workflow_start"): + callback.on_workflow_start(name, instance) + + def on_workflow_end(self, name: str, instance: object) -> None: + """Execute this callback when a workflow ends.""" + for callback in self._callbacks: + if hasattr(callback, "on_workflow_end"): + callback.on_workflow_end(name, instance) + + def on_step_start(self, step_name: str) -> None: + """Execute this callback every time a step starts.""" + for callback in self._callbacks: + if hasattr(callback, "on_step_start"): + callback.on_step_start(step_name) + + def on_step_end(self, step_name: str, result: Any) -> None: + """Execute this callback every time a step ends.""" + for callback in self._callbacks: + if hasattr(callback, "on_step_end"): + callback.on_step_end(step_name, result) + + def on_step_progress(self, step_name: str, progress: Progress) -> None: + """Handle when progress occurs.""" + for callback in self._callbacks: + if hasattr(callback, "on_step_progress"): + callback.on_step_progress(step_name, progress) + + def on_error( + self, + message: str, + cause: BaseException | None = None, + stack: str | None = None, + details: dict | None = None, + ) -> None: + """Handle when an error occurs.""" + for callback in self._callbacks: + if hasattr(callback, "on_error"): + callback.on_error(message, cause, stack, details) + + def on_warning(self, message: str, details: dict | None = None) -> None: + """Handle when a warning occurs.""" + for callback in self._callbacks: + if hasattr(callback, "on_warning"): + callback.on_warning(message, details) + + def on_log(self, message: str, details: dict | None = None) -> None: + """Handle when a log message occurs.""" + for callback in self._callbacks: + if hasattr(callback, "on_log"): + callback.on_log(message, details) + + def on_measure(self, name: str, value: float, details: dict | None = None) -> None: + """Handle when a measurement occurs.""" + for callback in self._callbacks: + if hasattr(callback, "on_measure"): + callback.on_measure(name, value, details) diff --git a/graphrag/cli/query.py b/graphrag/cli/query.py index 7c45d24740..1d97b4fcfb 100644 --- a/graphrag/cli/query.py +++ b/graphrag/cli/query.py @@ -16,7 +16,7 @@ from graphrag.index.create_pipeline_config import create_pipeline_config from graphrag.logger.print_progress import PrintProgressLogger from graphrag.storage.factory import StorageFactory -from graphrag.utils.storage import load_table_from_storage +from graphrag.utils.storage import load_table_from_storage, storage_has_table logger = PrintProgressLogger("") @@ -43,10 +43,10 @@ def run_global_search( dataframe_dict = _resolve_output_files( config=config, output_list=[ - "create_final_nodes.parquet", - "create_final_entities.parquet", - 
"create_final_communities.parquet", - "create_final_community_reports.parquet", + "create_final_nodes", + "create_final_entities", + "create_final_communities", + "create_final_community_reports", ], optional_list=[], ) @@ -127,14 +127,14 @@ def run_local_search( dataframe_dict = _resolve_output_files( config=config, output_list=[ - "create_final_nodes.parquet", - "create_final_community_reports.parquet", - "create_final_text_units.parquet", - "create_final_relationships.parquet", - "create_final_entities.parquet", + "create_final_nodes", + "create_final_community_reports", + "create_final_text_units", + "create_final_relationships", + "create_final_entities", ], optional_list=[ - "create_final_covariates.parquet", + "create_final_covariates", ], ) final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"] @@ -217,11 +217,11 @@ def run_drift_search( dataframe_dict = _resolve_output_files( config=config, output_list=[ - "create_final_nodes.parquet", - "create_final_community_reports.parquet", - "create_final_text_units.parquet", - "create_final_relationships.parquet", - "create_final_entities.parquet", + "create_final_nodes", + "create_final_community_reports", + "create_final_text_units", + "create_final_relationships", + "create_final_entities", ], ) final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"] @@ -332,24 +332,20 @@ def _resolve_output_files( storage_obj = StorageFactory().create_storage( storage_type=storage_config["type"], kwargs=storage_config ) - for output_file in output_list: - df_key = output_file.split(".")[0] - df_value = asyncio.run( - load_table_from_storage(name=output_file, storage=storage_obj) - ) - dataframe_dict[df_key] = df_value + for name in output_list: + df_value = asyncio.run(load_table_from_storage(name=name, storage=storage_obj)) + dataframe_dict[name] = df_value # for optional output files, set the dict entry to None instead of erroring out if it does not exist if optional_list: for optional_file in optional_list: - file_exists = asyncio.run(storage_obj.has(optional_file)) - df_key = optional_file.split(".")[0] + file_exists = asyncio.run(storage_has_table(optional_file, storage_obj)) if file_exists: df_value = asyncio.run( load_table_from_storage(name=optional_file, storage=storage_obj) ) - dataframe_dict[df_key] = df_value + dataframe_dict[optional_file] = df_value else: - dataframe_dict[df_key] = None + dataframe_dict[optional_file] = None return dataframe_dict diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py index 701069dd4b..433da098d6 100644 --- a/graphrag/config/create_graphrag_config.py +++ b/graphrag/config/create_graphrag_config.py @@ -8,11 +8,11 @@ from pathlib import Path from typing import Any, cast -from datashaper import AsyncType from environs import Env import graphrag.config.defaults as defs from graphrag.config.enums import ( + AsyncType, CacheType, InputFileType, InputType, diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index ac1c87d1e5..73f27dbe33 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -5,9 +5,8 @@ from pathlib import Path -from datashaper import AsyncType - from graphrag.config.enums import ( + AsyncType, CacheType, InputFileType, InputType, diff --git a/graphrag/config/enums.py b/graphrag/config/enums.py index 4ff1e35571..b13da14874 100644 --- a/graphrag/config/enums.py +++ b/graphrag/config/enums.py @@ -114,3 +114,10 @@ class LLMType(str, Enum): def __repr__(self): """Get a string representation.""" 
return f'"{self.value}"' + + +class AsyncType(str, Enum): + """Enum for the type of async to use.""" + + AsyncIO = "asyncio" + Threaded = "threaded" diff --git a/graphrag/config/models/llm_config.py b/graphrag/config/models/llm_config.py index 3759bd949e..78459e8f1a 100644 --- a/graphrag/config/models/llm_config.py +++ b/graphrag/config/models/llm_config.py @@ -3,10 +3,10 @@ """Parameterization settings for the default configuration.""" -from datashaper import AsyncType from pydantic import BaseModel, Field import graphrag.config.defaults as defs +from graphrag.config.enums import AsyncType from graphrag.config.models.llm_parameters import LLMParameters from graphrag.config.models.parallelization_parameters import ParallelizationParameters diff --git a/graphrag/index/config/embeddings.py b/graphrag/index/config/embeddings.py index 02e9c912c4..74e4d11809 100644 --- a/graphrag/index/config/embeddings.py +++ b/graphrag/index/config/embeddings.py @@ -3,6 +3,10 @@ """A module containing embeddings values.""" +from graphrag.config.enums import TextEmbeddingTarget +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.config.models.text_embedding_config import TextEmbeddingConfig + entity_title_embedding = "entity.title" entity_description_embedding = "entity.description" relationship_description_embedding = "relationship.description" @@ -27,3 +31,41 @@ community_full_content_embedding, text_unit_text_embedding, } + + +def get_embedded_fields(settings: GraphRagConfig) -> set[str]: + """Get the fields to embed based on the enum or specifically skipped embeddings.""" + match settings.embeddings.target: + case TextEmbeddingTarget.all: + return all_embeddings.difference(settings.embeddings.skip) + case TextEmbeddingTarget.required: + return required_embeddings + case TextEmbeddingTarget.none: + return set() + case _: + msg = f"Unknown embeddings target: {settings.embeddings.target}" + raise ValueError(msg) + + +def get_embedding_settings( + settings: TextEmbeddingConfig, + vector_store_params: dict | None = None, +) -> dict: + """Transform GraphRAG config into settings for workflows.""" + # TEMP + vector_store_settings = settings.vector_store + if vector_store_settings is None: + return {"strategy": settings.resolved_strategy()} + # + # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding. + # settings.vector_store.base contains connection information, or may be undefined + # settings.vector_store. 
contains the specific settings for this embedding + # + strategy = settings.resolved_strategy() # get the default strategy + strategy.update({ + "vector_store": {**(vector_store_params or {}), **vector_store_settings} + }) # update the default strategy with the vector store settings + # This ensures the vector store config is part of the strategy and not the global config + return { + "strategy": strategy, + } diff --git a/graphrag/index/config/input.py b/graphrag/index/config/input.py index 5d4b08dfa8..f9dad568d5 100644 --- a/graphrag/index/config/input.py +++ b/graphrag/index/config/input.py @@ -10,7 +10,6 @@ from pydantic import BaseModel, Field from graphrag.config.enums import InputFileType, InputType -from graphrag.index.config.workflow import PipelineWorkflowStep T = TypeVar("T") @@ -56,11 +55,6 @@ class PipelineInputConfig(BaseModel, Generic[T]): ) """The optional file filter for the input files.""" - post_process: list[PipelineWorkflowStep] | None = Field( - description="The post processing steps for the input.", default=None - ) - """The post processing steps for the input.""" - encoding: str | None = Field( description="The encoding for the input files.", default=None ) diff --git a/graphrag/index/config/workflow.py b/graphrag/index/config/workflow.py index 30e77d504f..58f1e5fddf 100644 --- a/graphrag/index/config/workflow.py +++ b/graphrag/index/config/workflow.py @@ -9,9 +9,6 @@ from pydantic import BaseModel, Field -PipelineWorkflowStep = dict[str, Any] -"""Represent a step in a workflow.""" - PipelineWorkflowConfig = dict[str, Any] """Represent a configuration for a workflow.""" @@ -22,11 +19,6 @@ class PipelineWorkflowReference(BaseModel): name: str | None = Field(description="Name of the workflow.", default=None) """Name of the workflow.""" - steps: list[PipelineWorkflowStep] | None = Field( - description="The optional steps for the workflow.", default=None - ) - """The optional steps for the workflow.""" - config: PipelineWorkflowConfig | None = Field( description="The optional configuration for the workflow.", default=None ) diff --git a/graphrag/index/context.py b/graphrag/index/context.py index c45decd173..c9242783c9 100644 --- a/graphrag/index/context.py +++ b/graphrag/index/context.py @@ -37,10 +37,3 @@ class PipelineRunContext: "Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider." cache: PipelineCache "Cache instance for reading previous LLM responses." - runtime_storage: PipelineStorage - "Runtime only storage for pipeline verbs to use. Items written here will only live in memory during the current run." 
- - -# TODO: For now, just has the same props available to it -VerbRunContext = PipelineRunContext -"""Provides the context for the current verb run.""" diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py index 4ec4342222..75a213239e 100644 --- a/graphrag/index/create_pipeline_config.py +++ b/graphrag/index/create_pipeline_config.py @@ -12,11 +12,9 @@ InputFileType, ReportingType, StorageType, - TextEmbeddingTarget, ) from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.config.models.storage_config import StorageConfig -from graphrag.config.models.text_embedding_config import TextEmbeddingConfig from graphrag.index.config.cache import ( PipelineBlobCacheConfig, PipelineCacheConfigTypes, @@ -25,10 +23,7 @@ PipelineMemoryCacheConfig, PipelineNoneCacheConfig, ) -from graphrag.index.config.embeddings import ( - all_embeddings, - required_embeddings, -) +from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings from graphrag.index.config.input import ( PipelineCSVInputConfig, PipelineInputConfigTypes, @@ -53,7 +48,7 @@ from graphrag.index.config.workflow import ( PipelineWorkflowReference, ) -from graphrag.index.workflows.default_workflows import ( +from graphrag.index.workflows import ( compute_communities, create_base_text_units, create_final_communities, @@ -92,7 +87,7 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC _log_llm_settings(settings) skip_workflows = settings.skip_workflows - embedded_fields = _get_embedded_fields(settings) + embedded_fields = get_embedded_fields(settings) covariates_enabled = ( settings.claim_extraction.enabled and create_final_covariates not in skip_workflows @@ -123,19 +118,6 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC return result -def _get_embedded_fields(settings: GraphRagConfig) -> set[str]: - match settings.embeddings.target: - case TextEmbeddingTarget.all: - return all_embeddings.difference(settings.embeddings.skip) - case TextEmbeddingTarget.required: - return required_embeddings - case TextEmbeddingTarget.none: - return set() - case _: - msg = f"Unknown embeddings target: {settings.embeddings.target}" - raise ValueError(msg) - - def _log_llm_settings(settings: GraphRagConfig) -> None: log.info( "Using LLM Config %s", @@ -189,28 +171,6 @@ def _text_unit_workflows( ] -def _get_embedding_settings( - settings: TextEmbeddingConfig, - vector_store_params: dict | None = None, -) -> dict: - vector_store_settings = settings.vector_store - if vector_store_settings is None: - return {"strategy": settings.resolved_strategy()} - # - # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding. - # settings.vector_store.base contains connection information, or may be undefined - # settings.vector_store. 
contains the specific settings for this embedding - # - strategy = settings.resolved_strategy() # get the default strategy - strategy.update({ - "vector_store": {**(vector_store_params or {}), **vector_store_settings} - }) # update the default strategy with the vector store settings - # This ensures the vector store config is part of the strategy and not the global config - return { - "strategy": strategy, - } - - def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference]: return [ PipelineWorkflowReference( @@ -307,7 +267,7 @@ def _embeddings_workflows( name=generate_text_embeddings, config={ "snapshot_embeddings": settings.snapshots.embeddings, - "text_embed": _get_embedding_settings(settings.embeddings), + "text_embed": get_embedding_settings(settings.embeddings), "embedded_fields": embedded_fields, }, ), diff --git a/graphrag/index/exporter.py b/graphrag/index/exporter.py deleted file mode 100644 index 4910e87467..0000000000 --- a/graphrag/index/exporter.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""ParquetExporter module.""" - -import logging -import traceback - -import pandas as pd -from pyarrow.lib import ArrowInvalid, ArrowTypeError - -from graphrag.index.typing import ErrorHandlerFn -from graphrag.storage.pipeline_storage import PipelineStorage - -log = logging.getLogger(__name__) - - -class ParquetExporter: - """ParquetExporter class. - - A class that exports dataframe's to a storage destination in .parquet file format. - """ - - _storage: PipelineStorage - _on_error: ErrorHandlerFn - - def __init__( - self, - storage: PipelineStorage, - on_error: ErrorHandlerFn, - ): - """Create a new Parquet Table TableExporter.""" - self._storage = storage - self._on_error = on_error - - async def export(self, name: str, data: pd.DataFrame) -> None: - """Export dataframe to storage.""" - filename = f"{name}.parquet" - log.info("exporting parquet table %s", filename) - try: - await self._storage.set(filename, data.to_parquet()) - except ArrowTypeError as e: - log.exception("Error while exporting parquet table") - self._on_error( - e, - traceback.format_exc(), - None, - ) - except ArrowInvalid as e: - log.exception("Error while exporting parquet table") - self._on_error( - e, - traceback.format_exc(), - None, - ) diff --git a/graphrag/index/flows/__init__.py b/graphrag/index/flows/__init__.py index b09c865054..13b7827bb7 100644 --- a/graphrag/index/flows/__init__.py +++ b/graphrag/index/flows/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) 2024 Microsoft Corporation. 
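Aside, not part of the patch: ParquetExporter's job is taken over by the storage helpers used later in this change (see the generate_text_embeddings hunk below), so callers write tables directly instead of going through an exporter object. A rough usage sketch, assuming an existing PipelineStorage instance and leaving filename conventions to graphrag/utils/storage.py:

    import pandas as pd

    from graphrag.storage.pipeline_storage import PipelineStorage
    from graphrag.utils.storage import write_table_to_storage


    async def persist_entities(entities: pd.DataFrame, storage: PipelineStorage) -> None:
        # hypothetical caller: serialize the frame and store it under the given table name
        await write_table_to_storage(entities, "create_final_entities", storage)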
# Licensed under the MIT License -"""Core workflows without DataShaper wrappings.""" +"""Core workflows functions without workflow/pipeline wrappings.""" diff --git a/graphrag/index/flows/create_base_text_units.py b/graphrag/index/flows/create_base_text_units.py index 63f8f62b6e..33dad0aebd 100644 --- a/graphrag/index/flows/create_base_text_units.py +++ b/graphrag/index/flows/create_base_text_units.py @@ -3,20 +3,15 @@ """All the steps to transform base text_units.""" -from dataclasses import dataclass -from typing import Any, cast +from typing import cast import pandas as pd -from datashaper import ( - FieldAggregateOperation, - Progress, - VerbCallbacks, - aggregate_operation_mapping, -) +from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.config.models.chunking_config import ChunkStrategyType from graphrag.index.operations.chunk_text.chunk_text import chunk_text from graphrag.index.utils.hashing import gen_sha512_hash +from graphrag.logger.progress import Progress def create_base_text_units( @@ -37,20 +32,16 @@ def create_base_text_units( callbacks.progress(Progress(percent=0)) - aggregated = _aggregate_df( - sort, - groupby=[*group_by_columns] if len(group_by_columns) > 0 else None, - aggregations=[ - { - "column": "text_with_ids", - "operation": "array_agg", - "to": "texts", - } - ], + aggregated = ( + ( + sort.groupby(group_by_columns, sort=False) + if len(group_by_columns) > 0 + else sort.groupby(lambda _x: True) + ) + .agg(texts=("text_with_ids", list)) + .reset_index() ) - callbacks.progress(Progress(percent=1)) - aggregated["chunks"] = chunk_text( aggregated, column="texts", @@ -81,57 +72,3 @@ def create_base_text_units( return cast( "pd.DataFrame", aggregated[aggregated["text"].notna()].reset_index(drop=True) ) - - -# TODO: would be nice to inline this completely in the main method with pandas -def _aggregate_df( - input: pd.DataFrame, - aggregations: list[dict[str, Any]], - groupby: list[str] | None = None, -) -> pd.DataFrame: - """Aggregate method definition.""" - aggregations_to_apply = _load_aggregations(aggregations) - df_aggregations = { - agg.column: _get_pandas_agg_operation(agg) - for agg in aggregations_to_apply.values() - } - if groupby is None: - output_grouped = input.groupby(lambda _x: True) - else: - output_grouped = input.groupby(groupby, sort=False) - output = cast("pd.DataFrame", output_grouped.agg(df_aggregations)) - output.rename( - columns={agg.column: agg.to for agg in aggregations_to_apply.values()}, - inplace=True, - ) - output.columns = [agg.to for agg in aggregations_to_apply.values()] - return output.reset_index() - - -@dataclass -class Aggregation: - """Aggregation class method definition.""" - - column: str | None - operation: str - to: str - - # Only useful for the concat operation - separator: str | None = None - - -def _get_pandas_agg_operation(agg: Aggregation) -> Any: - if agg.operation == "string_concat": - return (agg.separator or ",").join - return aggregate_operation_mapping[FieldAggregateOperation(agg.operation)] - - -def _load_aggregations( - aggregations: list[dict[str, Any]], -) -> dict[str, Aggregation]: - return { - aggregation["column"]: Aggregation( - aggregation["column"], aggregation["operation"], aggregation["to"] - ) - for aggregation in aggregations - } diff --git a/graphrag/index/flows/create_final_community_reports.py b/graphrag/index/flows/create_final_community_reports.py index 574945de9d..f94103db04 100644 --- a/graphrag/index/flows/create_final_community_reports.py +++ 
b/graphrag/index/flows/create_final_community_reports.py @@ -6,12 +6,10 @@ from uuid import uuid4 import pandas as pd -from datashaper import ( - AsyncType, - VerbCallbacks, -) from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.enums import AsyncType from graphrag.index.operations.summarize_communities import ( prepare_community_reports, restore_community_hierarchy, diff --git a/graphrag/index/flows/create_final_covariates.py b/graphrag/index/flows/create_final_covariates.py index f9b5f7e377..ce6cccaa9c 100644 --- a/graphrag/index/flows/create_final_covariates.py +++ b/graphrag/index/flows/create_final_covariates.py @@ -7,12 +7,10 @@ from uuid import uuid4 import pandas as pd -from datashaper import ( - AsyncType, - VerbCallbacks, -) from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.enums import AsyncType from graphrag.index.operations.extract_covariates.extract_covariates import ( extract_covariates, ) diff --git a/graphrag/index/flows/create_final_nodes.py b/graphrag/index/flows/create_final_nodes.py index 511ff429e7..f75ef2733a 100644 --- a/graphrag/index/flows/create_final_nodes.py +++ b/graphrag/index/flows/create_final_nodes.py @@ -4,10 +4,8 @@ """All the steps to transform final nodes.""" import pandas as pd -from datashaper import ( - VerbCallbacks, -) +from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.config.models.embed_graph_config import EmbedGraphConfig from graphrag.index.operations.compute_degree import compute_degree from graphrag.index.operations.create_graph import create_graph diff --git a/graphrag/index/flows/extract_graph.py b/graphrag/index/flows/extract_graph.py index 87e369f525..8eaa4d2951 100644 --- a/graphrag/index/flows/extract_graph.py +++ b/graphrag/index/flows/extract_graph.py @@ -7,12 +7,10 @@ from uuid import uuid4 import pandas as pd -from datashaper import ( - AsyncType, - VerbCallbacks, -) from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.enums import AsyncType from graphrag.index.operations.extract_entities import extract_entities from graphrag.index.operations.summarize_descriptions import ( summarize_descriptions, diff --git a/graphrag/index/flows/generate_text_embeddings.py b/graphrag/index/flows/generate_text_embeddings.py index 877966dab7..d8c547663d 100644 --- a/graphrag/index/flows/generate_text_embeddings.py +++ b/graphrag/index/flows/generate_text_embeddings.py @@ -6,11 +6,9 @@ import logging import pandas as pd -from datashaper import ( - VerbCallbacks, -) from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.index.config.embeddings import ( community_full_content_embedding, community_summary_embedding, @@ -22,8 +20,8 @@ text_unit_text_embedding, ) from graphrag.index.operations.embed_text import embed_text -from graphrag.index.operations.snapshot import snapshot from graphrag.storage.pipeline_storage import PipelineStorage +from graphrag.utils.storage import write_table_to_storage log = logging.getLogger(__name__) @@ -131,9 +129,4 @@ async def _run_and_snapshot_embeddings( if snapshot_embeddings_enabled is True: data = data.loc[:, ["id", "embedding"]] - await snapshot( - data, - name=f"embeddings.{name}", - storage=storage, - formats=["parquet"], - ) + await 
write_table_to_storage(data, f"embeddings.{name}", storage) diff --git a/graphrag/index/llm/load_llm.py b/graphrag/index/llm/load_llm.py index ecd91b4bca..eae2cf34bd 100644 --- a/graphrag/index/llm/load_llm.py +++ b/graphrag/index/llm/load_llm.py @@ -29,9 +29,8 @@ from .mock_llm import MockChatLLM if TYPE_CHECKING: - from datashaper import VerbCallbacks - from graphrag.cache.pipeline_cache import PipelineCache + from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.index.typing import ErrorHandlerFn log = logging.getLogger(__name__) diff --git a/graphrag/index/load_pipeline_config.py b/graphrag/index/load_pipeline_config.py deleted file mode 100644 index 77893b9535..0000000000 --- a/graphrag/index/load_pipeline_config.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing read_dotenv, load_pipeline_config, _parse_yaml and _create_include_constructor methods definition.""" - -import json -from pathlib import Path - -import yaml -from pyaml_env import parse_config as parse_config_with_env - -from graphrag.config.create_graphrag_config import create_graphrag_config, read_dotenv -from graphrag.index.config.pipeline import PipelineConfig -from graphrag.index.create_pipeline_config import create_pipeline_config - - -def load_pipeline_config(config_or_path: str | PipelineConfig) -> PipelineConfig: - """Load a pipeline config from a file path or a config object.""" - if isinstance(config_or_path, PipelineConfig): - config = config_or_path - elif config_or_path == "default": - config = create_pipeline_config(create_graphrag_config(root_dir=".")) - else: - # Is there a .env file in the same directory as the config? - read_dotenv(str(Path(config_or_path).parent)) - - if config_or_path.endswith(".json"): - with Path(config_or_path).open("rb") as f: - config = json.loads(f.read().decode(encoding="utf-8", errors="strict")) - elif config_or_path.endswith((".yml", ".yaml")): - config = _parse_yaml(config_or_path) - else: - msg = f"Invalid config file type: {config_or_path}" - raise ValueError(msg) - - config = PipelineConfig.model_validate(config) - if not config.root_dir: - config.root_dir = str(Path(config_or_path).parent.resolve()) - - if config.extends is not None: - if isinstance(config.extends, str): - config.extends = [config.extends] - for extended_config in config.extends: - extended_config = load_pipeline_config(extended_config) - merged_config = { - **json.loads(extended_config.model_dump_json()), - **json.loads(config.model_dump_json(exclude_unset=True)), - } - config = PipelineConfig.model_validate(merged_config) - - return config - - -def _parse_yaml(path: str): - """Parse a yaml file, with support for !include directives.""" - # I don't like that this is static - loader_class = yaml.SafeLoader - - # Add !include constructor if not already present. 
- if "!include" not in loader_class.yaml_constructors: - loader_class.add_constructor("!include", _create_include_constructor()) - - return parse_config_with_env(path, loader=loader_class, default_value="") - - -def _create_include_constructor(): - """Create a constructor for !include directives.""" - - def handle_include(loader: yaml.Loader, node: yaml.Node): - """Include file referenced at node.""" - filename = str(Path(loader.name).parent / node.value) - if filename.endswith((".yml", ".yaml")): - return _parse_yaml(filename) - - with Path(filename).open("rb") as f: - return f.read().decode(encoding="utf-8", errors="strict") - - return handle_include diff --git a/graphrag/index/operations/chunk_text/chunk_text.py b/graphrag/index/operations/chunk_text/chunk_text.py index 554cfbda35..02c12e6f1a 100644 --- a/graphrag/index/operations/chunk_text/chunk_text.py +++ b/graphrag/index/operations/chunk_text/chunk_text.py @@ -6,17 +6,14 @@ from typing import Any, cast import pandas as pd -from datashaper import ( - ProgressTicker, - VerbCallbacks, - progress_ticker, -) +from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType from graphrag.index.operations.chunk_text.typing import ( ChunkInput, ChunkStrategy, ) +from graphrag.logger.progress import ProgressTicker, progress_ticker def chunk_text( diff --git a/graphrag/index/operations/chunk_text/strategies.py b/graphrag/index/operations/chunk_text/strategies.py index 1468028537..3fc8fc6f2f 100644 --- a/graphrag/index/operations/chunk_text/strategies.py +++ b/graphrag/index/operations/chunk_text/strategies.py @@ -7,11 +7,11 @@ import nltk import tiktoken -from datashaper import ProgressTicker from graphrag.config.models.chunking_config import ChunkingConfig from graphrag.index.operations.chunk_text.typing import TextChunk from graphrag.index.text_splitting.text_splitting import Tokenizer +from graphrag.logger.progress import ProgressTicker def run_tokens( diff --git a/graphrag/index/operations/chunk_text/typing.py b/graphrag/index/operations/chunk_text/typing.py index 5f0994ec05..bf58ef5ec1 100644 --- a/graphrag/index/operations/chunk_text/typing.py +++ b/graphrag/index/operations/chunk_text/typing.py @@ -6,9 +6,8 @@ from collections.abc import Callable, Iterable from dataclasses import dataclass -from datashaper import ProgressTicker - from graphrag.config.models.chunking_config import ChunkingConfig +from graphrag.logger.progress import ProgressTicker @dataclass diff --git a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py index f335802c5f..f4a7e5f367 100644 --- a/graphrag/index/operations/embed_text/embed_text.py +++ b/graphrag/index/operations/embed_text/embed_text.py @@ -9,9 +9,9 @@ import numpy as np import pandas as pd -from datashaper import VerbCallbacks from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy from graphrag.utils.embeddings import create_collection_name from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument diff --git a/graphrag/index/operations/embed_text/strategies/mock.py b/graphrag/index/operations/embed_text/strategies/mock.py index 3ebb1de8a2..9facd66643 100644 --- a/graphrag/index/operations/embed_text/strategies/mock.py +++ b/graphrag/index/operations/embed_text/strategies/mock.py @@ -7,10 +7,10 @@ from 
collections.abc import Iterable from typing import Any -from datashaper import ProgressTicker, VerbCallbacks, progress_ticker - from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult +from graphrag.logger.progress import ProgressTicker, progress_ticker async def run( # noqa RUF029 async is required for interface diff --git a/graphrag/index/operations/embed_text/strategies/openai.py b/graphrag/index/operations/embed_text/strategies/openai.py index 36be774203..5bef604dab 100644 --- a/graphrag/index/operations/embed_text/strategies/openai.py +++ b/graphrag/index/operations/embed_text/strategies/openai.py @@ -8,17 +8,18 @@ from typing import Any import numpy as np -from datashaper import ProgressTicker, VerbCallbacks, progress_ticker from fnllm import EmbeddingsLLM from pydantic import TypeAdapter import graphrag.config.defaults as defs from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.config.models.llm_parameters import LLMParameters from graphrag.index.llm.load_llm import load_llm_embeddings from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult from graphrag.index.text_splitting.text_splitting import TokenTextSplitter from graphrag.index.utils.is_null import is_null +from graphrag.logger.progress import ProgressTicker, progress_ticker log = logging.getLogger(__name__) diff --git a/graphrag/index/operations/embed_text/strategies/typing.py b/graphrag/index/operations/embed_text/strategies/typing.py index b53d710c0b..5962045a67 100644 --- a/graphrag/index/operations/embed_text/strategies/typing.py +++ b/graphrag/index/operations/embed_text/strategies/typing.py @@ -6,9 +6,8 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass -from datashaper import VerbCallbacks - from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks @dataclass diff --git a/graphrag/index/operations/extract_covariates/extract_covariates.py b/graphrag/index/operations/extract_covariates/extract_covariates.py index 5dab42b8df..323d95627d 100644 --- a/graphrag/index/operations/extract_covariates/extract_covariates.py +++ b/graphrag/index/operations/extract_covariates/extract_covariates.py @@ -9,20 +9,18 @@ from typing import Any import pandas as pd -from datashaper import ( - AsyncType, - VerbCallbacks, - derive_from_rows, -) import graphrag.config.defaults as defs from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.enums import AsyncType from graphrag.index.llm.load_llm import load_llm, read_llm_params from graphrag.index.operations.extract_covariates.claim_extractor import ClaimExtractor from graphrag.index.operations.extract_covariates.typing import ( Covariate, CovariateExtractionResult, ) +from graphrag.index.run.derive_from_rows import derive_from_rows log = logging.getLogger(__name__) @@ -65,7 +63,7 @@ async def run_strategy(row): input, run_strategy, callbacks, - scheduling_type=async_mode, + async_type=async_mode, num_threads=num_threads, ) return pd.DataFrame([item for row in results for item in row or []]) diff --git a/graphrag/index/operations/extract_covariates/typing.py b/graphrag/index/operations/extract_covariates/typing.py index f5c7e0a02e..8f95b9b5fb 100644 --- 
a/graphrag/index/operations/extract_covariates/typing.py +++ b/graphrag/index/operations/extract_covariates/typing.py @@ -7,9 +7,8 @@ from dataclasses import dataclass from typing import Any -from datashaper import VerbCallbacks - from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks @dataclass diff --git a/graphrag/index/operations/extract_entities/extract_entities.py b/graphrag/index/operations/extract_entities/extract_entities.py index e3b7410d06..d50e1219b3 100644 --- a/graphrag/index/operations/extract_entities/extract_entities.py +++ b/graphrag/index/operations/extract_entities/extract_entities.py @@ -7,19 +7,17 @@ from typing import Any import pandas as pd -from datashaper import ( - AsyncType, - VerbCallbacks, - derive_from_rows, -) from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.enums import AsyncType from graphrag.index.bootstrap import bootstrap from graphrag.index.operations.extract_entities.typing import ( Document, EntityExtractStrategy, ExtractEntityStrategyType, ) +from graphrag.index.run.derive_from_rows import derive_from_rows log = logging.getLogger(__name__) @@ -124,7 +122,7 @@ async def run_strategy(row): text_units, run_strategy, callbacks, - scheduling_type=async_mode, + async_type=async_mode, num_threads=num_threads, ) diff --git a/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py b/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py index 9084321621..2a403112a1 100644 --- a/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py +++ b/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py @@ -4,11 +4,11 @@ """A module containing run_graph_intelligence, run_extract_entities and _create_text_splitter methods to run graph intelligence.""" import networkx as nx -from datashaper import VerbCallbacks from fnllm import ChatLLM import graphrag.config.defaults as defs from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.index.llm.load_llm import load_llm, read_llm_params from graphrag.index.operations.extract_entities.graph_extractor import GraphExtractor from graphrag.index.operations.extract_entities.typing import ( diff --git a/graphrag/index/operations/extract_entities/nltk_strategy.py b/graphrag/index/operations/extract_entities/nltk_strategy.py index 81103c6955..e133aeeab4 100644 --- a/graphrag/index/operations/extract_entities/nltk_strategy.py +++ b/graphrag/index/operations/extract_entities/nltk_strategy.py @@ -5,10 +5,10 @@ import networkx as nx import nltk -from datashaper import VerbCallbacks from nltk.corpus import words from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.index.operations.extract_entities.typing import ( Document, EntityExtractionResult, diff --git a/graphrag/index/operations/extract_entities/typing.py b/graphrag/index/operations/extract_entities/typing.py index 7eb2440674..247c781003 100644 --- a/graphrag/index/operations/extract_entities/typing.py +++ b/graphrag/index/operations/extract_entities/typing.py @@ -9,9 +9,9 @@ from typing import Any import networkx as nx -from datashaper import VerbCallbacks from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks ExtractedEntity = dict[str, Any] 
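Aside, not part of the patch: the import swaps in these operation modules all follow the same pattern; the DataShaper VerbCallbacks and ProgressTicker are replaced by graphrag's own equivalents, with no behavioral change intended. A hypothetical operation sketched with the new imports (function name and column are illustrative):

    import pandas as pd

    from graphrag.callbacks.verb_callbacks import VerbCallbacks
    from graphrag.logger.progress import progress_ticker


    def normalize_titles(input: pd.DataFrame, callbacks: VerbCallbacks) -> pd.DataFrame:
        """Illustrative operation: tick progress once per row, mirroring derive_from_rows."""
        tick = progress_ticker(callbacks.progress, num_total=len(input))
        output = input.copy()
        for idx, row in output.iterrows():
            output.at[idx, "title"] = str(row["title"]).strip().lower()
            tick(1)
        tick.done()
        return output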
ExtractedRelationship = dict[str, Any] diff --git a/graphrag/index/operations/layout_graph/layout_graph.py b/graphrag/index/operations/layout_graph/layout_graph.py index a4c7471292..b96ef91e34 100644 --- a/graphrag/index/operations/layout_graph/layout_graph.py +++ b/graphrag/index/operations/layout_graph/layout_graph.py @@ -5,8 +5,8 @@ import networkx as nx import pandas as pd -from datashaper import VerbCallbacks +from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.index.operations.embed_graph.typing import NodeEmbeddings from graphrag.index.operations.layout_graph.typing import GraphLayout diff --git a/graphrag/index/operations/snapshot.py b/graphrag/index/operations/snapshot.py deleted file mode 100644 index 1a61fce1cd..0000000000 --- a/graphrag/index/operations/snapshot.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing snapshot method definition.""" - -import pandas as pd - -from graphrag.storage.pipeline_storage import PipelineStorage - - -async def snapshot( - input: pd.DataFrame, - name: str, - formats: list[str], - storage: PipelineStorage, -) -> None: - """Take a entire snapshot of the tabular data.""" - for fmt in formats: - if fmt == "parquet": - await storage.set(f"{name}.parquet", input.to_parquet()) - elif fmt == "json": - await storage.set( - f"{name}.json", input.to_json(orient="records", lines=True) - ) diff --git a/graphrag/index/operations/snapshot_graphml.py b/graphrag/index/operations/snapshot_graphml.py index 6d1d488494..c1eb9b0688 100644 --- a/graphrag/index/operations/snapshot_graphml.py +++ b/graphrag/index/operations/snapshot_graphml.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""A module containing snapshot method definition.""" +"""A module containing snapshot_graphml method definition.""" import networkx as nx diff --git a/graphrag/index/operations/summarize_communities/prepare_community_reports.py b/graphrag/index/operations/summarize_communities/prepare_community_reports.py index 45a6fec6d8..66fcaa2bb5 100644 --- a/graphrag/index/operations/summarize_communities/prepare_community_reports.py +++ b/graphrag/index/operations/summarize_communities/prepare_community_reports.py @@ -6,18 +6,16 @@ import logging import pandas as pd -from datashaper import ( - VerbCallbacks, - progress_iterable, -) import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas +from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import ( parallel_sort_context_batch, ) from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import ( get_levels, ) +from graphrag.logger.progress import progress_iterable log = logging.getLogger(__name__) diff --git a/graphrag/index/operations/summarize_communities/strategies.py b/graphrag/index/operations/summarize_communities/strategies.py index 9003e777bf..e630baba73 100644 --- a/graphrag/index/operations/summarize_communities/strategies.py +++ b/graphrag/index/operations/summarize_communities/strategies.py @@ -6,10 +6,10 @@ import logging import traceback -from datashaper import VerbCallbacks from fnllm import ChatLLM from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.index.llm.load_llm import load_llm, read_llm_params from 
graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import ( CommunityReportsExtractor, diff --git a/graphrag/index/operations/summarize_communities/summarize_communities.py b/graphrag/index/operations/summarize_communities/summarize_communities.py index d4c5c01072..df6dd631e1 100644 --- a/graphrag/index/operations/summarize_communities/summarize_communities.py +++ b/graphrag/index/operations/summarize_communities/summarize_communities.py @@ -6,17 +6,13 @@ import logging import pandas as pd -from datashaper import ( - AsyncType, - NoopVerbCallbacks, - VerbCallbacks, - derive_from_rows, - progress_ticker, -) import graphrag.config.defaults as defaults import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.enums import AsyncType from graphrag.index.operations.summarize_communities.community_reports_extractor import ( prep_community_report_context, ) @@ -28,6 +24,8 @@ CommunityReportsStrategy, CreateCommunityReportsStrategyType, ) +from graphrag.index.run.derive_from_rows import derive_from_rows +from graphrag.logger.progress import progress_ticker log = logging.getLogger(__name__) @@ -77,7 +75,7 @@ async def run_generate(record): run_generate, callbacks=NoopVerbCallbacks(), num_threads=num_threads, - scheduling_type=async_mode, + async_type=async_mode, ) reports.extend([lr for lr in local_reports if lr is not None]) diff --git a/graphrag/index/operations/summarize_communities/typing.py b/graphrag/index/operations/summarize_communities/typing.py index 6c6b7e6773..2a1ed3aca5 100644 --- a/graphrag/index/operations/summarize_communities/typing.py +++ b/graphrag/index/operations/summarize_communities/typing.py @@ -7,10 +7,10 @@ from enum import Enum from typing import Any -from datashaper import VerbCallbacks from typing_extensions import TypedDict from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks ExtractedEntity = dict[str, Any] StrategyConfig = dict[str, Any] diff --git a/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py b/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py index 4a22b9b554..e5de39f57f 100644 --- a/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py +++ b/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py @@ -3,10 +3,10 @@ """A module containing run_graph_intelligence, run_resolve_entities and _create_text_list_splitter methods to run graph intelligence.""" -from datashaper import VerbCallbacks from fnllm import ChatLLM from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.index.llm.load_llm import load_llm, read_llm_params from graphrag.index.operations.summarize_descriptions.description_summary_extractor import ( SummarizeExtractor, diff --git a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py index cf6650dd08..d1ad4af487 100644 --- a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py +++ b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py @@ -8,17 +8,14 @@ from typing 
import Any import pandas as pd -from datashaper import ( - ProgressTicker, - VerbCallbacks, - progress_ticker, -) from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks from graphrag.index.operations.summarize_descriptions.typing import ( SummarizationStrategy, SummarizeStrategyType, ) +from graphrag.logger.progress import ProgressTicker, progress_ticker log = logging.getLogger(__name__) diff --git a/graphrag/index/operations/summarize_descriptions/typing.py b/graphrag/index/operations/summarize_descriptions/typing.py index ca0ee13626..919ff9fd1c 100644 --- a/graphrag/index/operations/summarize_descriptions/typing.py +++ b/graphrag/index/operations/summarize_descriptions/typing.py @@ -8,9 +8,8 @@ from enum import Enum from typing import Any, NamedTuple -from datashaper import VerbCallbacks - from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.verb_callbacks import VerbCallbacks StrategyConfig = dict[str, Any] diff --git a/graphrag/index/run/__init__.py b/graphrag/index/run/__init__.py index afb43acd8e..d5e41d66a5 100644 --- a/graphrag/index/run/__init__.py +++ b/graphrag/index/run/__init__.py @@ -2,7 +2,3 @@ # Licensed under the MIT License """Run module for GraphRAG.""" - -from graphrag.index.run.run import run_pipeline, run_pipeline_with_config - -__all__ = ["run_pipeline", "run_pipeline_with_config"] diff --git a/graphrag/index/run/derive_from_rows.py b/graphrag/index/run/derive_from_rows.py new file mode 100644 index 0000000000..283621bb93 --- /dev/null +++ b/graphrag/index/run/derive_from_rows.py @@ -0,0 +1,158 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Apply a generic transform function to each row in a table.""" + +import asyncio +import inspect +import logging +import traceback +from collections.abc import Awaitable, Callable, Coroutine, Hashable +from typing import Any, TypeVar, cast + +import pandas as pd + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.enums import AsyncType +from graphrag.logger.progress import progress_ticker + +logger = logging.getLogger(__name__) +ItemType = TypeVar("ItemType") + + +class ParallelizationError(ValueError): + """Exception for invalid parallel processing.""" + + def __init__(self, num_errors: int): + super().__init__( + f"{num_errors} Errors occurred while running parallel transformation, could not complete!" + ) + + +async def derive_from_rows( + input: pd.DataFrame, + transform: Callable[[pd.Series], Awaitable[ItemType]], + callbacks: VerbCallbacks, + num_threads: int = 4, + async_type: AsyncType = AsyncType.AsyncIO, +) -> list[ItemType | None]: + """Apply a generic transform function to each row. Any errors will be reported and thrown.""" + match async_type: + case AsyncType.AsyncIO: + return await derive_from_rows_asyncio( + input, transform, callbacks, num_threads + ) + case AsyncType.Threaded: + return await derive_from_rows_asyncio_threads( + input, transform, callbacks, num_threads + ) + case _: + msg = f"Unsupported scheduling type {async_type}" + raise ValueError(msg) + + +"""A module containing the derive_from_rows_async method.""" + + +async def derive_from_rows_asyncio_threads( + input: pd.DataFrame, + transform: Callable[[pd.Series], Awaitable[ItemType]], + callbacks: VerbCallbacks, + num_threads: int | None = 4, +) -> list[ItemType | None]: + """ + Derive from rows asynchronously. + + This is useful for IO bound operations. 
+ """ + semaphore = asyncio.Semaphore(num_threads or 4) + + async def gather(execute: ExecuteFn[ItemType]) -> list[ItemType | None]: + tasks = [asyncio.to_thread(execute, row) for row in input.iterrows()] + + async def execute_task(task: Coroutine) -> ItemType | None: + async with semaphore: + # fire off the thread + thread = await task + return await thread + + return await asyncio.gather(*[execute_task(task) for task in tasks]) + + return await _derive_from_rows_base(input, transform, callbacks, gather) + + +"""A module containing the derive_from_rows_async method.""" + + +async def derive_from_rows_asyncio( + input: pd.DataFrame, + transform: Callable[[pd.Series], Awaitable[ItemType]], + callbacks: VerbCallbacks, + num_threads: int = 4, +) -> list[ItemType | None]: + """ + Derive from rows asynchronously. + + This is useful for IO bound operations. + """ + semaphore = asyncio.Semaphore(num_threads or 4) + + async def gather(execute: ExecuteFn[ItemType]) -> list[ItemType | None]: + async def execute_row_protected( + row: tuple[Hashable, pd.Series], + ) -> ItemType | None: + async with semaphore: + return await execute(row) + + tasks = [ + asyncio.create_task(execute_row_protected(row)) for row in input.iterrows() + ] + return await asyncio.gather(*tasks) + + return await _derive_from_rows_base(input, transform, callbacks, gather) + + +ItemType = TypeVar("ItemType") + +ExecuteFn = Callable[[tuple[Hashable, pd.Series]], Awaitable[ItemType | None]] +GatherFn = Callable[[ExecuteFn], Awaitable[list[ItemType | None]]] + + +async def _derive_from_rows_base( + input: pd.DataFrame, + transform: Callable[[pd.Series], Awaitable[ItemType]], + callbacks: VerbCallbacks, + gather: GatherFn[ItemType], +) -> list[ItemType | None]: + """ + Derive from rows asynchronously. + + This is useful for IO bound operations. + """ + tick = progress_ticker(callbacks.progress, num_total=len(input)) + errors: list[tuple[BaseException, str]] = [] + + async def execute(row: tuple[Any, pd.Series]) -> ItemType | None: + try: + result = transform(row[1]) + if inspect.iscoroutine(result): + result = await result + except Exception as e: # noqa: BLE001 + errors.append((e, traceback.format_exc())) + return None + else: + return cast("ItemType", result) + finally: + tick(1) + + result = await gather(execute) + + tick.done() + + for error, stack in errors: + callbacks.error("parallel transformation error", error, stack) + + if len(errors) > 0: + raise ParallelizationError(len(errors)) + + return result diff --git a/graphrag/index/run/postprocess.py b/graphrag/index/run/postprocess.py deleted file mode 100644 index 52c20064ad..0000000000 --- a/graphrag/index/run/postprocess.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""Post Processing functions for the GraphRAG run module.""" - -from typing import cast - -import pandas as pd -from datashaper import DEFAULT_INPUT_NAME, WorkflowCallbacks - -from graphrag.index.config.input import PipelineInputConfigTypes -from graphrag.index.config.workflow import PipelineWorkflowStep -from graphrag.index.context import PipelineRunContext -from graphrag.index.workflows.load import create_workflow - - -def _create_postprocess_steps( - config: PipelineInputConfigTypes | None, -) -> list[PipelineWorkflowStep] | None: - """Retrieve the post process steps for the pipeline.""" - return config.post_process if config is not None else None - - -async def _run_post_process_steps( - post_process: list[PipelineWorkflowStep] | None, - dataset: pd.DataFrame, - context: PipelineRunContext, - callbacks: WorkflowCallbacks, -) -> pd.DataFrame: - """Run the pipeline. - - Args: - - post_process - The post process steps to run - - dataset - The dataset to run the steps on - - context - The pipeline run context - Returns: - - output - The dataset after running the post process steps - """ - if post_process: - input_workflow = create_workflow( - "Input Post Process", - post_process, - ) - input_workflow.add_table(DEFAULT_INPUT_NAME, dataset) - await input_workflow.run( - context=context, - callbacks=callbacks, - ) - dataset = cast("pd.DataFrame", input_workflow.output()) - return dataset diff --git a/graphrag/index/run/profiling.py b/graphrag/index/run/profiling.py deleted file mode 100644 index 36efcde019..0000000000 --- a/graphrag/index/run/profiling.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Profiling functions for the GraphRAG run module.""" - -import json -import logging -import time -from dataclasses import asdict - -from datashaper import MemoryProfile, Workflow, WorkflowRunResult - -from graphrag.index.context import PipelineRunStats -from graphrag.storage.pipeline_storage import PipelineStorage - -log = logging.getLogger(__name__) - - -async def _save_profiler_stats( - storage: PipelineStorage, workflow_name: str, profile: MemoryProfile -): - """Save the profiler stats to the storage.""" - await storage.set( - f"{workflow_name}_profiling.peak_stats.csv", - profile.peak_stats.to_csv(index=True), - ) - - await storage.set( - f"{workflow_name}_profiling.snapshot_stats.csv", - profile.snapshot_stats.to_csv(index=True), - ) - - await storage.set( - f"{workflow_name}_profiling.time_stats.csv", - profile.time_stats.to_csv(index=True), - ) - - await storage.set( - f"{workflow_name}_profiling.detailed_view.csv", - profile.detailed_view.to_csv(index=True), - ) - - -async def _dump_stats(stats: PipelineRunStats, storage: PipelineStorage) -> None: - """Dump the stats to the storage.""" - await storage.set( - "stats.json", json.dumps(asdict(stats), indent=4, ensure_ascii=False) - ) - - -async def _write_workflow_stats( - workflow: Workflow, - workflow_result: WorkflowRunResult, - workflow_start_time: float, - start_time: float, - stats: PipelineRunStats, - storage: PipelineStorage, -) -> None: - """Write the workflow stats to the storage.""" - for vt in workflow_result.verb_timings: - stats.workflows[workflow.name][f"{vt.index}_{vt.verb}"] = vt.timing - - workflow_end_time = time.time() - stats.workflows[workflow.name]["overall"] = workflow_end_time - workflow_start_time - stats.total_runtime = time.time() - start_time - await _dump_stats(stats, storage) - - if 
workflow_result.memory_profile is not None: - await _save_profiler_stats( - storage, workflow.name, workflow_result.memory_profile - ) diff --git a/graphrag/index/run/run.py b/graphrag/index/run/run.py deleted file mode 100644 index 7ef8ee1426..0000000000 --- a/graphrag/index/run/run.py +++ /dev/null @@ -1,282 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Different methods to run the pipeline.""" - -import gc -import logging -import time -import traceback -from collections.abc import AsyncIterable -from typing import cast - -import pandas as pd -from datashaper import NoopVerbCallbacks, WorkflowCallbacks - -from graphrag.cache.factory import CacheFactory -from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks -from graphrag.index.config.pipeline import ( - PipelineConfig, - PipelineWorkflowReference, -) -from graphrag.index.config.workflow import PipelineWorkflowStep -from graphrag.index.exporter import ParquetExporter -from graphrag.index.input.factory import create_input -from graphrag.index.load_pipeline_config import load_pipeline_config -from graphrag.index.run.postprocess import ( - _create_postprocess_steps, - _run_post_process_steps, -) -from graphrag.index.run.profiling import _dump_stats -from graphrag.index.run.utils import ( - _apply_substitutions, - _validate_dataset, - create_run_context, -) -from graphrag.index.run.workflow import ( - _create_callback_chain, - _process_workflow, -) -from graphrag.index.typing import PipelineRunResult -from graphrag.index.update.incremental_index import ( - get_delta_docs, - update_dataframe_outputs, -) -from graphrag.index.workflows import ( - VerbDefinitions, - WorkflowDefinitions, - load_workflows, -) -from graphrag.logger.base import ProgressLogger -from graphrag.logger.null_progress import NullProgressLogger -from graphrag.storage.factory import StorageFactory -from graphrag.storage.pipeline_storage import PipelineStorage - -log = logging.getLogger(__name__) - - -async def run_pipeline_with_config( - config_or_path: PipelineConfig | str, - workflows: list[PipelineWorkflowReference] | None = None, - dataset: pd.DataFrame | None = None, - storage: PipelineStorage | None = None, - update_index_storage: PipelineStorage | None = None, - cache: PipelineCache | None = None, - callbacks: list[WorkflowCallbacks] | None = None, - logger: ProgressLogger | None = None, - input_post_process_steps: list[PipelineWorkflowStep] | None = None, - additional_verbs: VerbDefinitions | None = None, - additional_workflows: WorkflowDefinitions | None = None, - memory_profile: bool = False, - run_id: str | None = None, - is_resume_run: bool = False, - is_update_run: bool = False, - **_kwargs: dict, -) -> AsyncIterable[PipelineRunResult]: - """Run a pipeline with the given config. - - Args: - - config_or_path - The config to run the pipeline with - - workflows - The workflows to run (this overrides the config) - - dataset - The dataset to run the pipeline on (this overrides the config) - - storage - The storage to use for the pipeline (this overrides the config) - - cache - The cache to use for the pipeline (this overrides the config) - - logger - The logger to use for the pipeline (this overrides the config) - - input_post_process_steps - The post process steps to run on the input data (this overrides the config) - - additional_verbs - The custom verbs to use for the pipeline. 
- - additional_workflows - The custom workflows to use for the pipeline. - - memory_profile - Whether or not to profile the memory. - - run_id - The run id to start or resume from. - """ - if isinstance(config_or_path, str): - log.info("Running pipeline with config %s", config_or_path) - else: - log.info("Running pipeline") - - run_id = run_id or time.strftime("%Y%m%d-%H%M%S") - config = load_pipeline_config(config_or_path) - config = _apply_substitutions(config, run_id) - root_dir = config.root_dir or "" - - progress_logger = logger or NullProgressLogger() - storage_config = config.storage.model_dump() # type: ignore - storage = storage or StorageFactory().create_storage( - storage_type=storage_config["type"], # type: ignore - kwargs=storage_config, - ) - - if is_update_run: - update_storage_config = config.update_index_storage.model_dump() # type: ignore - update_index_storage = update_index_storage or StorageFactory().create_storage( - storage_type=update_storage_config["type"], # type: ignore - kwargs=update_storage_config, - ) - - # TODO: remove the type ignore when the new config system guarantees the existence of a cache config - cache_config = config.cache.model_dump() # type: ignore - cache = cache or CacheFactory().create_cache( - cache_type=cache_config["type"], # type: ignore - root_dir=root_dir, - kwargs=cache_config, - ) - # TODO: remove the type ignore when the new config system guarantees the existence of an input config - dataset = ( - dataset - if dataset is not None - else await create_input(config.input, progress_logger, root_dir) # type: ignore - ) - - post_process_steps = input_post_process_steps or _create_postprocess_steps( - config.input - ) - workflows = workflows or config.workflows - - if is_update_run and update_index_storage: - delta_dataset = await get_delta_docs(dataset, storage) - - # Fail on empty delta dataset - if delta_dataset.new_inputs.empty: - error_msg = "Incremental Indexing Error: No new documents to process." 
- raise ValueError(error_msg) - - delta_storage = update_index_storage.child("delta") - - # Run the pipeline on the new documents - tables_dict = {} - async for table in run_pipeline( - workflows=workflows, - dataset=delta_dataset.new_inputs, - storage=delta_storage, - cache=cache, - callbacks=callbacks, - input_post_process_steps=post_process_steps, - memory_profile=memory_profile, - additional_verbs=additional_verbs, - additional_workflows=additional_workflows, - progress_logger=progress_logger, - is_resume_run=False, - ): - tables_dict[table.workflow] = table.result - - progress_logger.success("Finished running workflows on new documents.") - await update_dataframe_outputs( - dataframe_dict=tables_dict, - storage=storage, - update_storage=update_index_storage, - config=config, - cache=cache, - callbacks=NoopVerbCallbacks(), - progress_logger=progress_logger, - ) - - else: - async for table in run_pipeline( - workflows=workflows, - dataset=dataset, - storage=storage, - cache=cache, - callbacks=callbacks, - input_post_process_steps=post_process_steps, - memory_profile=memory_profile, - additional_verbs=additional_verbs, - additional_workflows=additional_workflows, - progress_logger=progress_logger, - is_resume_run=is_resume_run, - ): - yield table - - -async def run_pipeline( - workflows: list[PipelineWorkflowReference], - dataset: pd.DataFrame, - storage: PipelineStorage | None = None, - cache: PipelineCache | None = None, - callbacks: list[WorkflowCallbacks] | None = None, - progress_logger: ProgressLogger | None = None, - input_post_process_steps: list[PipelineWorkflowStep] | None = None, - additional_verbs: VerbDefinitions | None = None, - additional_workflows: WorkflowDefinitions | None = None, - memory_profile: bool = False, - is_resume_run: bool = False, - **_kwargs: dict, -) -> AsyncIterable[PipelineRunResult]: - """Run the pipeline. - - Args: - - workflows - The workflows to run - - dataset - The dataset to run the pipeline on, specifically a dataframe with the following columns at a minimum: - - id - The id of the document - - text - The text of the document - - title - The title of the document - These must exist after any post process steps are run if there are any! 
- - storage - The storage to use for the pipeline - - cache - The cache to use for the pipeline - - progress_logger - The logger to use for the pipeline - - input_post_process_steps - The post process steps to run on the input data - - additional_verbs - The custom verbs to use for the pipeline - - additional_workflows - The custom workflows to use for the pipeline - - debug - Whether or not to run in debug mode - Returns: - - output - An iterable of workflow results as they complete running, as well as any errors that occur - """ - start_time = time.time() - - progress_reporter = progress_logger or NullProgressLogger() - callbacks = callbacks or [ConsoleWorkflowCallbacks()] - callback_chain = _create_callback_chain(callbacks, progress_reporter) - context = create_run_context(storage=storage, cache=cache, stats=None) - exporter = ParquetExporter( - context.storage, - lambda e, s, d: cast("WorkflowCallbacks", callback_chain).on_error( - "Error exporting table", e, s, d - ), - ) - - loaded_workflows = load_workflows( - workflows, - additional_verbs=additional_verbs, - additional_workflows=additional_workflows, - memory_profile=memory_profile, - ) - workflows_to_run = loaded_workflows.workflows - workflow_dependencies = loaded_workflows.dependencies - dataset = await _run_post_process_steps( - input_post_process_steps, dataset, context, callback_chain - ) - - # ensure the incoming data is valid - _validate_dataset(dataset) - - log.info("Final # of rows loaded: %s", len(dataset)) - context.stats.num_documents = len(dataset) - last_workflow = "input" - - try: - await _dump_stats(context.stats, context.storage) - - for workflow_to_run in workflows_to_run: - # flush out any intermediate dataframes - gc.collect() - last_workflow = workflow_to_run.workflow.name - result = await _process_workflow( - workflow_to_run.workflow, - context, - callback_chain, - exporter, - workflow_dependencies, - dataset, - start_time, - is_resume_run, - ) - if result: - yield result - - context.stats.total_runtime = time.time() - start_time - await _dump_stats(context.stats, context.storage) - except Exception as e: - log.exception("error running workflow %s", last_workflow) - cast("WorkflowCallbacks", callback_chain).on_error( - "Error running pipeline!", e, traceback.format_exc() - ) - yield PipelineRunResult(last_workflow, None, [e]) diff --git a/graphrag/index/run/run_workflows.py b/graphrag/index/run/run_workflows.py new file mode 100644 index 0000000000..096fe9fb1a --- /dev/null +++ b/graphrag/index/run/run_workflows.py @@ -0,0 +1,199 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Different methods to run the pipeline.""" + +import json +import logging +import time +import traceback +from collections.abc import AsyncIterable +from dataclasses import asdict +from typing import cast + +import pandas as pd + +from graphrag.cache.factory import CacheFactory +from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks +from graphrag.callbacks.delegating_verb_callbacks import DelegatingVerbCallbacks +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.context import PipelineRunStats +from graphrag.index.input.factory import create_input +from graphrag.index.run.utils import create_callback_chain, create_run_context +from graphrag.index.typing import PipelineRunResult +from graphrag.index.update.incremental_index import ( + get_delta_docs, + update_dataframe_outputs, +) +from graphrag.index.workflows import all_workflows +from graphrag.logger.base import ProgressLogger +from graphrag.logger.null_progress import NullProgressLogger +from graphrag.logger.progress import Progress +from graphrag.storage.factory import StorageFactory +from graphrag.storage.pipeline_storage import PipelineStorage +from graphrag.utils.storage import delete_table_from_storage, write_table_to_storage + +log = logging.getLogger(__name__) + + +# these are transient outputs written to storage for downstream workflow use +# they are not required after indexing, so we'll clean them up at the end for clarity +# (unless snapshots.transient is set!) +transient_outputs = [ + "input", + "base_communities", + "base_entity_nodes", + "base_relationship_edges", + "create_base_text_units", +] + + +async def run_workflows( + workflows: list[str], + config: GraphRagConfig, + cache: PipelineCache | None = None, + callbacks: list[WorkflowCallbacks] | None = None, + logger: ProgressLogger | None = None, + run_id: str | None = None, + is_update_run: bool = False, +) -> AsyncIterable[PipelineRunResult]: + """Run all workflows using a simplified pipeline.""" + run_id = run_id or time.strftime("%Y%m%d-%H%M%S") + root_dir = config.root_dir or "" + progress_logger = logger or NullProgressLogger() + callbacks = callbacks or [ConsoleWorkflowCallbacks()] + callback_chain = create_callback_chain(callbacks, progress_logger) + storage_config = config.storage.model_dump() # type: ignore + storage = StorageFactory().create_storage( + storage_type=storage_config["type"], # type: ignore + kwargs=storage_config, + ) + cache_config = config.cache.model_dump() # type: ignore + cache = cache or CacheFactory().create_cache( + cache_type=cache_config["type"], # type: ignore + root_dir=root_dir, + kwargs=cache_config, + ) + + dataset = await create_input(config.input, logger, root_dir) + + if is_update_run: + progress_logger.info("Running incremental indexing.") + + update_storage_config = config.update_index_storage.model_dump() # type: ignore + update_index_storage = StorageFactory().create_storage( + storage_type=update_storage_config["type"], # type: ignore + kwargs=update_storage_config, + ) + + delta_dataset = await get_delta_docs(dataset, storage) + + # Fail on empty delta dataset + if delta_dataset.new_inputs.empty: + error_msg = "Incremental Indexing Error: No new documents to process." 
+ raise ValueError(error_msg) + + delta_storage = update_index_storage.child("delta") + + # Run the pipeline on the new documents + tables_dict = {} + async for table in _run_workflows( + workflows=workflows, + config=config, + dataset=delta_dataset.new_inputs, + cache=cache, + storage=delta_storage, + callbacks=callback_chain, + logger=progress_logger, + ): + tables_dict[table.workflow] = table.result + + progress_logger.success("Finished running workflows on new documents.") + + await update_dataframe_outputs( + dataframe_dict=tables_dict, + storage=storage, + update_storage=update_index_storage, + config=config, + cache=cache, + callbacks=NoopVerbCallbacks(), + progress_logger=progress_logger, + ) + + else: + progress_logger.info("Running standard indexing.") + + async for table in _run_workflows( + workflows=workflows, + config=config, + dataset=dataset, + cache=cache, + storage=storage, + callbacks=callback_chain, + logger=progress_logger, + ): + yield table + + +async def _run_workflows( + workflows: list[str], + config: GraphRagConfig, + dataset: pd.DataFrame, + cache: PipelineCache, + storage: PipelineStorage, + callbacks: WorkflowCallbacks, + logger: ProgressLogger, +) -> AsyncIterable[PipelineRunResult]: + start_time = time.time() + + context = create_run_context(storage=storage, cache=cache, stats=None) + + log.info("Final # of rows loaded: %s", len(dataset)) + context.stats.num_documents = len(dataset) + last_workflow = "input" + + try: + await _dump_stats(context.stats, context.storage) + await write_table_to_storage(dataset, "input", context.storage) + + for workflow in workflows: + last_workflow = workflow + run_workflow = all_workflows[workflow] + progress = logger.child(workflow, transient=False) + callbacks.on_workflow_start(workflow, None) + verb_callbacks = DelegatingVerbCallbacks(workflow, callbacks) + work_time = time.time() + result = await run_workflow( + config, + context, + verb_callbacks, + ) + progress(Progress(percent=1)) + callbacks.on_workflow_end(workflow, result) + yield PipelineRunResult(workflow, result, None) + + context.stats.workflows[workflow] = {"overall": time.time() - work_time} + + context.stats.total_runtime = time.time() - start_time + await _dump_stats(context.stats, context.storage) + + if not config.snapshots.transient: + for output in transient_outputs: + await delete_table_from_storage(output, context.storage) + + except Exception as e: + log.exception("error running workflow %s", last_workflow) + cast("WorkflowCallbacks", callbacks).on_error( + "Error running pipeline!", e, traceback.format_exc() + ) + yield PipelineRunResult(last_workflow, None, [e]) + + +async def _dump_stats(stats: PipelineRunStats, storage: PipelineStorage) -> None: + """Dump the stats to the storage.""" + await storage.set( + "stats.json", json.dumps(asdict(stats), indent=4, ensure_ascii=False) + ) diff --git a/graphrag/index/run/utils.py b/graphrag/index/run/utils.py index e78ee11179..04dd31df24 100644 --- a/graphrag/index/run/utils.py +++ b/graphrag/index/run/utils.py @@ -3,89 +3,16 @@ """Utility functions for the GraphRAG run module.""" -import logging -from string import Template -from typing import Any - -import pandas as pd - from graphrag.cache.memory_pipeline_cache import InMemoryCache from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.index.config.cache import ( - PipelineBlobCacheConfig, - PipelineFileCacheConfig, -) -from graphrag.index.config.pipeline import PipelineConfig -from graphrag.index.config.reporting import ( - 
PipelineBlobReportingConfig, - PipelineFileReportingConfig, -) -from graphrag.index.config.storage import ( - PipelineBlobStorageConfig, - PipelineFileStorageConfig, -) +from graphrag.callbacks.progress_workflow_callbacks import ProgressWorkflowCallbacks +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks +from graphrag.callbacks.workflow_callbacks_manager import WorkflowCallbacksManager from graphrag.index.context import PipelineRunContext, PipelineRunStats +from graphrag.logger.base import ProgressLogger from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage from graphrag.storage.pipeline_storage import PipelineStorage -log = logging.getLogger(__name__) - - -def _validate_dataset(dataset: Any): - """Validate the dataset for the pipeline. - - Args: - - dataset - The dataset to validate - """ - if not isinstance(dataset, pd.DataFrame): - msg = "Dataset must be a pandas dataframe!" - raise TypeError(msg) - - -def _apply_substitutions(config: PipelineConfig, run_id: str) -> PipelineConfig: - """Apply the substitutions to the configuration.""" - substitutions = {"timestamp": run_id} - - if ( - isinstance( - config.storage, PipelineFileStorageConfig | PipelineBlobStorageConfig - ) - and config.storage.base_dir - ): - config.storage.base_dir = Template(config.storage.base_dir).substitute( - substitutions - ) - if ( - config.update_index_storage - and isinstance( - config.update_index_storage, - PipelineFileStorageConfig | PipelineBlobStorageConfig, - ) - and config.update_index_storage.base_dir - ): - config.update_index_storage.base_dir = Template( - config.update_index_storage.base_dir - ).substitute(substitutions) - if ( - isinstance(config.cache, PipelineFileCacheConfig | PipelineBlobCacheConfig) - and config.cache.base_dir - ): - config.cache.base_dir = Template(config.cache.base_dir).substitute( - substitutions - ) - - if ( - isinstance( - config.reporting, PipelineFileReportingConfig | PipelineBlobReportingConfig - ) - and config.reporting.base_dir - ): - config.reporting.base_dir = Template(config.reporting.base_dir).substitute( - substitutions - ) - - return config - def create_run_context( storage: PipelineStorage | None, @@ -97,5 +24,16 @@ def create_run_context( stats=stats or PipelineRunStats(), cache=cache or InMemoryCache(), storage=storage or MemoryPipelineStorage(), - runtime_storage=MemoryPipelineStorage(), ) + + +def create_callback_chain( + callbacks: list[WorkflowCallbacks] | None, progress: ProgressLogger | None +) -> WorkflowCallbacks: + """Create a callback manager that encompasses multiple callbacks.""" + manager = WorkflowCallbacksManager() + for callback in callbacks or []: + manager.register(callback) + if progress is not None: + manager.register(ProgressWorkflowCallbacks(progress)) + return manager diff --git a/graphrag/index/run/workflow.py b/graphrag/index/run/workflow.py deleted file mode 100644 index e4758fd951..0000000000 --- a/graphrag/index/run/workflow.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
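[Illustrative sketch, not part of the diff] The hunks above introduce the simplified `run_workflows` entry point and the `create_callback_chain` helper in run/utils.py. A minimal driver might look like the following; it assumes `run_workflows` is importable from `graphrag.index.run.run_workflows` and that `PipelineRunResult` exposes `workflow`, `result`, and `errors` attributes (the constructor is called positionally in the hunk above), with `config` being an already-loaded `GraphRagConfig`.

```python
# Hypothetical driver for the simplified pipeline (a sketch, not part of this patch).
# Assumes run_workflows lives in graphrag.index.run.run_workflows and that
# PipelineRunResult exposes `workflow` and `errors` attributes.
import asyncio

from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.index.run.run_workflows import run_workflows
from graphrag.index.workflows import all_workflows


async def run_pipeline(config) -> None:
    # `config` is an already-loaded GraphRagConfig; run every built-in workflow in order.
    async for result in run_workflows(
        workflows=list(all_workflows),
        config=config,
        callbacks=[ConsoleWorkflowCallbacks()],
    ):
        status = "failed" if result.errors else "completed"
        print(f"{result.workflow}: {status}")


# asyncio.run(run_pipeline(config))
```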
-# Licensed under the MIT License - -"""Workflow functions for the GraphRAG update module.""" - -import logging -import time -from typing import cast - -import pandas as pd -from datashaper import ( - DEFAULT_INPUT_NAME, - Workflow, - WorkflowCallbacks, - WorkflowCallbacksManager, -) - -from graphrag.callbacks.progress_workflow_callbacks import ProgressWorkflowCallbacks -from graphrag.index.config.pipeline import PipelineConfig -from graphrag.index.context import PipelineRunContext -from graphrag.index.exporter import ParquetExporter -from graphrag.index.run.profiling import _write_workflow_stats -from graphrag.index.typing import PipelineRunResult -from graphrag.logger.base import ProgressLogger -from graphrag.storage.pipeline_storage import PipelineStorage -from graphrag.utils.storage import load_table_from_storage - -log = logging.getLogger(__name__) - - -async def _inject_workflow_data_dependencies( - workflow: Workflow, - workflow_dependencies: dict[str, list[str]], - dataset: pd.DataFrame, - storage: PipelineStorage, -) -> None: - """Inject the data dependencies into the workflow.""" - workflow.add_table(DEFAULT_INPUT_NAME, dataset) - deps = workflow_dependencies[workflow.name] - log.info("dependencies for %s: %s", workflow.name, deps) - for id in deps: - workflow_id = f"workflow:{id}" - try: - table = await load_table_from_storage(f"{id}.parquet", storage) - except ValueError: - # our workflows allow for transient tables, and we avoid putting those in storage - # however, we need to keep the table in the dependency list for proper execution order. - # this allows us to catch missing table errors and issue a warning for pipeline users who may genuinely have an error (which we expect to be very rare) - # todo: this issue will resolve itself once we remove DataShaper completely - log.warning( - "Dependency table %s not found in storage: it may be a runtime-only in-memory table. If you see further errors, this may be an actual problem.", - id, - ) - table = pd.DataFrame() - workflow.add_table(workflow_id, table) - - -async def _export_workflow_output( - workflow: Workflow, exporter: ParquetExporter -) -> pd.DataFrame: - """Export the output from each step of the workflow.""" - output = cast("pd.DataFrame", workflow.output()) - # only write final output that is not empty (i.e. 
has content) - # NOTE: this design is intentional - it accounts for workflow steps with "side effects" that don't produce a formal output to save - if not output.empty: - await exporter.export(workflow.name, output) - return output - - -def _create_callback_chain( - callbacks: list[WorkflowCallbacks] | None, progress: ProgressLogger | None -) -> WorkflowCallbacks: - """Create a callback manager that encompasses multiple callbacks.""" - manager = WorkflowCallbacksManager() - for callback in callbacks or []: - manager.register(callback) - if progress is not None: - manager.register(ProgressWorkflowCallbacks(progress)) - return manager - - -async def _process_workflow( - workflow: Workflow, - context: PipelineRunContext, - callbacks: WorkflowCallbacks, - exporter: ParquetExporter, - workflow_dependencies: dict[str, list[str]], - dataset: pd.DataFrame, - start_time: float, - is_resume_run: bool, -): - workflow_name = workflow.name - if is_resume_run and await context.storage.has(f"{workflow_name}.parquet"): - log.info("Skipping %s because it already exists", workflow_name) - return None - - context.stats.workflows[workflow_name] = {"overall": 0.0} - - await _inject_workflow_data_dependencies( - workflow, - workflow_dependencies, - dataset, - context.storage, - ) - - workflow_start_time = time.time() - result = await workflow.run(context, callbacks) - await _write_workflow_stats( - workflow, - result, - workflow_start_time, - start_time, - context.stats, - context.storage, - ) - - # Save the output from the workflow - output = await _export_workflow_output(workflow, exporter) - workflow.dispose() - return PipelineRunResult(workflow_name, output, None) - - -def _find_workflow_config( - config: PipelineConfig, workflow_name: str, step: str | None = None -) -> dict: - """Find a workflow in the pipeline configuration. - - Parameters - ---------- - config : PipelineConfig - The pipeline configuration. - workflow_name : str - The name of the workflow. - step : str - The step in the workflow. - - Returns - ------- - dict - The workflow configuration. - """ - try: - workflow = next( - filter(lambda workflow: workflow.name == workflow_name, config.workflows) - ) - except StopIteration as err: - error_message = ( - f"Workflow {workflow_name} not found in the pipeline configuration." 
- ) - raise ValueError(error_message) from err - - if not workflow.config: - return {} - return workflow.config if not step else workflow.config.get(step, {}) diff --git a/graphrag/index/update/entities.py b/graphrag/index/update/entities.py index 3c117ac837..849fa4a749 100644 --- a/graphrag/index/update/entities.py +++ b/graphrag/index/update/entities.py @@ -8,14 +8,13 @@ import numpy as np import pandas as pd -from datashaper import VerbCallbacks from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.index.config.pipeline import PipelineConfig +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.operations.summarize_descriptions.graph_intelligence_strategy import ( run_graph_intelligence as run_entity_summarization, ) -from graphrag.index.run.workflow import _find_workflow_config def _group_and_resolve_entities( @@ -91,7 +90,7 @@ def _group_and_resolve_entities( async def _run_entity_summarization( entities_df: pd.DataFrame, - config: PipelineConfig, + config: GraphRagConfig, cache: PipelineCache, callbacks: VerbCallbacks, ) -> pd.DataFrame: @@ -101,7 +100,7 @@ async def _run_entity_summarization( ---------- entities_df : pd.DataFrame The entities dataframe. - config : PipelineConfig + config : GraphRagConfig The pipeline configuration. cache : PipelineCache Pipeline cache used during the summarization process. @@ -111,10 +110,9 @@ async def _run_entity_summarization( pd.DataFrame The updated entities dataframe with summarized descriptions. """ - summarize_config = _find_workflow_config( - config, "extract_graph", "summarize_descriptions" + summarization_strategy = config.summarize_descriptions.resolved_strategy( + config.root_dir, ) - strategy = summarize_config.get("strategy", {}) # Prepare tasks for async summarization where needed async def process_row(row): @@ -122,7 +120,7 @@ async def process_row(row): if isinstance(description, list) and len(description) > 1: # Run entity summarization asynchronously result = await run_entity_summarization( - row["title"], description, callbacks, cache, strategy + row["title"], description, callbacks, cache, summarization_strategy ) return result.description # Handle case where description is a single-item list or not a list diff --git a/graphrag/index/update/incremental_index.py b/graphrag/index/update/incremental_index.py index 05da47f01d..4ba486af6b 100644 --- a/graphrag/index/update/incremental_index.py +++ b/graphrag/index/update/incremental_index.py @@ -7,12 +7,12 @@ import numpy as np import pandas as pd -from datashaper import VerbCallbacks from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.index.config.pipeline import PipelineConfig +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings -from graphrag.index.run.workflow import _find_workflow_config from graphrag.index.update.communities import ( _merge_and_resolve_nodes, _update_and_merge_communities, @@ -25,7 +25,11 @@ from graphrag.index.update.relationships import _update_and_merge_relationships from graphrag.logger.print_progress import ProgressLogger from graphrag.storage.pipeline_storage import PipelineStorage -from graphrag.utils.storage import load_table_from_storage +from graphrag.utils.storage import ( + 
load_table_from_storage, + storage_has_table, + write_table_to_storage, +) @dataclass @@ -61,9 +65,7 @@ async def get_delta_docs( InputDelta The input delta. With new inputs and deleted inputs. """ - final_docs = await load_table_from_storage( - "create_final_documents.parquet", storage - ) + final_docs = await load_table_from_storage("create_final_documents", storage) # Select distinct title from final docs and from dataset previous_docs: list[str] = final_docs["title"].unique().tolist() @@ -82,7 +84,7 @@ async def update_dataframe_outputs( dataframe_dict: dict[str, pd.DataFrame], storage: PipelineStorage, update_storage: PipelineStorage, - config: PipelineConfig, + config: GraphRagConfig, cache: PipelineCache, callbacks: VerbCallbacks, progress_logger: ProgressLogger, @@ -121,7 +123,7 @@ async def update_dataframe_outputs( # Merge final covariates if ( - await storage.has("create_final_covariates.parquet") + await storage_has_table("create_final_covariates", storage) and "create_final_covariates" in dataframe_dict ): progress_logger.info("Updating Final Covariates") @@ -145,13 +147,10 @@ async def update_dataframe_outputs( dataframe_dict, storage, update_storage, community_id_mapping ) - # Extract the embeddings config - embeddings_config = _find_workflow_config( - config=config, workflow_name="generate_text_embeddings" - ) - # Generate text embeddings progress_logger.info("Updating Text Embeddings") + embedded_fields = get_embedded_fields(config) + text_embed = get_embedding_settings(config.embeddings) await generate_text_embeddings( final_documents=final_documents_df, final_relationships=merged_relationships_df, @@ -161,9 +160,9 @@ async def update_dataframe_outputs( callbacks=callbacks, cache=cache, storage=update_storage, - text_embed_config=embeddings_config.get("text_embed", {}), - embedded_fields=embeddings_config.get("embedded_fields", {}), - snapshot_embeddings_enabled=embeddings_config.get("snapshot_embeddings", False), + text_embed_config=text_embed, + embedded_fields=embedded_fields, + snapshot_embeddings_enabled=config.snapshots.embeddings, ) @@ -172,7 +171,7 @@ async def _update_community_reports( ): """Update the community reports output.""" old_community_reports = await load_table_from_storage( - "create_final_community_reports.parquet", storage + "create_final_community_reports", storage ) delta_community_reports = dataframe_dict["create_final_community_reports"] @@ -180,9 +179,8 @@ async def _update_community_reports( old_community_reports, delta_community_reports, community_id_mapping ) - await update_storage.set( - "create_final_community_reports.parquet", - merged_community_reports.to_parquet(), + await write_table_to_storage( + merged_community_reports, "create_final_community_reports", update_storage ) return merged_community_reports @@ -192,42 +190,40 @@ async def _update_communities( dataframe_dict, storage, update_storage, community_id_mapping ): """Update the communities output.""" - old_communities = await load_table_from_storage( - "create_final_communities.parquet", storage - ) + old_communities = await load_table_from_storage("create_final_communities", storage) delta_communities = dataframe_dict["create_final_communities"] merged_communities = _update_and_merge_communities( old_communities, delta_communities, community_id_mapping ) - await update_storage.set( - "create_final_communities.parquet", merged_communities.to_parquet() + await write_table_to_storage( + merged_communities, "create_final_communities", update_storage ) async def 
_update_nodes(dataframe_dict, storage, update_storage, merged_entities_df): """Update the nodes output.""" - old_nodes = await load_table_from_storage("create_final_nodes.parquet", storage) + old_nodes = await load_table_from_storage("create_final_nodes", storage) delta_nodes = dataframe_dict["create_final_nodes"] merged_nodes, community_id_mapping = _merge_and_resolve_nodes( old_nodes, delta_nodes, merged_entities_df ) - await update_storage.set("create_final_nodes.parquet", merged_nodes.to_parquet()) + await write_table_to_storage(merged_nodes, "create_final_nodes", update_storage) + return merged_nodes, community_id_mapping async def _update_covariates(dataframe_dict, storage, update_storage): """Update the covariates output.""" - old_covariates = await load_table_from_storage( - "create_final_covariates.parquet", storage - ) + old_covariates = await load_table_from_storage("create_final_covariates", storage) delta_covariates = dataframe_dict["create_final_covariates"] merged_covariates = _merge_covariates(old_covariates, delta_covariates) - await update_storage.set( - "create_final_covariates.parquet", merged_covariates.to_parquet() + + await write_table_to_storage( + merged_covariates, "create_final_covariates", update_storage ) @@ -235,17 +231,15 @@ async def _update_text_units( dataframe_dict, storage, update_storage, entity_id_mapping ): """Update the text units output.""" - old_text_units = await load_table_from_storage( - "create_final_text_units.parquet", storage - ) + old_text_units = await load_table_from_storage("create_final_text_units", storage) delta_text_units = dataframe_dict["create_final_text_units"] merged_text_units = _update_and_merge_text_units( old_text_units, delta_text_units, entity_id_mapping ) - await update_storage.set( - "create_final_text_units.parquet", merged_text_units.to_parquet() + await write_table_to_storage( + merged_text_units, "create_final_text_units", update_storage ) return merged_text_units @@ -254,7 +248,7 @@ async def _update_text_units( async def _update_relationships(dataframe_dict, storage, update_storage): """Update the relationships output.""" old_relationships = await load_table_from_storage( - "create_final_relationships.parquet", storage + "create_final_relationships", storage ) delta_relationships = dataframe_dict["create_final_relationships"] merged_relationships_df = _update_and_merge_relationships( @@ -262,8 +256,8 @@ async def _update_relationships(dataframe_dict, storage, update_storage): delta_relationships, ) - await update_storage.set( - "create_final_relationships.parquet", merged_relationships_df.to_parquet() + await write_table_to_storage( + merged_relationships_df, "create_final_relationships", update_storage ) return merged_relationships_df @@ -273,9 +267,7 @@ async def _update_entities( dataframe_dict, storage, update_storage, config, cache, callbacks ): """Update Final Entities output.""" - old_entities = await load_table_from_storage( - "create_final_entities.parquet", storage - ) + old_entities = await load_table_from_storage("create_final_entities", storage) delta_entities = dataframe_dict["create_final_entities"] merged_entities_df, entity_id_mapping = _group_and_resolve_entities( @@ -291,8 +283,8 @@ async def _update_entities( ) # Save the updated entities back to storage - await update_storage.set( - "create_final_entities.parquet", merged_entities_df.to_parquet() + await write_table_to_storage( + merged_entities_df, "create_final_entities", update_storage ) return merged_entities_df, entity_id_mapping @@ 
-310,13 +302,14 @@ async def _concat_dataframes(name, dataframe_dict, storage, update_storage): storage : PipelineStorage The storage used to store the dataframes. """ - old_df = await load_table_from_storage(f"{name}.parquet", storage) + old_df = await load_table_from_storage(name, storage) delta_df = dataframe_dict[name] # Merge the final documents final_df = pd.concat([old_df, delta_df], copy=False) - await update_storage.set(f"{name}.parquet", final_df.to_parquet()) + await write_table_to_storage(final_df, name, update_storage) + return final_df diff --git a/graphrag/index/utils/ds_util.py b/graphrag/index/utils/ds_util.py deleted file mode 100644 index e59d30754f..0000000000 --- a/graphrag/index/utils/ds_util.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A utility module datashaper-specific utility methods.""" - -from typing import cast - -from datashaper import TableContainer, VerbInput - -_NAMED_INPUTS_REQUIRED = "Named inputs are required" - - -def get_required_input_table(input: VerbInput, name: str) -> TableContainer: - """Get a required input table by name.""" - return cast("TableContainer", get_named_input_table(input, name, required=True)) - - -def get_named_input_table( - input: VerbInput, name: str, required: bool = False -) -> TableContainer | None: - """Get an input table from datashaper verb-inputs by name.""" - named_inputs = input.named - if named_inputs is None: - if not required: - return None - raise ValueError(_NAMED_INPUTS_REQUIRED) - - result = named_inputs.get(name) - if result is None and required: - msg = f"input '${name}' is required" - raise ValueError(msg) - return result diff --git a/graphrag/index/utils/load_graph.py b/graphrag/index/utils/load_graph.py deleted file mode 100644 index 57992a04c8..0000000000 --- a/graphrag/index/utils/load_graph.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Networkx load_graph utility definition.""" - -import networkx as nx - - -def load_graph(graphml: str | nx.Graph) -> nx.Graph: - """Load a graph from a graphml file or a networkx graph.""" - return nx.parse_graphml(graphml) if isinstance(graphml, str) else graphml diff --git a/graphrag/index/utils/topological_sort.py b/graphrag/index/utils/topological_sort.py deleted file mode 100644 index a19b464559..0000000000 --- a/graphrag/index/utils/topological_sort.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
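[Illustrative sketch, not part of the diff] The incremental-update hunks above replace raw `storage.set(f"{name}.parquet", df.to_parquet())` calls with the name-based helpers from `graphrag.utils.storage`, so call sites no longer carry a `.parquet` suffix. A small round-trip of those helpers, assuming the in-memory storage backend accepts the serialized table the same way the file and blob backends do:

```python
# Round-trip sketch for the name-based table helpers used throughout this patch.
# Illustrative only; assumes MemoryPipelineStorage round-trips tables like other backends.
import asyncio

import pandas as pd

from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.utils.storage import (
    load_table_from_storage,
    storage_has_table,
    write_table_to_storage,
)


async def round_trip() -> None:
    storage = MemoryPipelineStorage()
    df = pd.DataFrame({"title": ["doc-1", "doc-2"]})
    await write_table_to_storage(df, "create_final_documents", storage)
    if await storage_has_table("create_final_documents", storage):
        loaded = await load_table_from_storage("create_final_documents", storage)
        print(len(loaded))  # 2


if __name__ == "__main__":
    asyncio.run(round_trip())
```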
-# Licensed under the MIT License - -"""Topological sort utility method.""" - -from graphlib import TopologicalSorter - - -def topological_sort(graph: dict[str, list[str]]) -> list[str]: - """Topological sort.""" - ts = TopologicalSorter(graph) - return list(ts.static_order()) diff --git a/graphrag/index/validate_config.py b/graphrag/index/validate_config.py index 07e4638fc3..a98e4cb707 100644 --- a/graphrag/index/validate_config.py +++ b/graphrag/index/validate_config.py @@ -6,8 +6,7 @@ import asyncio import sys -from datashaper import NoopVerbCallbacks - +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.llm.load_llm import load_llm, load_llm_embeddings from graphrag.logger.print_progress import ProgressLogger diff --git a/graphrag/index/workflows/__init__.py b/graphrag/index/workflows/__init__.py index db1cb74c7b..a904dc7bb8 100644 --- a/graphrag/index/workflows/__init__.py +++ b/graphrag/index/workflows/__init__.py @@ -1,25 +1,108 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""The Indexing Engine workflows package root.""" -from graphrag.index.workflows.load import create_workflow, load_workflows -from graphrag.index.workflows.typing import ( - StepDefinition, - VerbDefinitions, - VerbTiming, - WorkflowConfig, - WorkflowDefinitions, - WorkflowToRun, +"""A package containing all built-in workflow definitions.""" + +from collections.abc import Awaitable, Callable + +import pandas as pd + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.context import PipelineRunContext + +from .compute_communities import ( + run_workflow as run_compute_communities, +) +from .compute_communities import ( + workflow_name as compute_communities, +) +from .create_base_text_units import ( + run_workflow as run_create_base_text_units, +) +from .create_base_text_units import ( + workflow_name as create_base_text_units, +) +from .create_final_communities import ( + run_workflow as run_create_final_communities, +) +from .create_final_communities import ( + workflow_name as create_final_communities, +) +from .create_final_community_reports import ( + run_workflow as run_create_final_community_reports, +) +from .create_final_community_reports import ( + workflow_name as create_final_community_reports, +) +from .create_final_covariates import ( + run_workflow as run_create_final_covariates, +) +from .create_final_covariates import ( + workflow_name as create_final_covariates, +) +from .create_final_documents import ( + run_workflow as run_create_final_documents, +) +from .create_final_documents import ( + workflow_name as create_final_documents, +) +from .create_final_entities import ( + run_workflow as run_create_final_entities, +) +from .create_final_entities import ( + workflow_name as create_final_entities, +) +from .create_final_nodes import ( + run_workflow as run_create_final_nodes, +) +from .create_final_nodes import ( + workflow_name as create_final_nodes, +) +from .create_final_relationships import ( + run_workflow as run_create_final_relationships, +) +from .create_final_relationships import ( + workflow_name as create_final_relationships, +) +from .create_final_text_units import ( + run_workflow as run_create_final_text_units, +) +from .create_final_text_units import ( + workflow_name as create_final_text_units, +) +from .extract_graph import ( + run_workflow as 
run_extract_graph, +) +from .extract_graph import ( + workflow_name as extract_graph, +) +from .generate_text_embeddings import ( + run_workflow as run_generate_text_embeddings, +) +from .generate_text_embeddings import ( + workflow_name as generate_text_embeddings, ) -__all__ = [ - "StepDefinition", - "VerbDefinitions", - "VerbTiming", - "WorkflowConfig", - "WorkflowDefinitions", - "WorkflowToRun", - "create_workflow", - "load_workflows", -] +all_workflows: dict[ + str, + Callable[ + [GraphRagConfig, PipelineRunContext, VerbCallbacks], + Awaitable[pd.DataFrame | None], + ], +] = { + compute_communities: run_compute_communities, + create_base_text_units: run_create_base_text_units, + create_final_communities: run_create_final_communities, + create_final_community_reports: run_create_final_community_reports, + create_final_covariates: run_create_final_covariates, + create_final_documents: run_create_final_documents, + create_final_entities: run_create_final_entities, + create_final_nodes: run_create_final_nodes, + create_final_relationships: run_create_final_relationships, + create_final_text_units: run_create_final_text_units, + extract_graph: run_extract_graph, + generate_text_embeddings: run_generate_text_embeddings, +} +"""This is a dictionary of all built-in workflows. To be replaced with an injectable provider!""" diff --git a/graphrag/index/workflows/compute_communities.py b/graphrag/index/workflows/compute_communities.py new file mode 100644 index 0000000000..51cf511d50 --- /dev/null +++ b/graphrag/index/workflows/compute_communities.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run_workflow method definition.""" + +import pandas as pd + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.context import PipelineRunContext +from graphrag.index.flows.compute_communities import compute_communities +from graphrag.utils.storage import load_table_from_storage, write_table_to_storage + +workflow_name = "compute_communities" + + +async def run_workflow( + config: GraphRagConfig, + context: PipelineRunContext, + _callbacks: VerbCallbacks, +) -> pd.DataFrame | None: + """All the steps to create the base communities.""" + base_relationship_edges = await load_table_from_storage( + "base_relationship_edges", context.storage + ) + + max_cluster_size = config.cluster_graph.max_cluster_size + use_lcc = config.cluster_graph.use_lcc + seed = config.cluster_graph.seed + + base_communities = compute_communities( + base_relationship_edges, + max_cluster_size=max_cluster_size, + use_lcc=use_lcc, + seed=seed, + ) + + await write_table_to_storage(base_communities, "base_communities", context.storage) + + return base_communities diff --git a/graphrag/index/workflows/create_base_text_units.py b/graphrag/index/workflows/create_base_text_units.py new file mode 100644 index 0000000000..91d5822884 --- /dev/null +++ b/graphrag/index/workflows/create_base_text_units.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024 Microsoft Corporation.
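[Illustrative sketch, not part of the diff] Every entry in the `all_workflows` registry above shares the `(GraphRagConfig, PipelineRunContext, VerbCallbacks) -> Awaitable[pd.DataFrame | None]` signature, so until the injectable provider mentioned in its docstring exists, a custom step could be slotted in by adding to the dict. The workflow name and output table below are hypothetical; the storage helpers and `title` column follow the usage shown elsewhere in this patch.

```python
# Sketch of a custom workflow following the new uniform signature and registering
# itself in the module-level dict. The workflow name and output table are illustrative.
import pandas as pd

from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.workflows import all_workflows
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

workflow_name = "snapshot_document_titles"


async def run_workflow(
    _config: GraphRagConfig,
    context: PipelineRunContext,
    _callbacks: VerbCallbacks,
) -> pd.DataFrame | None:
    """Derive a small titles table from the final documents output."""
    documents = await load_table_from_storage("create_final_documents", context.storage)
    output = documents[["title"]].drop_duplicates().reset_index(drop=True)
    await write_table_to_storage(output, workflow_name, context.storage)
    return output


all_workflows[workflow_name] = run_workflow
```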
+# Licensed under the MIT License + +"""A module containing run_workflow method definition.""" + +import pandas as pd + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.context import PipelineRunContext +from graphrag.index.flows.create_base_text_units import ( + create_base_text_units, +) +from graphrag.utils.storage import load_table_from_storage, write_table_to_storage + +workflow_name = "create_base_text_units" + + +async def run_workflow( + config: GraphRagConfig, + context: PipelineRunContext, + callbacks: VerbCallbacks, +) -> pd.DataFrame | None: + """All the steps to transform base text_units.""" + documents = await load_table_from_storage("input", context.storage) + + chunks = config.chunks + + output = create_base_text_units( + documents, + callbacks, + chunks.group_by_columns, + chunks.size, + chunks.overlap, + chunks.encoding_model, + strategy=chunks.strategy, + ) + + await write_table_to_storage(output, workflow_name, context.storage) + + return output diff --git a/graphrag/index/workflows/create_final_communities.py b/graphrag/index/workflows/create_final_communities.py new file mode 100644 index 0000000000..e1cf950e97 --- /dev/null +++ b/graphrag/index/workflows/create_final_communities.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run_workflow method definition.""" + +import pandas as pd + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.context import PipelineRunContext +from graphrag.index.flows.create_final_communities import ( + create_final_communities, +) +from graphrag.utils.storage import load_table_from_storage, write_table_to_storage + +workflow_name = "create_final_communities" + + +async def run_workflow( + _config: GraphRagConfig, + context: PipelineRunContext, + _callbacks: VerbCallbacks, +) -> pd.DataFrame | None: + """All the steps to transform final communities.""" + base_entity_nodes = await load_table_from_storage( + "base_entity_nodes", context.storage + ) + base_relationship_edges = await load_table_from_storage( + "base_relationship_edges", context.storage + ) + base_communities = await load_table_from_storage( + "base_communities", context.storage + ) + + output = create_final_communities( + base_entity_nodes, + base_relationship_edges, + base_communities, + ) + + await write_table_to_storage(output, workflow_name, context.storage) + + return output diff --git a/graphrag/index/workflows/create_final_community_reports.py b/graphrag/index/workflows/create_final_community_reports.py new file mode 100644 index 0000000000..7aacc79fbf --- /dev/null +++ b/graphrag/index/workflows/create_final_community_reports.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing run_workflow method definition.""" + +import pandas as pd + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.context import PipelineRunContext +from graphrag.index.flows.create_final_community_reports import ( + create_final_community_reports, +) +from graphrag.utils.storage import load_table_from_storage, write_table_to_storage + +workflow_name = "create_final_community_reports" + + +async def run_workflow( + config: GraphRagConfig, + context: PipelineRunContext, + callbacks: VerbCallbacks, +) -> pd.DataFrame | None: + """All the steps to transform community reports.""" + nodes = await load_table_from_storage("create_final_nodes", context.storage) + edges = await load_table_from_storage("create_final_relationships", context.storage) + entities = await load_table_from_storage("create_final_entities", context.storage) + communities = await load_table_from_storage( + "create_final_communities", context.storage + ) + claims = None + if config.claim_extraction.enabled: + claims = await load_table_from_storage( + "create_final_covariates", context.storage + ) + async_mode = config.community_reports.async_mode + num_threads = config.community_reports.parallelization.num_threads + summarization_strategy = config.community_reports.resolved_strategy(config.root_dir) + + output = await create_final_community_reports( + nodes, + edges, + entities, + communities, + claims, + callbacks, + context.cache, + summarization_strategy, + async_mode=async_mode, + num_threads=num_threads, + ) + + await write_table_to_storage(output, workflow_name, context.storage) + + return output diff --git a/graphrag/index/workflows/create_final_covariates.py b/graphrag/index/workflows/create_final_covariates.py new file mode 100644 index 0000000000..9ab91fdf16 --- /dev/null +++ b/graphrag/index/workflows/create_final_covariates.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing run_workflow method definition.""" + +import pandas as pd + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.context import PipelineRunContext +from graphrag.index.flows.create_final_covariates import ( + create_final_covariates, +) +from graphrag.utils.storage import load_table_from_storage, write_table_to_storage + +workflow_name = "create_final_covariates" + + +async def run_workflow( + config: GraphRagConfig, + context: PipelineRunContext, + callbacks: VerbCallbacks, +) -> pd.DataFrame | None: + """All the steps to extract and format covariates.""" + text_units = await load_table_from_storage( + "create_base_text_units", context.storage + ) + + extraction_strategy = config.claim_extraction.resolved_strategy( + config.root_dir, config.encoding_model + ) + + async_mode = config.claim_extraction.async_mode + num_threads = config.claim_extraction.parallelization.num_threads + + output = await create_final_covariates( + text_units, + callbacks, + context.cache, + "claim", + extraction_strategy, + async_mode=async_mode, + entity_types=None, + num_threads=num_threads, + ) + + await write_table_to_storage(output, workflow_name, context.storage) + + return output diff --git a/graphrag/index/workflows/create_final_documents.py b/graphrag/index/workflows/create_final_documents.py new file mode 100644 index 0000000000..bbc1490b8f --- /dev/null +++ b/graphrag/index/workflows/create_final_documents.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run_workflow method definition.""" + +import pandas as pd + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.context import PipelineRunContext +from graphrag.index.flows.create_final_documents import ( + create_final_documents, +) +from graphrag.utils.storage import load_table_from_storage, write_table_to_storage + +workflow_name = "create_final_documents" + + +async def run_workflow( + config: GraphRagConfig, + context: PipelineRunContext, + _callbacks: VerbCallbacks, +) -> pd.DataFrame | None: + """All the steps to transform final documents.""" + documents = await load_table_from_storage("input", context.storage) + text_units = await load_table_from_storage( + "create_base_text_units", context.storage + ) + + input = config.input + output = create_final_documents( + documents, text_units, input.document_attribute_columns + ) + + await write_table_to_storage(output, workflow_name, context.storage) + + return output diff --git a/graphrag/index/workflows/create_final_entities.py b/graphrag/index/workflows/create_final_entities.py new file mode 100644 index 0000000000..565da6cf6b --- /dev/null +++ b/graphrag/index/workflows/create_final_entities.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing run_workflow method definition.""" + +import pandas as pd + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.context import PipelineRunContext +from graphrag.index.flows.create_final_entities import ( + create_final_entities, +) +from graphrag.utils.storage import load_table_from_storage, write_table_to_storage + +workflow_name = "create_final_entities" + + +async def run_workflow( + _config: GraphRagConfig, + context: PipelineRunContext, + _callbacks: VerbCallbacks, +) -> pd.DataFrame | None: + """All the steps to transform final entities.""" + base_entity_nodes = await load_table_from_storage( + "base_entity_nodes", context.storage + ) + + output = create_final_entities(base_entity_nodes) + + await write_table_to_storage(output, workflow_name, context.storage) + + return output diff --git a/graphrag/index/workflows/create_final_nodes.py b/graphrag/index/workflows/create_final_nodes.py new file mode 100644 index 0000000000..aa1ec3c177 --- /dev/null +++ b/graphrag/index/workflows/create_final_nodes.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run_workflow method definition.""" + +import pandas as pd + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.context import PipelineRunContext +from graphrag.index.flows.create_final_nodes import ( + create_final_nodes, +) +from graphrag.utils.storage import load_table_from_storage, write_table_to_storage + +workflow_name = "create_final_nodes" + + +async def run_workflow( + config: GraphRagConfig, + context: PipelineRunContext, + callbacks: VerbCallbacks, +) -> pd.DataFrame | None: + """All the steps to transform final nodes.""" + base_entity_nodes = await load_table_from_storage( + "base_entity_nodes", context.storage + ) + base_relationship_edges = await load_table_from_storage( + "base_relationship_edges", context.storage + ) + base_communities = await load_table_from_storage( + "base_communities", context.storage + ) + + embed_config = config.embed_graph + layout_enabled = config.umap.enabled + + output = create_final_nodes( + base_entity_nodes, + base_relationship_edges, + base_communities, + callbacks, + embed_config=embed_config, + layout_enabled=layout_enabled, + ) + + await write_table_to_storage(output, workflow_name, context.storage) + + return output diff --git a/graphrag/index/workflows/create_final_relationships.py b/graphrag/index/workflows/create_final_relationships.py new file mode 100644 index 0000000000..f6896420b0 --- /dev/null +++ b/graphrag/index/workflows/create_final_relationships.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing run_workflow method definition.""" + +import pandas as pd + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.context import PipelineRunContext +from graphrag.index.flows.create_final_relationships import ( + create_final_relationships, +) +from graphrag.utils.storage import load_table_from_storage, write_table_to_storage + +workflow_name = "create_final_relationships" + + +async def run_workflow( + _config: GraphRagConfig, + context: PipelineRunContext, + _callbacks: VerbCallbacks, +) -> pd.DataFrame | None: + """All the steps to transform final relationships.""" + base_relationship_edges = await load_table_from_storage( + "base_relationship_edges", context.storage + ) + + output = create_final_relationships(base_relationship_edges) + + await write_table_to_storage(output, workflow_name, context.storage) + + return output diff --git a/graphrag/index/workflows/create_final_text_units.py b/graphrag/index/workflows/create_final_text_units.py new file mode 100644 index 0000000000..d9d49fec4f --- /dev/null +++ b/graphrag/index/workflows/create_final_text_units.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run_workflow method definition.""" + +import pandas as pd + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.context import PipelineRunContext +from graphrag.index.flows.create_final_text_units import ( + create_final_text_units, +) +from graphrag.utils.storage import load_table_from_storage, write_table_to_storage + +workflow_name = "create_final_text_units" + + +async def run_workflow( + config: GraphRagConfig, + context: PipelineRunContext, + _callbacks: VerbCallbacks, +) -> pd.DataFrame | None: + """All the steps to transform the text units.""" + text_units = await load_table_from_storage( + "create_base_text_units", context.storage + ) + final_entities = await load_table_from_storage( + "create_final_entities", context.storage + ) + final_relationships = await load_table_from_storage( + "create_final_relationships", context.storage + ) + final_covariates = None + if config.claim_extraction.enabled: + final_covariates = await load_table_from_storage( + "create_final_covariates", context.storage + ) + + output = create_final_text_units( + text_units, + final_entities, + final_relationships, + final_covariates, + ) + + await write_table_to_storage(output, workflow_name, context.storage) + + return output diff --git a/graphrag/index/workflows/default_workflows.py b/graphrag/index/workflows/default_workflows.py deleted file mode 100644 index 009f9fa8ce..0000000000 --- a/graphrag/index/workflows/default_workflows.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""A package containing default workflows definitions.""" - -from graphrag.index.workflows.typing import WorkflowDefinitions -from graphrag.index.workflows.v1.compute_communities import ( - build_steps as build_compute_communities_steps, -) -from graphrag.index.workflows.v1.compute_communities import ( - workflow_name as compute_communities, -) -from graphrag.index.workflows.v1.create_base_text_units import ( - build_steps as build_create_base_text_units_steps, -) -from graphrag.index.workflows.v1.create_base_text_units import ( - workflow_name as create_base_text_units, -) -from graphrag.index.workflows.v1.create_final_communities import ( - build_steps as build_create_final_communities_steps, -) -from graphrag.index.workflows.v1.create_final_communities import ( - workflow_name as create_final_communities, -) -from graphrag.index.workflows.v1.create_final_community_reports import ( - build_steps as build_create_final_community_reports_steps, -) -from graphrag.index.workflows.v1.create_final_community_reports import ( - workflow_name as create_final_community_reports, -) -from graphrag.index.workflows.v1.create_final_covariates import ( - build_steps as build_create_final_covariates_steps, -) -from graphrag.index.workflows.v1.create_final_covariates import ( - workflow_name as create_final_covariates, -) -from graphrag.index.workflows.v1.create_final_documents import ( - build_steps as build_create_final_documents_steps, -) -from graphrag.index.workflows.v1.create_final_documents import ( - workflow_name as create_final_documents, -) -from graphrag.index.workflows.v1.create_final_entities import ( - build_steps as build_create_final_entities_steps, -) -from graphrag.index.workflows.v1.create_final_entities import ( - workflow_name as create_final_entities, -) -from graphrag.index.workflows.v1.create_final_nodes import ( - build_steps as build_create_final_nodes_steps, -) -from graphrag.index.workflows.v1.create_final_nodes import ( - workflow_name as create_final_nodes, -) -from graphrag.index.workflows.v1.create_final_relationships import ( - build_steps as build_create_final_relationships_steps, -) -from graphrag.index.workflows.v1.create_final_relationships import ( - workflow_name as create_final_relationships, -) -from graphrag.index.workflows.v1.create_final_text_units import ( - build_steps as build_create_final_text_units, -) -from graphrag.index.workflows.v1.create_final_text_units import ( - workflow_name as create_final_text_units, -) -from graphrag.index.workflows.v1.extract_graph import ( - build_steps as build_extract_graph_steps, -) -from graphrag.index.workflows.v1.extract_graph import ( - workflow_name as extract_graph, -) -from graphrag.index.workflows.v1.generate_text_embeddings import ( - build_steps as build_generate_text_embeddings_steps, -) -from graphrag.index.workflows.v1.generate_text_embeddings import ( - workflow_name as generate_text_embeddings, -) - -default_workflows: WorkflowDefinitions = { - extract_graph: build_extract_graph_steps, - compute_communities: build_compute_communities_steps, - create_base_text_units: build_create_base_text_units_steps, - create_final_text_units: build_create_final_text_units, - create_final_community_reports: build_create_final_community_reports_steps, - create_final_nodes: build_create_final_nodes_steps, - create_final_relationships: build_create_final_relationships_steps, - create_final_documents: build_create_final_documents_steps, - create_final_covariates: 
build_create_final_covariates_steps, - create_final_entities: build_create_final_entities_steps, - create_final_communities: build_create_final_communities_steps, - generate_text_embeddings: build_generate_text_embeddings_steps, -} diff --git a/graphrag/index/workflows/extract_graph.py b/graphrag/index/workflows/extract_graph.py new file mode 100644 index 0000000000..454bf7806a --- /dev/null +++ b/graphrag/index/workflows/extract_graph.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run_workflow method definition.""" + +import pandas as pd + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.context import PipelineRunContext +from graphrag.index.flows.extract_graph import ( + extract_graph, +) +from graphrag.index.operations.create_graph import create_graph +from graphrag.index.operations.snapshot_graphml import snapshot_graphml +from graphrag.utils.storage import load_table_from_storage, write_table_to_storage + +workflow_name = "extract_graph" + + +async def run_workflow( + config: GraphRagConfig, + context: PipelineRunContext, + callbacks: VerbCallbacks, +) -> pd.DataFrame | None: + """All the steps to create the base entity graph.""" + text_units = await load_table_from_storage( + "create_base_text_units", context.storage + ) + + extraction_strategy = config.entity_extraction.resolved_strategy( + config.root_dir, config.encoding_model + ) + extraction_num_threads = config.entity_extraction.parallelization.num_threads + extraction_async_mode = config.entity_extraction.async_mode + entity_types = config.entity_extraction.entity_types + + summarization_strategy = config.summarize_descriptions.resolved_strategy( + config.root_dir, + ) + summarization_num_threads = ( + config.summarize_descriptions.parallelization.num_threads + ) + + base_entity_nodes, base_relationship_edges = await extract_graph( + text_units, + callbacks, + context.cache, + extraction_strategy=extraction_strategy, + extraction_num_threads=extraction_num_threads, + extraction_async_mode=extraction_async_mode, + entity_types=entity_types, + summarization_strategy=summarization_strategy, + summarization_num_threads=summarization_num_threads, + ) + + await write_table_to_storage( + base_entity_nodes, "base_entity_nodes", context.storage + ) + await write_table_to_storage( + base_relationship_edges, "base_relationship_edges", context.storage + ) + + if config.snapshots.graphml: + # todo: extract graphs at each level, and add in meta like descriptions + graph = create_graph(base_relationship_edges) + await snapshot_graphml( + graph, + name="graph", + storage=context.storage, + ) diff --git a/graphrag/index/workflows/generate_text_embeddings.py b/graphrag/index/workflows/generate_text_embeddings.py new file mode 100644 index 0000000000..29a8bf0988 --- /dev/null +++ b/graphrag/index/workflows/generate_text_embeddings.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing run_workflow method definition.""" + +import pandas as pd + +from graphrag.callbacks.verb_callbacks import VerbCallbacks +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings +from graphrag.index.context import PipelineRunContext +from graphrag.index.flows.generate_text_embeddings import ( + generate_text_embeddings, +) +from graphrag.utils.storage import load_table_from_storage + +workflow_name = "generate_text_embeddings" + + +async def run_workflow( + config: GraphRagConfig, + context: PipelineRunContext, + callbacks: VerbCallbacks, +) -> pd.DataFrame | None: + """All the steps to generate text embeddings.""" + final_documents = await load_table_from_storage( + "create_final_documents", context.storage + ) + final_relationships = await load_table_from_storage( + "create_final_relationships", context.storage + ) + final_text_units = await load_table_from_storage( + "create_final_text_units", context.storage + ) + final_entities = await load_table_from_storage( + "create_final_entities", context.storage + ) + final_community_reports = await load_table_from_storage( + "create_final_community_reports", context.storage + ) + + embedded_fields = get_embedded_fields(config) + text_embed = get_embedding_settings(config.embeddings) + + await generate_text_embeddings( + final_documents=final_documents, + final_relationships=final_relationships, + final_text_units=final_text_units, + final_entities=final_entities, + final_community_reports=final_community_reports, + callbacks=callbacks, + cache=context.cache, + storage=context.storage, + text_embed_config=text_embed, + embedded_fields=embedded_fields, + snapshot_embeddings_enabled=config.snapshots.embeddings, + ) diff --git a/graphrag/index/workflows/load.py b/graphrag/index/workflows/load.py deleted file mode 100644 index 5aa874ecba..0000000000 --- a/graphrag/index/workflows/load.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License - -"""A module containing load_workflows, create_workflow, _get_steps_for_workflow and _remove_disabled_steps methods definition.""" - -from __future__ import annotations - -import logging -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, NamedTuple, cast - -from datashaper import Workflow - -from graphrag.index.errors import ( - NoWorkflowsDefinedError, - UndefinedWorkflowError, - UnknownWorkflowError, -) -from graphrag.index.utils.topological_sort import topological_sort -from graphrag.index.workflows.default_workflows import ( - default_workflows as _default_workflows, -) -from graphrag.index.workflows.typing import ( - VerbDefinitions, - WorkflowDefinitions, - WorkflowToRun, -) - -if TYPE_CHECKING: - from graphrag.index.config.workflow import ( - PipelineWorkflowConfig, - PipelineWorkflowReference, - PipelineWorkflowStep, - ) - -anonymous_workflow_count = 0 - -VerbFn = Callable[..., Any] -log = logging.getLogger(__name__) - - -class LoadWorkflowResult(NamedTuple): - """A workflow loading result object.""" - - workflows: list[WorkflowToRun] - """The loaded workflow names in the order they should be run.""" - - dependencies: dict[str, list[str]] - """A dictionary of workflow name to workflow dependencies.""" - - -def load_workflows( - workflows_to_load: list[PipelineWorkflowReference], - additional_verbs: VerbDefinitions | None = None, - additional_workflows: WorkflowDefinitions | None = None, - memory_profile: bool = False, -) -> LoadWorkflowResult: - """Load the given workflows. - - Args: - - workflows_to_load - The workflows to load - - additional_verbs - The list of custom verbs available to the workflows - - additional_workflows - The list of custom workflows - Returns: - - output[0] - The loaded workflow names in the order they should be run - - output[1] - A dictionary of workflow name to workflow dependencies - """ - workflow_graph: dict[str, WorkflowToRun] = {} - - global anonymous_workflow_count - for reference in workflows_to_load: - name = reference.name - is_anonymous = name is None or name.strip() == "" - if is_anonymous: - name = f"Anonymous Workflow {anonymous_workflow_count}" - anonymous_workflow_count += 1 - name = cast("str", name) - - config = reference.config - workflow = create_workflow( - name or "MISSING NAME!", - reference.steps, - config, - additional_verbs, - additional_workflows, - ) - workflow_graph[name] = WorkflowToRun(workflow, config=config or {}) - - # Backfill any missing workflows - for name in list(workflow_graph.keys()): - workflow = workflow_graph[name] - deps = [ - d.replace("workflow:", "") - for d in workflow.workflow.dependencies - if d.startswith("workflow:") - ] - for dependency in deps: - if dependency not in workflow_graph: - reference = {"name": dependency, **workflow.config} - workflow_graph[dependency] = WorkflowToRun( - workflow=create_workflow( - dependency, - config=reference, - additional_verbs=additional_verbs, - additional_workflows=additional_workflows, - memory_profile=memory_profile, - ), - config=reference, - ) - - # Run workflows in order of dependencies - def filter_wf_dependencies(name: str) -> list[str]: - externals = [ - e.replace("workflow:", "") - for e in workflow_graph[name].workflow.dependencies - ] - return [e for e in externals if e in workflow_graph] - - task_graph = {name: filter_wf_dependencies(name) for name in workflow_graph} - workflow_run_order = topological_sort(task_graph) - workflows = [workflow_graph[name] for name in workflow_run_order] - 
log.info("Workflow Run Order: %s", workflow_run_order) - return LoadWorkflowResult(workflows=workflows, dependencies=task_graph) - - -def create_workflow( - name: str, - steps: list[PipelineWorkflowStep] | None = None, - config: PipelineWorkflowConfig | None = None, - additional_verbs: VerbDefinitions | None = None, - additional_workflows: WorkflowDefinitions | None = None, - memory_profile: bool = False, -) -> Workflow: - """Create a workflow from the given config.""" - additional_workflows = { - **_default_workflows, - **(additional_workflows or {}), - } - steps = steps or _get_steps_for_workflow(name, config, additional_workflows) - return Workflow( - verbs=additional_verbs or {}, - schema={ - "name": name, - "steps": steps, - }, - validate=False, - memory_profile=memory_profile, - ) - - -def _get_steps_for_workflow( - name: str | None, - config: PipelineWorkflowConfig | None, - workflows: dict[str, Callable] | None, -) -> list[PipelineWorkflowStep]: - """Get the steps for the given workflow config.""" - if config is not None and "steps" in config: - return config["steps"] - - if workflows is None: - raise NoWorkflowsDefinedError - - if name is None: - raise UndefinedWorkflowError - - if name not in workflows: - raise UnknownWorkflowError(name) - - return workflows[name](config or {}) diff --git a/graphrag/index/workflows/typing.py b/graphrag/index/workflows/typing.py deleted file mode 100644 index 3b44545bd4..0000000000 --- a/graphrag/index/workflows/typing.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing 'WorkflowToRun' model.""" - -from collections.abc import Callable -from dataclasses import dataclass as dc_dataclass -from typing import Any - -from datashaper import TableContainer, Workflow - -StepDefinition = dict[str, Any] -"""A step definition.""" - -VerbDefinitions = dict[str, Callable[..., TableContainer]] -"""A mapping of verb names to their implementations.""" - -WorkflowConfig = dict[str, Any] -"""A workflow configuration.""" - -WorkflowDefinitions = dict[str, Callable[[WorkflowConfig], list[StepDefinition]]] -"""A mapping of workflow names to their implementations.""" - -VerbTiming = dict[str, float] -"""The timings of verbs by id.""" - - -@dc_dataclass -class WorkflowToRun: - """Workflow to run class definition.""" - - workflow: Workflow - config: dict[str, Any] diff --git a/graphrag/index/workflows/v1/__init__.py b/graphrag/index/workflows/v1/__init__.py deleted file mode 100644 index 69518f5ee2..0000000000 --- a/graphrag/index/workflows/v1/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""The Indexing Engine workflows package root.""" diff --git a/graphrag/index/workflows/v1/compute_communities.py b/graphrag/index/workflows/v1/compute_communities.py deleted file mode 100644 index 3e70725c32..0000000000 --- a/graphrag/index/workflows/v1/compute_communities.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -from typing import TYPE_CHECKING, cast - -import pandas as pd -from datashaper import ( - Table, - verb, -) -from datashaper.table_store.types import VerbResult, create_verb_result - -from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep -from graphrag.index.flows.compute_communities import compute_communities -from graphrag.index.operations.snapshot import snapshot -from graphrag.storage.pipeline_storage import PipelineStorage - -if TYPE_CHECKING: - from graphrag.config.models.cluster_graph_config import ClusterGraphConfig - -workflow_name = "compute_communities" - - -def build_steps( - config: PipelineWorkflowConfig, -) -> list[PipelineWorkflowStep]: - """ - Create the base communities from the graph edges. - - ## Dependencies - * `workflow:extract_graph` - """ - clustering_config = cast("ClusterGraphConfig", config.get("cluster_graph")) - max_cluster_size = clustering_config.max_cluster_size - use_lcc = clustering_config.use_lcc - seed = clustering_config.seed - - snapshot_transient = config.get("snapshot_transient", False) or False - - return [ - { - "verb": workflow_name, - "args": { - "max_cluster_size": max_cluster_size, - "use_lcc": use_lcc, - "seed": seed, - "snapshot_transient_enabled": snapshot_transient, - }, - "input": ({"source": "workflow:extract_graph"}), - }, - ] - - -@verb( - name=workflow_name, - treats_input_tables_as_immutable=True, -) -async def workflow( - storage: PipelineStorage, - runtime_storage: PipelineStorage, - max_cluster_size: int, - use_lcc: bool, - seed: int | None, - snapshot_transient_enabled: bool = False, - **_kwargs: dict, -) -> VerbResult: - """All the steps to create the base entity graph.""" - base_relationship_edges = await runtime_storage.get("base_relationship_edges") - - base_communities = compute_communities( - base_relationship_edges, - max_cluster_size=max_cluster_size, - use_lcc=use_lcc, - seed=seed, - ) - - await runtime_storage.set("base_communities", base_communities) - - if snapshot_transient_enabled: - await snapshot( - base_communities, - name="base_communities", - storage=storage, - formats=["parquet"], - ) - - return create_verb_result(cast("Table", pd.DataFrame())) diff --git a/graphrag/index/workflows/v1/create_base_text_units.py b/graphrag/index/workflows/v1/create_base_text_units.py deleted file mode 100644 index 84f5366df8..0000000000 --- a/graphrag/index/workflows/v1/create_base_text_units.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -from typing import TYPE_CHECKING, cast - -import pandas as pd -from datashaper import ( - DEFAULT_INPUT_NAME, - Table, - VerbCallbacks, - VerbInput, - verb, -) -from datashaper.table_store.types import VerbResult, create_verb_result - -from graphrag.config.models.chunking_config import ChunkStrategyType -from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep -from graphrag.index.flows.create_base_text_units import ( - create_base_text_units, -) -from graphrag.index.operations.snapshot import snapshot -from graphrag.storage.pipeline_storage import PipelineStorage - -if TYPE_CHECKING: - from graphrag.config.models.chunking_config import ChunkingConfig - -workflow_name = "create_base_text_units" - - -def build_steps( - config: PipelineWorkflowConfig, -) -> list[PipelineWorkflowStep]: - """ - Create the base table for text units. - - ## Dependencies - (input dataframe) - """ - chunks = cast("ChunkingConfig", config.get("chunks")) - group_by_columns = chunks.group_by_columns - size = chunks.size - overlap = chunks.overlap - encoding_model = chunks.encoding_model - strategy = chunks.strategy - - snapshot_transient = config.get("snapshot_transient", False) or False - return [ - { - "verb": workflow_name, - "args": { - "group_by_columns": group_by_columns, - "size": size, - "overlap": overlap, - "encoding_model": encoding_model, - "strategy": strategy, - "snapshot_transient_enabled": snapshot_transient, - }, - "input": {"source": DEFAULT_INPUT_NAME}, - }, - ] - - -@verb(name=workflow_name, treats_input_tables_as_immutable=True) -async def workflow( - input: VerbInput, - callbacks: VerbCallbacks, - storage: PipelineStorage, - runtime_storage: PipelineStorage, - group_by_columns: list[str], - size: int, - overlap: int, - encoding_model: str, - strategy: ChunkStrategyType, - snapshot_transient_enabled: bool = False, - **_kwargs: dict, -) -> VerbResult: - """All the steps to transform base text_units.""" - source = cast("pd.DataFrame", input.get_input()) - - output = create_base_text_units( - source, - callbacks, - group_by_columns, - size, - overlap, - encoding_model, - strategy=strategy, - ) - - await runtime_storage.set("base_text_units", output) - - if snapshot_transient_enabled: - await snapshot( - output, - name="create_base_text_units", - storage=storage, - formats=["parquet"], - ) - - return create_verb_result( - cast( - "Table", - pd.DataFrame(), - ) - ) diff --git a/graphrag/index/workflows/v1/create_final_communities.py b/graphrag/index/workflows/v1/create_final_communities.py deleted file mode 100644 index c9683991c5..0000000000 --- a/graphrag/index/workflows/v1/create_final_communities.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -from typing import cast - -from datashaper import ( - Table, - verb, -) -from datashaper.table_store.types import VerbResult, create_verb_result - -from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep -from graphrag.index.flows.create_final_communities import ( - create_final_communities, -) -from graphrag.storage.pipeline_storage import PipelineStorage - -workflow_name = "create_final_communities" - - -def build_steps( - _config: PipelineWorkflowConfig, -) -> list[PipelineWorkflowStep]: - """ - Create the final communities table. 
- - ## Dependencies - * `workflow:extract_graph` - """ - return [ - { - "verb": workflow_name, - "input": {"source": "workflow:extract_graph"}, - }, - ] - - -@verb(name=workflow_name, treats_input_tables_as_immutable=True) -async def workflow( - runtime_storage: PipelineStorage, - **_kwargs: dict, -) -> VerbResult: - """All the steps to transform final communities.""" - base_entity_nodes = await runtime_storage.get("base_entity_nodes") - base_relationship_edges = await runtime_storage.get("base_relationship_edges") - base_communities = await runtime_storage.get("base_communities") - output = create_final_communities( - base_entity_nodes, - base_relationship_edges, - base_communities, - ) - - return create_verb_result( - cast( - "Table", - output, - ) - ) diff --git a/graphrag/index/workflows/v1/create_final_community_reports.py b/graphrag/index/workflows/v1/create_final_community_reports.py deleted file mode 100644 index 401a4bffab..0000000000 --- a/graphrag/index/workflows/v1/create_final_community_reports.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -from typing import TYPE_CHECKING, cast - -from datashaper import ( - AsyncType, - Table, - VerbCallbacks, - VerbInput, - verb, -) -from datashaper.table_store.types import VerbResult, create_verb_result - -from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep -from graphrag.index.flows.create_final_community_reports import ( - create_final_community_reports, -) -from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table - -if TYPE_CHECKING: - import pandas as pd - -workflow_name = "create_final_community_reports" - - -def build_steps( - config: PipelineWorkflowConfig, -) -> list[PipelineWorkflowStep]: - """ - Create the final community reports table. 
- - ## Dependencies - * `workflow:extract_graph` - """ - covariates_enabled = config.get("covariates_enabled", False) - create_community_reports_config = config.get("create_community_reports", {}) - summarization_strategy = create_community_reports_config.get("strategy") - async_mode = create_community_reports_config.get("async_mode") - num_threads = create_community_reports_config.get("num_threads") - - input = { - "source": "workflow:create_final_nodes", - "relationships": "workflow:create_final_relationships", - "entities": "workflow:create_final_entities", - "communities": "workflow:create_final_communities", - } - if covariates_enabled: - input["covariates"] = "workflow:create_final_covariates" - - return [ - { - "verb": workflow_name, - "args": { - "summarization_strategy": summarization_strategy, - "async_mode": async_mode, - "num_threads": num_threads, - }, - "input": input, - }, - ] - - -@verb(name=workflow_name, treats_input_tables_as_immutable=True) -async def workflow( - input: VerbInput, - callbacks: VerbCallbacks, - cache: PipelineCache, - summarization_strategy: dict, - async_mode: AsyncType = AsyncType.AsyncIO, - num_threads: int = 4, - **_kwargs: dict, -) -> VerbResult: - """All the steps to transform community reports.""" - nodes = cast("pd.DataFrame", input.get_input()) - edges = cast("pd.DataFrame", get_required_input_table(input, "relationships").table) - entities = cast("pd.DataFrame", get_required_input_table(input, "entities").table) - communities = cast( - "pd.DataFrame", get_required_input_table(input, "communities").table - ) - - claims = get_named_input_table(input, "covariates") - if claims: - claims = cast("pd.DataFrame", claims.table) - - output = await create_final_community_reports( - nodes, - edges, - entities, - communities, - claims, - callbacks, - cache, - summarization_strategy, - async_mode=async_mode, - num_threads=num_threads, - ) - - return create_verb_result( - cast( - "Table", - output, - ) - ) diff --git a/graphrag/index/workflows/v1/create_final_covariates.py b/graphrag/index/workflows/v1/create_final_covariates.py deleted file mode 100644 index 2804e389f3..0000000000 --- a/graphrag/index/workflows/v1/create_final_covariates.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -from typing import Any, cast - -from datashaper import ( - AsyncType, - Table, - VerbCallbacks, - verb, -) -from datashaper.table_store.types import VerbResult, create_verb_result - -from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep -from graphrag.index.flows.create_final_covariates import ( - create_final_covariates, -) -from graphrag.storage.pipeline_storage import PipelineStorage - -workflow_name = "create_final_covariates" - - -def build_steps( - config: PipelineWorkflowConfig, -) -> list[PipelineWorkflowStep]: - """ - Create the final covariates table. 
- - ## Dependencies - * `workflow:create_base_text_units` - """ - claim_extract_config = config.get("claim_extract", {}) - extraction_strategy = claim_extract_config.get("strategy") - async_mode = claim_extract_config.get("async_mode", AsyncType.AsyncIO) - num_threads = claim_extract_config.get("num_threads") - - return [ - { - "verb": workflow_name, - "args": { - "covariate_type": "claim", - "extraction_strategy": extraction_strategy, - "async_mode": async_mode, - "num_threads": num_threads, - }, - "input": {"source": "workflow:create_base_text_units"}, - }, - ] - - -@verb(name=workflow_name, treats_input_tables_as_immutable=True) -async def workflow( - callbacks: VerbCallbacks, - cache: PipelineCache, - runtime_storage: PipelineStorage, - covariate_type: str, - extraction_strategy: dict[str, Any] | None, - async_mode: AsyncType = AsyncType.AsyncIO, - entity_types: list[str] | None = None, - num_threads: int = 4, - **_kwargs: dict, -) -> VerbResult: - """All the steps to extract and format covariates.""" - text_units = await runtime_storage.get("base_text_units") - - output = await create_final_covariates( - text_units, - callbacks, - cache, - covariate_type, - extraction_strategy, - async_mode=async_mode, - entity_types=entity_types, - num_threads=num_threads, - ) - - return create_verb_result(cast("Table", output)) diff --git a/graphrag/index/workflows/v1/create_final_documents.py b/graphrag/index/workflows/v1/create_final_documents.py deleted file mode 100644 index a9b5af67fd..0000000000 --- a/graphrag/index/workflows/v1/create_final_documents.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -from typing import TYPE_CHECKING, cast - -from datashaper import ( - DEFAULT_INPUT_NAME, - Table, - VerbInput, - verb, -) -from datashaper.table_store.types import VerbResult, create_verb_result - -from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep -from graphrag.index.flows.create_final_documents import ( - create_final_documents, -) -from graphrag.storage.pipeline_storage import PipelineStorage - -if TYPE_CHECKING: - import pandas as pd - - -workflow_name = "create_final_documents" - - -def build_steps( - config: PipelineWorkflowConfig, -) -> list[PipelineWorkflowStep]: - """ - Create the final documents table. 
- - ## Dependencies - * `workflow:create_base_text_units` - """ - document_attribute_columns = config.get("document_attribute_columns", None) - return [ - { - "verb": workflow_name, - "args": {"document_attribute_columns": document_attribute_columns}, - "input": { - "source": DEFAULT_INPUT_NAME, - "text_units": "workflow:create_base_text_units", - }, - }, - ] - - -@verb( - name=workflow_name, - treats_input_tables_as_immutable=True, -) -async def workflow( - input: VerbInput, - runtime_storage: PipelineStorage, - document_attribute_columns: list[str] | None = None, - **_kwargs: dict, -) -> VerbResult: - """All the steps to transform final documents.""" - source = cast("pd.DataFrame", input.get_input()) - text_units = await runtime_storage.get("base_text_units") - - output = create_final_documents(source, text_units, document_attribute_columns) - - return create_verb_result(cast("Table", output)) diff --git a/graphrag/index/workflows/v1/create_final_entities.py b/graphrag/index/workflows/v1/create_final_entities.py deleted file mode 100644 index 35a86bbdff..0000000000 --- a/graphrag/index/workflows/v1/create_final_entities.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -import logging -from typing import cast - -from datashaper import ( - Table, - verb, -) -from datashaper.table_store.types import VerbResult, create_verb_result - -from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep -from graphrag.index.flows.create_final_entities import ( - create_final_entities, -) -from graphrag.storage.pipeline_storage import PipelineStorage - -workflow_name = "create_final_entities" -log = logging.getLogger(__name__) - - -def build_steps( - config: PipelineWorkflowConfig, # noqa: ARG001 -) -> list[PipelineWorkflowStep]: - """ - Create the final entities table. - - ## Dependencies - * `workflow:extract_graph` - """ - return [ - { - "verb": workflow_name, - "args": {}, - "input": {"source": "workflow:extract_graph"}, - }, - ] - - -@verb( - name=workflow_name, - treats_input_tables_as_immutable=True, -) -async def workflow( - runtime_storage: PipelineStorage, - **_kwargs: dict, -) -> VerbResult: - """All the steps to transform final entities.""" - base_entity_nodes = await runtime_storage.get("base_entity_nodes") - - output = create_final_entities(base_entity_nodes) - - return create_verb_result(cast("Table", output)) diff --git a/graphrag/index/workflows/v1/create_final_nodes.py b/graphrag/index/workflows/v1/create_final_nodes.py deleted file mode 100644 index bdbfab084e..0000000000 --- a/graphrag/index/workflows/v1/create_final_nodes.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -from typing import cast - -from datashaper import ( - Table, - VerbCallbacks, - verb, -) -from datashaper.table_store.types import VerbResult, create_verb_result - -from graphrag.config.models.embed_graph_config import EmbedGraphConfig -from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep -from graphrag.index.flows.create_final_nodes import ( - create_final_nodes, -) -from graphrag.storage.pipeline_storage import PipelineStorage - -workflow_name = "create_final_nodes" - - -def build_steps( - config: PipelineWorkflowConfig, -) -> list[PipelineWorkflowStep]: - """ - Create the base table for the document graph. - - ## Dependencies - * `workflow:extract_graph` - """ - layout_enabled = config["layout_enabled"] - embed_config = cast("EmbedGraphConfig", config["embed_graph"]) - - return [ - { - "verb": workflow_name, - "args": {"layout_enabled": layout_enabled, "embed_config": embed_config}, - "input": { - "source": "workflow:extract_graph", - "communities": "workflow:compute_communities", - }, - }, - ] - - -@verb(name=workflow_name, treats_input_tables_as_immutable=True) -async def workflow( - callbacks: VerbCallbacks, - runtime_storage: PipelineStorage, - embed_config: EmbedGraphConfig, - layout_enabled: bool, - **_kwargs: dict, -) -> VerbResult: - """All the steps to transform final nodes.""" - base_entity_nodes = await runtime_storage.get("base_entity_nodes") - base_relationship_edges = await runtime_storage.get("base_relationship_edges") - base_communities = await runtime_storage.get("base_communities") - - output = create_final_nodes( - base_entity_nodes, - base_relationship_edges, - base_communities, - callbacks, - embed_config=embed_config, - layout_enabled=layout_enabled, - ) - - return create_verb_result( - cast( - "Table", - output, - ) - ) diff --git a/graphrag/index/workflows/v1/create_final_relationships.py b/graphrag/index/workflows/v1/create_final_relationships.py deleted file mode 100644 index 603b03f75d..0000000000 --- a/graphrag/index/workflows/v1/create_final_relationships.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -import logging -from typing import cast - -from datashaper import ( - Table, - verb, -) -from datashaper.table_store.types import VerbResult, create_verb_result - -from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep -from graphrag.index.flows.create_final_relationships import ( - create_final_relationships, -) -from graphrag.storage.pipeline_storage import PipelineStorage - -workflow_name = "create_final_relationships" - -log = logging.getLogger(__name__) - - -def build_steps( - config: PipelineWorkflowConfig, # noqa: ARG001 -) -> list[PipelineWorkflowStep]: - """ - Create the final relationships table. 
- - ## Dependencies - * `workflow:extract_graph` - """ - return [ - { - "verb": workflow_name, - "args": {}, - "input": { - "source": "workflow:extract_graph", - }, - }, - ] - - -@verb( - name=workflow_name, - treats_input_tables_as_immutable=True, -) -async def workflow( - runtime_storage: PipelineStorage, - **_kwargs: dict, -) -> VerbResult: - """All the steps to transform final relationships.""" - base_relationship_edges = await runtime_storage.get("base_relationship_edges") - - output = create_final_relationships(base_relationship_edges) - - return create_verb_result(cast("Table", output)) diff --git a/graphrag/index/workflows/v1/create_final_text_units.py b/graphrag/index/workflows/v1/create_final_text_units.py deleted file mode 100644 index 887477c593..0000000000 --- a/graphrag/index/workflows/v1/create_final_text_units.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -from typing import TYPE_CHECKING, cast - -from datashaper import ( - Table, - VerbInput, - VerbResult, - create_verb_result, - verb, -) - -from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep -from graphrag.index.flows.create_final_text_units import ( - create_final_text_units, -) -from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table -from graphrag.storage.pipeline_storage import PipelineStorage - -if TYPE_CHECKING: - import pandas as pd - -workflow_name = "create_final_text_units" - - -def build_steps( - config: PipelineWorkflowConfig, -) -> list[PipelineWorkflowStep]: - """ - Create the final text-units table. - - ## Dependencies - * `workflow:create_base_text_units` - * `workflow:create_final_entities` - * `workflow:create_final_communities` - """ - covariates_enabled = config.get("covariates_enabled", False) - - input = { - "source": "workflow:create_base_text_units", - "entities": "workflow:create_final_entities", - "relationships": "workflow:create_final_relationships", - } - - if covariates_enabled: - input["covariates"] = "workflow:create_final_covariates" - - return [ - { - "verb": workflow_name, - "args": {}, - "input": input, - }, - ] - - -@verb(name=workflow_name, treats_input_tables_as_immutable=True) -async def workflow( - input: VerbInput, - runtime_storage: PipelineStorage, - **_kwargs: dict, -) -> VerbResult: - """All the steps to transform the text units.""" - text_units = await runtime_storage.get("base_text_units") - final_entities = cast( - "pd.DataFrame", get_required_input_table(input, "entities").table - ) - final_relationships = cast( - "pd.DataFrame", get_required_input_table(input, "relationships").table - ) - final_covariates = get_named_input_table(input, "covariates") - - if final_covariates: - final_covariates = cast("pd.DataFrame", final_covariates.table) - - output = create_final_text_units( - text_units, - final_entities, - final_relationships, - final_covariates, - ) - - return create_verb_result(cast("Table", output)) diff --git a/graphrag/index/workflows/v1/extract_graph.py b/graphrag/index/workflows/v1/extract_graph.py deleted file mode 100644 index 65016d6a6a..0000000000 --- a/graphrag/index/workflows/v1/extract_graph.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -from typing import Any, cast - -import pandas as pd -from datashaper import ( - AsyncType, - Table, - VerbCallbacks, - verb, -) -from datashaper.table_store.types import VerbResult, create_verb_result - -from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep -from graphrag.index.flows.extract_graph import ( - extract_graph, -) -from graphrag.index.operations.create_graph import create_graph -from graphrag.index.operations.snapshot import snapshot -from graphrag.index.operations.snapshot_graphml import snapshot_graphml -from graphrag.storage.pipeline_storage import PipelineStorage - -workflow_name = "extract_graph" - - -def build_steps( - config: PipelineWorkflowConfig, -) -> list[PipelineWorkflowStep]: - """ - Create the base table for the entity graph. - - ## Dependencies - * `workflow:create_base_text_units` - """ - entity_extraction_config = config.get("entity_extract", {}) - async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO) - extraction_strategy = entity_extraction_config.get("strategy") - extraction_num_threads = entity_extraction_config.get("num_threads", 4) - entity_types = entity_extraction_config.get("entity_types") - - summarize_descriptions_config = config.get("summarize_descriptions", {}) - summarization_strategy = summarize_descriptions_config.get("strategy") - summarization_num_threads = summarize_descriptions_config.get("num_threads", 4) - - snapshot_graphml = config.get("snapshot_graphml", False) or False - snapshot_transient = config.get("snapshot_transient", False) or False - - return [ - { - "verb": workflow_name, - "args": { - "extraction_strategy": extraction_strategy, - "extraction_num_threads": extraction_num_threads, - "extraction_async_mode": async_mode, - "entity_types": entity_types, - "summarization_strategy": summarization_strategy, - "summarization_num_threads": summarization_num_threads, - "snapshot_graphml_enabled": snapshot_graphml, - "snapshot_transient_enabled": snapshot_transient, - }, - "input": ({"source": "workflow:create_base_text_units"}), - }, - ] - - -@verb( - name=workflow_name, - treats_input_tables_as_immutable=True, -) -async def workflow( - callbacks: VerbCallbacks, - cache: PipelineCache, - storage: PipelineStorage, - runtime_storage: PipelineStorage, - extraction_strategy: dict[str, Any] | None, - extraction_num_threads: int = 4, - extraction_async_mode: AsyncType = AsyncType.AsyncIO, - entity_types: list[str] | None = None, - summarization_strategy: dict[str, Any] | None = None, - summarization_num_threads: int = 4, - snapshot_graphml_enabled: bool = False, - snapshot_transient_enabled: bool = False, - **_kwargs: dict, -) -> VerbResult: - """All the steps to create the base entity graph.""" - text_units = await runtime_storage.get("base_text_units") - - base_entity_nodes, base_relationship_edges = await extract_graph( - text_units, - callbacks, - cache, - extraction_strategy=extraction_strategy, - extraction_num_threads=extraction_num_threads, - extraction_async_mode=extraction_async_mode, - entity_types=entity_types, - summarization_strategy=summarization_strategy, - summarization_num_threads=summarization_num_threads, - ) - - await runtime_storage.set("base_entity_nodes", base_entity_nodes) - await runtime_storage.set("base_relationship_edges", base_relationship_edges) - - if snapshot_graphml_enabled: - # todo: extract graphs at 
each level, and add in meta like descriptions - graph = create_graph(base_relationship_edges) - await snapshot_graphml( - graph, - name="graph", - storage=storage, - ) - - if snapshot_transient_enabled: - await snapshot( - base_entity_nodes, - name="base_entity_nodes", - storage=storage, - formats=["parquet"], - ) - await snapshot( - base_relationship_edges, - name="base_relationship_edges", - storage=storage, - formats=["parquet"], - ) - - return create_verb_result(cast("Table", pd.DataFrame())) diff --git a/graphrag/index/workflows/v1/generate_text_embeddings.py b/graphrag/index/workflows/v1/generate_text_embeddings.py deleted file mode 100644 index 5af6f354ea..0000000000 --- a/graphrag/index/workflows/v1/generate_text_embeddings.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -import logging -from typing import cast - -import pandas as pd -from datashaper import ( - Table, - VerbCallbacks, - VerbInput, - VerbResult, - create_verb_result, - verb, -) - -from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep -from graphrag.index.flows.generate_text_embeddings import ( - generate_text_embeddings, -) -from graphrag.index.utils.ds_util import get_required_input_table -from graphrag.storage.pipeline_storage import PipelineStorage - -log = logging.getLogger(__name__) - -workflow_name = "generate_text_embeddings" - -input = { - "source": "workflow:create_final_documents", - "relationships": "workflow:create_final_relationships", - "text_units": "workflow:create_final_text_units", - "entities": "workflow:create_final_entities", - "community_reports": "workflow:create_final_community_reports", -} - - -def build_steps( - config: PipelineWorkflowConfig, -) -> list[PipelineWorkflowStep]: - """ - Create the final embeddings files. 
- - ## Dependencies - * `workflow:create_final_documents` - * `workflow:create_final_relationships` - * `workflow:create_final_text_units` - * `workflow:create_final_entities` - * `workflow:create_final_community_reports` - """ - text_embed = config.get("text_embed", {}) - embedded_fields = config.get("embedded_fields", {}) - snapshot_embeddings = config.get("snapshot_embeddings", False) - return [ - { - "verb": workflow_name, - "args": { - "text_embed": text_embed, - "embedded_fields": embedded_fields, - "snapshot_embeddings_enabled": snapshot_embeddings, - }, - "input": input, - }, - ] - - -@verb(name=workflow_name, treats_input_tables_as_immutable=True) -async def workflow( - input: VerbInput, - callbacks: VerbCallbacks, - cache: PipelineCache, - storage: PipelineStorage, - text_embed: dict, - embedded_fields: set[str], - snapshot_embeddings_enabled: bool = False, - **_kwargs: dict, -) -> VerbResult: - """All the steps to generate embeddings.""" - source = cast("pd.DataFrame", input.get_input()) - final_relationships = cast( - "pd.DataFrame", get_required_input_table(input, "relationships").table - ) - final_text_units = cast( - "pd.DataFrame", get_required_input_table(input, "text_units").table - ) - final_entities = cast( - "pd.DataFrame", get_required_input_table(input, "entities").table - ) - - final_community_reports = cast( - "pd.DataFrame", get_required_input_table(input, "community_reports").table - ) - - await generate_text_embeddings( - final_documents=source, - final_relationships=final_relationships, - final_text_units=final_text_units, - final_entities=final_entities, - final_community_reports=final_community_reports, - callbacks=callbacks, - cache=cache, - storage=storage, - text_embed_config=text_embed, - embedded_fields=embedded_fields, - snapshot_embeddings_enabled=snapshot_embeddings_enabled, - ) - - return create_verb_result(cast("Table", pd.DataFrame())) diff --git a/graphrag/logger/base.py b/graphrag/logger/base.py index b730668e87..73b5668552 100644 --- a/graphrag/logger/base.py +++ b/graphrag/logger/base.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from typing import Any -from datashaper.progress.types import Progress +from graphrag.logger.progress import Progress class StatusLogger(ABC): diff --git a/graphrag/logger/progress.py b/graphrag/logger/progress.py new file mode 100644 index 0000000000..536786100b --- /dev/null +++ b/graphrag/logger/progress.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Progress reporting types.""" + +from collections.abc import Callable, Iterable +from dataclasses import dataclass +from typing import TypeVar + +T = TypeVar("T") + + +@dataclass +class Progress: + """A class representing the progress of a task.""" + + percent: float | None = None + """0 - 1 progress""" + + description: str | None = None + """Description of the progress""" + + total_items: int | None = None + """Total number of items""" + + completed_items: int | None = None + """Number of items completed""" "" + + +ProgressHandler = Callable[[Progress], None] +"""A function to handle progress reports.""" + + +class ProgressTicker: + """A class that emits progress reports incrementally.""" + + _callback: ProgressHandler | None + _num_total: int + _num_complete: int + + def __init__(self, callback: ProgressHandler | None, num_total: int): + self._callback = callback + self._num_total = num_total + self._num_complete = 0 + + def __call__(self, num_ticks: int = 1) -> None: + """Emit progress.""" + self._num_complete += num_ticks + if self._callback is not None: + self._callback( + Progress( + total_items=self._num_total, completed_items=self._num_complete + ) + ) + + def done(self) -> None: + """Mark the progress as done.""" + if self._callback is not None: + self._callback( + Progress(total_items=self._num_total, completed_items=self._num_total) + ) + + +def progress_ticker(callback: ProgressHandler | None, num_total: int) -> ProgressTicker: + """Create a progress ticker.""" + return ProgressTicker(callback, num_total) + + +def progress_iterable( + iterable: Iterable[T], + progress: ProgressHandler | None, + num_total: int | None = None, +) -> Iterable[T]: + """Wrap an iterable with a progress handler. Every time an item is yielded, the progress handler will be called with the current progress.""" + if num_total is None: + num_total = len(list(iterable)) + + tick = ProgressTicker(progress, num_total) + + for item in iterable: + tick(1) + yield item diff --git a/graphrag/logger/rich_progress.py b/graphrag/logger/rich_progress.py index 22145d12df..818697a4f2 100644 --- a/graphrag/logger/rich_progress.py +++ b/graphrag/logger/rich_progress.py @@ -6,7 +6,6 @@ # Print iterations progress import asyncio -from datashaper import Progress as DSProgress from rich.console import Console, Group from rich.live import Live from rich.progress import Progress, TaskID, TimeElapsedColumn @@ -14,6 +13,7 @@ from rich.tree import Tree from graphrag.logger.base import ProgressLogger +from graphrag.logger.progress import Progress as GRProgress # https://stackoverflow.com/a/34325723 @@ -138,7 +138,7 @@ def info(self, message: str) -> None: """Log information.""" self._console.print(message) - def __call__(self, progress_update: DSProgress) -> None: + def __call__(self, progress_update: GRProgress) -> None: """Update progress.""" if self._disposing: return diff --git a/graphrag/prompt_tune/loader/input.py b/graphrag/prompt_tune/loader/input.py index e2b6c49b56..db8a95804a 100644 --- a/graphrag/prompt_tune/loader/input.py +++ b/graphrag/prompt_tune/loader/input.py @@ -5,11 +5,11 @@ import numpy as np import pandas as pd -from datashaper import NoopVerbCallbacks from fnllm import ChatLLM from pydantic import TypeAdapter import graphrag.config.defaults as defs +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.config.models.llm_parameters import LLMParameters from 
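As a quick illustration of the progress utilities introduced in graphrag/logger/progress.py above (not part of the patch itself; the print-based handler is a stand-in for a real ProgressHandler):

from graphrag.logger.progress import Progress, ProgressTicker, progress_iterable

def on_progress(progress: Progress) -> None:
    # A trivial ProgressHandler: echo completed/total counts as they arrive.
    print(f"{progress.completed_items}/{progress.total_items}")

# Tick-style reporting: one tick per completed unit of work.
ticker = ProgressTicker(on_progress, num_total=3)
for _ in range(3):
    ticker(1)
ticker.done()  # emits a final report with completed_items == total_items

# Iterable wrapper: emits one tick each time an item is yielded.
for item in progress_iterable(["a", "b", "c"], on_progress):
    pass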
graphrag.index.input.factory import create_input diff --git a/graphrag/storage/blob_pipeline_storage.py b/graphrag/storage/blob_pipeline_storage.py index 701e59e25c..f72663052c 100644 --- a/graphrag/storage/blob_pipeline_storage.py +++ b/graphrag/storage/blob_pipeline_storage.py @@ -11,9 +11,9 @@ from azure.identity import DefaultAzureCredential from azure.storage.blob import BlobServiceClient -from datashaper import Progress from graphrag.logger.base import ProgressLogger +from graphrag.logger.progress import Progress from graphrag.storage.pipeline_storage import PipelineStorage log = logging.getLogger(__name__) diff --git a/graphrag/storage/cosmosdb_pipeline_storage.py b/graphrag/storage/cosmosdb_pipeline_storage.py index 9de9cf6dc0..c832ebc8bd 100644 --- a/graphrag/storage/cosmosdb_pipeline_storage.py +++ b/graphrag/storage/cosmosdb_pipeline_storage.py @@ -15,9 +15,9 @@ from azure.cosmos.exceptions import CosmosResourceNotFoundError from azure.cosmos.partition_key import PartitionKey from azure.identity import DefaultAzureCredential -from datashaper import Progress from graphrag.logger.base import ProgressLogger +from graphrag.logger.progress import Progress from graphrag.storage.pipeline_storage import PipelineStorage log = logging.getLogger(__name__) diff --git a/graphrag/storage/file_pipeline_storage.py b/graphrag/storage/file_pipeline_storage.py index f64df723b4..a2d45b89b3 100644 --- a/graphrag/storage/file_pipeline_storage.py +++ b/graphrag/storage/file_pipeline_storage.py @@ -14,9 +14,9 @@ import aiofiles from aiofiles.os import remove from aiofiles.ospath import exists -from datashaper import Progress from graphrag.logger.base import ProgressLogger +from graphrag.logger.progress import Progress from graphrag.storage.pipeline_storage import PipelineStorage log = logging.getLogger(__name__) diff --git a/graphrag/utils/storage.py b/graphrag/utils/storage.py index a28b0c4c1c..caf8003fc5 100644 --- a/graphrag/utils/storage.py +++ b/graphrag/utils/storage.py @@ -15,14 +15,15 @@ async def load_table_from_storage(name: str, storage: PipelineStorage) -> pd.DataFrame: """Load a parquet from the storage instance.""" - if not await storage.has(name): - msg = f"Could not find {name} in storage!" + filename = f"{name}.parquet" + if not await storage.has(filename): + msg = f"Could not find {filename} in storage!" 
raise ValueError(msg) try: - log.info("reading table from storage: %s", name) - return pd.read_parquet(BytesIO(await storage.get(name, as_bytes=True))) + log.info("reading table from storage: %s", filename) + return pd.read_parquet(BytesIO(await storage.get(filename, as_bytes=True))) except Exception: - log.exception("error loading table from storage: %s", name) + log.exception("error loading table from storage: %s", filename) raise @@ -30,4 +31,14 @@ async def write_table_to_storage( table: pd.DataFrame, name: str, storage: PipelineStorage ) -> None: """Write a table to storage.""" - await storage.set(name, table.to_parquet()) + await storage.set(f"{name}.parquet", table.to_parquet()) + + +async def delete_table_from_storage(name: str, storage: PipelineStorage) -> None: + """Delete a table to storage.""" + await storage.delete(f"{name}.parquet") + + +async def storage_has_table(name: str, storage: PipelineStorage) -> bool: + """Check if a table exists in storage.""" + return await storage.has(f"{name}.parquet") diff --git a/poetry.lock b/poetry.lock index 89c4bd1d56..cef2d2812f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -851,23 +851,6 @@ files = [ docs = ["ipython", "matplotlib", "numpydoc", "sphinx"] tests = ["pytest", "pytest-cov", "pytest-xdist"] -[[package]] -name = "datashaper" -version = "0.0.49" -description = "This project provides a collection of utilities for doing lightweight data wrangling." -optional = false -python-versions = ">=3.10,<4" -files = [ - {file = "datashaper-0.0.49-py3-none-any.whl", hash = "sha256:7f58cabacc834765595c6e04cfbbd05be6af71907e46ebc7a91d2a4add7c2643"}, - {file = "datashaper-0.0.49.tar.gz", hash = "sha256:05bfba5964474a62bdd5259ec3fa0173d01e365208b6a4aff4ea0e63096a7533"}, -] - -[package.dependencies] -diskcache = ">=5.6.3,<6.0.0" -jsonschema = ">=4.21.1,<5.0.0" -pandas = ">=2.2.0,<3.0.0" -pyarrow = ">=15.0.0,<16.0.0" - [[package]] name = "debugpy" version = "1.8.11" @@ -987,17 +970,6 @@ asttokens = ">=2.0.0,<3.0.0" executing = ">=1.1.1" pygments = ">=2.15.0" -[[package]] -name = "diskcache" -version = "5.6.3" -description = "Disk Cache -- Disk and file backed persistent cache." 
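The graphrag.utils.storage hunk above now keys every table by a bare name and appends the .parquet suffix internally. A small round-trip sketch of those helpers (MemoryPipelineStorage is borrowed from the tests elsewhere in this diff, and its byte round-trip behaviour is assumed here):

import asyncio

import pandas as pd

from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.utils.storage import (
    load_table_from_storage,
    storage_has_table,
    write_table_to_storage,
)

async def main() -> None:
    storage = MemoryPipelineStorage()
    table = pd.DataFrame({"id": [1, 2, 3]})

    # Callers pass "entities", not "entities.parquet"; the helpers add the suffix.
    await write_table_to_storage(table, "entities", storage)
    assert await storage_has_table("entities", storage)
    print((await load_table_from_storage("entities", storage)).shape)

asyncio.run(main())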
-optional = false -python-versions = ">=3" -files = [ - {file = "diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19"}, - {file = "diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc"}, -] - [[package]] name = "distro" version = "1.9.0" @@ -5286,4 +5258,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "54d5f5d253d47c5c28c874c9aa56eaeba543fa3c27fed6143ae266b0a07ed391" +content-hash = "1adafa89f86e853b424eb1d66d3434520596e6b1e782c975a497a1c857ceabb9" diff --git a/pyproject.toml b/pyproject.toml index 7d27132548..97a9557a78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,6 @@ format-jinja = """ [tool.poetry.dependencies] python = ">=3.10,<3.13" environs = "^11.0.0" -datashaper = "^0.0.49" # Vector Stores azure-search-documents = "^11.5.2" @@ -252,7 +251,6 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "tests/*" = ["S", "D", "ANN", "T201", "ASYNC", "ARG", "PTH", "TRY"] -"examples/*" = ["S", "D", "ANN", "T201", "PTH", "TRY", "PERF"] "graphrag/index/config/*" = ["TCH"] "*.ipynb" = ["T201"] @@ -264,7 +262,7 @@ convention = "numpy" # https://github.com/microsoft/pyright/blob/9f81564a4685ff5c55edd3959f9b39030f590b2f/docs/configuration.md#sample-pyprojecttoml-file [tool.pyright] -include = ["graphrag", "tests", "examples", "examples_notebooks"] +include = ["graphrag", "tests", "examples_notebooks"] exclude = ["**/node_modules", "**/__pycache__"] [tool.pytest.ini_options] diff --git a/tests/unit/config/test_default_config.py b/tests/unit/config/test_default_config.py index 1457092c59..65617cbfec 100644 --- a/tests/unit/config/test_default_config.py +++ b/tests/unit/config/test_default_config.py @@ -42,13 +42,7 @@ from graphrag.config.models.umap_config import UmapConfig from graphrag.index.config.cache import PipelineFileCacheConfig from graphrag.index.config.input import ( - PipelineCSVInputConfig, PipelineInputConfig, - PipelineTextInputConfig, -) -from graphrag.index.config.pipeline import ( - PipelineConfig, - PipelineWorkflowReference, ) from graphrag.index.config.reporting import PipelineFileReportingConfig from graphrag.index.config.storage import PipelineFileStorageConfig @@ -201,12 +195,10 @@ def test_clear_warnings(self): assert SummarizeDescriptionsConfig is not None assert TextEmbeddingConfig is not None assert UmapConfig is not None - assert PipelineConfig is not None assert PipelineFileReportingConfig is not None assert PipelineFileStorageConfig is not None assert PipelineInputConfig is not None assert PipelineFileCacheConfig is not None - assert PipelineWorkflowReference is not None @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test"}, clear=True) def test_string_repr(self): @@ -218,47 +210,31 @@ def test_string_repr(self): # __repr__ can be eval()'d repr_str = config.__repr__() - # TODO: add __repr__ to datashaper enum + # TODO: add __repr__ to enum repr_str = repr_str.replace("async_mode=,", "") assert eval(repr_str) is not None - # Pipeline config __str__ can be json loaded - pipeline_config = create_pipeline_config(config) - string_repr = str(pipeline_config) - assert string_repr is not None - assert json.loads(string_repr) is not None - - # Pipeline config __repr__ can be eval()'d - repr_str = pipeline_config.__repr__() - # TODO: add __repr__ to datashaper enum - repr_str = repr_str.replace( - "'async_mode': ,", "" - ) - assert eval(repr_str) is not None - @mock.patch.dict(os.environ, {}, clear=True) 
def test_default_config_with_no_env_vars_throws(self): with pytest.raises(ApiKeyMissingError): # This should throw an error because the API key is missing - create_pipeline_config(create_graphrag_config()) + create_graphrag_config() @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) def test_default_config_with_api_key_passes(self): # doesn't throw - config = create_pipeline_config(create_graphrag_config()) + config = create_graphrag_config() assert config is not None @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test"}, clear=True) def test_default_config_with_oai_key_passes_envvar(self): # doesn't throw - config = create_pipeline_config(create_graphrag_config()) + config = create_graphrag_config() assert config is not None def test_default_config_with_oai_key_passes_obj(self): # doesn't throw - config = create_pipeline_config( - create_graphrag_config({"llm": {"api_key": "test"}}) - ) + config = create_graphrag_config({"llm": {"api_key": "test"}}) assert config is not None @mock.patch.dict( @@ -352,29 +328,6 @@ def test_throws_if_azure_is_used_without_embedding_deployment_name(self): with pytest.raises(AzureDeploymentNameMissingError): create_graphrag_config() - @mock.patch.dict( - os.environ, - {"GRAPHRAG_API_KEY": "test", "GRAPHRAG_INPUT_FILE_TYPE": "csv"}, - clear=True, - ) - def test_csv_input_returns_correct_config(self): - config = create_pipeline_config(create_graphrag_config(root_dir="/some/root")) - assert config.root_dir == "/some/root" - # Make sure the input is a CSV input - assert isinstance(config.input, PipelineCSVInputConfig) - assert (config.input.file_pattern or "") == ".*\\.csv$" # type: ignore - - @mock.patch.dict( - os.environ, - {"GRAPHRAG_API_KEY": "test", "GRAPHRAG_INPUT_FILE_TYPE": "text"}, - clear=True, - ) - def test_text_input_returns_correct_config(self): - config = create_pipeline_config(create_graphrag_config(root_dir=".")) - assert isinstance(config.input, PipelineTextInputConfig) - assert config.input is not None - assert (config.input.file_pattern or "") == ".*\\.txt$" # type: ignore - @mock.patch.dict( os.environ, { diff --git a/tests/unit/indexing/config/__init__.py b/tests/unit/indexing/config/__init__.py deleted file mode 100644 index 0a3e38adfb..0000000000 --- a/tests/unit/indexing/config/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License diff --git a/tests/unit/indexing/config/default_config_with_everything_overridden.yml b/tests/unit/indexing/config/default_config_with_everything_overridden.yml deleted file mode 100644 index 7a2f712e46..0000000000 --- a/tests/unit/indexing/config/default_config_with_everything_overridden.yml +++ /dev/null @@ -1,20 +0,0 @@ -extends: default - -input: - file_type: text - base_dir: /some/overridden/dir - file_pattern: test.txt - -storage: - type: file - -cache: - type: file - -reporting: - type: file - -workflows: - - name: TEST_WORKFLOW - steps: - - verb: TEST_VERB diff --git a/tests/unit/indexing/config/default_config_with_overridden_input.yml b/tests/unit/indexing/config/default_config_with_overridden_input.yml deleted file mode 100644 index 68631a315a..0000000000 --- a/tests/unit/indexing/config/default_config_with_overridden_input.yml +++ /dev/null @@ -1,5 +0,0 @@ -extends: default -input: - file_type: text - base_dir: /some/overridden/dir - file_pattern: test.txt diff --git a/tests/unit/indexing/config/default_config_with_overridden_workflows.yml b/tests/unit/indexing/config/default_config_with_overridden_workflows.yml deleted file mode 100644 index c3c9d07c2c..0000000000 --- a/tests/unit/indexing/config/default_config_with_overridden_workflows.yml +++ /dev/null @@ -1,6 +0,0 @@ -extends: default - -workflows: - - name: TEST_WORKFLOW - steps: - - verb: TEST_VERB diff --git a/tests/unit/indexing/config/helpers.py b/tests/unit/indexing/config/helpers.py deleted file mode 100644 index f70b9af81e..0000000000 --- a/tests/unit/indexing/config/helpers.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License -import json -import unittest -from typing import Any - -from graphrag.config.create_graphrag_config import create_graphrag_config -from graphrag.index.create_pipeline_config import PipelineConfig, create_pipeline_config - - -def assert_contains_default_config( - test_case: unittest.TestCase, - config: Any, - check_input=True, - check_storage=True, - check_reporting=True, - check_cache=True, - check_workflows=True, -): - """Asserts that the config contains the default config.""" - assert config is not None - assert isinstance(config, PipelineConfig) - - checked_config = json.loads( - config.model_dump_json(exclude_defaults=True, exclude_unset=True) - ) - - actual_default_config = json.loads( - create_pipeline_config(create_graphrag_config()).model_dump_json( - exclude_defaults=True, exclude_unset=True - ) - ) - props_to_ignore = ["root_dir", "extends"] - - # Make sure there is some sort of workflows - if not check_workflows: - props_to_ignore.append("workflows") - - # Make sure it tries to load some sort of input - if not check_input: - props_to_ignore.append("input") - - # Make sure it tries to load some sort of storage - if not check_storage: - props_to_ignore.append("storage") - - # Make sure it tries to load some sort of reporting - if not check_reporting: - props_to_ignore.append("reporting") - - # Make sure it tries to load some sort of cache - if not check_cache: - props_to_ignore.append("cache") - - for prop in props_to_ignore: - checked_config.pop(prop, None) - actual_default_config.pop(prop, None) - - assert actual_default_config == actual_default_config | checked_config diff --git a/tests/unit/indexing/config/test_load.py b/tests/unit/indexing/config/test_load.py deleted file mode 100644 index c458081ced..0000000000 --- a/tests/unit/indexing/config/test_load.py +++ /dev/null @@ -1,131 +0,0 
@@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT Licenses -import json -import os -import unittest -from pathlib import Path -from typing import Any -from unittest import mock - -from graphrag.config.create_graphrag_config import create_graphrag_config -from graphrag.index.config.pipeline import PipelineConfig -from graphrag.index.create_pipeline_config import create_pipeline_config -from graphrag.index.load_pipeline_config import load_pipeline_config - -current_dir = os.path.dirname(__file__) - - -class TestLoadPipelineConfig(unittest.TestCase): - @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) - def test_config_passed_in_returns_config(self): - config = PipelineConfig() - result = load_pipeline_config(config) - assert result == config - - @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) - def test_loading_default_config_returns_config(self): - result = load_pipeline_config("default") - self.assert_is_default_config(result) - - @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) - def test_loading_default_config_with_input_overridden(self): - config = load_pipeline_config( - str(Path(current_dir) / "default_config_with_overridden_input.yml") - ) - - # Check that the config is merged - # but skip checking the input - self.assert_is_default_config( - config, check_input=False, ignore_workflows=["create_base_text_units"] - ) - - if config.input is None: - msg = "Input should not be none" - raise Exception(msg) - - # Check that the input is merged - assert config.input.file_pattern == "test.txt" - assert config.input.file_type == "text" - assert config.input.base_dir == "/some/overridden/dir" - - @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) - def test_loading_default_config_with_workflows_overridden(self): - config = load_pipeline_config( - str(Path(current_dir) / "default_config_with_overridden_workflows.yml") - ) - - # Check that the config is merged - # but skip checking the input - self.assert_is_default_config(config, check_workflows=False) - - # Make sure the workflows are overridden - assert len(config.workflows) == 1 - assert config.workflows[0].name == "TEST_WORKFLOW" - assert config.workflows[0].steps is not None - assert len(config.workflows[0].steps) == 1 # type: ignore - assert config.workflows[0].steps[0]["verb"] == "TEST_VERB" # type: ignore - - @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) - def assert_is_default_config( - self, - config: Any, - check_input=True, - check_storage=True, - check_reporting=True, - check_cache=True, - check_workflows=True, - ignore_workflows=None, - ): - if ignore_workflows is None: - ignore_workflows = [] - assert config is not None - assert isinstance(config, PipelineConfig) - - checked_config = json.loads( - config.model_dump_json(exclude_defaults=True, exclude_unset=True) - ) - - actual_default_config = json.loads( - create_pipeline_config( - create_graphrag_config(root_dir=".") - ).model_dump_json(exclude_defaults=True, exclude_unset=True) - ) - props_to_ignore = ["root_dir", "extends"] - - # Make sure there is some sort of workflows - if not check_workflows: - props_to_ignore.append("workflows") - - # Make sure it tries to load some sort of input - if not check_input: - props_to_ignore.append("input") - - # Make sure it tries to load some sort of storage - if not check_storage: - props_to_ignore.append("storage") - - # Make sure it tries to load some sort of reporting - if not check_reporting: - 
props_to_ignore.append("reporting") - - # Make sure it tries to load some sort of cache - if not check_cache: - props_to_ignore.append("cache") - - for prop in props_to_ignore: - checked_config.pop(prop, None) - actual_default_config.pop(prop, None) - - for prop in actual_default_config: - if prop == "workflows": - assert len(checked_config[prop]) == len(actual_default_config[prop]) - for i, workflow in enumerate(actual_default_config[prop]): - if workflow["name"] not in ignore_workflows: - assert workflow == actual_default_config[prop][i] - else: - assert checked_config[prop] == actual_default_config[prop] - - def setUp(self) -> None: - os.environ["GRAPHRAG_OPENAI_API_KEY"] = "test" - os.environ["GRAPHRAG_OPENAI_EMBEDDING_API_KEY"] = "test" - return super().setUp() diff --git a/tests/unit/indexing/test_exports.py b/tests/unit/indexing/test_exports.py deleted file mode 100644 index ee2b23e622..0000000000 --- a/tests/unit/indexing/test_exports.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License -from graphrag.index.create_pipeline_config import create_pipeline_config -from graphrag.index.run import run_pipeline, run_pipeline_with_config - - -def test_exported_functions(): - assert callable(create_pipeline_config) - assert callable(run_pipeline_with_config) - assert callable(run_pipeline) diff --git a/tests/unit/indexing/workflows/__init__.py b/tests/unit/indexing/workflows/__init__.py deleted file mode 100644 index 0a3e38adfb..0000000000 --- a/tests/unit/indexing/workflows/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License diff --git a/tests/unit/indexing/workflows/helpers.py b/tests/unit/indexing/workflows/helpers.py deleted file mode 100644 index 512e8294c2..0000000000 --- a/tests/unit/indexing/workflows/helpers.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License -mock_verbs = { - "mock_verb": lambda x: x, - "mock_verb_2": lambda x: x, -} - -mock_workflows = { - "mock_workflow": lambda _x: [ - { - "verb": "mock_verb", - "args": { - "column": "test", - }, - } - ], - "mock_workflow_2": lambda _x: [ - { - "verb": "mock_verb", - "args": { - "column": "test", - }, - }, - { - "verb": "mock_verb_2", - "args": { - "column": "test", - }, - }, - ], -} diff --git a/tests/unit/indexing/workflows/test_export.py b/tests/unit/indexing/workflows/test_export.py deleted file mode 100644 index 206b4869e6..0000000000 --- a/tests/unit/indexing/workflows/test_export.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -from typing import Any, cast - -import pandas as pd -from datashaper import ( - Table, - VerbInput, - VerbResult, - create_verb_result, -) - -from graphrag.index.config.pipeline import PipelineWorkflowReference -from graphrag.index.run import run_pipeline -from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage -from graphrag.storage.pipeline_storage import PipelineStorage - - -async def mock_verb( - input: VerbInput, storage: PipelineStorage, **_kwargs -) -> VerbResult: - source = cast("pd.DataFrame", input.get_input()) - - output = source[["id"]] - - await storage.set("mock_write", source[["id"]]) - - return create_verb_result( - cast( - "Table", - output, - ) - ) - - -async def mock_no_return_verb( - input: VerbInput, storage: PipelineStorage, **_kwargs -) -> VerbResult: - source = cast("pd.DataFrame", input.get_input()) - - # write some outputs to storage independent of the return - await storage.set("empty_write", source[["name"]]) - - return create_verb_result( - cast( - "Table", - pd.DataFrame(), - ) - ) - - -async def test_normal_result_exports_parquet(): - mock_verbs: Any = {"mock_verb": mock_verb} - mock_workflows: Any = { - "mock_workflow": lambda _x: [ - { - "verb": "mock_verb", - "args": { - "column": "test", - }, - } - ] - } - workflows = [ - PipelineWorkflowReference( - name="mock_workflow", - config=None, - ) - ] - dataset = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]}) - storage = MemoryPipelineStorage() - pipeline_result = [ - gen - async for gen in run_pipeline( - workflows, - dataset, - storage=storage, - additional_workflows=mock_workflows, - additional_verbs=mock_verbs, - ) - ] - - assert len(pipeline_result) == 1 - assert storage.keys() == ["stats.json", "mock_write", "mock_workflow.parquet"], ( - "Mock workflow output should be written to storage by the exporter when there is a non-empty data frame" - ) - - -async def test_empty_result_does_not_export_parquet(): - mock_verbs: Any = {"mock_no_return_verb": mock_no_return_verb} - mock_workflows: Any = { - "mock_workflow": lambda _x: [ - { - "verb": "mock_no_return_verb", - "args": { - "column": "test", - }, - } - ] - } - workflows = [ - PipelineWorkflowReference( - name="mock_workflow", - config=None, - ) - ] - dataset = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]}) - storage = MemoryPipelineStorage() - pipeline_result = [ - gen - async for gen in run_pipeline( - workflows, - dataset, - storage=storage, - additional_workflows=mock_workflows, - additional_verbs=mock_verbs, - ) - ] - - assert len(pipeline_result) == 1 - assert storage.keys() == [ - "stats.json", - "empty_write", - ], "Mock workflow output should not be written to storage by the exporter" diff --git a/tests/unit/indexing/workflows/test_load.py b/tests/unit/indexing/workflows/test_load.py deleted file mode 100644 index 60ae6647b4..0000000000 --- a/tests/unit/indexing/workflows/test_load.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License -import unittest - -import pytest - -from graphrag.index.config.pipeline import PipelineWorkflowReference -from graphrag.index.errors import UnknownWorkflowError -from graphrag.index.workflows.load import create_workflow, load_workflows - -from .helpers import mock_verbs, mock_workflows - - -class TestCreateWorkflow(unittest.TestCase): - def test_workflow_with_steps_should_not_fail(self): - create_workflow( - "workflow_with_steps", - [ - { - "verb": "mock_verb", - "args": { - "column": "test", - }, - } - ], - config=None, - additional_verbs=mock_verbs, - ) - - def test_non_existent_workflow_without_steps_should_crash(self): - # since we don't have a workflow named "test", and the user didn't provide any steps, we should crash - # since we don't know what to do - with pytest.raises(UnknownWorkflowError): - create_workflow("test", None, config=None, additional_verbs=mock_verbs) - - def test_existing_workflow_should_not_crash(self): - create_workflow( - "mock_workflow", - None, - config=None, - additional_verbs=mock_verbs, - additional_workflows=mock_workflows, - ) - - -class TestLoadWorkflows(unittest.TestCase): - def test_non_existent_workflow_should_crash(self): - with pytest.raises(UnknownWorkflowError): - load_workflows( - [ - PipelineWorkflowReference( - name="some_workflow_that_does_not_exist", - config=None, - ) - ], - additional_workflows=mock_workflows, - additional_verbs=mock_verbs, - ) - - def test_single_workflow_should_not_crash(self): - load_workflows( - [ - PipelineWorkflowReference( - name="mock_workflow", - config=None, - ) - ], - additional_workflows=mock_workflows, - additional_verbs=mock_verbs, - ) - - def test_multiple_workflows_should_not_crash(self): - load_workflows( - [ - PipelineWorkflowReference( - name="mock_workflow", - config=None, - ), - PipelineWorkflowReference( - name="mock_workflow_2", - config=None, - ), - ], - # the two above are in the "mock_workflows" list - additional_workflows=mock_workflows, - additional_verbs=mock_verbs, - ) - - def test_two_interdependent_workflows_should_provide_correct_order(self): - ordered_workflows, _deps = load_workflows( - [ - PipelineWorkflowReference( - name="interdependent_workflow_1", - steps=[ - { - "verb": "mock_verb", - "args": { - "column": "test", - }, - "input": { - "source": "workflow:interdependent_workflow_2" - }, # This one is dependent on the second one, so when it comes out of load_workflows, it should be first - } - ], - ), - PipelineWorkflowReference( - name="interdependent_workflow_2", - steps=[ - { - "verb": "mock_verb", - "args": { - "column": "test", - }, - } - ], - ), - ], - # the two above are in the "mock_workflows" list - additional_workflows=mock_workflows, - additional_verbs=mock_verbs, - ) - - # two should only come out - assert len(ordered_workflows) == 2 - assert ordered_workflows[0].workflow.name == "interdependent_workflow_2" - assert ordered_workflows[1].workflow.name == "interdependent_workflow_1" - - def test_three_interdependent_workflows_should_provide_correct_order(self): - ordered_workflows, _deps = load_workflows( - [ - PipelineWorkflowReference( - name="interdependent_workflow_3", - steps=[ - { - "verb": "mock_verb", - "args": { - "column": "test", - }, - } - ], - ), - PipelineWorkflowReference( - name="interdependent_workflow_1", - steps=[ - { - "verb": "mock_verb", - "args": { - "column": "test", - }, - "input": {"source": "workflow:interdependent_workflow_2"}, - } - ], - ), - PipelineWorkflowReference( - name="interdependent_workflow_2", - steps=[ 
- { - "verb": "mock_verb", - "args": { - "column": "test", - }, - "input": {"source": "workflow:interdependent_workflow_3"}, - } - ], - ), - ], - # the two above are in the "mock_workflows" list - additional_workflows=mock_workflows, - additional_verbs=mock_verbs, - ) - - order = [ - "interdependent_workflow_3", - "interdependent_workflow_2", - "interdependent_workflow_1", - ] - assert [x.workflow.name for x in ordered_workflows] == order - - def test_two_workflows_dependent_on_another_single_workflow_should_provide_correct_order( - self, - ): - ordered_workflows, _deps = load_workflows( - [ - # Workflows 1 and 2 are dependent on 3, so 3 should come out first - PipelineWorkflowReference( - name="interdependent_workflow_3", - steps=[ - { - "verb": "mock_verb", - "args": { - "column": "test", - }, - } - ], - ), - PipelineWorkflowReference( - name="interdependent_workflow_1", - steps=[ - { - "verb": "mock_verb", - "args": { - "column": "test", - }, - "input": {"source": "workflow:interdependent_workflow_3"}, - } - ], - ), - PipelineWorkflowReference( - name="interdependent_workflow_2", - steps=[ - { - "verb": "mock_verb", - "args": { - "column": "test", - }, - "input": {"source": "workflow:interdependent_workflow_3"}, - } - ], - ), - ], - # the two above are in the "mock_workflows" list - additional_workflows=mock_workflows, - additional_verbs=mock_verbs, - ) - - assert len(ordered_workflows) == 3 - assert ordered_workflows[0].workflow.name == "interdependent_workflow_3" - - # The order of the other two doesn't matter, but they need to be there - assert ordered_workflows[1].workflow.name in [ - "interdependent_workflow_1", - "interdependent_workflow_2", - ] - assert ordered_workflows[2].workflow.name in [ - "interdependent_workflow_1", - "interdependent_workflow_2", - ] diff --git a/tests/verbs/test_compute_communities.py b/tests/verbs/test_compute_communities.py index 1b23ef97b9..a460793e0b 100644 --- a/tests/verbs/test_compute_communities.py +++ b/tests/verbs/test_compute_communities.py @@ -1,34 +1,35 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.index.flows.compute_communities import ( - compute_communities, -) -from graphrag.index.workflows.v1.compute_communities import ( - workflow_name, -) +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.index.workflows.compute_communities import run_workflow +from graphrag.utils.storage import load_table_from_storage from .util import ( compare_outputs, - get_config_for_workflow, + create_test_context, load_test_table, ) -def test_compute_communities(): - edges = load_test_table("base_relationship_edges") +async def test_compute_communities(): expected = load_test_table("base_communities") - config = get_config_for_workflow(workflow_name) - cluster_config = config["cluster_graph"] + context = await create_test_context( + storage=["base_relationship_edges"], + ) - actual = compute_communities( - edges, - cluster_config.max_cluster_size, - cluster_config.use_lcc, - cluster_config.seed, + config = create_graphrag_config() + + await run_workflow( + config, + context, + NoopVerbCallbacks(), ) + actual = await load_table_from_storage("base_communities", context.storage) + columns = list(expected.columns.values) compare_outputs(actual, expected, columns) assert len(actual.columns) == len(expected.columns) diff --git a/tests/verbs/test_create_base_text_units.py b/tests/verbs/test_create_base_text_units.py index cf1d267aa3..587db6549d 100644 --- a/tests/verbs/test_create_base_text_units.py +++ b/tests/verbs/test_create_base_text_units.py @@ -1,65 +1,33 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.index.run.utils import create_run_context -from graphrag.index.workflows.v1.create_base_text_units import ( - build_steps, - workflow_name, -) +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.index.workflows.create_base_text_units import run_workflow, workflow_name +from graphrag.utils.storage import load_table_from_storage from .util import ( compare_outputs, - get_config_for_workflow, - get_workflow_output, - load_input_tables, + create_test_context, load_test_table, ) async def test_create_base_text_units(): - input_tables = load_input_tables(inputs=[]) expected = load_test_table(workflow_name) - context = create_run_context(None, None, None) + context = await create_test_context() - config = get_config_for_workflow(workflow_name) + config = create_graphrag_config() # test data was created with 4o, so we need to match the encoding for chunks to be identical - config["chunks"].encoding_model = "o200k_base" - - steps = build_steps(config) + config.chunks.encoding_model = "o200k_base" - await get_workflow_output( - input_tables, - { - "steps": steps, - }, + await run_workflow( + config, context, + NoopVerbCallbacks(), ) - actual = await context.runtime_storage.get("base_text_units") - compare_outputs(actual, expected) - - -async def test_create_base_text_units_with_snapshot(): - input_tables = load_input_tables(inputs=[]) - - context = create_run_context(None, None, None) - - config = get_config_for_workflow(workflow_name) - # test data was created with 4o, so we need to match the encoding for chunks to be identical - config["chunks"].encoding_model = "o200k_base" - config["snapshot_transient"] = True - - steps = build_steps(config) + actual = await load_table_from_storage(workflow_name, 
context.storage) - await get_workflow_output( - input_tables, - { - "steps": steps, - }, - context, - ) - - assert context.storage.keys() == ["create_base_text_units.parquet"], ( - "Text unit snapshot keys differ" - ) + compare_outputs(actual, expected) diff --git a/tests/verbs/test_create_final_communities.py b/tests/verbs/test_create_final_communities.py index b9f16f4c2b..07c9e9baa5 100644 --- a/tests/verbs/test_create_final_communities.py +++ b/tests/verbs/test_create_final_communities.py @@ -1,32 +1,42 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.index.flows.create_final_communities import ( - create_final_communities, -) -from graphrag.index.workflows.v1.create_final_communities import ( +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.index.workflows.create_final_communities import ( + run_workflow, workflow_name, ) +from graphrag.utils.storage import load_table_from_storage from .util import ( compare_outputs, + create_test_context, load_test_table, ) -def test_create_final_communities(): - base_entity_nodes = load_test_table("base_entity_nodes") - base_relationship_edges = load_test_table("base_relationship_edges") - base_communities = load_test_table("base_communities") - +async def test_create_final_communities(): expected = load_test_table(workflow_name) - actual = create_final_communities( - base_entity_nodes=base_entity_nodes, - base_relationship_edges=base_relationship_edges, - base_communities=base_communities, + context = await create_test_context( + storage=[ + "base_entity_nodes", + "base_relationship_edges", + "base_communities", + ], + ) + + config = create_graphrag_config() + + await run_workflow( + config, + context, + NoopVerbCallbacks(), ) + actual = await load_table_from_storage(workflow_name, context.storage) + assert "period" in expected.columns assert "id" in expected.columns columns = list(expected.columns.values) diff --git a/tests/verbs/test_create_final_community_reports.py b/tests/verbs/test_create_final_community_reports.py index 85a6c3ee2b..896fe6e3cb 100644 --- a/tests/verbs/test_create_final_community_reports.py +++ b/tests/verbs/test_create_final_community_reports.py @@ -3,23 +3,24 @@ import pytest -from datashaper.errors import VerbParallelizationError +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.enums import LLMType from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import ( CommunityReportResponse, FindingModel, ) -from graphrag.index.workflows.v1.create_final_community_reports import ( - build_steps, +from graphrag.index.run.derive_from_rows import ParallelizationError +from graphrag.index.workflows.create_final_community_reports import ( + run_workflow, workflow_name, ) +from graphrag.utils.storage import load_table_from_storage from .util import ( compare_outputs, - get_config_for_workflow, - get_workflow_output, - load_input_tables, + create_test_context, load_test_table, ) @@ -48,28 +49,32 @@ async def test_create_final_community_reports(): - input_tables = load_input_tables([ - "workflow:create_final_nodes", - "workflow:create_final_covariates", - "workflow:create_final_relationships", - "workflow:create_final_entities", - "workflow:create_final_communities", - ]) expected = load_test_table(workflow_name) - 
config = get_config_for_workflow(workflow_name) - - config["create_community_reports"]["strategy"]["llm"] = MOCK_LLM_CONFIG + context = await create_test_context( + storage=[ + "create_final_nodes", + "create_final_covariates", + "create_final_relationships", + "create_final_entities", + "create_final_communities", + ] + ) - steps = build_steps(config) + config = create_graphrag_config() + config.community_reports.strategy = { + "type": "graph_intelligence", + "llm": MOCK_LLM_CONFIG, + } - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, + await run_workflow( + config, + context, + NoopVerbCallbacks(), ) + actual = await load_table_from_storage(workflow_name, context.storage) + assert len(actual.columns) == len(expected.columns) # only assert a couple of columns that are not mock - most of this table is LLM-generated @@ -81,25 +86,24 @@ async def test_create_final_community_reports(): async def test_create_final_community_reports_missing_llm_throws(): - input_tables = load_input_tables([ - "workflow:create_final_nodes", - "workflow:create_final_covariates", - "workflow:create_final_relationships", - "workflow:create_final_entities", - "workflow:create_final_communities", - ]) - - config = get_config_for_workflow(workflow_name) - - # deleting the llm config results in a default mock injection in run_graph_intelligence - del config["create_community_reports"]["strategy"]["llm"] - - steps = build_steps(config) - - with pytest.raises(VerbParallelizationError): - await get_workflow_output( - input_tables, - { - "steps": steps, - }, + context = await create_test_context( + storage=[ + "create_final_nodes", + "create_final_covariates", + "create_final_relationships", + "create_final_entities", + "create_final_communities", + ] + ) + + config = create_graphrag_config() + config.community_reports.strategy = { + "type": "graph_intelligence", + } + + with pytest.raises(ParallelizationError): + await run_workflow( + config, + context, + NoopVerbCallbacks(), ) diff --git a/tests/verbs/test_create_final_covariates.py b/tests/verbs/test_create_final_covariates.py index aecd3e7782..8236abd7bc 100644 --- a/tests/verbs/test_create_final_covariates.py +++ b/tests/verbs/test_create_final_covariates.py @@ -2,20 +2,20 @@ # Licensed under the MIT License import pytest -from datashaper.errors import VerbParallelizationError from pandas.testing import assert_series_equal +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.enums import LLMType -from graphrag.index.run.utils import create_run_context -from graphrag.index.workflows.v1.create_final_covariates import ( - build_steps, +from graphrag.index.run.derive_from_rows import ParallelizationError +from graphrag.index.workflows.create_final_covariates import ( + run_workflow, workflow_name, ) +from graphrag.utils.storage import load_table_from_storage from .util import ( - get_config_for_workflow, - get_workflow_output, - load_input_tables, + create_test_context, load_test_table, ) @@ -29,29 +29,27 @@ async def test_create_final_covariates(): - input_tables = load_input_tables(["workflow:create_base_text_units"]) + input = load_test_table("create_base_text_units") expected = load_test_table(workflow_name) - context = create_run_context(None, None, None) - await context.runtime_storage.set( - "base_text_units", input_tables["workflow:create_base_text_units"] + context = await create_test_context( + 
storage=["create_base_text_units"], ) - config = get_config_for_workflow(workflow_name) + config = create_graphrag_config() + config.claim_extraction.strategy = { + "type": "graph_intelligence", + "llm": MOCK_LLM_CONFIG, + "claim_description": "description", + } - config["claim_extract"]["strategy"]["llm"] = MOCK_LLM_CONFIG - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, + await run_workflow( + config, context, + NoopVerbCallbacks(), ) - input = input_tables["workflow:create_base_text_units"] + actual = await load_table_from_storage(workflow_name, context.storage) assert len(actual.columns) == len(expected.columns) # our mock only returns one covariate per text unit, so that's a 1:1 mapping versus the LLM-extracted content in the test data @@ -83,24 +81,19 @@ async def test_create_final_covariates(): async def test_create_final_covariates_missing_llm_throws(): - input_tables = load_input_tables(["workflow:create_base_text_units"]) - - context = create_run_context(None, None, None) - await context.runtime_storage.set( - "base_text_units", input_tables["workflow:create_base_text_units"] + context = await create_test_context( + storage=["create_base_text_units"], ) - config = get_config_for_workflow(workflow_name) - - del config["claim_extract"]["strategy"]["llm"] - - steps = build_steps(config) + config = create_graphrag_config() + config.claim_extraction.strategy = { + "type": "graph_intelligence", + "claim_description": "description", + } - with pytest.raises(VerbParallelizationError): - await get_workflow_output( - input_tables, - { - "steps": steps, - }, + with pytest.raises(ParallelizationError): + await run_workflow( + config, context, + NoopVerbCallbacks(), ) diff --git a/tests/verbs/test_create_final_documents.py b/tests/verbs/test_create_final_documents.py index f58b0e2721..a6916530a0 100644 --- a/tests/verbs/test_create_final_documents.py +++ b/tests/verbs/test_create_final_documents.py @@ -1,70 +1,59 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.index.run.utils import create_run_context -from graphrag.index.workflows.v1.create_final_documents import ( - build_steps, +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.index.workflows.create_final_documents import ( + run_workflow, workflow_name, ) +from graphrag.utils.storage import load_table_from_storage from .util import ( compare_outputs, - get_config_for_workflow, - get_workflow_output, - load_input_tables, + create_test_context, load_test_table, ) async def test_create_final_documents(): - input_tables = load_input_tables([ - "workflow:create_base_text_units", - ]) expected = load_test_table(workflow_name) - context = create_run_context(None, None, None) - await context.runtime_storage.set( - "base_text_units", input_tables["workflow:create_base_text_units"] + context = await create_test_context( + storage=["create_base_text_units"], ) - config = get_config_for_workflow(workflow_name) + config = create_graphrag_config() - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - context=context, + await run_workflow( + config, + context, + NoopVerbCallbacks(), ) + actual = await load_table_from_storage(workflow_name, context.storage) + compare_outputs(actual, expected) async def test_create_final_documents_with_attribute_columns(): - input_tables = load_input_tables(["workflow:create_base_text_units"]) expected = load_test_table(workflow_name) - context = create_run_context(None, None, None) - await context.runtime_storage.set( - "base_text_units", input_tables["workflow:create_base_text_units"] + context = await create_test_context( + storage=["create_base_text_units"], ) - config = get_config_for_workflow(workflow_name) - - config["document_attribute_columns"] = ["title"] + config = create_graphrag_config() + config.input.document_attribute_columns = ["title"] - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - context=context, + await run_workflow( + config, + context, + NoopVerbCallbacks(), ) + actual = await load_table_from_storage(workflow_name, context.storage) + # we should have dropped "title" and added "attributes" # our test dataframe does not have attributes, so we'll assert without it # and separately confirm it is in the output diff --git a/tests/verbs/test_create_final_entities.py b/tests/verbs/test_create_final_entities.py index 491830205f..6d4430d398 100644 --- a/tests/verbs/test_create_final_entities.py +++ b/tests/verbs/test_create_final_entities.py @@ -1,24 +1,37 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.index.flows.create_final_entities import ( - create_final_entities, -) -from graphrag.index.workflows.v1.create_final_entities import ( +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.index.workflows.create_final_entities import ( + run_workflow, workflow_name, ) +from graphrag.utils.storage import load_table_from_storage from .util import ( compare_outputs, + create_test_context, load_test_table, ) -def test_create_final_entities(): - input = load_test_table("base_entity_nodes") +async def test_create_final_entities(): expected = load_test_table(workflow_name) - actual = create_final_entities(input) + context = await create_test_context( + storage=["base_entity_nodes"], + ) + + config = create_graphrag_config() + + await run_workflow( + config, + context, + NoopVerbCallbacks(), + ) + + actual = await load_table_from_storage(workflow_name, context.storage) compare_outputs(actual, expected) assert len(actual.columns) == len(expected.columns) diff --git a/tests/verbs/test_create_final_nodes.py b/tests/verbs/test_create_final_nodes.py index db3b6ec57f..f37cb20cec 100644 --- a/tests/verbs/test_create_final_nodes.py +++ b/tests/verbs/test_create_final_nodes.py @@ -1,39 +1,42 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from datashaper import NoopVerbCallbacks - -from graphrag.config.models.embed_graph_config import EmbedGraphConfig -from graphrag.index.flows.create_final_nodes import ( - create_final_nodes, -) -from graphrag.index.workflows.v1.create_final_nodes import ( +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.index.workflows.create_final_nodes import ( + run_workflow, workflow_name, ) +from graphrag.utils.storage import load_table_from_storage from .util import ( compare_outputs, + create_test_context, load_test_table, ) -def test_create_final_nodes(): - base_entity_nodes = load_test_table("base_entity_nodes") - base_relationship_edges = load_test_table("base_relationship_edges") - base_communities = load_test_table("base_communities") - +async def test_create_final_nodes(): expected = load_test_table(workflow_name) - embed_config = EmbedGraphConfig(enabled=False) - actual = create_final_nodes( - base_entity_nodes=base_entity_nodes, - base_relationship_edges=base_relationship_edges, - base_communities=base_communities, - callbacks=NoopVerbCallbacks(), - embed_config=embed_config, - layout_enabled=False, + context = await create_test_context( + storage=[ + "base_entity_nodes", + "base_relationship_edges", + "base_communities", + ], ) + config = create_graphrag_config() + + await run_workflow( + config, + context, + NoopVerbCallbacks(), + ) + + actual = await load_table_from_storage(workflow_name, context.storage) + assert "id" in expected.columns columns = list(expected.columns.values) columns.remove("id") diff --git a/tests/verbs/test_create_final_relationships.py b/tests/verbs/test_create_final_relationships.py index 9f01e08304..223ca20ea4 100644 --- a/tests/verbs/test_create_final_relationships.py +++ b/tests/verbs/test_create_final_relationships.py @@ -1,24 +1,38 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.index.flows.create_final_relationships import ( - create_final_relationships, -) -from graphrag.index.workflows.v1.create_final_relationships import ( + +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.index.workflows.create_final_relationships import ( + run_workflow, workflow_name, ) +from graphrag.utils.storage import load_table_from_storage from .util import ( compare_outputs, + create_test_context, load_test_table, ) -def test_create_final_relationships(): - edges = load_test_table("base_relationship_edges") +async def test_create_final_relationships(): expected = load_test_table(workflow_name) - actual = create_final_relationships(edges) + context = await create_test_context( + storage=["base_relationship_edges"], + ) + + config = create_graphrag_config() + + await run_workflow( + config, + context, + NoopVerbCallbacks(), + ) + + actual = await load_table_from_storage(workflow_name, context.storage) assert "id" in expected.columns columns = list(expected.columns.values) diff --git a/tests/verbs/test_create_final_text_units.py b/tests/verbs/test_create_final_text_units.py index b87e61f55d..19fb11c6f0 100644 --- a/tests/verbs/test_create_final_text_units.py +++ b/tests/verbs/test_create_final_text_units.py @@ -1,80 +1,70 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.index.run.utils import create_run_context -from graphrag.index.workflows.v1.create_final_text_units import ( - build_steps, +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.index.workflows.create_final_text_units import ( + run_workflow, workflow_name, ) +from graphrag.utils.storage import load_table_from_storage from .util import ( compare_outputs, - get_config_for_workflow, - get_workflow_output, - load_input_tables, + create_test_context, load_test_table, ) async def test_create_final_text_units(): - input_tables = load_input_tables([ - "workflow:create_base_text_units", - "workflow:create_final_entities", - "workflow:create_final_relationships", - "workflow:create_final_covariates", - ]) expected = load_test_table(workflow_name) - context = create_run_context(None, None, None) - await context.runtime_storage.set( - "base_text_units", input_tables["workflow:create_base_text_units"] + context = await create_test_context( + storage=[ + "create_base_text_units", + "create_final_entities", + "create_final_relationships", + "create_final_covariates", + ], ) - config = get_config_for_workflow(workflow_name) + config = create_graphrag_config() + config.claim_extraction.enabled = True - config["covariates_enabled"] = True - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - context=context, + await run_workflow( + config, + context, + NoopVerbCallbacks(), ) + actual = await load_table_from_storage(workflow_name, context.storage) + compare_outputs(actual, expected) async def test_create_final_text_units_no_covariates(): - input_tables = load_input_tables([ - "workflow:create_base_text_units", - "workflow:create_final_entities", - "workflow:create_final_relationships", - "workflow:create_final_covariates", - ]) expected = load_test_table(workflow_name) - context = create_run_context(None, None, None) - await context.runtime_storage.set( - 
"base_text_units", input_tables["workflow:create_base_text_units"] + context = await create_test_context( + storage=[ + "create_base_text_units", + "create_final_entities", + "create_final_relationships", + "create_final_covariates", + ], ) - config = get_config_for_workflow(workflow_name) + config = create_graphrag_config() + config.claim_extraction.enabled = False - config["covariates_enabled"] = False - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - context=context, + await run_workflow( + config, + context, + NoopVerbCallbacks(), ) + actual = await load_table_from_storage(workflow_name, context.storage) + # we're short a covariate_ids column columns = list(expected.columns.values) columns.remove("covariate_ids") diff --git a/tests/verbs/test_extract_graph.py b/tests/verbs/test_extract_graph.py index 3ccc1d22b6..68c9bb231b 100644 --- a/tests/verbs/test_extract_graph.py +++ b/tests/verbs/test_extract_graph.py @@ -3,17 +3,16 @@ import pytest +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.enums import LLMType -from graphrag.index.run.utils import create_run_context -from graphrag.index.workflows.v1.extract_graph import ( - build_steps, - workflow_name, +from graphrag.index.workflows.extract_graph import ( + run_workflow, ) +from graphrag.utils.storage import load_table_from_storage from .util import ( - get_config_for_workflow, - get_workflow_output, - load_input_tables, + create_test_context, load_test_table, ) @@ -49,35 +48,34 @@ async def test_extract_graph(): - input_tables = load_input_tables([ - "workflow:create_base_text_units", - ]) - nodes_expected = load_test_table("base_entity_nodes") edges_expected = load_test_table("base_relationship_edges") - context = create_run_context(None, None, None) - await context.runtime_storage.set( - "base_text_units", input_tables["workflow:create_base_text_units"] + context = await create_test_context( + storage=["create_base_text_units"], ) - config = get_config_for_workflow(workflow_name) - config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG - config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG - - steps = build_steps(config) - - await get_workflow_output( - input_tables, - { - "steps": steps, - }, - context=context, + config = create_graphrag_config() + config.entity_extraction.strategy = { + "type": "graph_intelligence", + "llm": MOCK_LLM_ENTITY_CONFIG, + } + config.summarize_descriptions.strategy = { + "type": "graph_intelligence", + "llm": MOCK_LLM_SUMMARIZATION_CONFIG, + } + + await run_workflow( + config, + context, + NoopVerbCallbacks(), ) # graph construction creates transient tables for nodes, edges, and communities - nodes_actual = await context.runtime_storage.get("base_entity_nodes") - edges_actual = await context.runtime_storage.get("base_relationship_edges") + nodes_actual = await load_table_from_storage("base_entity_nodes", context.storage) + edges_actual = await load_table_from_storage( + "base_relationship_edges", context.storage + ) assert len(nodes_actual.columns) == len(nodes_expected.columns), ( "Nodes dataframe columns differ" @@ -91,69 +89,26 @@ async def test_extract_graph(): # this is because the mock responses always result in a single description, which is returned verbatim rather than summarized # we need to update the mocking to provide somewhat unique graphs so a true merge happens # 
the assertion should grab a node and ensure the description matches the mock description, not the original as we are doing below - assert nodes_actual["description"].to_numpy()[0] == "Company_A is a test company" - assert len(context.storage.keys()) == 0, "Storage should be empty" - - -async def test_extract_graph_with_snapshots(): - input_tables = load_input_tables([ - "workflow:create_base_text_units", - ]) - - context = create_run_context(None, None, None) - await context.runtime_storage.set( - "base_text_units", input_tables["workflow:create_base_text_units"] - ) - - config = get_config_for_workflow(workflow_name) - - config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG - config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG - config["snapshot_graphml"] = True - config["snapshot_transient"] = True - config["embed_graph_enabled"] = True # need this on in order to see the snapshot - - steps = build_steps(config) - - await get_workflow_output( - input_tables, - { - "steps": steps, - }, - context=context, - ) - - assert context.storage.keys() == [ - "graph.graphml", - "base_entity_nodes.parquet", - "base_relationship_edges.parquet", - ], "Graph snapshot keys differ" - async def test_extract_graph_missing_llm_throws(): - input_tables = load_input_tables([ - "workflow:create_base_text_units", - ]) - - context = create_run_context(None, None, None) - await context.runtime_storage.set( - "base_text_units", input_tables["workflow:create_base_text_units"] + context = await create_test_context( + storage=["create_base_text_units"], ) - config = get_config_for_workflow(workflow_name) - - config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG - del config["summarize_descriptions"]["strategy"]["llm"] - - steps = build_steps(config) + config = create_graphrag_config() + config.entity_extraction.strategy = { + "type": "graph_intelligence", + "llm": MOCK_LLM_ENTITY_CONFIG, + } + config.summarize_descriptions.strategy = { + "type": "graph_intelligence", + } with pytest.raises(ValueError): # noqa PT011 - await get_workflow_output( - input_tables, - { - "steps": steps, - }, - context=context, + await run_workflow( + config, + context, + NoopVerbCallbacks(), ) diff --git a/tests/verbs/test_generate_text_embeddings.py b/tests/verbs/test_generate_text_embeddings.py index c0919501d8..640284c7ca 100644 --- a/tests/verbs/test_generate_text_embeddings.py +++ b/tests/verbs/test_generate_text_embeddings.py @@ -1,53 +1,44 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from io import BytesIO - -import pandas as pd - +from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks +from graphrag.config.create_graphrag_config import create_graphrag_config +from graphrag.config.enums import TextEmbeddingTarget from graphrag.index.config.embeddings import ( all_embeddings, ) -from graphrag.index.run.utils import create_run_context -from graphrag.index.workflows.v1.generate_text_embeddings import ( - build_steps, - workflow_name, +from graphrag.index.workflows.generate_text_embeddings import ( + run_workflow, ) +from graphrag.utils.storage import load_table_from_storage from .util import ( - get_config_for_workflow, - get_workflow_output, - load_input_tables, + create_test_context, ) async def test_generate_text_embeddings(): - input_tables = load_input_tables( - inputs=[ - "workflow:create_final_documents", - "workflow:create_final_relationships", - "workflow:create_final_text_units", - "workflow:create_final_entities", - "workflow:create_final_community_reports", + context = await create_test_context( + storage=[ + "create_final_documents", + "create_final_relationships", + "create_final_text_units", + "create_final_entities", + "create_final_community_reports", ] ) - context = create_run_context(None, None, None) - - config = get_config_for_workflow(workflow_name) - - config["text_embed"]["strategy"]["type"] = "mock" - config["snapshot_embeddings"] = True - config["embedded_fields"] = all_embeddings + config = create_graphrag_config() + config.embeddings.strategy = { + "type": "mock", + } + config.embeddings.target = TextEmbeddingTarget.all + config.snapshots.embeddings = True - steps = build_steps(config) - - await get_workflow_output( - input_tables, - { - "steps": steps, - }, + await run_workflow( + config, context, + NoopVerbCallbacks(), ) parquet_files = context.storage.keys() @@ -56,23 +47,19 @@ async def test_generate_text_embeddings(): assert f"embeddings.{field}.parquet" in parquet_files # entity description should always be here, let's assert its format - entity_description_embeddings_buffer = BytesIO( - await context.storage.get( - "embeddings.entity.description.parquet", as_bytes=True - ) - ) - entity_description_embeddings = pd.read_parquet( - entity_description_embeddings_buffer + entity_description_embeddings = await load_table_from_storage( + "embeddings.entity.description", context.storage ) + assert len(entity_description_embeddings.columns) == 2 assert "id" in entity_description_embeddings.columns assert "embedding" in entity_description_embeddings.columns # every other embedding is optional but we've turned them all on, so check a random one - document_text_embeddings_buffer = BytesIO( - await context.storage.get("embeddings.document.text.parquet", as_bytes=True) + document_text_embeddings = await load_table_from_storage( + "embeddings.document.text", context.storage ) - document_text_embeddings = pd.read_parquet(document_text_embeddings_buffer) + assert len(document_text_embeddings.columns) == 2 assert "id" in document_text_embeddings.columns assert "embedding" in document_text_embeddings.columns diff --git a/tests/verbs/util.py b/tests/verbs/util.py index 8c9cc990ef..91a9625893 100644 --- a/tests/verbs/util.py +++ b/tests/verbs/util.py @@ -1,35 +1,31 @@ # Copyright (c) 2024 Microsoft Corporation. 
 # Licensed under the MIT License
 
-from typing import cast
-
 import pandas as pd
-from datashaper import Workflow
 from pandas.testing import assert_series_equal
 
-from graphrag.config.create_graphrag_config import create_graphrag_config
-from graphrag.index.config.workflow import PipelineWorkflowConfig
 from graphrag.index.context import PipelineRunContext
-from graphrag.index.create_pipeline_config import create_pipeline_config
 from graphrag.index.run.utils import create_run_context
+from graphrag.utils.storage import write_table_to_storage
 
 pd.set_option("display.max_columns", None)
 
 
-def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
-    """Harvest all the referenced input IDs from the workflow being tested and pass them here."""
-    # stick all the inputs in a map - Workflow looks them up by name
-    input_tables: dict[str, pd.DataFrame] = {}
+async def create_test_context(storage: list[str] | None = None) -> PipelineRunContext:
+    """Create a test context with tables loaded into storage."""
+    context = create_run_context(None, None, None)
 
-    source = pd.read_parquet("tests/verbs/data/source_documents.parquet")
-    input_tables["source"] = source
+    # always set the input docs
+    input = load_test_table("source_documents")
+    await write_table_to_storage(input, "input", context.storage)
 
-    for input in inputs:
-        # remove the workflow: prefix if it exists, because that is not part of the actual table filename
-        name = input.replace("workflow:", "")
-        input_tables[input] = pd.read_parquet(f"tests/verbs/data/{name}.parquet")
+    if storage:
+        for name in storage:
+            table = load_test_table(name)
+            # normal storage interface insists on bytes
+            await write_table_to_storage(table, name, context.storage)
 
-    return input_tables
+    return context
 
 
 def load_test_table(output: str) -> pd.DataFrame:
@@ -37,41 +33,6 @@ def load_test_table(output: str) -> pd.DataFrame:
     return pd.read_parquet(f"tests/verbs/data/{output}.parquet")
 
 
-def get_config_for_workflow(name: str) -> PipelineWorkflowConfig:
-    """Instantiates the bare minimum config to get a default workflow config for testing."""
-    config = create_graphrag_config()
-
-    # this flag needs to be set before creating the pipeline config, or the entire covariate workflow will be excluded
-    config.claim_extraction.enabled = True
-
-    pipeline_config = create_pipeline_config(config)
-
-    result = next(conf for conf in pipeline_config.workflows if conf.name == name)
-
-    return cast("PipelineWorkflowConfig", result.config)
-
-
-async def get_workflow_output(
-    input_tables: dict[str, pd.DataFrame],
-    schema: dict,
-    context: PipelineRunContext | None = None,
-) -> pd.DataFrame:
-    """Pass in the input tables, the schema, and the output name"""
-
-    # the bare minimum workflow is the pipeline schema and table context
-    workflow = Workflow(
-        schema=schema,
-        input_tables=input_tables,
-    )
-
-    run_context = context or create_run_context(None, None, None)
-
-    await workflow.run(context=run_context)
-
-    # if there's only one output, it is the default here, no name required
-    return cast("pd.DataFrame", workflow.output())
-
-
 def compare_outputs(
     actual: pd.DataFrame, expected: pd.DataFrame, columns: list[str] | None = None
 ) -> None: