Skip to content

Commit 1d68af3

Browse files
Community workflow (#1495)
* Create separate communities workflow * Add test for new workflow * Rename workflows * Collapse subflows into parents * Rename flows, reuse variables * Semver * Fix integration test * Fix smoke tests * Fix megapipeline format * Rename missed files --------- Co-authored-by: Alonso Guevara <[email protected]>
1 parent de12521 commit 1d68af3

36 files changed

+783
-735
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Create separate community workflow, collapse subflows."
4+
}

graphrag/index/create_pipeline_config.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
PipelineWorkflowReference,
5353
)
5454
from graphrag.index.workflows.default_workflows import (
55-
create_base_entity_graph,
55+
compute_communities,
5656
create_base_text_units,
5757
create_final_communities,
5858
create_final_community_reports,
@@ -62,6 +62,7 @@
6262
create_final_nodes,
6363
create_final_relationships,
6464
create_final_text_units,
65+
extract_graph,
6566
generate_text_embeddings,
6667
)
6768

@@ -216,7 +217,7 @@ def _get_embedding_settings(
216217
def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference]:
217218
return [
218219
PipelineWorkflowReference(
219-
name=create_base_entity_graph,
220+
name=extract_graph,
220221
config={
221222
"snapshot_graphml": settings.snapshots.graphml,
222223
"snapshot_transient": settings.snapshots.transient,
@@ -235,9 +236,15 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
235236
settings.root_dir,
236237
),
237238
},
239+
},
240+
),
241+
PipelineWorkflowReference(
242+
name=compute_communities,
243+
config={
238244
"cluster_graph": {
239245
"strategy": settings.cluster_graph.resolved_strategy()
240246
},
247+
"snapshot_transient": settings.snapshots.transient,
241248
},
242249
),
243250
PipelineWorkflowReference(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""All the steps to create the base entity graph."""
5+
6+
from typing import Any
7+
8+
import pandas as pd
9+
10+
from graphrag.index.operations.cluster_graph import cluster_graph
11+
from graphrag.index.operations.create_graph import create_graph
12+
from graphrag.index.operations.snapshot import snapshot
13+
from graphrag.storage.pipeline_storage import PipelineStorage
14+
15+
16+
async def compute_communities(
    base_relationship_edges: pd.DataFrame,
    storage: PipelineStorage,
    clustering_strategy: dict[str, Any],
    snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
    """Derive hierarchical communities from the relationship edges.

    Builds a graph from `base_relationship_edges`, clusters it with
    `clustering_strategy`, and returns one row per (level, community, title)
    with the community id cast to int. When `snapshot_transient_enabled` is
    set, the table is also persisted to `storage` as a parquet snapshot
    named "base_communities".
    """
    relationship_graph = create_graph(base_relationship_edges)

    clusters = cluster_graph(relationship_graph, strategy=clustering_strategy)

    # One (level, community, parent) row per clustered node title.
    column_names = pd.Index(["level", "community", "parent", "title"])
    base_communities = pd.DataFrame(clusters, columns=column_names).explode("title")
    base_communities["community"] = base_communities["community"].astype(int)

    if snapshot_transient_enabled:
        # Persist the transient table for inspection/debugging.
        await snapshot(
            base_communities,
            name="base_communities",
            storage=storage,
            formats=["parquet"],
        )

    return base_communities

graphrag/index/flows/create_base_entity_graph.py renamed to graphrag/index/flows/extract_graph.py

+4-29
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
)
1515

1616
from graphrag.cache.pipeline_cache import PipelineCache
17-
from graphrag.index.operations.cluster_graph import cluster_graph
1817
from graphrag.index.operations.create_graph import create_graph
1918
from graphrag.index.operations.extract_entities import extract_entities
2019
from graphrag.index.operations.snapshot import snapshot
@@ -25,13 +24,11 @@
2524
from graphrag.storage.pipeline_storage import PipelineStorage
2625

2726

28-
async def create_base_entity_graph(
27+
async def extract_graph(
2928
text_units: pd.DataFrame,
3029
callbacks: VerbCallbacks,
3130
cache: PipelineCache,
3231
storage: PipelineStorage,
33-
runtime_storage: PipelineStorage,
34-
clustering_strategy: dict[str, Any],
3532
extraction_strategy: dict[str, Any] | None = None,
3633
extraction_num_threads: int = 4,
3734
extraction_async_mode: AsyncType = AsyncType.AsyncIO,
@@ -40,7 +37,7 @@ async def create_base_entity_graph(
4037
summarization_num_threads: int = 4,
4138
snapshot_graphml_enabled: bool = False,
4239
snapshot_transient_enabled: bool = False,
43-
) -> None:
40+
) -> tuple[pd.DataFrame, pd.DataFrame]:
4441
"""All the steps to create the base entity graph."""
4542
# this returns a graph for each text unit, to be merged later
4643
entity_dfs, relationship_dfs = await extract_entities(
@@ -73,17 +70,6 @@ async def create_base_entity_graph(
7370

7471
base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)
7572

76-
communities = cluster_graph(
77-
graph,
78-
strategy=clustering_strategy,
79-
)
80-
81-
base_communities = _prep_communities(communities)
82-
83-
await runtime_storage.set("base_entity_nodes", base_entity_nodes)
84-
await runtime_storage.set("base_relationship_edges", base_relationship_edges)
85-
await runtime_storage.set("base_communities", base_communities)
86-
8773
if snapshot_graphml_enabled:
8874
# todo: extract graphs at each level, and add in meta like descriptions
8975
await snapshot_graphml(
@@ -105,12 +91,8 @@ async def create_base_entity_graph(
10591
storage=storage,
10692
formats=["parquet"],
10793
)
108-
await snapshot(
109-
base_communities,
110-
name="base_communities",
111-
storage=storage,
112-
formats=["parquet"],
113-
)
94+
95+
return (base_entity_nodes, base_relationship_edges)
11496

11597

11698
def _merge_entities(entity_dfs) -> pd.DataFrame:
@@ -158,13 +140,6 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:
158140
return edges
159141

160142

161-
def _prep_communities(communities) -> pd.DataFrame:
162-
# Convert the input into a DataFrame and explode the title column
163-
return pd.DataFrame(
164-
communities, columns=pd.Index(["level", "community", "parent", "title"])
165-
).explode("title")
166-
167-
168143
def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
169144
return pd.DataFrame([
170145
{"name": node, "degree": int(degree)}

graphrag/index/update/entities.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ async def _run_entity_summarization(
112112
The updated entities dataframe with summarized descriptions.
113113
"""
114114
summarize_config = _find_workflow_config(
115-
config, "create_base_entity_graph", "summarize_descriptions"
115+
config, "extract_graph", "summarize_descriptions"
116116
)
117117
strategy = summarize_config.get("strategy", {})
118118

graphrag/index/workflows/default_workflows.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,12 @@
33

44
"""A package containing default workflows definitions."""
55

6-
# load and register all subflows
7-
from graphrag.index.workflows.v1.subflows import * # noqa
8-
96
from graphrag.index.workflows.typing import WorkflowDefinitions
10-
from graphrag.index.workflows.v1.create_base_entity_graph import (
11-
build_steps as build_create_base_entity_graph_steps,
7+
from graphrag.index.workflows.v1.compute_communities import (
8+
build_steps as build_compute_communities_steps,
129
)
13-
from graphrag.index.workflows.v1.create_base_entity_graph import (
14-
workflow_name as create_base_entity_graph,
10+
from graphrag.index.workflows.v1.compute_communities import (
11+
workflow_name as compute_communities,
1512
)
1613
from graphrag.index.workflows.v1.create_base_text_units import (
1714
build_steps as build_create_base_text_units_steps,
@@ -67,16 +64,22 @@
6764
from graphrag.index.workflows.v1.create_final_text_units import (
6865
workflow_name as create_final_text_units,
6966
)
67+
from graphrag.index.workflows.v1.extract_graph import (
68+
build_steps as build_extract_graph_steps,
69+
)
70+
from graphrag.index.workflows.v1.extract_graph import (
71+
workflow_name as extract_graph,
72+
)
7073
from graphrag.index.workflows.v1.generate_text_embeddings import (
7174
build_steps as build_generate_text_embeddings_steps,
7275
)
73-
7476
from graphrag.index.workflows.v1.generate_text_embeddings import (
7577
workflow_name as generate_text_embeddings,
7678
)
7779

7880
default_workflows: WorkflowDefinitions = {
79-
create_base_entity_graph: build_create_base_entity_graph_steps,
81+
extract_graph: build_extract_graph_steps,
82+
compute_communities: build_compute_communities_steps,
8083
create_base_text_units: build_create_base_text_units_steps,
8184
create_final_text_units: build_create_final_text_units,
8285
create_final_community_reports: build_create_final_community_reports_steps,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""A module containing build_steps method definition."""
5+
6+
from typing import Any, cast
7+
8+
import pandas as pd
9+
from datashaper import (
10+
Table,
11+
verb,
12+
)
13+
from datashaper.table_store.types import VerbResult, create_verb_result
14+
15+
from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
16+
from graphrag.index.flows.compute_communities import compute_communities
17+
from graphrag.storage.pipeline_storage import PipelineStorage
18+
19+
workflow_name = "compute_communities"


def build_steps(
    config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
    """
    Create the base communities from the graph edges.

    ## Dependencies
    * `workflow:extract_graph`
    """
    # Fall back to leiden clustering when no strategy is configured.
    cluster_section = config.get("cluster_graph", {"strategy": {"type": "leiden"}})
    strategy = cluster_section.get("strategy")

    # Normalize falsy config values (None/"") to a plain False.
    transient_enabled = config.get("snapshot_transient", False) or False

    step = {
        "verb": workflow_name,
        "args": {
            "clustering_strategy": strategy,
            "snapshot_transient_enabled": transient_enabled,
        },
        "input": {"source": "workflow:extract_graph"},
    }
    return [step]
49+
50+
51+
@verb(
    name=workflow_name,
    treats_input_tables_as_immutable=True,
)
async def workflow(
    storage: PipelineStorage,
    runtime_storage: PipelineStorage,
    clustering_strategy: dict[str, Any],
    snapshot_transient_enabled: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """Compute the base communities table and stash it in runtime storage.

    Reads "base_relationship_edges" (produced by the extract_graph workflow)
    from `runtime_storage`, runs the community-computation flow, and writes
    the result back under "base_communities". The verb itself emits an empty
    table — downstream workflows consume the runtime-storage entry.
    """
    edges = await runtime_storage.get("base_relationship_edges")

    base_communities = await compute_communities(
        edges,
        storage,
        clustering_strategy=clustering_strategy,
        snapshot_transient_enabled=snapshot_transient_enabled,
    )

    await runtime_storage.set("base_communities", base_communities)

    return create_verb_result(cast("Table", pd.DataFrame()))

graphrag/index/workflows/v1/create_base_entity_graph.py

-59
This file was deleted.

0 commit comments

Comments
 (0)