Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify flow config #1554

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
13 changes: 10 additions & 3 deletions graphrag/config/create_graphrag_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from graphrag.config.input_models.graphrag_config_input import GraphRagConfigInput
from graphrag.config.input_models.llm_config_input import LLMConfigInput
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType
from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
from graphrag.config.models.community_reports_config import CommunityReportsConfig
Expand Down Expand Up @@ -412,12 +412,15 @@ def hydrate_parallelization_params(
encoding_model = (
reader.str(Fragment.encoding_model) or global_encoding_model
)

strategy = reader.str("strategy")
chunks_model = ChunkingConfig(
size=reader.int("size") or defs.CHUNK_SIZE,
overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
group_by_columns=group_by_columns,
encoding_model=encoding_model,
strategy=ChunkStrategyType(strategy)
if strategy
else ChunkStrategyType.tokens,
)
with (
reader.envvar_prefix(Section.snapshot),
Expand Down Expand Up @@ -522,8 +525,12 @@ def hydrate_parallelization_params(
)

with reader.use(values.get("cluster_graph")):
use_lcc = reader.bool("use_lcc")
cluster_graph_model = ClusterGraphConfig(
max_cluster_size=reader.int("max_cluster_size") or defs.MAX_CLUSTER_SIZE
max_cluster_size=reader.int("max_cluster_size")
or defs.MAX_CLUSTER_SIZE,
natoverse marked this conversation as resolved.
Show resolved Hide resolved
use_lcc=use_lcc if use_lcc is not None else defs.USE_LCC,
seed=reader.int("seed"),
)

with (
Expand Down
1 change: 1 addition & 0 deletions graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
CLAIM_MAX_GLEANINGS = 1
CLAIM_EXTRACTION_ENABLED = False
MAX_CLUSTER_SIZE = 10
USE_LCC = True
COMMUNITY_REPORT_MAX_LENGTH = 2000
COMMUNITY_REPORT_MAX_INPUT_LENGTH = 8000
ENTITY_EXTRACTION_ENTITY_TYPES = ["organization", "person", "geo", "event"]
Expand Down
34 changes: 17 additions & 17 deletions graphrag/config/models/chunking_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,24 @@

"""Parameterization settings for the default configuration."""

from enum import Enum

from pydantic import BaseModel, Field

import graphrag.config.defaults as defs


class ChunkStrategyType(str, Enum):
    """Enumerate the supported text-chunking strategies.

    Members subclass ``str``, so each one compares equal to its plain string
    value and can be built from it, e.g. ``ChunkStrategyType("tokens")``.
    """

    tokens = "tokens"
    sentence = "sentence"

    def __repr__(self) -> str:
        """Return the member's value wrapped in double quotes."""
        return f'"{self.value}"'


class ChunkingConfig(BaseModel):
"""Configuration section for chunking."""

Expand All @@ -19,22 +32,9 @@ class ChunkingConfig(BaseModel):
description="The chunk by columns to use.",
default=defs.CHUNK_GROUP_BY_COLUMNS,
)
strategy: dict | None = Field(
description="The chunk strategy to use, overriding the default tokenization strategy",
default=None,
strategy: ChunkStrategyType = Field(
description="The chunking strategy to use.", default=ChunkStrategyType.tokens
)
encoding_model: str | None = Field(
default=None, description="The encoding model to use."
encoding_model: str = Field(
description="The encoding model to use.", default=defs.ENCODING_MODEL
)

def resolved_strategy(self, encoding_model: str | None) -> dict:
"""Get the resolved chunking strategy."""
from graphrag.index.operations.chunk_text import ChunkStrategyType

return self.strategy or {
"type": ChunkStrategyType.tokens,
"chunk_size": self.size,
"chunk_overlap": self.overlap,
"group_by_columns": self.group_by_columns,
"encoding_name": encoding_model or self.encoding_model,
}
17 changes: 6 additions & 11 deletions graphrag/config/models/cluster_graph_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,10 @@ class ClusterGraphConfig(BaseModel):
max_cluster_size: int = Field(
description="The maximum cluster size to use.", default=defs.MAX_CLUSTER_SIZE
)
strategy: dict | None = Field(
description="The cluster strategy to use.", default=None
use_lcc: bool = Field(
description="Whether to use the largest connected component.",
default=defs.USE_LCC,
)
seed: int | None = Field(
description="The seed to use for the clustering.", default=None
natoverse marked this conversation as resolved.
Show resolved Hide resolved
)

def resolved_strategy(self) -> dict:
"""Get the resolved cluster strategy."""
from graphrag.index.operations.cluster_graph import GraphCommunityStrategyType

return self.strategy or {
"type": GraphCommunityStrategyType.leiden,
"max_cluster_size": self.max_cluster_size,
}
11 changes: 2 additions & 9 deletions graphrag/index/create_pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,8 @@ def _text_unit_workflows(
PipelineWorkflowReference(
name=create_base_text_units,
config={
"chunks": settings.chunks,
"snapshot_transient": settings.snapshots.transient,
"chunk_by": settings.chunks.group_by_columns,
"text_chunk": {
"strategy": settings.chunks.resolved_strategy(
settings.encoding_model
)
},
},
),
PipelineWorkflowReference(
Expand Down Expand Up @@ -243,9 +238,7 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
PipelineWorkflowReference(
name=compute_communities,
config={
"cluster_graph": {
"strategy": settings.cluster_graph.resolved_strategy()
},
"cluster_graph": settings.cluster_graph,
"snapshot_transient": settings.snapshots.transient,
},
),
Expand Down
10 changes: 6 additions & 4 deletions graphrag/index/flows/compute_communities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

"""All the steps to create the base entity graph."""

from typing import Any

import pandas as pd

from graphrag.index.operations.cluster_graph import cluster_graph
Expand All @@ -13,14 +11,18 @@

def compute_communities(
base_relationship_edges: pd.DataFrame,
clustering_strategy: dict[str, Any],
max_cluster_size: int,
use_lcc: bool,
seed: int | None,
) -> pd.DataFrame:
"""All the steps to create the base entity graph."""
graph = create_graph(base_relationship_edges)

communities = cluster_graph(
graph,
strategy=clustering_strategy,
max_cluster_size,
use_lcc,
seed,
)

base_communities = pd.DataFrame(
Expand Down
40 changes: 25 additions & 15 deletions graphrag/index/flows/create_base_text_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@
aggregate_operation_mapping,
)

from graphrag.index.operations.chunk_text import chunk_text
from graphrag.config.models.chunking_config import ChunkStrategyType
from graphrag.index.operations.chunk_text.chunk_text import chunk_text
from graphrag.index.utils.hashing import gen_sha512_hash


def create_base_text_units(
documents: pd.DataFrame,
callbacks: VerbCallbacks,
chunk_by_columns: list[str],
chunk_strategy: dict[str, Any] | None = None,
group_by_columns: list[str],
size: int,
overlap: int,
encoding_model: str,
strategy: ChunkStrategyType,
) -> pd.DataFrame:
"""All the steps to transform base text_units."""
sort = documents.sort_values(by=["id"], ascending=[True])
Expand All @@ -35,7 +39,7 @@ def create_base_text_units(

aggregated = _aggregate_df(
sort,
groupby=[*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
groupby=[*group_by_columns] if len(group_by_columns) > 0 else None,
aggregations=[
{
"column": "text_with_ids",
Expand All @@ -47,30 +51,36 @@ def create_base_text_units(

callbacks.progress(Progress(percent=1))

chunked = chunk_text(
aggregated["chunks"] = chunk_text(
aggregated,
column="texts",
to="chunks",
size=size,
overlap=overlap,
encoding_model=encoding_model,
strategy=strategy,
callbacks=callbacks,
strategy=chunk_strategy,
)

chunked = cast("pd.DataFrame", chunked[[*chunk_by_columns, "chunks"]])
chunked = chunked.explode("chunks")
chunked.rename(
aggregated = cast("pd.DataFrame", aggregated[[*group_by_columns, "chunks"]])
aggregated = aggregated.explode("chunks")
aggregated.rename(
columns={
"chunks": "chunk",
},
inplace=True,
)
chunked["id"] = chunked.apply(lambda row: gen_sha512_hash(row, ["chunk"]), axis=1)
chunked[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame(
chunked["chunk"].tolist(), index=chunked.index
aggregated["id"] = aggregated.apply(
lambda row: gen_sha512_hash(row, ["chunk"]), axis=1
)
aggregated[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame(
aggregated["chunk"].tolist(), index=aggregated.index
)
# rename for downstream consumption
chunked.rename(columns={"chunk": "text"}, inplace=True)
aggregated.rename(columns={"chunk": "text"}, inplace=True)

return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True))
return cast(
"pd.DataFrame", aggregated[aggregated["text"].notna()].reset_index(drop=True)
)


# TODO: would be nice to inline this completely in the main method with pandas
Expand Down
8 changes: 0 additions & 8 deletions graphrag/index/operations/chunk_text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,3 @@
# Licensed under the MIT License

"""The Indexing Engine text chunk package root."""

from graphrag.index.operations.chunk_text.chunk_text import (
ChunkStrategy,
ChunkStrategyType,
chunk_text,
)

__all__ = ["ChunkStrategy", "ChunkStrategyType", "chunk_text"]
46 changes: 23 additions & 23 deletions graphrag/index/operations/chunk_text/chunk_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,22 @@
progress_ticker,
)

from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType
natoverse marked this conversation as resolved.
Show resolved Hide resolved
from graphrag.index.operations.chunk_text.typing import (
ChunkInput,
ChunkStrategy,
ChunkStrategyType,
)


def chunk_text(
input: pd.DataFrame,
column: str,
to: str,
size: int,
overlap: int,
encoding_model: str,
strategy: ChunkStrategyType,
callbacks: VerbCallbacks,
strategy: dict[str, Any] | None = None,
) -> pd.DataFrame:
) -> pd.Series:
"""
Chunk a piece of text into smaller pieces.

Expand Down Expand Up @@ -60,35 +62,33 @@ def chunk_text(
type: sentence
```
"""
output = input
if strategy is None:
strategy = {}
strategy_name = strategy.get("type", ChunkStrategyType.tokens)
strategy_config = {**strategy}
strategy_exec = load_strategy(strategy_name)

num_total = _get_num_total(output, column)
tick = progress_ticker(callbacks.progress, num_total)
strategy_exec = load_strategy(strategy)

output[to] = output.apply(
cast(
"Any",
lambda x: run_strategy(strategy_exec, x[column], strategy_config, tick),
num_total = _get_num_total(input, column)
tick = progress_ticker(callbacks.progress, num_total)
# collapse the config back to a single object to support "polymorphic" function call
config = ChunkingConfig(size=size, overlap=overlap, encoding_model=encoding_model)
return cast(
"pd.Series",
input.apply(
cast(
"Any",
lambda x: run_strategy(strategy_exec, x[column], config, tick),
),
axis=1,
),
axis=1,
)
return output


def run_strategy(
strategy: ChunkStrategy,
strategy_exec: ChunkStrategy,
input: ChunkInput,
strategy_args: dict[str, Any],
config: ChunkingConfig,
tick: ProgressTicker,
) -> list[str | tuple[list[str] | None, str, int]]:
"""Run strategy method definition."""
if isinstance(input, str):
return [item.text_chunk for item in strategy([input], {**strategy_args}, tick)]
return [item.text_chunk for item in strategy_exec([input], config, tick)]

# We can work with both just a list of text content
# or a list of tuples of (document_id, text content)
Expand All @@ -100,7 +100,7 @@ def run_strategy(
else:
texts.append(item[1])

strategy_results = strategy(texts, {**strategy_args}, tick)
strategy_results = strategy_exec(texts, config, tick)

results = []
for strategy_result in strategy_results:
Expand Down
13 changes: 6 additions & 7 deletions graphrag/index/operations/chunk_text/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,23 @@
"""A module containing chunk strategies."""

from collections.abc import Iterable
from typing import Any

import nltk
import tiktoken
from datashaper import ProgressTicker

import graphrag.config.defaults as defs
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.index.text_splitting.text_splitting import Tokenizer


def run_tokens(
input: list[str], args: dict[str, Any], tick: ProgressTicker
input: list[str], config: ChunkingConfig, tick: ProgressTicker
) -> Iterable[TextChunk]:
"""Chunks text into chunks based on encoding tokens."""
tokens_per_chunk = args.get("chunk_size", defs.CHUNK_SIZE)
chunk_overlap = args.get("chunk_overlap", defs.CHUNK_OVERLAP)
encoding_name = args.get("encoding_name", defs.ENCODING_MODEL)
tokens_per_chunk = config.size
chunk_overlap = config.overlap
encoding_name = config.encoding_model
enc = tiktoken.get_encoding(encoding_name)

def encode(text: str) -> list[int]:
Expand Down Expand Up @@ -83,7 +82,7 @@ def _split_text_on_tokens(


def run_sentences(
input: list[str], _args: dict[str, Any], tick: ProgressTicker
input: list[str], _config: ChunkingConfig, tick: ProgressTicker
) -> Iterable[TextChunk]:
"""Chunks text into multiple parts by sentence."""
for doc_idx, text in enumerate(input):
Expand Down
Loading
Loading